UNet++:用于医学图像分割的巢型UNet.

1. 图像分割网络的发展

自编码器结构在2006年被Hinton提出并发表于Nature,用于解决图像压缩和去噪任务。这种编码器-解码器的设计思路随后被引入图像分割任务中,并取得很好的结果。全卷积网络(FCN)是最早提出的图像分割网络之一,该网络采用卷积的编码器-解码器结构,并引入了跳跃连接。UNet同年被提出;FCNUNet的主要区别在于:

  1. FCN的解码器只有一层转置卷积层,而UNet是完全对称的结构;
  2. FCN的跳跃连接采用求和(summation)操作,而UNet的跳跃连接采用叠加(concatenation)操作。

这类编码器-解码器结构的图像分割网络,主要由三部分构成:

  1. 编码器
  2. 解码器
  3. 跳跃连接

(1)编码器

编码器相当于特征提取器,执行下采样操作。编码器需要从图像中提取好的特征表示,并具有较快的收敛速度;但是不合适的编码器可能会在较简单的任务中过拟合。

图像分割任务中引入下采样的主要原因包括:

下采样操作的主要缺陷:

不同的网络选用的特征提取结构不同(相当于backbone不同):

此外,不同的网络选用的下采样操作也不同:

(2)解码器

解码器执行上采样操作,其主要贡献包括:

上采样操作的主要缺陷:

此外,不同的网络选用的上采样操作也不同:

(3)跳跃连接

跳跃连接最早在ResNet中使用,其重要性在于:

不同的网络选用的跳跃连接也不同:

2. 对UNet的层数讨论

原始的UNet设计了$4$层下采样层,不同深度的UNet在不同的数据集上往往会有不同的表现。如下图设计了$4$种不同深度的UNet,测试其在两种图像分类任务上的表现。在电子显微镜图像上$4$层网络取得最好的效果,而在细胞图像上$3$层网络取得最好的效果。

作者受此启发,设计了能够同时利用浅层特征和深层特征的UNet++网络。其基本思想是将不同层数的UNet网络叠加到同一个网络中,让网络自己学习不同深度的特征的重要性。且这些网络共享编码器(特征提取器)。最初的模型结构设计如下:

3. UNet++的改进

上述网络结构存在一些问题。该网络结构是不能被训练的,原因在于不会有任何梯度经过下图的红色区域,因为它和损失函数在反向传播时是断开的。

为了解决这一问题,有两种可行的思路:

  1. 将较长的跳跃连接换成较短的跳跃连接
  2. 深度监督(deep supervision)

通过结合这两种思路,不仅整合了不同层次的特征,获得精度的提升;而且通过灵活的网络结构配合深度监督,在可接受的精度范围内大幅度的缩减参数量。

(1)Change skip connection

这个思路可以参考论文Deep Layer Aggregation。将上述结构中的较长的跳跃连接全部换成短的跳跃连接,既能够整合不同层次的特征,又能够实现整体的梯度更新。

但是长连接仍然是有必要的,它能够联系输入图像的很多信息,有助于还原下采样所带来的信息损失。作者给出了一种综合长连接和短连接的方案,即论文中给出的UNet++结构:

为了探究额外的参数量对性能的影响,作者又设计了一个和UNet++参数量差不多的wide UNet结构(UNet++:$9.04$M,wide UNet:$9.13$M)。实验证明单纯把网络变宽,增加参数量对性能的提升并不大,UNet++使用参数是高效的。

(2)Deep supervision

深度监督是另一种解决反向传播时中间部分无法进行梯度更新的方法。其实现过程是在每一层级的上采样特征后增加一个$1 \times 1$的卷积,相当于对每个层级的子网络进行集成;每个子网络的输出便可以作为图像分割的结果:

该操作最大的好处是使得模型能够进行剪枝。注意到在测试阶段,只进行前向传播,此时更高层级的网络对低层级的网络没有影响,可以被剪掉;但是在训练阶段,会进行前向传播和反向传播,高层级的网络需要向低层级的网络回传梯度。通过对整个网络进行训练,测试时可以剪枝掉高层网络以获得更小的模型。由于剪掉的部分在训练时的反向传播中是有贡献的,因此会使得低层级网络的表现更好。比如实验发现第三个输出的效果和第四个输出效果差不多时,就可以直接删去图中棕色部分,实现剪枝。

下面是不同子模型在四种医学图像分割数据集上的性能表现。从图中可以看出,对于大多数比较简单的分割问题,不需要非常深的网络就可以达到不错的精度;对于比较难的数据集,网络越深其分割性能是在不断上升的。

4. UNet++的实现

import torch
import torch.nn as nn

class ContinusParalleConv(nn.Module):
    # 一个连续的卷积模块,包含BatchNorm 在前 和 在后 两种模式
    def __init__(self, in_channels, out_channels, pre_Batch_Norm = True):
        super(ContinusParalleConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
 
        if pre_Batch_Norm:
          self.Conv_forward = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.ReLU(),
            nn.Conv2d(self.in_channels, self.out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1))
 
        else:
          self.Conv_forward = nn.Sequential(
            nn.Conv2d(self.in_channels, self.out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU())
 
    def forward(self, x):
        x = self.Conv_forward(x)
        return x
 
class UnetPlusPlus(nn.Module):
    def __init__(self, num_classes, deep_supervision=False):
        super(UnetPlusPlus, self).__init__()
        self.num_classes = num_classes
        self.deep_supervision = deep_supervision
        self.filters = [64, 128, 256, 512, 1024]
        
        self.CONV3_1 = ContinusParalleConv(512*2, 512, pre_Batch_Norm = True)
 
        self.CONV2_2 = ContinusParalleConv(256*3, 256, pre_Batch_Norm = True)
        self.CONV2_1 = ContinusParalleConv(256*2, 256, pre_Batch_Norm = True)
 
        self.CONV1_1 = ContinusParalleConv(128*2, 128, pre_Batch_Norm = True)
        self.CONV1_2 = ContinusParalleConv(128*3, 128, pre_Batch_Norm = True)
        self.CONV1_3 = ContinusParalleConv(128*4, 128, pre_Batch_Norm = True)
 
        self.CONV0_1 = ContinusParalleConv(64*2, 64, pre_Batch_Norm = True)
        self.CONV0_2 = ContinusParalleConv(64*3, 64, pre_Batch_Norm = True)
        self.CONV0_3 = ContinusParalleConv(64*4, 64, pre_Batch_Norm = True)
        self.CONV0_4 = ContinusParalleConv(64*5, 64, pre_Batch_Norm = True)
 
 
        self.stage_0 = ContinusParalleConv(3, 64, pre_Batch_Norm = False)
        self.stage_1 = ContinusParalleConv(64, 128, pre_Batch_Norm = False)
        self.stage_2 = ContinusParalleConv(128, 256, pre_Batch_Norm = False)
        self.stage_3 = ContinusParalleConv(256, 512, pre_Batch_Norm = False)
        self.stage_4 = ContinusParalleConv(512, 1024, pre_Batch_Norm = False)
 
        self.pool = nn.MaxPool2d(2)
    
        self.upsample_3_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1) 
 
        self.upsample_2_1 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1) 
        self.upsample_2_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1) 
 
        self.upsample_1_1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1) 
        self.upsample_1_2 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1) 
        self.upsample_1_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1) 
 
        self.upsample_0_1 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1) 
        self.upsample_0_2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1) 
        self.upsample_0_3 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1) 
        self.upsample_0_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1) 
 
        
        # 分割头
        self.final_super_0_1 = nn.Sequential(
          nn.BatchNorm2d(64),
          nn.ReLU(),
          nn.Conv2d(64, self.num_classes, 3, padding=1),
        )        
        self.final_super_0_2 = nn.Sequential(
          nn.BatchNorm2d(64),
          nn.ReLU(),
          nn.Conv2d(64, self.num_classes, 3, padding=1),
        )        
        self.final_super_0_3 = nn.Sequential(
          nn.BatchNorm2d(64),
          nn.ReLU(),
          nn.Conv2d(64, self.num_classes, 3, padding=1),
        )        
        self.final_super_0_4 = nn.Sequential(
          nn.BatchNorm2d(64),
          nn.ReLU(),
          nn.Conv2d(64, self.num_classes, 3, padding=1),
        )        
 
        
    def forward(self, x):
        x_0_0 = self.stage_0(x)
        x_1_0 = self.stage_1(self.pool(x_0_0))
        x_2_0 = self.stage_2(self.pool(x_1_0))
        x_3_0 = self.stage_3(self.pool(x_2_0))
        x_4_0 = self.stage_4(self.pool(x_3_0))
        
        x_0_1 = torch.cat([self.upsample_0_1(x_1_0) , x_0_0], 1)
        x_0_1 =  self.CONV0_1(x_0_1)
        
        x_1_1 = torch.cat([self.upsample_1_1(x_2_0), x_1_0], 1)
        x_1_1 = self.CONV1_1(x_1_1)
        
        x_2_1 = torch.cat([self.upsample_2_1(x_3_0), x_2_0], 1)
        x_2_1 = self.CONV2_1(x_2_1)
        
        x_3_1 = torch.cat([self.upsample_3_1(x_4_0), x_3_0], 1)
        x_3_1 = self.CONV3_1(x_3_1)
 
        x_2_2 = torch.cat([self.upsample_2_2(x_3_1), x_2_0, x_2_1], 1)
        x_2_2 = self.CONV2_2(x_2_2)
        
        x_1_2 = torch.cat([self.upsample_1_2(x_2_1), x_1_0, x_1_1], 1)
        x_1_2 = self.CONV1_2(x_1_2)
        
        x_1_3 = torch.cat([self.upsample_1_3(x_2_2), x_1_0, x_1_1, x_1_2], 1)
        x_1_3 = self.CONV1_3(x_1_3)
 
        x_0_2 = torch.cat([self.upsample_0_2(x_1_1), x_0_0, x_0_1], 1)
        x_0_2 = self.CONV0_2(x_0_2)
        
        x_0_3 = torch.cat([self.upsample_0_3(x_1_2), x_0_0, x_0_1, x_0_2], 1)
        x_0_3 = self.CONV0_3(x_0_3)
        
        x_0_4 = torch.cat([self.upsample_0_4(x_1_3), x_0_0, x_0_1, x_0_2, x_0_3], 1)
        x_0_4 = self.CONV0_4(x_0_4)
    
    
        if self.deep_supervision:
            out_put1 = self.final_super_0_1(x_0_1)
            out_put2 = self.final_super_0_2(x_0_2)
            out_put3 = self.final_super_0_3(x_0_3)
            out_put4 = self.final_super_0_4(x_0_4)
            return [out_put1, out_put2, out_put3, out_put4]
        else:
            return self.final_super_0_4(x_0_4)
 
 
if __name__ == "__main__":
    print("deep_supervision: False")
    deep_supervision = False
    device = torch.device('cpu')
    inputs = torch.randn((1, 3, 224, 224)).to(device)
    model = UnetPlusPlus(num_classes=3, deep_supervision=deep_supervision).to(device)
    outputs = model(inputs)
    print(outputs.shape)    
    
    print("deep_supervision: True")
    deep_supervision = True
    model = UnetPlusPlus(num_classes=3, deep_supervision=deep_supervision).to(device)
    outputs = model(inputs)
    for out in outputs:
      print(out.shape)