SimAM:为卷积神经网络设计的简单无参数注意力模块.
计算机视觉中现有的注意力模块通常关注通道域或空间域,这两种注意力机制与人脑中基于特征的注意力和基于空间的注意力完全对应。通道注意力是一种1D注意力,它对不同通道区别对待,对所有位置同等对待;空间注意力是一种2D注意力,它对不同位置区别对待,对所有通道同等对待。
在人类生活中,这两种机制共存,并在视觉处理过程中共同有助于信息选择。因此作者提出了一个注意力模块来进行类似的操作,以便每个神经元被分配一个唯一的权重。作者认为三维权值的计算应该很直接,同时允许模块保持一个轻量级的属性。
作者认为注意力机制的实现应当通过神经科学中的统一原则引导设计,因此提出一个基于神经科学理论的模块来解决这些问题。在视觉神经学中,那些信息量最大的神经元通常与周围神经元拥有不同的放电模式。同时一个活跃的神经元也可能抑制周围的神经元活动,这种现象被称为空间抑制。
换言之,在视觉中,表现出明显空间抑制效应的神经元应该被赋予更高的重要性,而找到这些神经元的最简单方式就是测量一个目标神经元与其他神经元之间的线性可分性。定义神经元的能量函数:
\[e_t(w_t,b_t,\mathbf{y},x_i) = (y_t-\hat{t})^2+\frac{1}{M-1}\sum_{i=1}^{M-1}(y_0-\hat{x}_i)^2\]$t$和$x_i$是输入\(X\in \mathbb{R}^{C\times H\times W}\)中单通道上的目标神经元和其他神经元。$\hat{t}=w_tt+b_t$和$\hat{x}_i=w_tx_i+b_t$是$t$和$x_i$的线性变换,$w_t$和$b_t$分别代表线性变换的权重和偏置。$i$是空间维度上的索引,$M=H\times W$代表该通道上神经元的个数。
上式中的所有量都是标量,当$y_t=\hat{t}$和所有$x_i=y_o$时取得最小值,其中,$y_t$和$y_o$是两个不同的值,简便起见,使用二值标签,即$y_t=1, y_o=-1$。
求解上式的最小值等价于求解目标神经元$t$和其他所有神经元$i$之间的线性可分性,若添加正则项,则最终的能量函数如下:
\[e_t(w_t,b_t,\mathbf{y},x_i) = \frac{1}{M-1}\sum_{i=1}^{M-1}(-1-(w_tx_i+b_t))^2+(1-(w_tt+b_t))^2+\lambda w_t^2\]公式的来源应该是SVM,将当前神经元设置为正类,其余神经元设置为负类,来衡量他们之间的差异性。
理论上,特征的每个通道拥有$M$个能量函数,逐一求解是很大的计算负担。上式最小化可以获得解析解:
\[\begin{aligned} w_t&=-\frac{2(t-\mu_t)}{(t-\mu_t)^2+2\sigma_t^2+2\lambda} \\ b_t&=-\frac{1}{2}(t-\mu_t)w_t \end{aligned}\]其中\(\mu_t=\frac{1}{M-1}\sum_{i=1}^{M-1}x_i,\sigma_t^2=\frac{1}{M-1}\sum_{i=1}^{M-1}(x_i-\mu_t)^2\),实际上就是该通道中除去目标神经元的均值和方差。
由于解析解是在单个通道上获得的,因此可以合理假设每个通道中所有像素遵循相同的分布,最小能量即为:
\[e_t^*=\frac{4(\mu^2+\lambda)}{(t-\mu)^2+2\sigma^2+2\lambda}\]能量越低,神经元$t$与周围神经元的区别越大,重要性越高。因此神经元的重要性可以通过$1/e_t^{*}$得到。
根据以往的神经学研究,哺乳动物大脑中的注意力调节通常表现为神经元反应的增益效应,因此使用放缩运算而非加法来实现加权:
\[\widetilde{X}=sigmoid(\frac{1}{E})\otimes X,\]同时sigmoid函数还可以限制$E$中的过大值,并且不会影响每个神经元的相对重要性。
# X: input feature [N, C, H, W]
# lambda: coefficient λ in Eqn
def forward (X, lambda):
# spatial size
n = X.shape[2] * X.shape[3] - 1
# square of (t - u)
d = (X - X.mean(dim=[2,3])).pow(2)
# d.sum() / n is channel variance
v = d.sum(dim=[2,3]) / n
# E_inv groups all importance of X
E_inv = d / (4 * (v + lambda)) + 0.5
# return attended features
return X * torch.sigmoid(E_inv)