为轻量型网络设计的坐标注意力机制.
- paper:Coordinate Attention for Efficient Mobile Network Design
- arXiv:link
注意力机制能够提高卷积网络的特征表达能力。对于移动网络,通道注意力机制能够有效建模通道之间的相关性,但忽略了特征的位置信息;而更复杂的注意力会引入更多计算。作者提出了一种坐标注意力(Coordinate Attention)机制,通过沿水平或垂直方向捕捉较远距离的关系,同时保持精确的互补位置信息。该注意力机制计算简单,可以应用于现有移动网络中,且几乎不会引入额外的计算。
上图(a)为通道注意力网络SENet的结构,由于标准的卷积操作无法对通道相关性进行建模,因此显式地构造通道之间的相关性,并采用全局平均池化捕捉特征的全局信息。上图(b)为CBAM网络结构,该网络采用通道和空间注意力串联的形式,需要的计算量较大。
作者提出的坐标注意力如上图(c)所示,它能够同时建模通道相关性和空间的远程依赖性。具体地,该机制由两步组成,分别是坐标信息嵌入(Coordinate Information Embedding, CIE)和坐标注意力生成(Coordinate Attention Generation, CAG)。
Coordinate Information Embedding
通道注意力中的全局池化能够编码空间的全局信息,但它将空间信息压缩为一个通道描述子,难以保持位置信息。为使注意力模块能够保留精确的位置信息,将全局池化拆分成两个$1D$特征编码操作。即分别沿水平与垂直坐标方向进行编码:
\[z_c^h(h) = \frac{1}{W} \sum_{0≤i≤W}^{} x_c(h,i)\] \[z_c^w(w) = \frac{1}{H} \sum_{0≤j≤H}^{} x_c(j,w)\]Coordinate Attention Generation
分别沿水平和垂直位置获得具有精确编码信息的特征后,将其进行拼接,然后送入$1 \times 1$卷积$F_1$得到表示编码空间信息的中间特征:
\[f=\delta(F_1([z^h,z^w])), \quad f \in R^{\frac{C}{r} \times (H+W)}\]再将中间特征$f$沿空间维度拆分成$f^h \in R^{\frac{C}{r} \times H}$和$f^w \in R^{\frac{C}{r} \times W}$,并使用$1 \times 1$卷积$F_h$和$F_w$生成注意力权重:
\[g^h = \delta(F_h(f^h))\] \[g^w = \delta(F_w(f^w))\]最终通过注意力权重计算注意力模块的输入:
\[y_c(i,j) = x_c(i,j) \times g^h_c(i) \times g^w_c(j)\]本文提出的注意力机制可以用于增强各种移动网络的卷积特征提取。作者使用MobileNetV2和MobileNeXt为例,将坐标注意力应用于通道数更多的特征上:
通过实验,作者发现相比图像分类与目标检测任务,所提注意力机制在语义分割任务方面取得的性能提升更大。这是因为坐标注意力有助于通过精确的位置信息捕获远程依赖关系,而精确的位置信息对于语义分割等稠密预测非常重要。
Pytorch实现
class CoordAtt(nn.Module):
def __init__(self, channels, reduction=16):
super(CoordAtt, self).__init__()
self.fc1 = nn.Sequential(
nn.Conv2d(channels, channels//reduction, 1),
nn.BatchNorm2d(channels//reduction),
nn.ReLU(inplace=True)
)
self.xfc = nn.Conv2d(channels//reduction, channels, 1)
self.yfc = nn.Conv2d(channels//reduction, channels, 1)
def forward(self, x):
B, _, H, W = x.size()
# X Avg Pool and Y Avg Pool
xap = F.adaptive_avg_pool2d(x, (H, 1))
yap = F.adaptive_avg_pool2d(x, (1, W))
# Concat+Conv2d+BatchNorm+Non-linear
mer = torch.cat([xap.transpose_(2, 3), yap], dim=3)
fc1 = self.fc1(mer)
# split
xat, yat = torch.split(fc1, (H, W), dim=3)
# Conv2d-Sigmoid and Conv2d-Sigmoid
xat = torch.sigmoid(self.xfc(xat))
yat = torch.sigmoid(self.yfc(yat))
# Attention Multiplier
out = x * xat * yat
return out