BAM:瓶颈注意力模块.
BAM模块通过并联使用通道注意力和空间注意力增强特征的表达能力。
其中通道注意力模块对输入特征沿着通道维度计算一阶统计量(全局平均池化),然后通过带有瓶颈层的全连接层学习通道之间的相关性。
class ChannelGate(nn.Module):
def __init__(self, channel, reduction_ratio=16, num_layers=1):
super(ChannelGate, self).__init__()
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.gate_c = nn.Sequential(
nn.Linear(channel, channel//reduction_ratio),
nn.BatchNorm1d(channel//reduction_ratio),
nn.ReLU(inplace=True),
nn.Linear(channel//reduction_ratio,channel),
)
def forward(self, x):
avg_pool = self.avgpool(x).squeeze(-1).squeeze(-1) # [b, c]
channel_attn = self.gate_c(avg_pool) # [b, c]
return channel_attn.unsqueeze(-1).unsqueeze(-1).expand_as(x) # [b, c, h, w]
空间注意力模块首先应用$1 \times 1$卷积压缩通道维度,然后使用两个空洞率为$4$的$3 \times 3$空洞卷积提取多感受野特征,并最终应用$1 \times 1$卷积把通道数压缩为$1$。
class SpatialGate(nn.Module):
def __init__(self,
channel,
reduction_ratio=16,
dilation_conv_num=2,
dilation_val=4):
super(SpatialGate, self).__init__()
reduced_c = channel // reduction_ratio
self.gate_s = nn.Sequential()
self.gate_s.add_module(
'gate_s_conv_reduce0',
nn.Conv2d(channel, reduced_c, kernel_size=1))
self.gate_s.add_module('gate_s_bn_reduce0',
nn.BatchNorm2d(reduced_c))
self.gate_s.add_module('gate_s_relu_reduce0', nn.ReLU())
# 进行多个空洞卷积,丰富感受野
for i in range(dilation_conv_num):
self.gate_s.add_module(
'gate_s_conv_di_%d' % i,
nn.Conv2d(reduced_c, reduced_c,
kernel_size=3,
padding=dilation_val,
dilation=dilation_val))
self.gate_s.add_module(
'gate_s_bn_di_%d' % i,
nn.BatchNorm2d(reduced_c))
self.gate_s.add_module('gate_s_relu_di_%d' % i, nn.ReLU())
self.gate_s.add_module(
'gate_s_conv_final',
nn.Conv2d(reduced_c, 1, kernel_size=1))
def forward(self, x):
return self.gate_s(x).expand_as(x)
BAM模块并联空间注意力和通道注意力生成的注意力图,使用Sigmoid进行归一化后与输入特征相乘,并通过残差连接构造输出特征。
import torch
import torch.nn as nn
import torch.nn.functional as F
class BAM(nn.Module):
def __init__(self, channel):
super(BAM, self).__init__()
self.channel_att = ChannelGate(channel)
self.spatial_att = SpatialGate(channel)
def forward(self, x):
att = torch.sigmoid(self.channel_att(x) + self.spatial_att(x))
return att * x + x
BAM模块可以即插即用到任意卷积神经网络中,作者把该模块放到下采样层(池化层)之前,相当于放置在网络的瓶颈层处,因此称为瓶颈注意力模块。