APCNet: Adaptive Pyramid Context Network for Semantic Segmentation

Current image segmentation models have the following problems:

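This paper proposes APCNet, an end-to-end model that fuses global information and improves recognition across multiple scales. The pyramid layer of APCNet is built from several ACMs (Adaptive Context Modules). Each ACM takes a scale parameter $s$ that determines the region size.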
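To make the role of the scale parameter concrete, the sketch below shows how adaptive average pooling with output size $s$ divides a feature map into $s\times s$ regions. The tensor is random, the 32x32 spatial size is an arbitrary assumption, and the scale set (1, 2, 3, 6) and channel count 2048 mirror the defaults used in the code later in this post.

```python
import torch
import torch.nn.functional as F

# Stand-in for a backbone feature map: 1 image, 2048 channels, 32x32 grid (assumed sizes)
X = torch.randn(1, 2048, 32, 32)

pooled_shapes = {}
for s in (1, 2, 3, 6):
    # Each output cell is the average over one of the s x s regions of the map
    pooled = F.adaptive_avg_pool2d(X, s)
    pooled_shapes[s] = tuple(pooled.shape)
    print(s, pooled_shapes[s])
```

A larger $s$ gives more, smaller regions, so each pyramid level summarises context at a different granularity.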

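An ACM computes a context vector for every local position and adds it, weighted, to the feature map, thereby aggregating contextual information. It consists of two branches: the GLA (Global-guided Local Affinity) branch and the Aggregate branch.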

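In the GLA branch, let $X$ denote the feature map output by the backbone. $X$ first passes through a 1x1 Conv to give a feature map $x$, and spatial global pooling then maps $x$ to a global information vector $g(X)$. The sum of $x$ and $g(X)$ passes through 1x1 Conv → Sigmoid to produce the GLA vector $\alpha^s$, which is reshaped into an $hw\times s^2$ matrix of affinity weights used to build the context vectors.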
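The GLA steps above can be sketched as follows. This is a minimal illustration under stated assumptions, not the post's implementation: the sizes are arbitrary, the convolution `to_affinity` is a hypothetical name with randomly initialised weights, and the extra 1x1 convolution that the full code applies to the global vector is omitted for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, h, w, s = 2, 512, 8, 8, 3         # illustrative sizes (assumptions)

x = torch.randn(B, C, h, w)             # feature map after the first 1x1 conv
to_affinity = nn.Conv2d(C, s * s, 1)    # hypothetical 1x1 conv, random weights

g = F.adaptive_avg_pool2d(x, 1)         # global information vector, shape (B, C, 1, 1)
logits = to_affinity(x + g)             # g broadcasts over all h*w positions
alpha = torch.sigmoid(logits)           # affinity coefficients between 0 and 1

# One row per pixel, one column per pooled region: (B, h*w, s^2)
alpha = alpha.permute(0, 2, 3, 1).reshape(B, h * w, s * s)
print(alpha.shape)  # torch.Size([2, 64, 9])
```

Each row of `alpha` tells one pixel how strongly to draw on each of the $s^2$ pooled regions.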

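In the Aggregate branch, the feature map $X$ passes through AdaptivePooling(size=s) → Conv(kernel size=1x1) → reshape to give $y^s$ of shape $s^2\times 512$. Matrix-multiplying the GLA vector $\alpha^s$ from the GLA branch with $y^s$ yields an $hw\times 512$ result, which completes the initial feature fusion. Finally, this result is reshaped to the original spatial size and added to the residual $x$ from the GLA branch to produce the overall fused output.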
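Written out per position, the matrix product says that the context vector at each pixel is an affinity-weighted sum of the $s^2$ pooled region features. The index notation ($i$, $j$, $z_i^s$) is introduced here only for illustration:

$$
z_i^s = \sum_{j=1}^{s^2} \alpha_{i,j}^s \, y_j^s, \qquad i = 1, \dots, hw,
$$

where $\alpha^s \in \mathbb{R}^{hw\times s^2}$ and $y^s \in \mathbb{R}^{s^2\times 512}$, so stacking the $z_i^s$ row by row gives the $hw\times 512$ matrix $\alpha^s y^s$.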

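import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import Resize


class ACMModle(nn.Module):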
    def __init__(self, in_channels=2048, channels=512, pool_scale=1, fusion=True):
        super(ACMModle, self).__init__()
        self.pool_scale = pool_scale
        self.in_channels = in_channels
        self.channels = channels
        self.fusion = fusion
        
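        # 1x1 convs that reduce channels: one for the GLA branch (full-resolution x),
        # one for the Aggregate branch (pooled features y^s)
        self.reduce_Conv = nn.Conv2d(self.in_channels, self.channels, 1)
        self.reduce_Pool_Conv = nn.Conv2d(self.in_channels, self.channels, 1)
        
        # 1x1 conv applied to the aggregated context before the residual sum
        self.residual_conv = nn.Conv2d(self.channels, self.channels, 1)
        # 1x1 conv applied to the globally pooled vector g(X)
        self.global_info = nn.Conv2d(self.channels, self.channels, 1)
        # 1x1 conv producing s^2 affinity coefficients per pixel
        self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)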
        
        if self.fusion:
            self.fusion_conv = nn.Conv2d(self.channels, self.channels, 1)
 
    def forward(self, X):
        batch_size, c, h, w = X.shape

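        # GLA branch: reduce channels, then add the global vector g(X) at every position
        x = self.reduce_Conv(X)                              # (B, channels, h, w)
        GI = self.global_info(F.adaptive_avg_pool2d(x, 1))   # (B, channels, 1, 1)
        GI = torchvision.transforms.Resize(x.shape[2:])(GI)  # (B, channels, h, w)
        Affinity_matrix = self.gla(x + GI).permute(0, 2, 3, 1).reshape(batch_size, -1, self.pool_scale**2)
        Affinity_matrix = torch.sigmoid(Affinity_matrix)     # (B, h*w, s^2)
        
        # Aggregate branch: pool to s x s, then reduce channels with the 1x1 conv
        # so that the result matches the `channels` expected by residual_conv
        pooled_x = self.reduce_Pool_Conv(F.adaptive_avg_pool2d(X, self.pool_scale))
        pooled_x = pooled_x.view(batch_size, self.channels, self.pool_scale**2).permute(0, 2, 1).contiguous()  # (B, s^2, channels)

        MatrixProduct = torch.matmul(Affinity_matrix, pooled_x)  # (B, h*w, channels)
        MatrixProduct = MatrixProduct.permute(0, 2, 1).contiguous()
        MatrixProduct = MatrixProduct.view(batch_size, self.channels, h, w)
        MatrixProduct = self.residual_conv(MatrixProduct)
        Z_out = F.relu(MatrixProduct + x)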
        
        if self.fusion:
            Z_out = self.fusion_conv(Z_out)
        return Z_out
    
    
class ACMModuleList(nn.ModuleList):
    def __init__(self, pool_scales=(1, 2, 3, 6), in_channels=2048, channels=512):
        super(ACMModuleList, self).__init__()
        self.pool_scales = pool_scales
        self.in_channels = in_channels
        self.channels = channels
        
        for pool_scale in pool_scales:
            self.append(
                ACMModle(in_channels, channels, pool_scale)
            )
            
    def forward(self, x):
        out = []
        for ACM in self:
            ACM_out = ACM(x)
            out.append(ACM_out)
        return out
    

class APCNet(nn.Module):
    def __init__(self, num_classes):
        super(APCNet, self).__init__()
        self.num_classes = num_classes
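        # `ResNet` is a backbone module that is not shown in this post. The rest of
        # this class assumes it returns a 2048-channel feature map at 1/8 of the
        # input resolution (see in_channels below and the 8x upsampling in forward).
        self.backbone = ResNet.resnet50(replace_stride_with_dilation=[1,2,4])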
        self.in_channels = 2048
        self.channels = 512
        self.ACM_pyramid = ACMModuleList(pool_scales=[1,2,3,6], in_channels=self.in_channels, channels=self.channels)
        self.conv1 = nn.Sequential(
            nn.Conv2d(4*self.channels + self.in_channels, self.channels, 3, padding=1),
            nn.BatchNorm2d(self.channels),
            nn.ReLU()
        )
        self.cls_conv = nn.Conv2d(self.channels, self.num_classes, 3, padding=1)
        
    def forward(self, x):
        x = self.backbone(x)
        ACM_out = self.ACM_pyramid(x)
        ACM_out.append(x)
        x = torch.cat(ACM_out, dim=1)
        x = self.conv1(x)
        x = Resize((8*x.shape[-2], 8*x.shape[-1]))(x)
        x = self.cls_conv(x)
        return x
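As a shape check on the decoder head above, the following sketch traces the tensors through the concatenation, fusion convolution, upsampling, and classifier. It is self-contained and uses random stand-ins for the backbone and ACM outputs. The spatial size, the 21-class setting, and the use of `F.interpolate` in place of `Resize` are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, h, w, num_classes = 1, 8, 8, 21              # illustrative sizes (assumptions)

backbone_out = torch.randn(B, 2048, h, w)       # stand-in for the backbone feature map
acm_outs = [torch.randn(B, 512, h, w) for _ in (1, 2, 3, 6)]  # stand-ins for the four ACM outputs

fused = torch.cat(acm_outs + [backbone_out], dim=1)           # 4*512 + 2048 = 4096 channels
conv1 = nn.Sequential(nn.Conv2d(4096, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU())
cls_conv = nn.Conv2d(512, num_classes, 3, padding=1)

feat = conv1(fused)
# Upsample by 8 to undo the assumed backbone output stride, then classify every pixel
feat = F.interpolate(feat, scale_factor=8, mode="bilinear", align_corners=False)
logits = cls_conv(feat)
print(logits.shape)  # torch.Size([1, 21, 64, 64])
```

The input width of the fusion convolution depends on the number of pyramid scales, so changing `pool_scales` in `APCNet` would also require changing the `4*self.channels` term.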