为轻量型网络设计的坐标注意力机制.

注意力机制能够提高卷积网络的特征表达能力。对于移动网络,通道注意力机制能够有效建模通道之间的相关性,但忽略了特征的位置信息;而更复杂的注意力会引入更多计算。作者提出了一种坐标注意力(Coordinate Attention)机制,通过沿水平或垂直方向捕捉较远距离的关系,同时保持精确的互补位置信息。该注意力机制计算简单,可以应用于现有移动网络中,且几乎不会引入额外的计算。

上图(a)为通道注意力网络SENet的结构,由于标准的卷积操作无法对通道相关性进行建模,因此显式地构造通道之间的相关性,并采用全局平均池化捕捉特征的全局信息。上图(b)CBAM网络结构,该网络采用通道和空间注意力串联的形式,需要的计算量较大。

作者提出的坐标注意力如上图(c)所示,它能够同时建模通道相关性和空间的远程依赖性。具体地,该机制由两步组成,分别是坐标信息嵌入(Coordinate Information Embedding, CIE)坐标注意力生成(Coordinate Attention Generation, CAG)

Coordinate Information Embedding

通道注意力中的全局池化能够编码空间的全局信息,但它将空间信息压缩为一个通道描述子,难以保持位置信息。为使注意力模块能够保留精确的位置信息,将全局池化拆分成两个$1D$特征编码操作。即分别沿水平与垂直坐标方向进行编码:

\[z_c^h(h) = \frac{1}{W} \sum_{0≤i≤W}^{} x_c(h,i)\] \[z_c^w(w) = \frac{1}{H} \sum_{0≤j≤H}^{} x_c(j,w)\]

Coordinate Attention Generation

分别沿水平和垂直位置获得具有精确编码信息的特征后,将其进行拼接,然后送入$1 \times 1$卷积$F_1$得到表示编码空间信息的中间特征:

\[f=\delta(F_1([z^h,z^w])), \quad f \in R^{\frac{C}{r} \times (H+W)}\]

再将中间特征$f$沿空间维度拆分成$f^h \in R^{\frac{C}{r} \times H}$和$f^w \in R^{\frac{C}{r} \times W}$,并使用$1 \times 1$卷积$F_h$和$F_w$生成注意力权重:

\[g^h = \delta(F_h(f^h))\] \[g^w = \delta(F_w(f^w))\]

最终通过注意力权重计算注意力模块的输入:

\[y_c(i,j) = x_c(i,j) \times g^h_c(i) \times g^w_c(j)\]

本文提出的注意力机制可以用于增强各种移动网络的卷积特征提取。作者使用MobileNetV2MobileNeXt为例,将坐标注意力应用于通道数更多的特征上:

通过实验,作者发现相比图像分类与目标检测任务,所提注意力机制在语义分割任务方面取得的性能提升更大。这是因为坐标注意力有助于通过精确的位置信息捕获远程依赖关系,而精确的位置信息对于语义分割等稠密预测非常重要。

Pytorch实现

class CoordAtt(nn.Module):
    def __init__(self, channels, reduction=16):
        super(CoordAtt, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Conv2d(channels, channels//reduction, 1),
            nn.BatchNorm2d(channels//reduction),
            nn.ReLU(inplace=True)
        )

        self.xfc = nn.Conv2d(channels//reduction, channels, 1)
        self.yfc = nn.Conv2d(channels//reduction, channels, 1)

    def forward(self, x):
        B, _, H, W = x.size()
        # X Avg Pool and Y Avg Pool
        xap = F.adaptive_avg_pool2d(x, (H, 1))
        yap = F.adaptive_avg_pool2d(x, (1, W))

        # Concat+Conv2d+BatchNorm+Non-linear
        mer = torch.cat([xap.transpose_(2, 3), yap], dim=3)
        fc1 = self.fc1(mer)
        
        # split
        xat, yat = torch.split(fc1, (H, W), dim=3)

        # Conv2d-Sigmoid and Conv2d-Sigmoid
        xat = torch.sigmoid(self.xfc(xat))
        yat = torch.sigmoid(self.yfc(yat))

        # Attention Multiplier
        out = x * xat * yat
        return out