SANet:通过特征分组和通道置换实现轻量型置换注意力.
卷积网络中的注意力模块,有通道注意力和空间注意力两种类型。组合两者通常能够取得更好的性能,但会导致计算量的增加。作者提出了一种置换注意力(shuffle attention)机制,实现了注意力的高效组合。
置换注意力的实现如下。首先对输入特征沿通道维度拆分为$g$组,对每一组特征平均拆分后使用并行的通道注意力和空间注意力提取特征,将所有组的特征进行集成,并通过通道置换操作进行不同通道间的交互。
具体地,模块主要由四部分组成:
- Feature Grouping:对输入特征进行分组。记输入特征$x \in \Bbb{R}^{c \times h \times w}$,沿通道维度把特征分成$g$组:$x=[x_1,…,x_g]$。对于每组特征$x_g \in \Bbb{R}^{\frac{c}{g} \times h \times w}$,沿通道再次拆分成两个分支$x_{k1},x_{k2} \in \Bbb{R}^{\frac{c}{2g} \times h \times w}$,分别计算通道注意力和空间注意力。
- Channel Attention:广泛使用的通道注意力是SENet,但其参数量较多,不利于网络的轻量化。作者采用最简单的全局平均池化GAP+缩放scale+Sigmoid组合:
- Spatial Attention:空间注意力中统计量的计算是通过GroupNorm实现的:
- Aggregation:完成上述两种注意力计算后,对其进行集成。首先通过通道组合得到$x_k’=[x_{k1}’,x_{k2}’] \in \Bbb{R}^{\frac{c}{g} \times h \times w}$,再通过通道置换shuffle实现组间通信。最终输出与输入尺寸相同的特征。
模型实现的代码如下:
class ShuffleAttention(nn.Module):
def __init__(self, channel, groups=64):
super(ShuffleAttention, self).__init__()
self.groups = groups
self.avg_pool = nn.AdaptiveAvgPool2d(1)
mid_channel = channel // (groups * 2)
self.affine1 = nn.Conv2d(mid_channel, mid_channel, 1)
self.affine2 = nn.Conv2d(mid_channel, mid_channel, 1)
self.sigmoid = nn.Sigmoid()
self.gn = nn.GroupNorm(num_groups=mid_channel,
num_channels=mid_channel)
def channel_shuffle(self, x, groups):
b, c, h, w = x.shape
x = x.reshape(b, groups, -1, h, w) # [b, g, c/g, h, w]
x = x.permute(0, 2, 1, 3, 4) # [b, c/g, g, h, w]
x = x.reshape(b, -1, h, w) # [b, c, h, w]
return x
def forward(self, x):
b, c, h, w = x.shape
x = x.reshape(b * self.groups, -1, h, w) # [bg, c/g, h, w]
x_0, x_1 = x.chunk(2, dim=1) # [bg, c/2g, h, w]
# channel attention using SE
xn = self.avg_pool(x_0) # [bg, c/2g, 1, 1]
xn = self.affine1(xn) # [bg, c/2g, 1, 1]
xn = x_0 * self.sigmoid(xn) # [bg, c/2g, h, w]
# spatial attention using Group Norm
xs = self.gn(x_1) # [bg, c/2g, h, w]
xs = self.affine2(xs) # [bg, c/2g, h, w]
xs = x_1 * self.sigmoid(xs) # [bg, c/2g, h, w]
# concatenate along channel axis
out = torch.cat([xn, xs], dim=1) # [bg, c/g, h, w]
out = out.reshape(b, -1, h, w) # [b, c, h, w]
# channel shuffle
out = self.channel_shuffle(out, 2)
return out
if __name__ == "__main__":
t = torch.ones((32, 256, 24, 24))
sa = ShuffleAttention(256)
out = sa(t)
print(out.shape)
作者在ImageNet-1k数据集上对比不同模型的准确率、参数量和FLOPS;相比其他SOTA注意力机制,所提方案具有更高精度、更低计算复杂度。
为验证SA的有效性,作者采用GradCAM对其进行可视化。对比可得出SA使得分类模型聚焦于目标信息更相关的区域,进而有效的提高分类精度。