SPANet:图像识别的空间金字塔注意力网络.
注意力机制通常采用全局平均池化GAP提取图像通道的特征,类似于结构正则化,能够防止过拟合。但是GAP会过度强调正则化效果,而忽略了原始特征表示和结构信息。
本文作者设计了空间金字塔注意力网络(Spatial Pyramid Attention Network, SPANet),通过横向添加空间金字塔注意力同时考虑结构正则化和结构信息。
SPANet在4×4、2×2和1×1三个尺度上对输入特征图进行自适应平均池化。4×4平均池化捕捉了更多的特征表示和结构信息,1×1平均池化具有较强结构正则化的效果,2×2平均池化旨在平衡结构信息和结构正则化之间的关系。然后将三个输出特征连接并调整为一维向量以生成通道注意力分布。SPANet既能保持特征表示,又能继承全局平均池化的优点。
大多数注意力方法服从这样的的设计规则:以自身作为输入学习一个注意力图并作用于自身。作者探索了三种变体结构
- SPANet-A使用与传统注意力路径连接类似的模式。
- SPANet-B确保注意力路径独立于原始卷积路径,使注意力路径能够学习更广义的权重。虽然两条路径彼此独立,但并非完全无关,因为注意力路径和卷积路径是联合训练的。
- SPANet-C的设计是考虑到两个分支之间的通道不匹配问题。
class SPALayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SPALayer, self).__init__()
self.avg_pool1 = nn.AdaptiveAvgPool2d(1)
self.avg_pool2 = nn.AdaptiveAvgPool2d(2)
self.avg_pool4 = nn.AdaptiveAvgPool2d(4)
self.fc = nn.Sequential(
nn.Linear(channel*21, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y1 = self.avg_pool1(x).view(b, c) # like resize() in numpy
y2 = self.avg_pool2(x).view(b, 4 * c)
y3 = self.avg_pool4(x).view(b, 16 * c)
y = torch.cat((y1, y2, y3), 1)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)