GCNet: Combining Non-Local Neural Networks and Channel Attention

Capturing global dependencies in a visual scene improves segmentation performance. In CNNs, long-range dependencies (equivalently, an enlarged receptive field) are mainly built by stacking convolutional layers, but this is inefficient and hard to optimize: information is difficult to propagate between distant positions, and deep stacks of convolutions can suffer from kernel degradation. To address this, the Non-Local network builds long-range dependencies with a self-attention mechanism: for each query position, it computes the relationship between that query and all positions (the keys) to form an attention map, then aggregates the values weighted by that map to produce the final output.
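For reference, here is a minimal sketch of a Non-Local block in its embedded Gaussian form, assuming 1x1 convolutions for the query/key/value transforms; the class and variable names are illustrative rather than taken from the paper's code:

import torch
import torch.nn as nn

class NonLocalBlock(nn.Module):
    def __init__(self, in_channels, reduction=2):
        super(NonLocalBlock, self).__init__()
        inter_channels = in_channels // reduction
        self.theta = nn.Conv2d(in_channels, inter_channels, 1)  # query transform
        self.phi = nn.Conv2d(in_channels, inter_channels, 1)    # key transform
        self.g = nn.Conv2d(in_channels, inter_channels, 1)      # value transform
        self.out = nn.Conv2d(inter_channels, in_channels, 1)    # restore channel count

    def forward(self, x):
        b, c, h, w = x.size()
        query = self.theta(x).view(b, -1, h * w).permute(0, 2, 1)  # [b, HW, c']
        key = self.phi(x).view(b, -1, h * w)                       # [b, c', HW]
        value = self.g(x).view(b, -1, h * w).permute(0, 2, 1)      # [b, HW, c']
        # attention map: similarity of every query position to every key position
        attn = torch.softmax(torch.bmm(query, key), dim=-1)        # [b, HW, HW]
        # weighted aggregation of values, restored to [b, c', H, W]
        y = torch.bmm(attn, value).permute(0, 2, 1).reshape(b, -1, h, w)
        return x + self.out(y)  # residual connection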

The Non-Local network is excellent at building long-range relationships, but its huge computational cost hinders wider application: the attention map is H*W by H*W, so the cost grows quadratically with the number of spatial positions. The authors observe that the Non-Local block (Figure a) computes a separate attention map for every query position, yet the resulting attention maps are almost identical across positions. They therefore propose a simplified NL block (Figure b) that computes the attention map only once.

The simplified NL block drops the separate query computation: the query and key transforms share weights, i.e., query and key are identical, so the attention map no longer depends on the query position and is computed once. The authors visualize the attention maps of both structures and find them nearly indistinguishable, so the subsequent design builds on the Simplified Non-Local structure.
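Under that reading, the Simplified Non-Local block reduces to the sketch below (reusing the imports above; names are illustrative): a single 1x1 convolution yields one attention map over all H*W positions, which is computed once and shared by every query position.

class SimplifiedNonLocalBlock(nn.Module):
    def __init__(self, in_channels):
        super(SimplifiedNonLocalBlock, self).__init__()
        self.conv_attn = nn.Conv2d(in_channels, 1, 1)             # Wk: one logit per position
        self.conv_value = nn.Conv2d(in_channels, in_channels, 1)  # Wv: value transform

    def forward(self, x):
        b, c, h, w = x.size()
        # one attention map over all positions, computed once: [b, HW, 1]
        attn = torch.softmax(self.conv_attn(x).view(b, 1, h * w), dim=-1).permute(0, 2, 1)
        # global context: attention-weighted sum of features -> [b, c, 1, 1]
        context = torch.bmm(x.view(b, c, h * w), attn).view(b, c, 1, 1)
        # transform the context and broadcast-add the same vector to every position
        return x + self.conv_value(context)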

Building on this, the authors borrow the design of SENet (Figure c) and make two changes, arriving at the Global Context Block (Figure d): the single 1x1 value transform of the simplified block is replaced with a two-layer bottleneck (channel reduction by a factor of scale) to cut parameters, and LayerNorm + ReLU are inserted between the two convolutions to ease optimization.
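For comparison, a minimal sketch of the SE block whose bottleneck design the GC block borrows, assuming the standard global-average-pool squeeze and two-layer excitation (reusing the imports above; names are illustrative):

class SEBlock(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction),
            nn.ReLU(),
            nn.Linear(in_channels // reduction, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # squeeze: global average pooling -> per-channel descriptor [b, c]
        s = x.mean(dim=(2, 3))
        # excitation: bottleneck MLP produces channel-wise gates in (0, 1)
        g = self.fc(s).view(b, c, 1, 1)
        # rescale: channel-wise multiplication (the GC block uses addition instead)
        return x * g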

The PyTorch code for the Global Context Block is as follows:

import torch
import torch.nn as nn

class GlobalContextBlock(nn.Module):
    def __init__(self, in_channels, scale=16):
        super(GlobalContextBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = self.in_channels // scale

        # Wk: one attention logit per spatial position
        self.Conv_key = nn.Conv2d(self.in_channels, 1, 1)
        # softmax over the H*W positions (dim=1 of a [b, H*W, 1] tensor)
        self.SoftMax = nn.Softmax(dim=1)

        # SE-style bottleneck transform: reduce channels by `scale`,
        # LayerNorm + ReLU, then restore the channel count
        self.Conv_value = nn.Sequential(
            nn.Conv2d(self.in_channels, self.out_channels, 1),
            nn.LayerNorm([self.out_channels, 1, 1]),
            nn.ReLU(),
            nn.Conv2d(self.out_channels, self.in_channels, 1),
        )

    def forward(self, x):
        b, c, h, w = x.size()
        # key: [b, 1, H, W] -> [b, 1, H*W] -> [b, H*W, 1], normalized over positions
        key = self.SoftMax(self.Conv_key(x).view(b, 1, -1).permute(0, 2, 1).contiguous())
        # query: the input features themselves, [b, c, H*W]
        query = x.view(b, c, h * w)
        # context modeling: [b, c, H*W] x [b, H*W, 1] -> [b, c, 1, 1]
        context = torch.matmul(query, key).view(b, c, 1, 1)
        # transform the pooled context and fuse it by broadcast addition
        value = self.Conv_value(context)
        out = x + value
        return out

if __name__ == "__main__":
    x = torch.randn((2, 1024, 24, 24))
    GCBlock = GlobalContextBlock(in_channels=1024)
    out = GCBlock(x)
    print("GCBlock output.shape:", out.shape)  # [2, 1024, 24, 24]

An example of applying the Global Context Block in a semantic segmentation model:

from torchvision.models import resnet50

class GCNet(nn.Module):
    def __init__(self, num_classes):
        super(GCNet, self).__init__()
        # Dilated ResNet-50 backbone (output stride 8): torchvision's resnet50
        # stands in for the original ResNet helper here; replace_stride_with_dilation
        # expects three booleans, one per stage after layer1. The avgpool/fc head
        # is dropped so the backbone returns a [b, 2048, H/8, W/8] feature map.
        resnet = resnet50(replace_stride_with_dilation=[False, True, True])
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.gc_block = GlobalContextBlock(in_channels=2048, scale=16)
        self.Conv_1 = nn.Sequential(
            nn.Conv2d(2048, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)
        )
        self.cls_seg = nn.Conv2d(512, num_classes, 3, padding=1)

    def forward(self, x):
        """Forward function."""
        output = self.backbone(x)       # [b, 2048, H/8, W/8]
        output = self.gc_block(output)  # fuse global context
        output = self.Conv_1(output)    # reduce channels, upsample to input size
        output = self.cls_seg(output)   # per-pixel class logits
        return output

if __name__ == "__main__":
    x = torch.randn((2, 3, 224, 224))
    model = GCNet(num_classes=2)
    out = model(x)
    print("GCNet output.shape:", out.shape)  # [2, 2, 224, 224]