探索图像识别的自注意力机制.

作者提出一种将self-attention机制应用到图像识别领域的方法。作者认为使用卷积网络进行图像识别任务实际上在实现两个函数:

  1. 特征聚集(feature aggregation): 即通过卷积核在特征图上进行卷积来融合特征的过程。
  2. 特征变换(feature transformation): 在卷积完成后进行的一系列线性和非线性变换(比如全连接和激活函数。这一部分通过感知机就能很好地完成)

在以上观点的基础上,作者提出使用self-attention机制来替代卷积作为特征聚集方法。为此,作者考虑两种self-attention形式:pairwise self-attentionpatchwise self-attention。用这两种形式的self-attention机制作为网络的basic block提出SAN网络结构。与经典卷积网络ResNet进行对比,SAN网络具有更少参数和运算量,同时在ImageNet数据集上的分类精确度有较大提升。

⚪ Pairwise Self-attention

pairwise self-attention计算公式如下:

\[\mathbf y_i = \sum\limits_{j\in \mathcal{R}(i)}\alpha(\mathbf x_i,\mathbf x_j) \odot \beta(\mathbf x_j)\]

其中$\odot$表示Hadamard product(矩阵的对应位置相乘),\(\mathbf{x}\)是特征图上一点,\(\mathbf{y}\)是经过self-attention模块运算后得到的特征图上的对应点。\(\mathcal{R}(i)\)是对应位置$i$周围的局部区域(类似卷积过程中卷积核所在区域)。\(\alpha(\mathbf x_i,\mathbf x_j)\)是权重向量,\(\beta(\mathbf x_j)\)是对\(\mathbf{x}_j\)进行embedding后的结果。

从上述计算公式中可以看出,pairwise self-attention方式和卷积方式最大的区别在于权重的确定:卷积核的权重在学习完成后就是一个固定的标量,再用这个标量与特征图上一点的每个维度相乘。而在pairwise self-attention方法中,权重通过\(\alpha(\mathbf x_i,\mathbf x_j)\)计算得到,而且计算结果是一个向量,再用这个向量与\(\beta(\mathbf x_j)\)对位相乘。显然这种方式考虑到了特征在不同通道上的权重大小。

作者对\(\alpha(\mathbf x_i,\mathbf x_j)\)进行了分解:

\[\alpha(\mathbf x_i,\mathbf x_j) = \gamma(\delta(\mathbf x_i,\mathbf x_j))\]

这样做的好处在于,在尝试不同$\delta$函数的选择是就不必考虑向量的维度问题,将维度匹配问题交给$\gamma$函数解决。\(\gamma=\{Linear \to ReLU \to Linear\}\),作者尝试了五种$\delta$函数的选择:

此外作者还将位置$i$和位置$j$的坐标信息纳入\(\gamma(\delta(\mathbf x_i,\mathbf x_j))\)的计算过程之中。

⚪ Patchwise Self-attention

patchwise self-attention计算公式如下:

\[\mathbf y_i = \sum\limits_{j\in \mathcal{R}(i)}\alpha(\mathbf{x}_{\mathcal{R}_{(i)}})_j \odot \beta(\mathbf x_j)\]

其中,\(\mathbf x_{\mathcal{R}(i)}\)是\(\mathcal{R}(i)\)所在区域的特征图,\(\alpha(\mathbf{x}_{\mathcal{R}_{(i)}})\)是权重张量。可以看到,patchwise self-attentionpairwise self-attention的区别就在于patchwise self-attention中没有对\((\mathbf x_i, \mathbf x_j)\)的配对计算,而是整个区域用来计算得到一个权重张量,再用下标$j$来索引这个张量,再用这个向量与\(\beta(\mathbf x_j)\)对位相乘。

同样地,$\alpha$函数进行分解:

\[\alpha(\mathbf{x}_{\mathcal{R}_{(i)}})=\gamma(\delta(\mathbf x_{\mathcal{R}(i)}))\]

作者尝试$\delta$函数三种不同选择:

⚪ self-attention network

基于pairwise self-attentionpatchwise self-attention,作者设计了self-attentionbasic block

利用这样的block就替代了传统CNNconv+bn/relu的过程便得到了SAN网络。