UPerNet: 场景理解的统一感知解析.

人类在识别物体时,往往会通过多角度、多层次的观察来判断物体的类别。然而在图像分割领域,可用于Multi-task learning的数据集较少,制作起来也较为困难,因为不同任务的数据标注是异质的。例如,场景解析数据集ADE20K的所有标注都是像素级别的物体标注;而描述纹理信息的数据集DTD(Describable Textures Dataset)的标注都是图像级别的。这成为了构建多任务数据集的瓶颈所在。

为了解决缺乏Multi-task数据集的问题,作者构建了Broadly and Densely Labeled Dataset (Broden),统一了ADE20K、Pascal-Context、Pascal-Part、OpenSurfaces和Describable Textures Dataset (DTD)这几个数据集。这些数据集涵盖了各种场景、物体、物体的组成部件以及材质。作者对类别不均衡问题做了进一步处理,包括删除出现在少于$50$张图像中的类别,以及删除像素数少于$50000$的类别。总之,作者构建了一个十分宏大的Multi-task数据集,总共包含62,262张图像。

UPerNet给出了一个Multi-task learning的示范,并为此创建了一个多任务数据集。作者合理设计了UPerNet的主干网络(backbone)和检测头(head),用于不同任务的预测。UPerNet的模型设计总体基于FPN(Feature Pyramid Network)和PPM(Pyramid Pooling Module),并为每一个task设计了不同的检测头。

下面实现UPerNet的语义分割部分:

# Encoder 采用ResNet,返回每个模块的输出特征
# Decoder = FPN+PPM

class PPM(nn.ModuleList):
    """Pyramid Pooling Module branches (PSPNet-style).

    One branch is built per entry of ``pool_sizes``. Each branch adaptively
    pools the input to ``pool_size x pool_size`` and projects it with a 1x1
    convolution to ``out_channels``. ``forward`` bilinearly upsamples every
    branch back to the input's spatial size and returns the branch outputs as
    a list (they are NOT concatenated here).

    Args:
        pool_sizes: iterable of target sizes for the adaptive pooling layers.
        in_channels: channel count of the input feature map.
        out_channels: channel count produced by each branch.
        pooling: ``"avg"`` (default; average pooling as used by PSPNet and
            UPerNet) or ``"max"`` (adaptive max pooling, the earlier behaviour
            of this class).

    Raises:
        ValueError: if ``pooling`` is not ``"avg"`` or ``"max"``.
    """

    def __init__(self, pool_sizes, in_channels, out_channels, pooling="avg"):
        super(PPM, self).__init__()
        pool_layers = {"avg": nn.AdaptiveAvgPool2d, "max": nn.AdaptiveMaxPool2d}
        if pooling not in pool_layers:
            raise ValueError(f"pooling must be 'avg' or 'max', got {pooling!r}")
        pool_cls = pool_layers[pooling]
        for pool_size in pool_sizes:
            self.append(
                nn.Sequential(
                    pool_cls(pool_size),
                    nn.Conv2d(in_channels, out_channels, kernel_size=1),
                )
            )

    def forward(self, x):
        """Return a list with one upsampled branch output per pool size.

        Each element has shape ``(N, out_channels, H, W)``, where ``H, W`` are
        the spatial dimensions of ``x``.
        """
        size = x.shape[-2:]
        return [
            nn.functional.interpolate(branch(x), size=size, mode='bilinear', align_corners=True)
            for branch in self
        ]


class PPMHEAD(nn.Module):
    """PPM head: pyramid-pooled context concatenated with the input, then fused.

    The input ``x`` is passed through ``PPM``, whose branch outputs are
    concatenated (channel-wise) with ``x`` itself, in the order
    ``[branch_1, ..., branch_k, x]``. The result is reduced to
    ``out_channels`` by a 1x1 Conv -> BatchNorm -> ReLU stack.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count of each PPM branch and of the output.
        pool_sizes: pooling grid sizes, one PPM branch per entry. Any iterable
            is accepted; it is materialised once so it can be both iterated
            and measured.
    """

    def __init__(self, in_channels, out_channels, pool_sizes=(1, 2, 3, 6)):
        super(PPMHEAD, self).__init__()
        # Materialise once: PPM iterates it and len() is needed below, which
        # would fail (or see an exhausted iterator) for a one-shot iterable.
        pool_sizes = tuple(pool_sizes)
        self.psp_modules = PPM(pool_sizes, in_channels, out_channels)
        self.final = nn.Sequential(
            nn.Conv2d(in_channels + len(pool_sizes) * out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Concatenate the PPM branches with ``x`` and fuse them to ``out_channels``."""
        # Build a fresh list instead of mutating whatever PPM returned.
        features = [*self.psp_modules(x), x]
        return self.final(torch.cat(features, 1))


class FPNHEAD(nn.Module):
    """UPerNet decoder: PPM on the deepest feature plus top-down FPN fusion.

    The deepest backbone feature is processed by ``PPMHEAD``. The result is
    merged top-down with the three shallower features through 1x1 lateral
    projections, and each merge is smoothed by another 1x1 Conv-BN-ReLU. All
    four resulting levels are resized to the shallowest level's resolution,
    concatenated and fused back to ``out_channels``.

    Args:
        channels: channel count of the deepest feature (``input_fpn[-1]``).
            The lateral convolutions expect ``input_fpn[-2]``, ``[-3]`` and
            ``[-4]`` to carry ``channels // 2``, ``channels // 4`` and
            ``channels // 8`` channels respectively.
        out_channels: channel width of every decoder level and of the output.
    """

    def __init__(self, channels=2048, out_channels=256):
        super(FPNHEAD, self).__init__()
        self.PPMHead = PPMHEAD(in_channels=channels, out_channels=out_channels)

        # Lateral 1x1 projections (Conv_fuseN) and post-merge 1x1 smoothing
        # (Conv_fuseN_) for the three shallower levels, deepest first.
        self.Conv_fuse1 = self._conv_bn_relu(channels // 2, out_channels)
        self.Conv_fuse1_ = self._conv_bn_relu(out_channels, out_channels)
        self.Conv_fuse2 = self._conv_bn_relu(channels // 4, out_channels)
        self.Conv_fuse2_ = self._conv_bn_relu(out_channels, out_channels)
        self.Conv_fuse3 = self._conv_bn_relu(channels // 8, out_channels)
        self.Conv_fuse3_ = self._conv_bn_relu(out_channels, out_channels)

        # Fuses the four concatenated pyramid levels back to out_channels.
        self.fuse_all = self._conv_bn_relu(out_channels * 4, out_channels)

        # Extra 1x1 conv applied only to the upsampled PPM output before the first merge.
        self.conv_x1 = nn.Conv2d(out_channels, out_channels, 1)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Return a 1x1 Conv2d -> BatchNorm2d -> ReLU stack."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    @staticmethod
    def _resize_like(x, ref):
        """Bilinearly resize ``x`` to the spatial size of ``ref``."""
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, input_fpn):
        """Decode a feature pyramid into one fused map.

        Args:
            input_fpn: indexable sequence of at least four feature maps.
                ``input_fpn[-1]`` has ``channels`` channels and is the deepest
                level; ``input_fpn[-4]`` is the shallowest level used. The
                shallow-to-deep ordering is presumed from the lateral conv
                channel counts; confirm against the backbone.

        Returns:
            Tensor with ``out_channels`` channels at ``input_fpn[-4]``'s
            spatial resolution.
        """
        c2, c3, c4, c5 = input_fpn[-4], input_fpn[-3], input_fpn[-2], input_fpn[-1]

        # Context from the deepest level; shape (N, out_channels, *c5.shape[-2:]).
        x1 = self.PPMHead(c5)

        # Each top-down step resizes to the lateral feature's actual size
        # instead of assuming an exact 2x ratio. Adjacent levels are not always
        # 2x apart (e.g. 32 vs 63 for odd input sizes, or equal sizes for
        # dilated stages), and a fixed 2x would make the addition fail.
        x = self.conv_x1(self._resize_like(x1, c4)) + self.Conv_fuse1(c4)
        x2 = self.Conv_fuse1_(x)

        x = self._resize_like(x2, c3) + self.Conv_fuse2(c3)
        x3 = self.Conv_fuse2_(x)

        x = self._resize_like(x3, c2) + self.Conv_fuse3(c2)
        x4 = self.Conv_fuse3_(x)

        # Bring every level to the finest resolution before fusing.
        x1, x2, x3 = (self._resize_like(level, x4) for level in (x1, x2, x3))
        return self.fuse_all(torch.cat([x1, x2, x3, x4], 1))


class UPerNet(nn.Module):
    """UPerNet semantic-segmentation model: ResNet-50 encoder + FPN/PPM decoder.

    Args:
        num_classes: number of output channels (classes) of the logits.
    """

    def __init__(self, num_classes):
        super(UPerNet, self).__init__()
        # NOTE(review): torchvision's resnet50 reads replace_stride_with_dilation
        # as three booleans, so [1, 2, 4] would dilate all of layer2-4 (every
        # entry is truthy). Confirm the project ResNet module's semantics.
        self.backbone = ResNet.resnet50(replace_stride_with_dilation=[1, 2, 4])
        self.decoder = FPNHEAD()
        # 256 must match FPNHEAD's default out_channels.
        self.cls_seg = nn.Sequential(
            nn.Conv2d(256, num_classes, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Return per-pixel class logits of shape ``(N, num_classes, H, W)``.

        ``H, W`` are the spatial dimensions of the input ``x``.
        """
        input_size = x.shape[-2:]
        features = self.backbone(x)
        out = self.decoder(features)
        # Resize to the input's own size rather than a fixed 4x. The decoder
        # output is not exactly H/4 x W/4 when H or W is not divisible by the
        # backbone stride, and the logits must match the label resolution.
        out = nn.functional.interpolate(out, size=input_size, mode='bilinear', align_corners=True)
        return self.cls_seg(out)