为多标签分类设计的简单有效的残差注意力.

多标签图像识别一直是计算机视觉中一项非常具有挑战性的实际应用任务。现有的一些多标签的方法多诉诸于复杂的空间注意力模型,这些模型往往难以优化。作者提出了一个简单并且容易训练的特定类别残差注意力(class-specific residual attention, CSRA)模块。该模块充分利用每个物体类别的空间注意力,并且取得了较高的精度。同时计算成本也几乎可以忽略不计。

作者的动机来源于下面的四行代码。在许多不同的预训练模型和数据集上,即使没有任何额外的训练,只用4行代码也可以提高多标签识别的准确率。作者证明了这个不同空间区域间的最大值池化的操作其实是一个类别特定的注意力操作。

# x:feature tensor,output of CNN backbone
# x’s size: (B, d, H, W)
# y_raw: by applying classifier(’FC’) to ’x’
# y_raw’s size: (B, C, HxW)
# C: number of classes
y_raw = FC(x).flatten(2) # (B, d, H, W) => (B, C, H, W) => (B, C, HxW)
y_avg = torch.mean(y_raw, dim=2) # (B, C, HxW) => (B, C)
y_max = torch.max(y_raw, dim=2)[0] # (B, C, HxW) => (B, C)
score = y_avg + Lambda * y_max

方法的总体结构如图所示。首先将图片通过一个卷积主干网络得到特征图$x$,然后经过一个$1 \times 1$的卷积将$d \times h \times w$的特征图变为$c \times h \times w$,其中$c$是类的个数。然后将得到的特征送入多头CSRA模块,得到输出\(\hat{y}_{T_i}\),将所有的\(\hat{y}_{T_i}\)相加,得到最终的输出\(\hat{y}_o=\sum_i\hat{y}_{T_i}\)。

CSRA模块的结构如下图所示。以一个$2048 \times 7 \times 7$的特征输入举例。将这个特征图分为$x_1,…,x_{49},x_i \in R^{2048}$, 将其送入一个$1 \times 1$卷积,卷积的尺寸为$m_i \in R^{2048}$,$i$表示第$i$个类。然后将输出分别进行空间池化和平均池化,得到两个输出$a^i$和$g$。将$g$和$\lambda a^i$相加得到$f^i$,并得到最终的输出$y^i=m_i^T\cdot f^i$。

\[g = \frac{1}{49}\sum_{k=1}^{49} x_k\] \[\begin{aligned} s_k^i & = \frac{\exp(Tx_k^Tm_i)}{\sum_{j=1}^{49}\exp(Tx_j^Tm_i)} \\ a^i &= \sum_{k=1}^{49} s_k^i x_k \end{aligned}\]

其中$T>0$是一个控制参数。下面分析最终的输出$y^i$:

\[\begin{aligned} y^i&=m_i^T\cdot f^i =m_i^T\cdot(g+\lambda a^i)\\ & = \frac{1}{49}\sum_{k=1}^{49} x_k^Tm_i+\lambda \sum_{k=1}^{49} \frac{\exp(Tx_k^Tm_i)}{\sum_{j=1}^{49}\exp(Tx_j^Tm_i)} x_k^Tm_i \end{aligned}\]

$x_k^Tm_i$表示特征第$k$个空间位置对第$i$类的分类分数。上式第一项表示考虑所有空间位置对分类分数的平均影响;第二项当$T \to \infty$时为$\max_k x_k^Tm_i$,表示考虑空间位置对分类分数的最大影响。

也可以分析特征向量:

\[\begin{aligned} f^i &= g+\lambda a^i\\ & = \frac{1}{49}\sum_{k=1}^{49} x_k+\lambda \sum_{k=1}^{49} s_k^i x_k \\ & = (1+\lambda) \sum_{k=1}^{49} \frac{\frac{1}{49}+\lambda s_k^i}{1+\lambda}x_k \end{aligned}\]

注意到\(\sum_{k=1}^{49} \frac{1}{49}=1, \sum_{k=1}^{49} s_k^i=1\),因此CSRA模块的第$i$类特征向量是不同空间位置的特征$x_k$的归一化加权组合,其中前者独立于类别和位置,后者依赖于类别$i$和位置$k$。

控制参数$T$的值很难调整,不同的类可能需要不同的取值。因此作者采用了多头注意力机制。虽然每个分支的$T$值不同,但对所有分支都共享同一个$\lambda$。记$H$为注意力头的个数。一般的对$T$的取值如下:

\[\begin{aligned} & H=1: T=1 \text { or } T=\infty \\ & H=2: T_1=1 \text { and } T_2=\infty \\ & H=4: T_{1: 3}=1,2,4 \text { and } T_4=\infty \\ & H=6: T_{1: 5}=1,2,3,4,5 \text { and } T_6=\infty \\ & H=8: T_{1: 7}=1,2,3,4,5,6,7 \text { and } T_8=\infty \end{aligned}\]

同时为了提高收敛速度,作者将每个分类器的权重向量进行了归一化处理,即:

\[m_i \leftarrow \frac{m_i}{||m_i||}\]
class CSRA(nn.Module):
     def __init__(self, input_dim, num_classes, T, lam):
         super(CSRA, self).__init__()
         self.T = T      # 控制参数 T       
         self.lam = lam  # Lambda                        
         self.head = nn.Conv2d(input_dim, num_classes, 1, bias=False)
         self.softmax = nn.Softmax(dim=2)
 
     def forward(self, x):
         # x (B d H W)
         # normalize classifier
         # score (B C HxW)
         score = self.head(x) / torch.norm(self.head.weight, dim=1, 
                                           keepdim=True).transpose(0,1)
         score = score.flatten(2)
         base_logit = torch.mean(score, dim=2)
 
         if self.T == 99: # max-pooling
             att_logit = torch.max(score, dim=2)[0]
         else:
             score_soft = self.softmax(score * self.T)
             att_logit = torch.sum(score * score_soft, dim=2)
 
         return base_logit + self.lam * att_logit
     
     
class MHA(nn.Module):  # 多头注意力模块
     temp_settings = {  # 初始化控制参数 T
         1: [1],
         2: [1, 99],
         4: [1, 2, 4, 99],
         6: [1, 2, 3, 4, 5, 99],
         8: [1, 2, 3, 4, 5, 6, 7, 99]
     }
 
     def __init__(self, num_heads, lam, input_dim, num_classes):
         super(MHA, self).__init__()
         self.temp_list = self.temp_settings[num_heads]
         self.multi_head = nn.ModuleList([
             CSRA(input_dim, num_classes, self.temp_list[i], lam)
             for i in range(num_heads)
         ])
 
     def forward(self, x):
         logit = 0.
         for head in self.multi_head:
             logit += head(x)
         return logit