SANet: Lightweight Shuffle Attention via Feature Grouping and Channel Shuffle

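Attention modules in convolutional networks come in two types: channel attention and spatial attention. Combining the two usually gives better performance, but it also increases the computational cost. The authors propose a shuffle attention (SA) mechanism that combines the two kinds of attention efficiently.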

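Shuffle attention works as follows. The input feature map is first split into $g$ groups along the channel dimension. Each group is then split evenly into two halves, which are processed in parallel by a channel attention branch and a spatial attention branch. Finally, the features of all groups are aggregated, and a channel shuffle operation lets information flow between different channels.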
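
The data flow described above can be followed purely at the level of tensor shapes. The sketch below is a shape-bookkeeping helper only; the name `sa_shapes` is hypothetical, and it assumes a `[b, c, h, w]` input with `c` divisible by `2 * g`. Folding the groups into the batch dimension is one possible implementation choice, not the only one.

```python
def sa_shapes(b, c, h, w, g):
    """Return the tensor shape at each stage of shuffle attention (illustrative)."""
    if c % (2 * g) != 0:
        raise ValueError("c must be divisible by 2 * g")
    return {
        "input": (b, c, h, w),
        # groups folded into the batch dimension (an implementation choice)
        "grouped": (b * g, c // g, h, w),
        # each group is split in half: one half per attention branch
        "branch": (b * g, c // (2 * g), h, w),
        # after concatenation, un-grouping, and channel shuffle
        "output": (b, c, h, w),
    }


shapes = sa_shapes(32, 256, 24, 24, g=64)
print(shapes["grouped"])  # (2048, 4, 24, 24)
print(shapes["branch"])   # (2048, 2, 24, 24)
```

The output shape always equals the input shape, so the module can be inserted after a convolutional block without changing the layers around it.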

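Specifically, the module consists of four main parts: feature grouping, channel attention, spatial attention, and aggregation. The two attention branches are computed as: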

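\[x_{k1}' = \text{sigmoid}(W_1\mathcal{F}_{GP}(x_{k1})+b_1) \cdot x_{k1}\]

\[x_{k2}' = \text{sigmoid}(W_2\,\text{GN}(x_{k2})+b_2) \cdot x_{k2}\]

Here $x_{k1}$ and $x_{k2}$ are the two halves of the $k$-th group, $\mathcal{F}_{GP}$ is global average pooling, $\text{GN}$ is group normalization, and $W_1$, $b_1$, $W_2$, $b_2$ are learnable parameters.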
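
To make the two gating equations concrete, here is a minimal pure-Python sketch for a single channel whose spatial values are flattened into a list. It assumes scalar per-channel parameters `w` and `b`, and it normalizes each channel on its own, which matches `num_groups == num_channels` in the implementation shown in this post. The function names are illustrative, and GroupNorm's own affine parameters are left out.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def channel_gate(x, w, b):
    # F_GP: global average pooling collapses the channel to one scalar
    pooled = sum(x) / len(x)
    gate = sigmoid(w * pooled + b)  # one gate shared by every position
    return [gate * v for v in x]


def spatial_gate(x, w, b, eps=1e-5):
    # GN with one channel per group: normalize over the spatial positions
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    normed = [(v - mean) / math.sqrt(var + eps) for v in x]
    # one gate per position, so this branch can weight locations differently
    return [sigmoid(w * n + b) * v for n, v in zip(normed, x)]


x = [1.0, 2.0, 3.0, 4.0]
print(channel_gate(x, w=0.0, b=0.0))  # [0.5, 1.0, 1.5, 2.0]
print(spatial_gate(x, w=0.0, b=0.0))  # [0.5, 1.0, 1.5, 2.0]
```

With `w = 0` both gates reduce to a constant 0.5. With a nonzero `w`, the channel gate still scales the whole channel by a single factor, while the spatial gate assigns a different factor to each position.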

The model can be implemented as follows:

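import torch
import torch.nn as nn


class ShuffleAttention(nn.Module):
    """Shuffle attention: grouped channel/spatial gating plus a channel shuffle."""

    def __init__(self, channel, groups=64):
        super(ShuffleAttention, self).__init__()
        # each group is split into two halves, so the channels must divide evenly
        assert channel % (2 * groups) == 0, "channel must be divisible by 2 * groups"
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        mid_channel = channel // (groups * 2)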
        self.affine1 = nn.Conv2d(mid_channel, mid_channel, 1)
        self.affine2 = nn.Conv2d(mid_channel, mid_channel, 1)
        self.sigmoid = nn.Sigmoid()
        self.gn = nn.GroupNorm(num_groups=mid_channel,
                               num_channels=mid_channel)

    def channel_shuffle(self, x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w) # [b, g, c/g, h, w]
        x = x.permute(0, 2, 1, 3, 4) # [b, c/g, g, h, w]
        x = x.reshape(b, -1, h, w) # [b, c, h, w]
        return x

    def forward(self, x):
        b, c, h, w = x.shape
        x = x.reshape(b * self.groups, -1, h, w) # [bg, c/g, h, w]
        x_0, x_1 = x.chunk(2, dim=1) # [bg, c/2g, h, w]
        # channel attention using SE
        xn = self.avg_pool(x_0) # [bg, c/2g, 1, 1]
        xn = self.affine1(xn) # [bg, c/2g, 1, 1]
        xn = x_0 * self.sigmoid(xn) # [bg, c/2g, h, w]
        # spatial attention using Group Norm
        xs = self.gn(x_1) # [bg, c/2g, h, w]
        xs = self.affine2(xs) # [bg, c/2g, h, w]
        xs = x_1 * self.sigmoid(xs) # [bg, c/2g, h, w]
        # concatenate along channel axis
        out = torch.cat([xn, xs], dim=1) # [bg, c/g, h, w]
        out = out.reshape(b, -1, h, w) # [b, c, h, w]
        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out


if __name__ == "__main__":
    t = torch.ones((32, 256, 24, 24))
    sa = ShuffleAttention(256)
    out = sa(t)
    print(out.shape)
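
The effect of `channel_shuffle` is easiest to see on channel indices alone. The sketch below is pure Python and independent of the class above; `shuffled_order` is a hypothetical helper that returns, for each output position, the index of the input channel that ends up there.

```python
def shuffled_order(num_channels, groups):
    """Input-channel index found at each output position after a channel shuffle."""
    if num_channels % groups != 0:
        raise ValueError("num_channels must be divisible by groups")
    per_group = num_channels // groups
    # reshape to [groups, per_group], transpose to [per_group, groups], flatten
    return [g * per_group + i for i in range(per_group) for g in range(groups)]


print(shuffled_order(6, 2))  # [0, 3, 1, 4, 2, 5]
print(shuffled_order(8, 2))  # [0, 4, 1, 5, 2, 6, 3, 7]
```

With `groups=2`, as used in `forward`, the shuffle interleaves the first and second halves of the channel axis, so channels that were far apart become neighbors.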

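The authors compare the accuracy, parameter count, and FLOPs of different models on the ImageNet-1k dataset. Compared with other state-of-the-art attention mechanisms, the proposed method achieves higher accuracy at lower computational complexity.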
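
As a rough sanity check on the "lightweight" claim, the sketch below counts the learnable parameters of the module exactly as it is implemented in the code in this post: two 1x1 convolutions with bias plus one affine GroupNorm. The helper name `sa_param_count` is illustrative, and the numbers describe this implementation only, not the figures reported by the authors.

```python
def sa_param_count(channel, groups=64):
    """Learnable parameters of the ShuffleAttention module as coded in this post."""
    mid = channel // (groups * 2)
    conv = mid * mid + mid  # 1x1 conv: weight [mid, mid, 1, 1] plus bias [mid]
    gn = 2 * mid            # GroupNorm affine: scale and shift per channel
    return 2 * conv + gn    # affine1 + affine2 + gn


print(sa_param_count(256))  # 16
print(sa_param_count(512))  # 48
```

The count stays small because the groups are folded into the batch dimension, so all groups share the same weights.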

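To verify the effectiveness of SA, the authors visualize it with GradCAM. The comparison shows that SA makes the classification model focus on regions that are more relevant to the target object, which in turn improves classification accuracy.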