SPANet:图像识别的空间金字塔注意力网络.

注意力机制通常采用全局平均池化GAP提取图像通道的特征,类似于结构正则化,能够防止过拟合。但是GAP会过度强调正则化效果,而忽略了原始特征表示和结构信息。

本文作者设计了空间金字塔注意力网络(Spatial Pyramid Attention Network, SPANet),通过横向添加空间金字塔注意力同时考虑结构正则化和结构信息。

SPANet4×42×21×1三个尺度上对输入特征图进行自适应平均池化。4×4平均池化捕捉了更多的特征表示和结构信息,1×1平均池化具有较强结构正则化的效果,2×2平均池化旨在平衡结构信息和结构正则化之间的关系。然后将三个输出特征连接并调整为一维向量以生成通道注意力分布。SPANet既能保持特征表示,又能继承全局平均池化的优点。

大多数注意力方法服从这样的的设计规则:以自身作为输入学习一个注意力图并作用于自身。作者探索了三种变体结构

class SPALayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SPALayer, self).__init__()
        self.avg_pool1 = nn.AdaptiveAvgPool2d(1)
        self.avg_pool2 = nn.AdaptiveAvgPool2d(2)
        self.avg_pool4 = nn.AdaptiveAvgPool2d(4)
        self.fc = nn.Sequential(
            nn.Linear(channel*21, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y1 = self.avg_pool1(x).view(b, c)  # like resize() in numpy
        y2 = self.avg_pool2(x).view(b, 4 * c)
        y3 = self.avg_pool4(x).view(b, 16 * c)
        y = torch.cat((y1, y2, y3), 1)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)