NAM: Normalization-based Attention Module

Identifying less salient features is key to model compression, yet this has received little study in the context of attention mechanisms. This work proposes NAM, a normalization-based attention module that applies a weight sparsity penalty to the attention module, suppressing the weights associated with less salient features.

NAM applies Batch Norm to the input features and builds the attention distribution from $\gamma$, the learnable scale parameter of Batch Norm.

Batch Norm normalizes each channel with the mini-batch mean $\mu$ and variance $\sigma^2$, then applies a learnable scale $\gamma$ and shift $\beta$:

\[X \leftarrow \gamma \frac{X-\mu}{\sqrt{\sigma^2+\epsilon}} + \beta\]
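As a quick sanity check (the tensor sizes below are arbitrary and only illustrative), the expression above can be reproduced directly from nn.BatchNorm2d's batch statistics and affine parameters:

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3, affine=True)
bn.train()                               # use batch statistics, as during training
x = torch.randn(4, 3, 8, 8)              # (N, C, H, W)

y = bn(x)

# per-channel mean and (biased) variance over the mini-batch
mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
gamma = bn.weight.view(1, -1, 1, 1)      # learnable scale γ
beta = bn.bias.view(1, -1, 1, 1)         # learnable shift β
y_manual = gamma * (x - mu) / torch.sqrt(var + bn.eps) + beta

print(torch.allclose(y, y_manual, atol=1e-6))  # True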

The attention weight NAM assigns to channel $i$ is the normalized scale factor:

\[w_i = \frac{\gamma_i}{\sum_j \gamma_j}\]
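For example, with illustrative scale factors $\gamma = (2, 1, 1)$ for three channels, the attention weights become $w = (0.5, 0.25, 0.25)$: channels whose $\gamma$ has shrunk toward zero contribute correspondingly little. The channel attention module can be implemented as follows: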

import torch
import torch.nn as nn


class Channel_Att(nn.Module):
    def __init__(self, channels, t=16):  # t is unused in this channel-attention branch
        super(Channel_Att, self).__init__()
        self.channels = channels
        # BN whose learnable scale γ (self.bn.weight) serves as the channel saliency measure
        self.bn = nn.BatchNorm2d(self.channels, affine=True)

    def forward(self, x):
        residual = x
        x = self.bn(x)
        # w_i = |γ_i| / Σ_j |γ_j|; .data detaches γ so the weighting factor itself is not back-propagated
        weight_bn = self.bn.weight.data.abs() / torch.sum(self.bn.weight.data.abs())
        # move channels last so the per-channel weights broadcast over (N, H, W)
        x = x.permute(0, 2, 3, 1).contiguous()
        x = torch.mul(weight_bn, x)
        x = x.permute(0, 3, 1, 2).contiguous()
        # sigmoid gate on the re-weighted features, applied to the original input
        x = torch.sigmoid(x) * residual
        return x
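A minimal usage sketch (the channel count and feature-map size are illustrative): the module preserves the input shape, so it can be dropped into an existing network block.

# illustrative shapes: batch of 2, 64 channels, 32x32 feature map
att = Channel_Att(channels=64)
feat = torch.randn(2, 64, 32, 32)    # (N, C, H, W)
out = att(feat)
print(out.shape)                     # torch.Size([2, 64, 32, 32])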