SCNet:通过自校正卷积改进卷积神经网络.

本文作者提出了一种自校正(Self-calibrated)卷积(本质是多个卷积注意力组合的模块),用于替换基本的卷积结构,在不增加额外参数和计算量的情况下,可以自适应地在每个空间位置周围建立了远程空间和通道间的相互关系,达到扩增卷积感受野的目的,进而增强输出特征的多样性和区分度。

自校正卷积把输入特征沿通道维度拆分成两部分,一部分直接应用标准的卷积操作;另一部分在两个不同的尺度空间中进行卷积特征转换:原始特征空间和下采样后具有较小分辨率的隐空间。下采样过程采用步长为$4$的平均池化,由于下采样后的特征具有较大的感受野,因此在较小的特征隐空间中进行变换,然后通过线性插值进行上采样,以指导原始特征空间中的特征变换过程。

与传统的卷积相比,通过采用自校正操作允许每个空间位置收集周围的上下文信息,还可以对通道间的依赖性进行建模,在某种程度上避免了来自无关区域的某些无用信息。因此可以有效地扩大卷积层的感受野。如下图所示,自校正卷积层能够编码更大但更准确的区域,可以准确地定位目标物体。

import torch.nn as nn
import torch.nn.functional as F

class SCConv(nn.Module):
    def __init__(self, channels, pooling_r=4):
        super(SCConv, self).__init__()
        mid_channel = channels // 2
        self.k1 = nn.Sequential(
                    nn.Conv2d(mid_channel, mid_channel, kernel_size=3,
                              stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(mid_channel),
                    )
        self.k2 = nn.Sequential(
                    nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r), 
                    nn.Conv2d(mid_channel, mid_channel, kernel_size=3,
                              stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(mid_channel),
                    )
        self.k3 = nn.Sequential(
                    nn.Conv2d(mid_channel, mid_channel, kernel_size=3,
                              stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(mid_channel),
                    )
        self.k4 = nn.Sequential(
                    nn.Conv2d(mid_channel, mid_channel, kernel_size=3,
                              stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(mid_channel),
                    )

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        x1_down = self.k2(x1)
        x1_up = F.interpolate(x1_down, x1.size()[2:])
        out = torch.sigmoid(torch.add(x1, x1_up))
        out = torch.mul(self.k3(x1), out)
        out = self.k4(out)
        y = torch.cat([out, self.k1(x2)], dim=1)
        return y
    
x = torch.rand((16, 256, 64, 64))
scnet = SCConv(256)
print(scnet(x).shape) # torch.Size([16, 256, 64, 64])