CBAM: Convolutional Block Attention Module.

The CBAM module enhances feature representations by applying channel attention and spatial attention in sequence; each attention mechanism uses two first-order statistics (a global max and a global average).
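
In the notation of the CBAM paper, given a feature map $F \in \mathbb{R}^{C \times H \times W}$, the two attention maps and their sequential application are

$$M_c(F) = \sigma\big(\mathrm{MLP}(\mathrm{AvgPool}(F)) + \mathrm{MLP}(\mathrm{MaxPool}(F))\big)$$

$$M_s(F) = \sigma\big(f^{k \times k}([\mathrm{AvgPool}(F); \mathrm{MaxPool}(F)])\big)$$

$$F' = M_c(F) \otimes F, \qquad F'' = M_s(F') \otimes F'$$

where $\sigma$ is the sigmoid, $f^{k \times k}$ is a $k \times k$ convolution, and $\otimes$ denotes element-wise multiplication with broadcasting.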

Channel attention squeezes the spatial dimensions with global max pooling and global average pooling, then extracts channel-wise information through a weight-shared MLP (implemented with $1 \times 1$ convolutions, which avoids flattening and restoring the spatial dimensions):

import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, channel, ratio=4):
        super(ChannelAttention, self).__init__()
        # Squeeze H x W down to 1 x 1 with two complementary statistics
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP: reduce channels by `ratio`, then restore them
        self.sharedMLP = nn.Sequential(
            nn.Conv2d(channel, channel // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channel // ratio, channel, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        # Fuse the two descriptors and map them to (0, 1) channel weights
        return self.sigmoid(avgout + maxout)
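
A quick shape check of the channel attention map, building on the class above (a minimal sketch; the tensor sizes are arbitrary):

x = torch.randn(8, 64, 32, 32)  # (batch, channels, height, width)
ca = ChannelAttention(64)
print(ca(x).shape)  # torch.Size([8, 64, 1, 1]): one weight per channel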

Spatial attention squeezes the channel dimension with global max pooling and global average pooling (taken along the channel axis), then extracts spatial information through a convolutional layer:

class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        # 2 input channels (avg map + max map) -> 1 attention map;
        # the CBAM paper uses a 7x7 kernel (padding=3), 3x3 is a lighter variant
        self.conv = nn.Conv2d(2, 1, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Pool along the channel axis: each yields a (B, 1, H, W) map
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avgout, maxout], dim=1)
        x = self.conv(x)
        # (0, 1) weight per spatial location
        return self.sigmoid(x)
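
The spatial attention map has a single channel but full spatial resolution (again a minimal sketch with arbitrary sizes, reusing the class above):

x = torch.randn(8, 64, 32, 32)
sa = SpatialAttention()
print(sa(x).shape)  # torch.Size([8, 1, 32, 32]): one weight per location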

The CBAM module can be plugged into any convolutional neural network. The authors give an example block combining a residual block with CBAM:

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    def __init__(self, channel):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(channel, channel, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channel)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channel, channel, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channel)
        self.ca = ChannelAttention(channel)
        self.sa = SpatialAttention()

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.ca(out) * out  # channel weights (B, C, 1, 1) broadcast over H, W
        out = self.sa(out) * out  # spatial weights (B, 1, H, W) broadcast over channels
        return self.relu(out + x)  # attention is applied before the residual addition
		
if __name__ == "__main__":
    t = torch.ones((32, 256, 24, 24))  # (batch, channels, height, width)
    cbam = BasicBlock(256)
    out = cbam(t)
    print(out.shape)  # torch.Size([32, 256, 24, 24]): the input shape is preserved
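
For reuse outside a residual block, the two attentions can also be packaged as a standalone module (a minimal sketch; the class name CBAM and the ratio default are my own choices, mirroring the channel-then-spatial order of the paper):

class CBAM(nn.Module):
    def __init__(self, channel, ratio=4):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(channel, ratio)
        self.sa = SpatialAttention()

    def forward(self, x):
        x = self.ca(x) * x  # refine channels first ...
        x = self.sa(x) * x  # ... then spatial locations
        return x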