BAM:瓶颈注意力模块.

BAM模块通过并联使用通道注意力和空间注意力增强特征的表达能力。

其中通道注意力模块对输入特征沿着通道维度计算一阶统计量(全局平均池化),然后通过带有瓶颈层的全连接层学习通道之间的相关性。

class ChannelGate(nn.Module):
    def __init__(self, channel, reduction_ratio=16, num_layers=1):
        super(ChannelGate, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.gate_c = nn.Sequential(
            nn.Linear(channel, channel//reduction_ratio),
            nn.BatchNorm1d(channel//reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(channel//reduction_ratio,channel),
        )

    def forward(self, x):
        avg_pool = self.avgpool(x).squeeze(-1).squeeze(-1) # [b, c]
        channel_attn = self.gate_c(avg_pool) # [b, c]
        return channel_attn.unsqueeze(-1).unsqueeze(-1).expand_as(x) # [b, c, h, w]

空间注意力模块首先应用$1 \times 1$卷积压缩通道维度,然后使用两个空洞率为$4$的$3 \times 3$空洞卷积提取多感受野特征,并最终应用$1 \times 1$卷积把通道数压缩为$1$。

class SpatialGate(nn.Module):
    def __init__(self,
                 channel,
                 reduction_ratio=16,
                 dilation_conv_num=2,
                 dilation_val=4):
        super(SpatialGate, self).__init__()
        reduced_c = channel // reduction_ratio
        self.gate_s = nn.Sequential()

        self.gate_s.add_module(
            'gate_s_conv_reduce0',
            nn.Conv2d(channel, reduced_c, kernel_size=1))
        self.gate_s.add_module('gate_s_bn_reduce0',
                               nn.BatchNorm2d(reduced_c))
        self.gate_s.add_module('gate_s_relu_reduce0', nn.ReLU())

        # 进行多个空洞卷积,丰富感受野
        for i in range(dilation_conv_num):
            self.gate_s.add_module(
                'gate_s_conv_di_%d' % i,
                nn.Conv2d(reduced_c, reduced_c,
                          kernel_size=3,
                          padding=dilation_val,
                          dilation=dilation_val))
            self.gate_s.add_module(
                'gate_s_bn_di_%d' % i,
                nn.BatchNorm2d(reduced_c))
            self.gate_s.add_module('gate_s_relu_di_%d' % i, nn.ReLU())

        self.gate_s.add_module(
            'gate_s_conv_final',
            nn.Conv2d(reduced_c, 1, kernel_size=1))

    def forward(self, x):
        return self.gate_s(x).expand_as(x)

BAM模块并联空间注意力和通道注意力生成的注意力图,使用Sigmoid进行归一化后与输入特征相乘,并通过残差连接构造输出特征。

import torch
import torch.nn as nn
import torch.nn.functional as F

class BAM(nn.Module):
    def __init__(self, channel):
        super(BAM, self).__init__()
        self.channel_att = ChannelGate(channel)
        self.spatial_att = SpatialGate(channel)

    def forward(self, x):
        att = torch.sigmoid(self.channel_att(x) + self.spatial_att(x))
        return att * x + x

BAM模块可以即插即用到任意卷积神经网络中,作者把该模块放到下采样层(池化层)之前,相当于放置在网络的瓶颈层处,因此称为瓶颈注意力模块。