ISANet: Interlaced Sparse Self-Attention Network for Semantic Segmentation
The self-attention (non-local) mechanism has become a key tool in semantic segmentation for enlarging the receptive field, i.e. building long-range dependencies. Although attention works well, its extra resource consumption (high GPU memory usage and high computational complexity) is hard to afford, so a number of works have focused on reducing its cost. To cut the cost of the attention mechanism, this paper proposes the Interlaced Sparse Self-Attention (ISA) module.
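As a rough back-of-the-envelope estimate (my own, based on the factorization described next, not a figure quoted from the paper): for a feature map with $N = H \times W$ positions, standard self-attention builds an $N \times N$ similarity map, i.e. roughly $O(N^2)$ time and memory. ISA partitions the positions into local blocks of size $P_h \times P_w$ (the down_factor in the code below), attends over the strided groups (long-range) and then inside each block (short-range), so the similarity maps cost only about $O\big(\tfrac{(HW)^2}{P_h P_w} + \tfrac{(HW)^2}{\lceil H/P_h \rceil \lceil W/P_w \rceil}\big)$, which is minimized when the two group sizes are balanced.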
Standard self-attention is computed globally: every position is related to every other position. ISA splits this into two grouped steps. It first computes self-attention within groups of positions that are far apart, e.g. $A_1, A_2, A_3$ attend to each other and $B_1, B_2, B_3$ attend to each other, which gives the long-range attention; it then computes self-attention within groups of adjacent positions, i.e. $A_1, B_1$, then $A_2, B_2$, then $A_3, B_3$ attend to each other, which gives the short-range attention. After these alternating grouped self-attention steps, the dependency between any two positions is captured either directly or indirectly, as the small sketch below illustrates.
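A minimal 1-D sketch of this interlaced grouping (an illustration of the idea only, not code from the paper; the layout of the six positions is my own choice):

import torch

# Six positions laid out as A_1 B_1 A_2 B_2 A_3 B_3, split into three adjacent blocks of size 2.
x = torch.arange(6)             # positions 0..5 stand for A_1, B_1, A_2, B_2, A_3, B_3
blocks = x.view(3, 2)           # rows are the adjacent blocks (A_1,B_1), (A_2,B_2), (A_3,B_3)
# Long-range step: group positions that share the same offset inside their block,
# i.e. the strided groups (A_1, A_2, A_3) and (B_1, B_2, B_3).
long_range_groups = blocks.t()  # tensor([[0, 2, 4], [1, 3, 5]])
# Short-range step: group the adjacent positions (A_1,B_1), (A_2,B_2), (A_3,B_3).
short_range_groups = blocks     # tensor([[0, 1], [2, 3], [4, 5]])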
A standard self-attention block is implemented as follows:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttentionBlock2D(nn.Module):
    def __init__(self, in_channels, key_channels, value_channels, out_channels=None, bn_type=None):
        super(SelfAttentionBlock2D, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.key_channels = key_channels
        self.value_channels = value_channels
        if out_channels is None:
            self.out_channels = in_channels
        # key/query/value projections built from 1x1 convolutions
        self.f_key = nn.Sequential(
            nn.Conv2d(self.in_channels, self.key_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.key_channels),
            nn.ReLU(),
            nn.Conv2d(self.key_channels, self.key_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.key_channels),
            nn.ReLU(),
        )
        self.f_query = nn.Sequential(
            nn.Conv2d(self.in_channels, self.key_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.key_channels),
            nn.ReLU(),
            nn.Conv2d(self.key_channels, self.key_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.key_channels),
            nn.ReLU(),
        )
        self.f_value = nn.Conv2d(self.in_channels, self.value_channels, kernel_size=1, bias=False)
        # output projection back to out_channels
        self.W = nn.Sequential(
            nn.Conv2d(self.value_channels, self.out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        batch_size, h, w = x.size(0), x.size(2), x.size(3)
        # flatten the spatial dimensions: query/value -> (N, HW, C), key -> (N, C, HW)
        value = self.f_value(x).view(batch_size, self.value_channels, -1)
        value = value.permute(0, 2, 1)
        query = self.f_query(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)
        key = self.f_key(x).view(batch_size, self.key_channels, -1)
        # scaled dot-product attention over all HW positions
        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels ** -0.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.value_channels, h, w)
        context = self.W(context)
        return context
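As a quick shape check (a hedged usage sketch; the channel and spatial sizes are arbitrary), the block returns a context map with the same spatial size as its input:

block = SelfAttentionBlock2D(in_channels=64, key_channels=16, value_channels=32, out_channels=64).eval()
with torch.no_grad():
    out = block(torch.randn(2, 64, 24, 24))  # the similarity map is (24*24) x (24*24) per image
print(out.shape)  # expected: torch.Size([2, 64, 24, 24])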
The ISA Block is then implemented as follows:
class ISA_Block(nn.Module):
    def __init__(self, in_channels, key_channels, value_channels, out_channels, down_factor=[8, 8], bn_type=None):
        super(ISA_Block, self).__init__()
        self.out_channels = out_channels
        assert isinstance(down_factor, (tuple, list)) and len(down_factor) == 2
        self.down_factor = down_factor
        self.long_range_sa = SelfAttentionBlock2D(in_channels, key_channels, value_channels, out_channels, bn_type=bn_type)
        self.short_range_sa = SelfAttentionBlock2D(out_channels, key_channels, value_channels, out_channels, bn_type=bn_type)

    def forward(self, x):
        n, c, h, w = x.size()
        dh, dw = self.down_factor  # local block size along h and w, respectively
        out_h, out_w = math.ceil(h / dh), math.ceil(w / dw)
        # pad the feature map if its size is not divisible by the block size
        pad_h, pad_w = out_h * dh - h, out_w * dw - w
        if pad_h > 0 or pad_w > 0:  # padding on both sides of each spatial dimension
            feats = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
        else:
            feats = x
        # long-range attention: positions that share the same offset inside a (dh x dw) block
        # (i.e. positions sampled with stride dh/dw) form one (out_h x out_w) group
        feats = feats.view(n, c, out_h, dh, out_w, dw)
        feats = feats.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, c, out_h, out_w)
        feats = self.long_range_sa(feats)
        c = self.out_channels
        # short-range attention: regroup into the original (dh x dw) blocks of adjacent positions
        feats = feats.view(n, dh, dw, c, out_h, out_w)
        feats = feats.permute(0, 4, 5, 3, 1, 2).contiguous().view(-1, c, dh, dw)
        feats = self.short_range_sa(feats)
        # restore the (N, C, H, W) layout
        feats = feats.view(n, out_h, out_w, c, dh, dw).permute(0, 3, 1, 4, 2, 5)
        feats = feats.contiguous().view(n, c, dh * out_h, dw * out_w)
        # remove the padding
        if pad_h > 0 or pad_w > 0:
            feats = feats[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w]
        return feats
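A hedged shape check (the sizes are my own; 30 is chosen deliberately so that it is not divisible by 8): thanks to the internal padding and cropping, the output keeps the input resolution, with out_channels channels:

isa_block = ISA_Block(in_channels=64, key_channels=16, value_channels=32, out_channels=64, down_factor=[8, 8]).eval()
with torch.no_grad():
    out = isa_block(torch.randn(2, 64, 30, 30))  # 30x30 is padded to 32x32 internally, then cropped back
print(out.shape)  # expected: torch.Size([2, 64, 30, 30])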
The ISA Module is built from one or more ISA Blocks (one per entry in down_factors), whose outputs are fused with the input feature map:
class ISA_Module(nn.Module):
    def __init__(self, in_channels, key_channels, value_channels, out_channels, down_factors=[[8, 8]], dropout=0, bn_type=None):
        super(ISA_Module, self).__init__()
        assert isinstance(down_factors, (tuple, list))
        self.down_factors = down_factors
        # one ISA_Block per down_factor; several factors give a multi-scale variant
        self.stages = nn.ModuleList([
            ISA_Block(in_channels, key_channels, value_channels, out_channels, d, bn_type) for d in down_factors
        ])
        concat_channels = in_channels + out_channels
        if len(self.down_factors) > 1:
            # project the input so that it matches the concatenated multi-scale context
            self.up_conv = nn.Sequential(
                nn.Conv2d(in_channels, len(self.down_factors) * out_channels, kernel_size=1, padding=0, bias=False),
                nn.BatchNorm2d(len(self.down_factors) * out_channels),
                nn.ReLU(),
            )
            concat_channels = out_channels * len(self.down_factors) * 2
        self.conv_bn = nn.Sequential(
            nn.Conv2d(concat_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout2d(dropout),
        )

    def forward(self, x):
        priors = [stage(x) for stage in self.stages]
        if len(self.down_factors) == 1:
            context = priors[0]
        else:
            context = torch.cat(priors, dim=1)
            x = self.up_conv(x)
        # residual-style connection: concatenate the input with the attention context and fuse
        return self.conv_bn(torch.cat([x, context], dim=1))
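For illustration (a hedged sketch; the channel sizes and the second down_factor are my own choices), passing more than one down_factor activates the multi-scale branch with up_conv:

isa_module = ISA_Module(in_channels=64, key_channels=16, value_channels=32, out_channels=64,
                        down_factors=[[8, 8], [4, 4]], dropout=0).eval()
with torch.no_grad():
    out = isa_module(torch.randn(2, 64, 32, 32))
print(out.shape)  # expected: torch.Size([2, 64, 32, 32])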
Finally, ISANet itself is implemented as follows:
class ISANet(nn.Module):
    def __init__(self, num_classes):
        super(ISANet, self).__init__()
        self.ISAHead = ISA_Module(in_channels=2048, key_channels=256, value_channels=512, out_channels=512, dropout=0)
        # assumed: a custom dilated ResNet-50 that returns a 2048-channel feature map at 1/8 resolution
        self.backbone = ResNet.resnet50(replace_stride_with_dilation=[1, 2, 4])
        self.Conv_1 = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)
        )
        self.cls_seg = nn.Conv2d(512, num_classes, 3, padding=1)

    def forward(self, x):
        """Forward function."""
        output = self.backbone(x)      # stride-8 features, 2048 channels
        output = self.ISAHead(output)  # interlaced sparse self-attention head
        output = self.Conv_1(output)   # 3x3 conv + 8x bilinear upsampling back to the input resolution
        output = self.cls_seg(output)  # per-pixel classification
        return output
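An end-to-end usage sketch (hedged: it assumes `ResNet.resnet50` above is the author's dilated ResNet-50 returning a 2048-channel feature map at 1/8 of the input resolution; the class count and crop size are arbitrary):

model = ISANet(num_classes=19).eval()  # e.g. 19 classes as in Cityscapes
with torch.no_grad():
    logits = model(torch.randn(1, 3, 512, 512))
print(logits.shape)  # expected: torch.Size([1, 19, 512, 512]) for a stride-8 backbone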