GSoP-Net:全局二阶池化卷积网络.
本文提出了全局二阶池化(global second-order pooling, GSoP)模块,沿着特征的通道维度捕获全局二阶统计信息,可以方便地插入到现有的网络架构中,以较小的计算开销提高网络的性能。
GSoP把输入特征$x$沿通道维度进行降维后,计算通道之间的协方差矩阵,然后通过按行卷积把协方差特征转化为一个向量,并通过全连接层(由$1 \times 1$卷积实现)构造为权重向量,并作用于输入特征。
所设计的GSoP模块可以即插即用到网络的任意位置。通过在网络的中间层中引入该模块,可以在早期对整体图像进行高阶统计建模,增强了网络的非线性建模能力。
Pytorch代码如下:
import torch.nn as nn
class GSoP(nn.Module):
def __init__(self, in_channel, mid_channel=128):
super(GSoP, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channel, mid_channel, 1, 1, 0),
nn.BatchNorm2d(mid_channel),
nn.ReLU(inplace=True),)
# 通过组卷积实现按行卷积
self.row_wise_conv = nn.Sequential(
nn.Conv2d(
mid_channel, 4*mid_channel,
kernel_size=(mid_channel, 1),
groups = mid_channel),
nn.BatchNorm2d(4*mid_channel),)
self.conv2 = nn.Sequential(
nn.Conv2d(4*mid_channel, in_channel, 1, 1, 0),
nn.BatchNorm2d(in_channel),
nn.Sigmoid())
def forward(self, x):
# [B, C', H, W]
feas = self.conv1(x) # [B, C, H, W]
# 计算协方差矩阵
B, C = feas.shape[0], feas.shape[1]
for i in range(B):
fea = feas[i].view(C, -1).permute(1, 0) # [HW, C]
fea = fea - torch.mean(fea, axis=0) # [HW, C]
cov = torch.matmul(fea.T, fea).unsqueeze(0) # [1, C, C]
if i == 0:
covs = cov
else:
covs = torch.cat([covs, cov], dim=0) # [B, C, C]
covs = covs.unsqueeze(-1) # [B, C, C, 1]
out = self.row_wise_conv(covs) # [B, 4C, 1, 1]
out = self.conv2(out) # [B, C', 1, 1]
return x * out
if __name__ == "__main__":
t = torch.ones((32, 256, 24, 24))
gsop = GSoP(256)
out = gsop(t)
print(out.shape)