DMNet: 语义分割的动态多尺度滤波器.
本文提出了端到端的DMNet模型,可以利用动态多尺度的过滤器对语义进行细分,相对于之前模型参数固定的方法,DMNet可以对图像的内容进行自适应的变化。 DMNet模型使用了动态卷积模块DCM,来捕获多尺度语义信息,每一个DCM模块都可以处理与输入尺寸相关的比例变化。
DCM模块的目标是自适应捕获输入图像的特定比例表示。DCM模块中的上下文感知过滤器(Context-aware filters)中嵌入了丰富的内容和高级语义信息,而且这些filters能够适应输入的图像,捕获图像内部的不同尺寸信息。输入特征$x$经过一个卷积层来减少通道数,然后经过一个AdaptiveAvgPooling(k),$k$值是自定义的量,经过卷积后生成k×k×512大小的$g_k(x)$,最后用一个Depth-wise conv将上下两个分支的特征图融合得到DCM模块的输出。
class DCMModle(nn.Module):
def __init__(self, in_channels=2048, channels=512, filter_size=1, fusion=True):
super(DCMModle, self).__init__()
self.filter_size = filter_size
self.in_channels = in_channels
self.channels = channels
self.fusion = fusion
# Global Information vector
self.reduce_Conv = nn.Conv2d(self.in_channels, self.channels, 1)
self.filter = nn.AdaptiveAvgPool2d(self.filter_size)
self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1, 0)
self.activate = nn.Sequential(nn.BatchNorm2d(self.channels),
nn.ReLU()
)
if self.fusion:
self.fusion_conv = nn.Conv2d(self.channels, self.channels, 1)
def forward(self, x):
b, c, h, w = x.shape
generted_filter = self.filter_gen_conv(self.filter(x)).view(b, self.channels, self.filter_size, self.filter_size)
x = self.reduce_Conv(x)
c = self.channels
# [1, b * c, h, w], c = self.channels
x = x.view(1, b * c, h, w)
# [b * c, 1, filter_size, filter_size]
generted_filter = generted_filter.view(b * c, 1, self.filter_size,
self.filter_size)
pad = (self.filter_size - 1) // 2
if (self.filter_size - 1) % 2 == 0:
p2d = (pad, pad, pad, pad)
else:
p2d = (pad + 1, pad, pad + 1, pad)
x = F.pad(input=x, pad=p2d, mode='constant', value=0)
# [1, b * c, h, w]
output = nn.functional.conv2d(input=x, weight=generted_filter, groups=b * c)
# [b, c, h, w]
output = output.view(b, c, h, w)
output = self.activate(output)
if self.fusion:
output = self.fusion_conv(output)
return output
class DCMModuleList(nn.ModuleList):
def __init__(self, filter_sizes = [1,2,3,6], in_channels = 2048, channels = 512):
super(DCMModuleList, self).__init__()
self.filter_sizes = filter_sizes
self.in_channels = in_channels
self.channels = channels
for filter_size in self.filter_sizes:
self.append(
DCMModle(self.in_channels, self.channels, filter_size)
)
def forward(self, x):
out = []
for DCM in self:
DCM_out = DCM(x)
out.append(DCM_out)
return out
class DMNet(nn.Module):
def __init__(self, num_classes):
super(DMNet, self).__init__()
self.num_classes = num_classes
self.backbone = ResNet.resnet50(replace_stride_with_dilation=[1,2,4])
self.in_channels = 2048
self.channels = 512
self.DMNet_pyramid = DCMModuleList(filter_sizes=[1,2,3,6], in_channels=self.in_channels, channels=self.channels)
self.conv1 = nn.Sequential(
nn.Conv2d(4*self.channels + self.in_channels, self.channels, 3, padding=1),
nn.BatchNorm2d(self.channels),
nn.ReLU()
)
self.cls_conv = nn.Conv2d(self.channels, self.num_classes, 3, padding=1)
def forward(self, x):
x = self.backbone(x)
DM_out = self.DMNet_pyramid(x)
DM_out.append(x)
x = torch.cat(DM_out, dim=1)
x = self.conv1(x)
x = Resize((8*x.shape[-2], 8*x.shape[-1]))(x)
x = self.cls_conv(x)
return x