DeepLab v3: 重新评估图像语义分割中的扩张卷积.

Deeplab v3开始,Deeplab系列舍弃了CRF后处理模块,提出了更加通用的、适用任何网络的分割框架。DeepLabV3的作者做了两部分的工作:

  1. 探索更深层的模型下,空洞卷积的效果。
  2. 空洞空间金字塔池化 Atrous Spatial Pyramid Pooling(ASPP)的优化。

作者将ResNet深层的模块替换为空洞卷积,获得了比较大的感受野,而且可以捕获远端的信息。其中dilation rate的设计十分重要,不当的设计会造成精度降低。

此外作者对ASPP模块做了升级。在实验中发现,dilation rate组合不当的情况下,3x3的卷积核会退化成1x1的卷积。因此作者重新调整了rate组合,从V2中的$[6, 12, 18, 24]$改进成$[1, 6, 12, 18]$;此外作者认为空洞卷积损失了一定信息,因此增加了全局平均池化来保存全局的上下文信息:

#DeepLabV3版本的ASPP
class ASPP_module(nn.ModuleList):
    def __init__(self, in_channels, out_channels, dilation_list=[1, 6, 12, 18]):
        super(ASPP_module, self).__init__()
        self.dilation_list = dilation_list
        for dia_rate in self.dilation_list:
            self.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels, out_channels, 
                        kernel_size=1 if dia_rate==1 else 3, 
                        dilation=dia_rate, padding=0 if dia_rate==1 else dia_rate),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU()
                )
            )
            
    def forward(self, x):
        outputs = []
        for aspp_module in self:
            outputs.append(aspp_module(x))
        return torch.cat(outputs, 1)

class DeepLabV3(nn.Module):
    def __init__(self, num_classes):
        super(DeepLabV3, self).__init__()
        self.num_classes = num_classes
        self.ASPP_module = ASPP_module(512,256,dilation_list=[1,6,12,18]) 
        self.backbone = ResNet()
        self.final = nn.Sequential(
            nn.Conv2d(256*4+256, 256, kernel_size=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, self.num_classes, kernel_size=1)
        )
        self.avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1)),
            nn.Conv2d(512, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True))
        
    def forward(self, x):
        x = self.backbone(x)[-1]
        x_1 = self.ASPP_module(x)
        x_2 = nn.functional.interpolate(self.avg_pool(x), size=(x.size(2), x.size(3)), mode='bilinear', align_corners=True)
        x = torch.cat([x_1, x_2], 1)
        x = nn.functional.interpolate(input=x ,scale_factor=8,mode='bilinear', align_corners=True)
        x = self.final(x)