DeepLab v3+: 图像语义分割中的扩张可分离卷积.
Deeplab v3+在Deeplab v3的基础上做了扩展和改进,其主要改进就是在编解码结构上使用了ASPP。Deeplab v3+可以视作是融合了语义分割两大流派的一项工作,即编解码+ASPP结构。 下图是(a)DeepLab v3和(c)DeepLab v3+的对比:
DeepLab v3+的Decoder部分使用了卷积网络中间层的特征映射:
DeepLab v3+采用了两种卷积网络结构,分别是Resnet 101和Xception,后者效果更好,其深度可分离卷积的设计使得分割网络更加高效。作者提出了一个修改的Xception模型,在深度可分离卷积上应用空洞卷积。
class ASPP_module(nn.ModuleList):
def __init__(self, in_channels, out_channels, dilation_list=[6, 12, 18, 24]):
super(ASPP_module, self).__init__()
self.dilation_list = dilation_list
for dia_rate in self.dilation_list:
self.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1 if dia_rate==1 else 3, dilation=dia_rate, padding=0 if dia_rate==1 else dia_rate),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
)
def forward(self, x):
outputs = []
for aspp_module in self:
outputs.append(aspp_module(x))
return torch.cat(outputs, 1)
class DeepLabV3Plus(nn.Module):
def __init__(self, num_classes):
super(DeepLabV3Plus, self).__init__()
self.backbone = ResNet() # 最后一个模块使用空洞卷积
self.ASPP_module = ASPP_module(512,256,[1,6,12,18])
self.low_feature = nn.Sequential(
nn.Conv2d(256, 256, kernel_size=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(),
)
self.num_classes = num_classes
self.avg_pool = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(512, 256, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True))
self.conv1 = nn.Sequential(
nn.Conv2d(256*5, 256, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(),
)
self.conv2 = nn.Sequential(
nn.Conv2d(256*2, self.num_classes, kernel_size=3, padding=1),
nn.BatchNorm2d(self.num_classes),
nn.ReLU(),
)
self.conv3 = nn.Sequential(
nn.Conv2d(self.num_classes, self.num_classes, kernel_size=3, padding=1),
)
def forward(self, x):
x = self.backbone(x)
low_feature = self.low_feature(x[-3])
x_1 = self.ASPP_module(x[-1])
x_2 = nn.functional.interpolate(self.avg_pool(x[-1]), size=(x[-1].size(2), x[-1].size(3)), mode='bilinear', align_corners=True)
x = torch.cat([x_1, x_2], 1)
x = self.conv1(x)
x = nn.functional.interpolate(input=x ,scale_factor=4,mode='bilinear')
x = torch.cat([x, low_feature], 1)
x = self.conv2(x)
x = nn.functional.interpolate(input=x ,scale_factor=4,mode='bilinear')
x = self.conv3(x)
return x