DMNet:用于语义分割的动态多尺度滤波器(Dynamic Multi-scale Filters for Semantic Segmentation)。

本文提出了端到端的DMNet模型,利用动态多尺度滤波器进行语义分割。相对于之前滤波器参数固定的方法,DMNet的滤波器可以根据输入图像的内容自适应地变化。DMNet模型使用动态卷积模块(Dynamic Convolutional Module, DCM)来捕获多尺度语义信息,每一个DCM模块负责处理一种与输入相关的尺度变化。

DCM模块的目标是自适应地捕获输入图像在特定尺度下的表示。DCM模块中的上下文感知滤波器(Context-aware filters)嵌入了丰富的内容和高级语义信息,并且这些滤波器能够适应输入图像,捕获图像内部不同尺度的信息。DCM包含两个分支:在上分支中,输入特征$x$经过一个1×1卷积层来减少通道数;在下分支中,$x$先经过AdaptiveAvgPooling(k)($k$为自定义的滤波器尺寸),再经过1×1卷积生成大小为$k \times k \times 512$的上下文感知滤波器$g_k(x)$;最后以$g_k(x)$作为卷积核,对上分支降维后的特征做Depth-wise卷积,将上下两个分支的特征图融合,得到DCM模块的输出。

class DCMModle(nn.Module):
    """Dynamic Convolutional Module (DCM).

    Filters the channel-reduced input with a depth-wise kernel that is
    generated from the input itself, so every sample (and every channel)
    is convolved with its own ``filter_size x filter_size`` kernel.

    The class name keeps its original spelling so existing callers and
    checkpoints continue to work.

    Args:
        in_channels: Number of channels of the input feature map.
        channels: Number of channels after reduction; also the number of
            output channels.
        filter_size: Side length ``k`` of the generated kernels. It is also
            the output size of the adaptive average pooling that produces
            them.
        fusion: If true, a final 1x1 convolution is applied to the result.

    Shape:
        Input ``(b, in_channels, h, w)`` -> output ``(b, channels, h, w)``.
    """

    def __init__(self, in_channels=2048, channels=512, filter_size=1, fusion=True):
        super(DCMModle, self).__init__()
        self.filter_size = filter_size
        self.in_channels = in_channels
        self.channels = channels
        self.fusion = fusion

        # Feature branch: 1x1 conv that reduces in_channels -> channels.
        self.reduce_Conv = nn.Conv2d(self.in_channels, self.channels, 1)

        # Filter branch: pool the raw input to k x k, then a 1x1 conv turns
        # the pooled map into one k x k kernel per output channel.
        self.filter = nn.AdaptiveAvgPool2d(self.filter_size)
        self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1, 0)

        self.activate = nn.Sequential(nn.BatchNorm2d(self.channels),
                                      nn.ReLU()
                                      )
        if self.fusion:
            self.fusion_conv = nn.Conv2d(self.channels, self.channels, 1)

    def forward(self, x):
        """Apply the input-conditioned depth-wise convolution to ``x``."""
        b, _, h, w = x.shape
        c = self.channels
        k = self.filter_size

        # One single-channel k x k kernel per (sample, channel) pair:
        # [b, c, k, k] -> [b * c, 1, k, k].
        # reshape (not view) so non-contiguous layouts such as channels_last
        # are handled instead of raising.
        generated_filter = self.filter_gen_conv(self.filter(x)).reshape(b * c, 1, k, k)

        # Fold the batch into the channel axis so a single grouped conv can
        # apply a different kernel to every sample: [b, c, h, w] -> [1, b * c, h, w].
        x = self.reduce_Conv(x).reshape(1, b * c, h, w)

        # Zero-pad by k - 1 in total per spatial axis so the output keeps the
        # h x w size. For even k the extra pixel goes on the left/top side.
        pad_after = (k - 1) // 2
        pad_before = (k - 1) - pad_after
        x = F.pad(input=x, pad=(pad_before, pad_after, pad_before, pad_after),
                  mode='constant', value=0)

        # groups == b * c makes this depth-wise: each folded channel is
        # convolved with its own generated kernel. Result: [1, b * c, h, w].
        output = F.conv2d(input=x, weight=generated_filter, groups=b * c)

        # Unfold the batch again: [b, c, h, w].
        output = output.reshape(b, c, h, w)
        output = self.activate(output)
        if self.fusion:
            output = self.fusion_conv(output)
        return output
    
    
class DCMModuleList(nn.ModuleList):
    """A list of DCM modules, one per filter size, all applied to the same input.

    Args:
        filter_sizes: Iterable of filter sizes; one ``DCMModle`` is created
            for each entry, in order.
        in_channels: Input channel count passed to every ``DCMModle``.
        channels: Reduced/output channel count passed to every ``DCMModle``.
    """

    def __init__(self, filter_sizes=(1, 2, 3, 6), in_channels=2048, channels=512):
        super(DCMModuleList, self).__init__()
        # Copy into a fresh list: the default is immutable and a caller's
        # list is never aliased, so instances cannot affect each other.
        self.filter_sizes = list(filter_sizes)
        self.in_channels = in_channels
        self.channels = channels

        for filter_size in self.filter_sizes:
            self.append(
                DCMModle(self.in_channels, self.channels, filter_size)
            )

    def forward(self, x):
        """Return a new list with the output of every contained module for ``x``."""
        return [dcm(x) for dcm in self]
    

class DMNet(nn.Module):
    """Segmentation network: backbone -> parallel DCM branches -> fuse -> classify.

    The backbone features are fed to a ``DCMModuleList`` with filter sizes
    1, 2, 3 and 6. The branch outputs are concatenated with the backbone
    features along the channel axis, fused by a 3x3 conv + BN + ReLU,
    resized to 8x the spatial size, and mapped to ``num_classes`` channels
    by a final 3x3 conv.

    Args:
        num_classes: Number of output channels of the classification conv.

    NOTE(review): ``ResNet`` and ``Resize`` are defined outside this block.
    The code assumes the backbone returns a single feature map with 2048
    channels, and that ``Resize`` accepts a batched tensor -- confirm.
    """

    def __init__(self, num_classes):
        super(DMNet, self).__init__()
        self.num_classes = num_classes
        self.backbone = ResNet.resnet50(replace_stride_with_dilation=[1,2,4])
        self.in_channels = 2048
        self.channels = 512

        pyramid_sizes = [1, 2, 3, 6]
        self.DMNet_pyramid = DCMModuleList(
            filter_sizes=pyramid_sizes,
            in_channels=self.in_channels,
            channels=self.channels,
        )

        # Input to the fusion conv: one `channels`-wide map per pyramid
        # branch plus the raw backbone features.
        fused_channels = len(pyramid_sizes) * self.channels + self.in_channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(fused_channels, self.channels, 3, padding=1),
            nn.BatchNorm2d(self.channels),
            nn.ReLU(),
        )
        self.cls_conv = nn.Conv2d(self.channels, self.num_classes, 3, padding=1)

    def forward(self, x):
        """Return the ``num_classes``-channel map computed from input ``x``."""
        features = self.backbone(x)
        branch_outputs = self.DMNet_pyramid(features)

        # Concatenate all branch outputs, with the backbone features last.
        fused = torch.cat(branch_outputs + [features], dim=1)
        fused = self.conv1(fused)

        # Enlarge the fused map by a factor of 8 per spatial axis before
        # the classification conv.
        target_size = (8 * fused.shape[-2], 8 * fused.shape[-1])
        upsampled = Resize(target_size)(fused)
        return self.cls_conv(upsampled)