NAM: Normalization-based Attention Module
Identifying less salient features is key to model compression, yet this has not been investigated for attention mechanisms. This work proposes NAM, a normalization-based attention module that applies a weight sparsity penalty to the attention module, suppressing the weights associated with less salient features.
NAM applies Batch Norm to the input features and constructs the attention distribution from the learnable scale parameter $\gamma$ of Batch Norm.
The Batch Norm transform is:
\[X \leftarrow \gamma \frac{X-\mu}{\sqrt{\sigma^2+\epsilon}} + \beta\]
where $\mu$ and $\sigma^2$ are the per-channel mean and variance of the mini-batch, and $\gamma$, $\beta$ are the learnable affine parameters.
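In PyTorch, these two learnable parameters are exposed by nn.BatchNorm2d (with affine=True) as bn.weight (the per-channel $\gamma$) and bn.bias ($\beta$). A minimal sketch, with the channel count and input size chosen purely for illustration:

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(4, affine=True)   # 4 channels, chosen only for illustration
    x = torch.randn(2, 4, 8, 8)           # a random (N, C, H, W) batch
    _ = bn(x)                             # standard BN forward pass

    print(bn.weight.shape)                # gamma: one scale factor per channel -> torch.Size([4])
    print(bn.bias.shape)                  # beta:  one shift per channel        -> torch.Size([4])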
NAM then defines the attention weight of channel $i$ as its normalized scale factor:
\[w_i = \frac{\gamma_i}{\sum_j \gamma_j}\]
The channel-attention branch of NAM can be written as a small PyTorch module:

    import torch
    import torch.nn as nn

    class Channel_Att(nn.Module):
        def __init__(self, channels, t=16):  # t is kept from the original interface but unused here
            super(Channel_Att, self).__init__()
            self.channels = channels
            # affine=True keeps the learnable per-channel scale gamma (bn.weight) and shift beta (bn.bias)
            self.bn = nn.BatchNorm2d(self.channels, affine=True)

        def forward(self, x):
            residual = x
            x = self.bn(x)
            # normalize |gamma| so the per-channel weights sum to 1
            weight_bn = self.bn.weight.data.abs() / torch.sum(self.bn.weight.data.abs())
            # move channels to the last dim so the (C,)-shaped weights broadcast across each channel
            x = x.permute(0, 2, 3, 1).contiguous()
            x = torch.mul(weight_bn, x)
            x = x.permute(0, 3, 1, 2).contiguous()
            # sigmoid gate on the re-weighted features, applied to the original input
            x = torch.sigmoid(x) * residual
            return x
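A quick usage sketch (the batch size, channel count, and spatial size here are arbitrary): the module preserves the input shape, and the normalized scale factors used inside forward sum to 1 by construction.

    att = Channel_Att(channels=64)
    inp = torch.randn(8, 64, 32, 32)      # arbitrary (N, C, H, W) input
    out = att(inp)
    print(out.shape)                      # torch.Size([8, 64, 32, 32]): same shape as the input

    w = att.bn.weight.data.abs()
    print((w / w.sum()).sum())            # ~1.0: the channel weights form a distribution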