GENet:在通道注意力中利用特征上下文.

GENet是对通道注意力网络SENet的改进。SENet包括Squeeze过程和Excitation过程。Squeeze过程对特征$x$沿着通道维度进行全局平均池化,Excitation过程通过两层全连接层学习通道之间的相关性。

本文作者指出,SENetSqueeze过程对通道的空间维度统计量估计是粗略的(仅考虑了均值这个一阶统计量)。GENetSqueeze过程替换为Gather过程,即对每个局部的空间位置提取一个统计量,用于捕捉特征之间的上下文信息;对应的Excite操作则用于将其进行缩放还原回原始尺寸。

统计量的提取可以通过具有较大卷积核尺寸的通道卷积实现,引入可学习的参数;空间尺寸的还原通过插值操作实现。

import torch.nn as nn
import torch.nn.functional as F

class GEModule(nn.Module):
    def __init__(self, channels, kernel_size):
        super(GEModule, self).__init__()
        self.downop = nn.Sequential(
            nn.Conv2d(channels, channels, groups=channels,
                      stride=1, kernel_size=kernel_size, padding=0,
                      bias=False,),
            nn.BatchNorm2d(channels),)
        self.mlp = nn.Sequential(
            nn.Conv2d(
                channels, channels // 16,
                kernel_size=1, padding=0, bias=False),
            nn.ReLU(),
            nn.Conv2d(
                channels // 16, channels,
                kernel_size=1, padding=0, bias=False),)

    def forward(self, x):
        out = self.downop(x)
        out = self.mlp(out)
        shape_in = x.shape[-1]
        out = F.interpolate(out, shape_in)
        out = torch.sigmoid(out)
        out = x * out
        return out

x = torch.rand((16, 256, 64, 64))
genet = GEModule(256, 32)
print(genet(x).shape) # torch.Size([16, 256, 64, 64])