DeepLab v3+: Atrous Separable Convolution for Semantic Image Segmentation

DeepLab v3+ extends and improves on DeepLab v3; its main improvement is combining an encoder-decoder structure with ASPP. DeepLab v3+ can therefore be seen as a work that merges the two major approaches to semantic segmentation, i.e. the encoder-decoder and the ASPP structure. The figure below compares (a) DeepLab v3 and (c) DeepLab v3+:

The Decoder of DeepLab v3+ uses feature maps from an intermediate layer of the backbone network:

DeepLab v3+ adopts two backbone architectures, ResNet-101 and Xception. The latter performs better: its depthwise separable convolutions make the segmentation network more efficient, and the authors propose a modified Xception model that applies atrous convolution inside the depthwise separable convolutions.
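Before the full implementation, here is a minimal sketch (written for this post, not code from the paper) of what such an atrous separable convolution looks like in PyTorch: a 3x3 depthwise convolution with a dilation rate, followed by a 1x1 pointwise convolution. The class name AtrousSeparableConv is a hypothetical helper and is not used by the model below.

import torch.nn as nn

class AtrousSeparableConv(nn.Module):
    # hypothetical helper: atrous (dilated) depthwise separable convolution
    def __init__(self, in_channels, out_channels, dilation):
        super(AtrousSeparableConv, self).__init__()
        # depthwise: one 3x3 filter per channel (groups=in_channels), sampled with the given dilation;
        # padding equal to the dilation keeps the spatial size unchanged
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3,
                                   padding=dilation, dilation=dilation,
                                   groups=in_channels, bias=False)
        # pointwise: 1x1 convolution that mixes channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.pointwise(self.depthwise(x))))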

import torch
import torch.nn as nn

class ASPP_module(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling: parallel convolution branches with different dilation rates."""
    def __init__(self, in_channels, out_channels, dilation_list=[6, 12, 18, 24]):
        super(ASPP_module, self).__init__()
        self.dilation_list = dilation_list
        for dia_rate in self.dilation_list:
            # rate 1 becomes a plain 1x1 convolution; larger rates use 3x3 atrous convolutions
            # with padding equal to the rate so the spatial size is preserved
            self.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels,
                              kernel_size=1 if dia_rate == 1 else 3,
                              dilation=dia_rate,
                              padding=0 if dia_rate == 1 else dia_rate),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU()
                )
            )
            
    def forward(self, x):
        outputs = []
        for aspp_module in self:
            outputs.append(aspp_module(x))
        return torch.cat(outputs, 1)
 
class DeepLabV3Plus(nn.Module):
    def __init__(self, num_classes):
        super(DeepLabV3Plus, self).__init__()
        self.backbone = ResNet()  # last stage uses atrous convolution; assumed to return a list of feature maps
        self.ASPP_module = ASPP_module(512, 256, [1, 6, 12, 18])
        # 1x1 projection of the low-level feature map that the decoder fuses back in
        self.low_feature = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        
        self.num_classes = num_classes
        
        # image-level pooling branch: global average pooling followed by a 1x1 convolution
        self.avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True))
        
        # fuse the five 256-channel branches (four ASPP rates + the image-level pooling branch)
        self.conv1 = nn.Sequential(
            nn.Conv2d(256*5, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(256*2, self.num_classes, kernel_size=3, padding=1),
            nn.BatchNorm2d(self.num_classes),
            nn.ReLU(),
        )
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(self.num_classes, self.num_classes, kernel_size=3, padding=1),
        )    
        
    def forward(self, x):
        # the backbone is assumed to return feature maps ordered from shallow to deep
        x = self.backbone(x)

        # low-level features (output stride 4) that the decoder fuses back in
        low_feature = self.low_feature(x[-3])
        # ASPP branches plus the image-level pooling branch on the deepest feature map
        x_1 = self.ASPP_module(x[-1])
        x_2 = nn.functional.interpolate(self.avg_pool(x[-1]), size=(x[-1].size(2), x[-1].size(3)), mode='bilinear', align_corners=True)
        x = torch.cat([x_1, x_2], 1)
        x = self.conv1(x)
        # upsample by 4 and concatenate with the low-level features
        x = nn.functional.interpolate(input=x, scale_factor=4, mode='bilinear', align_corners=True)
        x = torch.cat([x, low_feature], 1)
        x = self.conv2(x)
        # upsample by 4 again to recover the input resolution, then predict per-pixel classes
        x = nn.functional.interpolate(input=x, scale_factor=4, mode='bilinear', align_corners=True)
        x = self.conv3(x)
        return x
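
As a quick sanity check, the snippet below is an illustrative smoke test only: ResNet() above is a placeholder, so this assumes a backbone that returns a 512-channel feature map at output stride 16 for x[-1] and a 256-channel map at output stride 4 for x[-3], matching the channel sizes used in the model. Under those assumptions a 224x224 input comes back at full resolution.

if __name__ == '__main__':
    # illustrative only: assumes the backbone described above
    model = DeepLabV3Plus(num_classes=21)
    dummy = torch.randn(1, 3, 224, 224)
    out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 21, 224, 224])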