BiSeNet: 实时语义分割的双边分割网络.
BiSeNet模型设计初衷是提升实时语义分割的速度(105FPS在Titan XP上)和精度($68.4\%$的mIoU在Cityscapes上)。
在实时语义分割的算法中,大多数工作主要采用三种加速模型计算的方法:
- 第一是限制输入大小,通过剪裁或者调整大小来降低计算的复杂度。这也是大部分工作最初的思路,但是这种方式会丢失空间上的部分细节,尤其是边缘细节。
- 第二是减少模型的通道,把模型的通道数缩减到一定的值,比如某个阶段$2048$个通道,直接缩小到$128$。这样缩小肯定会丢失一些信息,尤其是在较浅层,信息比较集中且重要的时候,会削弱空间上的一些信息。
- 第三是删去网络后面几层,让深层网络变浅一点,这会导致模型的感受野不大,导致一些物体分割不精确。
为了提高模型的精度,很多模型都借鉴了Unet中的U型结构,通过skip-connection融合骨干网络中的分层特征,填充细节来帮助分辨率恢复。不过这种方式会引入更多的计算。
本文作者在BiSeNet中设计了一个双边结构,分别为空间路径(Spatial Path)和上下文路径(Context Path)。通过一个特征融合模块(FFM)将两个路径的特征进行融合,得到分割结果。
⚪ 空间路径 Spatial Path
很多模型试图保留输入图像的原始分辨率,用空洞卷积的方式来编码空间信息,尽量扩大感受野;还有一些方法通过空间金字塔池化或者用大卷积核来捕捉空间信息,扩大感受野。空间信息和感受野对于模型精度的影响较大,但却很难同时满足两者,毕竟还要考虑速度问题。如果使用小尺寸的图像就会丢失信息。
因此在BiSeNet中,作者设计了一个简单但有效的快速下采样的空间路径,通过3个Conv+BN+ReLU的组合层将原图快速下采样8倍(通过卷积层的步幅来调整),保留空间信息的同时,速度却不慢。
class SpatialPath(nn.Module):
def __init__(self):
super(SpatialPath, self).__init__()
self.downpath = nn.Sequential(
nn.Conv2d(3, 64, 7, stride=2, padding=3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 128, 1),
nn.BatchNorm2d(128),
nn.ReLU(),
)
def forward(self, x):
return self.downpath(x)
⚪ 上下文路径 Context Path
空间路径能够编码足够的空间信息,但是需要更大的感受野,因此作者设计了一个Context Path来提供上下文信息,扩大感受野。
在这个路径中,可以通过ResNet作为backbone来快速下采样到16倍和32倍,并且作者设计了一个半U的结构,也就是只使用16x和32x下采样倍率的特征图,在保留信息的同时,不增加过多的计算量。每一个特征图都通过一个Attention Refinement Module(ARM)通过通道注意力突出特征。
在32x特征图的下方,作者还设计了一个全局池化的小模块,计算一个池化后的向量,加到32x特征图的ARM输出中。
class ARM(nn.Module):
def __init__(self, in_channels, out_channels):
super(ARM, self).__init__()
self.reduce_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.module = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(out_channels, out_channels, 1),
nn.BatchNorm2d(out_channels),
nn.Sigmoid()
)
def forward(self, x):
return self.reduce_conv(x)*self.module(x)
class ContextPath(nn.Module):
def __init__(self, out_channels=128):
super(ContextPath, self).__init__()
self.resnet = ResNet.resnet50(replace_stride_with_dilation=[1,2,4])
self.ARM16 = ARM(256, 128)
self.ARM32 = ARM(512, 128)
self.conv_head32 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.conv_head16 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.conv_avg = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(512, out_channels, 1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.up32 = nn.Upsample(scale_factor=2., mode="bilinear")
self.up16 = nn.Upsample(scale_factor=2., mode="bilinear")
def forward(self, x):
feat16, feat32 = self.resnet(x)
avg = self.conv_avg(feat32)
feat32_arm = self.ARM32(feat32) + avg
feat32_up = self.up32(feat32_arm)
feat32_up = self.conv_head32(feat32_up)
feat16_arm = self.ARM16(feat16) + feat32_up
feat16_up = self.up16(feat16_arm)
feat16_up = self.conv_head16(feat16_up)
return feat16_up, feat32_up
⚪ 特征融合模块 FFM
FFM模块用于编码两个分支的特征,设计了一个类似注意力机制的融合模块,编码空间路径(低级别信息)和上下文路径(高级别信息)的输出。最后将结果上采样8倍得到原图。
class FFM(nn.Module):
def __init__(self, channels=128):
super(FFM, self).__init__()
self.fuse = nn.Sequential(
nn.Conv2d(2*channels, channels, 1),
nn.BatchNorm2d(channels),
nn.ReLU()
)
self.skip_forward = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, channels//4, 1),
nn.ReLU(),
nn.Conv2d(channels//4, channels, 1),
nn.Sigmoid()
)
def forward(self, SP_input, CP_input):
x = torch.cat([SP_input, CP_input], 1)
x = self.fuse(x)
identify = self.skip_forward(x)
out = torch.mul(x, identify) + x
return out
⚪ BiSeNet
BiSeNet网络的整体结构如下:
class BiSeNet(nn.Module):
def __init__(self, num_classes):
super(BiSeNet, self).__init__()
self.num_classes = num_classes
self.SpatialPath = SpatialPath()
self.ContexPath = ContextPath()
self.FFM = FFM()
self.cls_seg = nn.Sequential(
nn.Conv2d(128, 128, 3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Upsample(scale_factor=8., mode="bilinear"),
nn.Conv2d(128, self.num_classes, 3, padding=1),
)
def forward(self, x):
b, c, h, w = x.size()
SP_out = self.SpatialPath(x)
CP_out16, CP_Out32 = self.ContexPath(x)
FFM_out = self.FFM(SP_out, CP_out16)
return self.cls_seg(FFM_out)