GSoP-Net: Global Second-Order Pooling Convolutional Networks

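The paper proposes a global second-order pooling (GSoP) module that captures global second-order statistics along the channel dimension of the features. It can be plugged into existing network architectures with little effort and improves their performance at a small computational cost.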

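GSoP first reduces the channel dimension of the input feature $x$ with a $1 \times 1$ convolution and then computes the covariance matrix between channels. A row-wise convolution turns this covariance matrix into a vector, and a fully connected layer (implemented as a $1 \times 1$ convolution) followed by a sigmoid maps the vector to a channel weight vector, which is multiplied with the input feature.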
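The same steps can be written as equations. This is a sketch that omits the BatchNorm layers and assumes the biased covariance (dividing by $HW$); the symbols $X$, $\mu$, $\Sigma$, $K$, $W$, $b$, $c$ and $\mathbf{w}$ are introduced here for illustration. $C$ is the reduced channel count, $C'$ the input channel count, $\mathbf{1} \in \mathbb{R}^{HW}$ a vector of ones, $\sigma$ the sigmoid, and $\odot$ a channel-wise product broadcast over all spatial positions. Each row $i$ of $\Sigma$ is mapped to 4 values by its own kernel $K_i \in \mathbb{R}^{4 \times C}$, which is exactly what a grouped convolution with one group per row computes.

$$
\begin{aligned}
X &= \mathrm{ReLU}\big(\mathrm{Conv}_{1\times 1}(x)\big) \in \mathbb{R}^{C \times HW}, \qquad x \in \mathbb{R}^{C' \times H \times W} \\
\bar{X} &= X - \mu \mathbf{1}^{\top}, \qquad \mu = \frac{1}{HW} X \mathbf{1} \\
\Sigma &= \frac{1}{HW} \bar{X} \bar{X}^{\top} \in \mathbb{R}^{C \times C} \\
z_{4i+k} &= \sum_{j=0}^{C-1} K_{i,k,j}\, \Sigma_{ij} + b_{4i+k}, \qquad i = 0,\dots,C-1,\; k = 0,\dots,3 \\
\mathbf{w} &= \sigma(W z + c) \in \mathbb{R}^{C'}, \qquad W \in \mathbb{R}^{C' \times 4C} \\
y &= x \odot \mathbf{w}
\end{aligned}
$$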

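Because the GSoP module keeps the shape of its input, it can be inserted at any position in a network. Placing it in intermediate layers lets the network model higher-order statistics of the whole image already at early stages, which strengthens its nonlinear modeling capability.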

The PyTorch implementation is as follows:

```python
import torch
import torch.nn as nn


class GSoP(nn.Module):
    def __init__(self, in_channel, mid_channel=128):
        super().__init__()
        # 1x1 conv reduces the channel dimension: C' -> C
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, mid_channel, 1, 1, 0),
            nn.BatchNorm2d(mid_channel),
            nn.ReLU(inplace=True),
        )
        # Row-wise convolution implemented as a group convolution:
        # each row of the C x C covariance matrix is one group (one input channel),
        # and a (C, 1) kernel collapses that row into 4 values
        self.row_wise_conv = nn.Sequential(
            nn.Conv2d(
                mid_channel, 4 * mid_channel,
                kernel_size=(mid_channel, 1),
                groups=mid_channel),
            nn.BatchNorm2d(4 * mid_channel),
        )
        # Fully connected layer (1x1 conv) + sigmoid -> channel weights in (0, 1)
        self.conv2 = nn.Sequential(
            nn.Conv2d(4 * mid_channel, in_channel, 1, 1, 0),
            nn.BatchNorm2d(in_channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # x: [B, C', H, W]
        feas = self.conv1(x)  # [B, C, H, W]
        B, C, H, W = feas.shape
        # Channel covariance matrix, computed for the whole batch at once
        fea = feas.reshape(B, C, H * W)  # [B, C, HW]
        fea = fea - fea.mean(dim=2, keepdim=True)  # center each channel over HW positions
        covs = torch.bmm(fea, fea.transpose(1, 2)) / (H * W)  # [B, C, C], biased covariance
        covs = covs.unsqueeze(-1)  # [B, C, C, 1]: row i of each matrix is channel i
        out = self.row_wise_conv(covs)  # [B, 4C, 1, 1]
        out = self.conv2(out)  # [B, C', 1, 1]
        return x * out  # reweight the input channels


if __name__ == "__main__":
    t = torch.ones((32, 256, 24, 24))
    gsop = GSoP(256)
    out = gsop(t)
    print(out.shape)  # torch.Size([32, 256, 24, 24])
```
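The two less obvious steps of the module, the batched channel covariance and the row-wise convolution realized as a group convolution, can be checked in isolation. The sketch below assumes PyTorch 1.10 or later (for `torch.cov`); the helper names `batched_channel_covariance`, `check_covariance` and `check_row_wise_conv` are hypothetical and not part of the original code. It compares a batched covariance (centering over the $H \times W$ positions and dividing by $HW$) against `torch.cov(..., correction=0)`, and compares `nn.Conv2d(C, 4C, kernel_size=(C, 1), groups=C)` against an explicit per-row weighted sum.

```python
import torch
import torch.nn as nn


def batched_channel_covariance(feas: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: biased channel covariance of [B, C, H, W] -> [B, C, C]."""
    B, C, H, W = feas.shape
    fea = feas.reshape(B, C, H * W)
    fea = fea - fea.mean(dim=2, keepdim=True)  # center each channel over HW positions
    return torch.bmm(fea, fea.transpose(1, 2)) / (H * W)


def check_covariance(B=2, C=8, H=5, W=6):
    feas = torch.randn(B, C, H, W, dtype=torch.float64)
    ours = batched_channel_covariance(feas)
    # torch.cov treats rows as variables and columns as observations;
    # correction=0 gives the biased (divide-by-N) estimate
    ref = torch.stack(
        [torch.cov(feas[i].reshape(C, -1), correction=0) for i in range(B)]
    )
    return torch.allclose(ours, ref)


def check_row_wise_conv(B=2, C=8, out_per_row=4):
    covs = torch.randn(B, C, C, 1, dtype=torch.float64)  # [B, C, C, 1]
    conv = nn.Conv2d(C, out_per_row * C, kernel_size=(C, 1), groups=C).double()
    out = conv(covs).reshape(B, C, out_per_row)  # [B, C, 4]: 4 outputs per row
    # Manual version: output channel i*4+k only sees row i of each matrix
    weight = conv.weight.reshape(C, out_per_row, C)  # [rows, 4, cols]
    bias = conv.bias.reshape(C, out_per_row)  # [rows, 4]
    manual = torch.einsum("ikj,bij->bik", weight, covs.squeeze(-1)) + bias
    return torch.allclose(out, manual)


if __name__ == "__main__":
    torch.manual_seed(0)
    print(check_covariance())     # True
    print(check_row_wise_conv())  # True
```

Both checks rely on PyTorch laying out grouped-convolution outputs contiguously per group, so output channels `4i` to `4i+3` depend only on row `i` of the covariance matrix.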