GCNet:结合非局部神经网络和通道注意力.
捕获视觉场景中的全局依赖能提高分割任务的效果。在CNN网络中远程依赖的建立(等同于感受野的扩增)主要依靠堆叠卷积层来实现,但是这种方法效率低且难以优化,因为长距离位置之间的信息难以传递,而且卷积层的堆叠可能会导致卷积核退化的问题。为了解决这个问题,Non-Local net通过自注意力机制来建立远程依赖,对于每一个查询(query),计算该query位置与全局所有位置(key)的关系来建立注意力图(attention map),然后将注意力图与value进行加权汇总,生成最终的输出。
Non-Local net建立远程关系的能力十分优秀,但是十分巨大的计算量成为了其进一步应用的缺陷。作者认为,Non-Local Block(图a)对每个特征点都计算一次attention map,计算量大,但所计算出的attention map几乎是相同的。作者提出了简化版的NL Block(图b),只计算一次attention map:
简化版的NL Block简化了query运算,query和key计算时权重共享,也就是query和key等同,这里就减少了计算query的过程,作者可视化了两种结构的attention map,发现效果差不多。因此后续的工作都基于Simplified Non-Local结构。
在此基础上,作者参考SENet网络(图c)做了如下改进,最终得到Global Context Block(图d):
- 忽略了计算value的卷积$W_v$。
- 将Non-Local Block右下角的$1 \times 1$卷积替换为bottleneck层,降低参数量(通常设下采样率$r=16$)。
- bottleneck层增加了网络优化的难度,因此引入Layer Norm。
Global Context Block的Pytorch代码如下:
class GlobalContextBlock(nn.Module):
def __init__(self, in_channels, scale = 16):
super(GlobalContextBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = self.in_channels//scale
self.Conv_key = nn.Conv2d(self.in_channels, 1, 1)
self.SoftMax = nn.Softmax(dim=1)
self.Conv_value = nn.Sequential(
nn.Conv2d(self.in_channels, self.out_channels, 1),
nn.LayerNorm([self.out_channels, 1, 1]),
nn.ReLU(),
nn.Conv2d(self.out_channels, self.in_channels, 1),
)
def forward(self, x):
b, c, h, w = x.size()
# key -> [b, 1, H, W] -> [b, 1, H*W] -> [b, H*W, 1]
key = self.SoftMax(self.Conv_key(x).view(b, 1, -1).permute(0, 2, 1).view(b, -1, 1).contiguous())
query = x.view(b, c, h*w)
# [b, c, h*w] * [b, H*W, 1]
concate_QK = torch.matmul(query, key)
concate_QK = concate_QK.view(b, c, 1, 1).contiguous()
value = self.Conv_value(concate_QK)
out = x + value
return out
if __name__ == "__main__":
x = torch.randn((2, 1024, 24, 24))
GCBlock = GlobalContextBlock(in_channels=1024)
out = GCBlock(x)
print("GCBlock output.shape:", out.shape)
把Global Context Block应用到图像分类模型的例子:
class GCNet(nn.Module):
def __init__(self, num_classes):
super(GCNet, self).__init__()
self.gc_block = GlobalContextBlock(in_channels=2048, scale = 16)
self.backbone = ResNet.resnet50(replace_stride_with_dilation=[1,2,4])
self.Conv_1 = nn.Sequential(
nn.Conv2d(2048, 512, 3, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)
)
self.cls_seg = nn.Conv2d(512, num_classes, 3, padding=1)
def forward(self, x):
"""Forward function."""
output = self.backbone(x)
output = self.gc_block(output)
output = self.Conv_1(output)
output = self.cls_seg(output)
return output
if __name__ == "__main__":
x = torch.randn((2, 3, 224, 224))
model = GCNet(num_classes=2)
out = model(x)
print("GCNet output.shape:", out.shape)