DFANet: 实时语义分割的深度特征聚合.
DFANet是一种用于实时语义分割的深度特征聚合算法,主要实现目标是提出一个能够在资源受限条件下快速实现分割的模型。与其他轻量级的模型相比,DFANet拥有更少的计算量和更高的分割精度:
作者首先回顾了图像分割模型的发展规律,提出以下几个观点:
- U-Net网络中利用高分辨率的特征图来帮助恢复上采样时花费了大量的时间。
- 如果单纯的减小输入图像的尺寸来减少时间消耗,就容易导致失去一些重要的边界特征、小物体的细节特征等。
- 如果使用浅层的网络,则会导致网络提取特征的能力不足。
- 有一些模型采用多分辨率分支结构来融合空间信息和上下文信息(图a),但是这些分支在高分辨率图像上的反复处理也会大幅限制速度,而且这些分支之间往往相互独立,限制了模型的学习能力。
- 语义分割网络的backbone多为Resnet,为了实现网络轻量化,使用带有深度可分离卷积的Xception是一个不错的选择。
- 主流的语义分割结构采用金字塔结构SPP(图b),虽然能够富集上下文特征,但是计算量十分大。
- 很多语义分割结构中特征重用的思想具有启发性(图c)。
- SENet的通道注意力启发了EncNet,作者也采用了EncNet中的上下文编码模块来对通道做通道注意力机制。
基于以上观点,作者以修改过的Xception为backbone网络,设计了一种多分支的框架DFANet来融合空间细节和上下文信息。
DFANet采用编码器-解码器结构。
在Encoder结构中,分为3个分支,每个分支都包含了三个encode模块和一个实现通道注意力的fc attention模块,而每一个模块的输出都和下一个分支中的输入相融合。三个分支的模块输出和输入互相融合之后,每个分支的enc2输出和fc attention输出都跳跃连接到Decoder结构。
在Decoder结构中,主要接受了每个分支的enc2模块输出和fc attention模块输出,其中,三个enc2模块输出相加后经过一个卷积再和三个fc attention模块的输出相加,经过上采样后得到分割结果。
这种层级之间的特征融合,能够将低级的特征和空间信息传递到最后的语义理解中,通过多层次的特征融合和通道注意力机制,帮助各个阶段的特征完善,以提升分割精度。
DFANet编码器中的每个分支构建如下:
class depthwise_separable_conv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,padding=1, dilation=1):
super(depthwise_separable_conv, self).__init__()
self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels,bias=False)
self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
def forward(self, x):
out = self.depthwise(x)
out = self.pointwise(out)
return out
class XceptionABlock(nn.Module):
"""
Base Block for XceptionA mentioned in DFANet paper.
"""
def __init__(self, in_channels, out_channels, stride=1):
super(XceptionABlock, self).__init__()
self.conv1 = nn.Sequential(
depthwise_separable_conv(in_channels, out_channels //4, stride=stride),
nn.BatchNorm2d(out_channels // 4),
nn.ReLU(),
)
self.conv2 = nn.Sequential(
depthwise_separable_conv(out_channels //4, out_channels //4),
nn.BatchNorm2d(out_channels // 4),
nn.ReLU(),
)
self.conv3 = nn.Sequential(
depthwise_separable_conv(out_channels //4, out_channels),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False)
def forward(self, x):
residual = self.conv1(x)
residual = self.conv2(residual)
residual = self.conv3(residual)
identity = self.skip(x)
return residual + identity
class enc(nn.Module):
"""
encoder block
"""
def __init__(self, in_channels, out_channels, stride=2, num_repeat=3):
super(enc, self).__init__()
stacks = [XceptionABlock(in_channels, out_channels, stride=2)]
for x in range(num_repeat - 1):
stacks.append(XceptionABlock(out_channels, out_channels))
self.build = nn.Sequential(*stacks)
def forward(self, x):
x = self.build(x)
return x
class ChannelAttention(nn.Module):
"""
channel attention module
"""
def __init__(self, in_channels, out_channels):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, 1000, bias=False),
nn.ReLU(),
nn.Linear(1000, out_channels, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b,c,_,_ = x.size()
y = self.avg_pool(x).view(b,c)
y = self.fc(y).view(b,c,1,1)
return x*y.expand_as(x)
class SubBranch(nn.Module):
"""
create 3 Sub Branches in DFANet
channel_cfg: the chnnels of each enc stage
branch_index: the index of each sub branch
"""
def __init__(self, channel_cfg, branch_index):
super(SubBranch, self).__init__()
self.enc2 = enc(channel_cfg[0], 48, num_repeat=3)
self.enc3 = enc(channel_cfg[1],96,num_repeat=6)
self.enc4 = enc(channel_cfg[2],192,num_repeat=3)
self.fc_atten = ChannelAttention(192, 192)
self.branch_index = branch_index
def forward(self,x0,*args):
out0=self.enc2(x0)
if self.branch_index in [1,2]:
out1 = self.enc3(torch.cat([out0,args[0]],1))
out2 = self.enc4(torch.cat([out1,args[1]],1))
else:
out1 = self.enc3(out0)
out2 = self.enc4(out1)
out3 = self.fc_atten(out2)
return [out0, out1, out2, out3]
DFANet的整体结构构建如下:
class DFA_Encoder(nn.Module):
def __init__(self, channel_cfg):
super(DFA_Encoder, self).__init__()
self.conv1=nn.Sequential(
nn.Conv2d(in_channels=3,out_channels=8,kernel_size=3,stride=2,padding=1,bias=False),
nn.BatchNorm2d(num_features=8),
nn.ReLU()
)
self.branch0=SubBranch(channel_cfg[0],branch_index=0)
self.branch1=SubBranch(channel_cfg[1],branch_index=1)
self.branch2=SubBranch(channel_cfg[2],branch_index=2)
def forward(self, x):
x = self.conv1(x)
x0,x1,x2,x5=self.branch0(x)
x3=F.interpolate(x5,x0.size()[2:],mode='bilinear',align_corners=True)
x1,x2,x3,x6=self.branch1(torch.cat([x0,x3],1),x1,x2)
x4=F.interpolate(x6,x1.size()[2:],mode='bilinear',align_corners=True)
x2,x3,x4,x7=self.branch2(torch.cat([x1,x4],1),x2,x3)
return [x0,x1,x2,x5,x6,x7]
class DFA_Decoder(nn.Module):
def __init__(self, decode_channels, num_classes):
super(DFA_Decoder, self).__init__()
self.conv0 = nn.Sequential(
nn.Conv2d(in_channels=48, out_channels=decode_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(decode_channels),
nn.ReLU()
)
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=48, out_channels=decode_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(decode_channels),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channels=48, out_channels=decode_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(decode_channels),
nn.ReLU()
)
self.conv3 = nn.Sequential(
nn.Conv2d(in_channels=192, out_channels=decode_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(decode_channels),
nn.ReLU()
)
self.conv4 = nn.Sequential(
nn.Conv2d(in_channels=192, out_channels=decode_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(decode_channels),
nn.ReLU()
)
self.conv5 = nn.Sequential(
nn.Conv2d(in_channels=192, out_channels=decode_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(decode_channels),
nn.ReLU()
)
self.conv_add = nn.Sequential(
nn.Conv2d(in_channels=decode_channels, out_channels=decode_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(decode_channels),
nn.ReLU()
)
self.conv_cls = nn.Conv2d(in_channels=decode_channels, out_channels=num_classes, kernel_size=3, padding=1, bias=False)
def forward(self, x0, x1, x2, x3, x4, x5):
x0 = self.conv0(x0)
x1 = F.interpolate(self.conv1(x1),x0.size()[2:],mode='bilinear',align_corners=True)
x2 = F.interpolate(self.conv2(x2),x0.size()[2:],mode='bilinear',align_corners=True)
x3 = F.interpolate(self.conv3(x3),x0.size()[2:],mode='bilinear',align_corners=True)
x4 = F.interpolate(self.conv5(x4),x0.size()[2:],mode='bilinear',align_corners=True)
x5 = F.interpolate(self.conv5(x5),x0.size()[2:],mode='bilinear',align_corners=True)
x_shallow = self.conv_add(x0+x1+x2)
x = self.conv_cls(x_shallow+x3+x4+x5)
x = F.interpolate(x,scale_factor=4,mode='bilinear',align_corners=True)
return x
class DFANet(nn.Module):
def __init__(self,channel_cfg,decoder_channel=64,num_classes=33):
"""
ch_cfg=[[8,48,96],
[240,144,288],
[240,144,288]]
"""
super(DFANet,self).__init__()
self.encoder=DFA_Encoder(channel_cfg)
self.decoder=DFA_Decoder(decoder_channel,num_classes)
def forward(self,x):
x0,x1,x2,x3,x4,x5=self.encoder(x)
x=self.decoder(x0,x1,x2,x3,x4,x5)
return x