BiSeNet: 实时语义分割的双边分割网络.

BiSeNet模型设计初衷是提升实时语义分割的速度(105FPSTitan XP上)和精度($68.4\%$的mIoUCityscapes上)。

在实时语义分割的算法中,大多数工作主要采用三种加速模型计算的方法:

  1. 第一是限制输入大小,通过剪裁或者调整大小来降低计算的复杂度。这也是大部分工作最初的思路,但是这种方式会丢失空间上的部分细节,尤其是边缘细节。
  2. 第二是减少模型的通道,把模型的通道数缩减到一定的值,比如某个阶段$2048$个通道,直接缩小到$128$。这样缩小肯定会丢失一些信息,尤其是在较浅层,信息比较集中且重要的时候,会削弱空间上的一些信息。
  3. 第三是删去网络后面几层,让深层网络变浅一点,这会导致模型的感受野不大,导致一些物体分割不精确。

为了提高模型的精度,很多模型都借鉴了Unet中的U型结构,通过skip-connection融合骨干网络中的分层特征,填充细节来帮助分辨率恢复。不过这种方式会引入更多的计算。

本文作者在BiSeNet中设计了一个双边结构,分别为空间路径(Spatial Path)上下文路径(Context Path)。通过一个特征融合模块(FFM)将两个路径的特征进行融合,得到分割结果。

⚪ 空间路径 Spatial Path

很多模型试图保留输入图像的原始分辨率,用空洞卷积的方式来编码空间信息,尽量扩大感受野;还有一些方法通过空间金字塔池化或者用大卷积核来捕捉空间信息,扩大感受野。空间信息和感受野对于模型精度的影响较大,但却很难同时满足两者,毕竟还要考虑速度问题。如果使用小尺寸的图像就会丢失信息。

因此在BiSeNet中,作者设计了一个简单但有效的快速下采样的空间路径,通过3Conv+BN+ReLU的组合层将原图快速下采样8倍(通过卷积层的步幅来调整),保留空间信息的同时,速度却不慢。

class SpatialPath(nn.Module):
    def __init__(self):
        super(SpatialPath, self).__init__()
        self.downpath = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.Conv2d(64, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.Conv2d(64, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.Conv2d(64, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        
    def forward(self, x):
        return self.downpath(x)

⚪ 上下文路径 Context Path

空间路径能够编码足够的空间信息,但是需要更大的感受野,因此作者设计了一个Context Path来提供上下文信息,扩大感受野。

在这个路径中,可以通过ResNet作为backbone来快速下采样到16倍和32倍,并且作者设计了一个半U的结构,也就是只使用16x32x下采样倍率的特征图,在保留信息的同时,不增加过多的计算量。每一个特征图都通过一个Attention Refinement Module(ARM)通过通道注意力突出特征。

32x特征图的下方,作者还设计了一个全局池化的小模块,计算一个池化后的向量,加到32x特征图的ARM输出中。

class ARM(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ARM, self).__init__()
        self.reduce_conv =  nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.module = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.reduce_conv(x)*self.module(x)
 
class ContextPath(nn.Module):
    def __init__(self, out_channels=128):
        super(ContextPath, self).__init__()
        self.resnet = ResNet.resnet50(replace_stride_with_dilation=[1,2,4])
        self.ARM16 = ARM(256, 128)
        self.ARM32 = ARM(512, 128)
        self.conv_head32 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.conv_head16 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.conv_avg = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.up32 = nn.Upsample(scale_factor=2., mode="bilinear")
        self.up16 = nn.Upsample(scale_factor=2., mode="bilinear")
        
    def forward(self, x):
        feat16, feat32 = self.resnet(x)
        avg = self.conv_avg(feat32)
        
        feat32_arm = self.ARM32(feat32) + avg
        feat32_up = self.up32(feat32_arm)
        feat32_up = self.conv_head32(feat32_up)
        
        feat16_arm = self.ARM16(feat16) + feat32_up
        feat16_up = self.up16(feat16_arm)
        feat16_up = self.conv_head16(feat16_up)     
        
        return feat16_up, feat32_up

⚪ 特征融合模块 FFM

FFM模块用于编码两个分支的特征,设计了一个类似注意力机制的融合模块,编码空间路径(低级别信息)和上下文路径(高级别信息)的输出。最后将结果上采样8倍得到原图。

class FFM(nn.Module):
    def __init__(self, channels=128):
        super(FFM, self).__init__()
        self.fuse = nn.Sequential(
            nn.Conv2d(2*channels, channels, 1),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
        
        self.skip_forward = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels//4, 1),
            nn.ReLU(),
            nn.Conv2d(channels//4, channels, 1),
            nn.Sigmoid()
        )
        
    def forward(self, SP_input, CP_input):
        x = torch.cat([SP_input, CP_input], 1)
        x = self.fuse(x)
        identify = self.skip_forward(x)
        out = torch.mul(x, identify) + x
        return out

⚪ BiSeNet

BiSeNet网络的整体结构如下:

class BiSeNet(nn.Module):
    def __init__(self, num_classes):
        super(BiSeNet, self).__init__()
        self.num_classes = num_classes
        
        self.SpatialPath = SpatialPath()
        self.ContexPath = ContextPath()
        self.FFM = FFM()
        self.cls_seg = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(scale_factor=8., mode="bilinear"),
            nn.Conv2d(128, self.num_classes, 3, padding=1),  
        )
        
    def forward(self, x):
        b, c, h, w = x.size()
        SP_out = self.SpatialPath(x)
        CP_out16, CP_Out32 = self.ContexPath(x)
        FFM_out = self.FFM(SP_out, CP_out16)
        return self.cls_seg(FFM_out)