GENet:在通道注意力中利用特征上下文.
GENet是对通道注意力网络SENet的改进。SENet包括Squeeze过程和Excitation过程。Squeeze过程对特征$x$沿着通道维度进行全局平均池化,Excitation过程通过两层全连接层学习通道之间的相关性。
本文作者指出,SENet的Squeeze过程对通道的空间维度统计量估计是粗略的(仅考虑了均值这个一阶统计量)。GENet把Squeeze过程替换为Gather过程,即对每个局部的空间位置提取一个统计量,用于捕捉特征之间的上下文信息;对应的Excite操作则用于将其进行缩放还原回原始尺寸。
统计量的提取可以通过具有较大卷积核尺寸的通道卷积实现,引入可学习的参数;空间尺寸的还原通过插值操作实现。
import torch.nn as nn
import torch.nn.functional as F
class GEModule(nn.Module):
def __init__(self, channels, kernel_size):
super(GEModule, self).__init__()
self.downop = nn.Sequential(
nn.Conv2d(channels, channels, groups=channels,
stride=1, kernel_size=kernel_size, padding=0,
bias=False,),
nn.BatchNorm2d(channels),)
self.mlp = nn.Sequential(
nn.Conv2d(
channels, channels // 16,
kernel_size=1, padding=0, bias=False),
nn.ReLU(),
nn.Conv2d(
channels // 16, channels,
kernel_size=1, padding=0, bias=False),)
def forward(self, x):
out = self.downop(x)
out = self.mlp(out)
shape_in = x.shape[-1]
out = F.interpolate(out, shape_in)
out = torch.sigmoid(out)
out = x * out
return out
x = torch.rand((16, 256, 64, 64))
genet = GEModule(256, 32)
print(genet(x).shape) # torch.Size([16, 256, 64, 64])