CBAM: Convolutional Block Attention Module.
The CBAM module strengthens feature representations by applying channel attention and spatial attention in series, where each attention mechanism relies on two first-order statistics: the global maximum and the global average.
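In the paper's notation, given an intermediate feature map $F \in \mathbb{R}^{C \times H \times W}$, the channel attention map $M_c$ and the spatial attention map $M_s$ are applied one after the other:

$$F' = M_c(F) \otimes F, \qquad F'' = M_s(F') \otimes F'$$

where $\otimes$ is element-wise multiplication and each attention map is broadcast along its missing dimensions.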
Channel attention compresses the spatial dimensions with global max pooling and global average pooling, then extracts channel-wise information through a shared fully connected layer (implemented as $1 \times 1$ convolutions, which avoids flattening and then reshaping the spatial dimensions):
import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, channel, ratio=4):  # reduction ratio (the paper's default is 16)
        super(ChannelAttention, self).__init__()
        # Squeeze the spatial dimensions to 1x1 with two pooling statistics
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP as 1x1 convolutions: C -> C/ratio -> C
        self.sharedMLP = nn.Sequential(
            nn.Conv2d(channel, channel // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channel // ratio, channel, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Both pooled descriptors pass through the same MLP and are then fused
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)
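This implements the paper's channel attention $M_c(F) = \sigma(\mathrm{MLP}(\mathrm{AvgPool}(F)) + \mathrm{MLP}(\mathrm{MaxPool}(F)))$. As a quick sanity check (a minimal sketch with arbitrary tensor sizes), the returned map has shape $(N, C, 1, 1)$ and broadcasts over the spatial dimensions:

ca = ChannelAttention(channel=64)
x = torch.randn(2, 64, 32, 32)
print(ca(x).shape)  # torch.Size([2, 64, 1, 1])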
Spatial attention compresses the channel dimension with global max pooling and global average pooling (taken per spatial location), then extracts spatial information with a convolutional layer:
class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        # Fuse the two single-channel descriptor maps into one attention map
        self.conv = nn.Conv2d(2, 1, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Pool along the channel dimension: (N, C, H, W) -> (N, 1, H, W)
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avgout, maxout], dim=1)  # (N, 2, H, W)
        x = self.conv(x)
        return self.sigmoid(x)
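This follows the paper's $M_s(F) = \sigma(f([\mathrm{AvgPool}(F); \mathrm{MaxPool}(F)]))$, where $f$ is a $k \times k$ convolution; the paper's default is $7 \times 7$, while this snippet uses $3 \times 3$. A quick shape check (arbitrary sizes) shows the single-channel map that broadcasts over the channels:

sa = SpatialAttention()
x = torch.randn(2, 64, 32, 32)
print(sa(x).shape)  # torch.Size([2, 1, 32, 32])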
The CBAM module can be plugged into any convolutional neural network. The authors provide a building block that combines a residual block with CBAM:
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    def __init__(self, channel):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(channel, channel, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channel)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channel, channel, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channel)
        # CBAM: channel attention followed by spatial attention
        self.ca = ChannelAttention(channel)
        self.sa = SpatialAttention()

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.ca(out) * out  # broadcast over H and W
        out = self.sa(out) * out  # broadcast over C
        return self.relu(out + x)  # residual connection after attention

if __name__ == "__main__":
    t = torch.ones((32, 256, 24, 24))
    cbam = BasicBlock(256)
    out = cbam(t)
    print(out.shape)  # torch.Size([32, 256, 24, 24])
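Note that, matching the ResNet integration in the paper, both attention maps are applied to the residual branch before the identity shortcut is added, so the skip path itself is left untouched.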