PSPNet: 金字塔场景解析网络.

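PSPNet模型的提出是为了解决场景解析问题,以提升场景分析中对颜色、形状相似的物体的分割精度。作者在ADE20K数据集上分析FCN的预测结果时,主要发现了如下问题:

1. 上下文关系不匹配(Mismatched Relationship):模型没有利用场景中的上下文关系,例如把水面上的船误判为汽车,而"汽车很少出现在水面上"这一上下文信息本可以纠正这种错误;
2. 类别混淆(Confusion Categories):许多类别外观相近、容易混淆,如 field 与 earth、mountain 与 hill、building 与 skyscraper;
3. 不显眼的类别(Inconspicuous Classes):尺寸很小的物体(如路灯、招牌)难以被识别;而尺寸很大的物体可能超出网络的感受野,导致预测结果不连续。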

为解决以上问题,作者提出了金字塔池化模块(Pyramid Pooling Module, PPM)。在深层网络中,感受野的大小大致体现了模型能够利用的上下文信息的多少。尽管理论上深层卷积网络(如ResNet)的感受野已经大于输入图像尺寸,但实际有效的感受野要小得多。这导致很多网络不能充分地结合全局上下文信息,于是作者提出了一种全局先验的方法:引入平均池化。

PPM模块并联了四个不同大小的平均池化层(输出尺寸分别为$1\times1$、$2\times2$、$3\times3$、$6\times6$),将原始特征图池化成不同级别的特征图,再经过$1\times1$卷积和上采样恢复到原始大小。这种操作聚合了多尺度的图像特征,融合了不同尺度和不同子区域之间的信息。最后,这些先验信息再和原始特征图在通道维度上进行拼接(concatenate),输入到最后的卷积模块完成预测。

PPM模块的实现如下:

class PPM(nn.ModuleList):
    """Pyramid Pooling Module (PPM) of PSPNet.

    Holds one branch per entry of ``pool_sizes``. Each branch applies
    adaptive average pooling to a ``pool_size`` x ``pool_size`` grid,
    followed by a 1x1 convolution mapping ``in_channels`` to ``out_channels``.
    ``forward`` bilinearly upsamples every branch output back to the spatial
    size of the input.

    Args:
        pool_sizes: iterable of pooling grid sizes, one branch per entry.
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by each branch.
    """

    def __init__(self, pool_sizes, in_channels, out_channels):
        super(PPM, self).__init__()
        self.pool_sizes = pool_sizes
        self.in_channels = in_channels
        self.out_channels = out_channels

        for pool_size in pool_sizes:
            self.append(
                nn.Sequential(
                    # Average (not max) pooling: PPM builds a regional/global
                    # context prior from averaged features, as in the paper.
                    nn.AdaptiveAvgPool2d(pool_size),
                    nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1),
                )
            )

    def forward(self, x):
        """Run every pyramid branch on ``x``.

        Args:
            x: feature map whose last two dimensions are treated as (H, W);
                presumably (N, C, H, W) — TODO confirm for callers.

        Returns:
            A new list with one tensor per branch, in ``pool_sizes`` order.
            Each tensor is upsampled to the spatial size of ``x``. Callers may
            extend and concatenate this list.
        """
        size = x.size()[-2:]
        return [
            nn.functional.interpolate(branch(x), size=size, mode='bilinear', align_corners=True)
            for branch in self
        ]

PSPNet的网络架构十分简单:backbone为resnet网络(最后两个stage用空洞卷积代替下采样),将原始图像下采样$8$倍得到特征图;特征图输入到PPM模块,并与PPM的输出在通道维度上拼接;最后经过卷积和$8$倍双线性插值上采样得到结果。

此外,作者还在网络中引入了辅助损失(auxiliary loss):在resnet101的res4b22层引出一条FCN分支,用于计算辅助损失。论文里将辅助损失loss2的权重设置为$0.4$。

用于计算辅助损失的辅助头定义如下:

# FCN-style segmentation head, used to compute the auxiliary loss.
class Aux_Head(nn.Module):
    """Auxiliary FCN head attached to an intermediate backbone feature map.

    Three 3x3 convolutions (stride 1, padding 1, so the spatial size is kept)
    reduce the channels as ``in_channels -> in_channels // 2 ->
    in_channels // 4 -> num_classes``. The first two convolutions are each
    followed by BatchNorm and ReLU. The last one has no activation.

    Args:
        in_channels: channels of the incoming feature map.
        num_classes: number of output channels.
    """

    def __init__(self, in_channels=1024, num_classes=3):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels

        half, quarter = in_channels // 2, in_channels // 4
        stages = []
        for c_in, c_out in ((in_channels, half), (half, quarter)):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        stages.append(nn.Conv2d(quarter, num_classes, kernel_size=3, padding=1))
        self.decode_head = nn.Sequential(*stages)

    def forward(self, x):
        """Return the raw class scores; spatial size matches ``x``."""
        return self.decode_head(x)

PSPNet构建如下:

from torchvision.models import resnet50, resnet101
from torchvision.models._utils import IntermediateLayerGetter

class PSPHEAD(nn.Module):
    """PSP decode head: pyramid pooling followed by a fusion convolution.

    The input feature map is passed through ``PPM``. The branch outputs and
    the input itself are concatenated along dim 1 (channels) and fused by a
    3x3 conv + BatchNorm + ReLU producing ``out_channels`` channels.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of each PPM branch and of the fused output.
        pool_sizes: grid sizes of the PPM branches (default 1, 2, 3, 6).
            A copy is stored on the instance as a list.
        num_classes: stored on the instance; this head does not use it.
    """

    def __init__(self, in_channels, out_channels, pool_sizes=(1, 2, 3, 6), num_classes=3):
        super(PSPHEAD, self).__init__()
        # Copy into a fresh list: no mutable default shared between
        # instances, and no aliasing of a caller-owned list.
        self.pool_sizes = list(pool_sizes)
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.psp_modules = PPM(self.pool_sizes, self.in_channels, self.out_channels)
        self.final = nn.Sequential(
            # Input width = original features + one out_channels slab per branch.
            nn.Conv2d(self.in_channels + len(self.pool_sizes)*self.out_channels, self.out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Fuse the pyramid-pooled context with ``x`` and return the result.

        Returns:
            The output of ``self.final``, with ``out_channels`` channels.
        """
        branches = self.psp_modules(x)
        # Concatenate along dim 1: the branch outputs first, then x itself.
        # A new list is built so the list returned by PPM is not mutated.
        out = torch.cat([*branches, x], 1)
        return self.final(out)

class Pspnet(nn.Module):
    """PSPNet: dilated ResNet-50 backbone, PSP decode head, optional auxiliary head.

    Args:
        num_classes: number of channels of the predicted segmentation map.
        aux_loss: if True, an ``Aux_Head`` is built on the ``layer3`` features.
            ``forward`` then also returns that head's upsampled prediction
            under the ``"aux_output"`` key.
    """

    def __init__(self, num_classes, aux_loss=True):
        super().__init__()
        self.num_classes = num_classes

        # ResNet-50 without pretrained weights. layer3 and layer4 use dilation
        # instead of stride, so their outputs stay at 1/8 of the input
        # resolution.
        resnet = resnet50(pretrained=False, replace_stride_with_dilation=[False, True, True])
        # Expose layer3 as "aux" (input of the auxiliary head) and layer4 as
        # "stage4" (input of the PSP head).
        self.backbone = IntermediateLayerGetter(
            resnet, return_layers={'layer3': "aux", 'layer4': 'stage4'}
        )
        self.aux_loss = aux_loss

        self.decoder = PSPHEAD(
            in_channels=2048,
            out_channels=512,
            pool_sizes=[1, 2, 3, 6],
            num_classes=self.num_classes,
        )
        # Final classifier: 512 fused channels -> num_classes scores.
        self.cls_seg = nn.Sequential(
            nn.Conv2d(512, self.num_classes, kernel_size=3, padding=1),
        )
        if aux_loss:
            self.aux_head = Aux_Head(in_channels=1024, num_classes=self.num_classes)

    def forward(self, x):
        """Predict segmentation scores upsampled to the input's (H, W).

        Returns:
            ``{"output": ...}``. When ``aux_loss`` is enabled, the dict also
            contains ``"aux_output"``, resized to the same (H, W).
        """
        _, _, height, width = x.size()
        target_size = (height, width)
        features = self.backbone(x)

        logits = self.cls_seg(self.decoder(features["stage4"]))
        result = {
            "output": nn.functional.interpolate(
                logits, size=target_size, mode='bilinear', align_corners=True
            )
        }

        # The auxiliary prediction is only produced when the aux head exists.
        if self.aux_loss:
            aux_logits = self.aux_head(features['aux'])
            result["aux_output"] = nn.functional.interpolate(
                aux_logits, size=target_size, mode='bilinear', align_corners=True
            )
        return result