U-Net: 用于医学图像分割的卷积网络.

UNet最早应用于生物医学图像分割,采用一种对称的U型网络设计,主要包括下采样编码、上采样解码和跳跃连接。

UNetEncoder进行4次最大池化下采样,Decoder进行4次转置卷积上采样。每一个上采样层和下采样层之间都有一个跳跃连接,实现了不同尺度的特征融合,从而可以进行多尺度预测;4次上采样也使得分割图恢复边缘等信息更加精细;而每一层的特征融合后都会经过一系列的卷积层,以此来处理特征图中的细节。

具体来说,高层(浅层)下采样倍数小,特征图具备更加细致的图特征;底层(深层)下采样倍数大,信息经过大量浓缩,空间损失大,但有助于目标区域(分类)判断;当两种特征进行融合时,分割效果往往会非常好。

由于输入医学图像的分辨率过大,对显存占用较高,因此在预测时采用滑动窗口的预测方式。此外网络在设计时没用使用padding参数,因为使用padding会导致图像边缘的损失,但是不使用padding会导致卷积过程中图像的分辨率越来越小,使得最后上采样回去的特征图尺寸和原图不匹配。

作者采用了一种重叠的切割策略(Overlap-tile strategy),该策略允许通过重叠的方法对任意大的图像进行无缝分割。为了预测图像边界区域中的像素,通过镜像输入图像来推断缺失的上下文。这种平铺策略对于将网络应用于大型图像很重要。比如需要预测图中黄色框的信息,就将蓝色框的数据作为输入,如果蓝色框内有一部分图像缺失,就对图像做镜像处理,从而构造黄色框区域的上下文信息。

U-Net的简单实现如下:

# 编码块
class UNetEnc(nn.Module):

    def __init__(self, in_channels, out_channels, dropout=False):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, 3, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, dilation=2),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers += [nn.Dropout(.5)]
        layers += [nn.MaxPool2d(2, stride=2, ceil_mode=True)]

        self.down = nn.Sequential(*layers)

    def forward(self, x):
        return self.down(x)

# 解码块		
class UNetDec(nn.Module):

    def __init__(self, in_channels, features, out_channels):
        super().__init__()

        self.up = nn.Sequential(
            nn.Conv2d(in_channels, features, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, features, 3),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(features, out_channels, 2, stride=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.up(x)
# U-Net
class UNet(nn.Module):

    def __init__(self, num_classes):
        super().__init__()

        self.enc1 = UNetEnc(3, 64)
        self.enc2 = UNetEnc(64, 128)
        self.enc3 = UNetEnc(128, 256)
        self.enc4 = UNetEnc(256, 512, dropout=True)
        self.center = nn.Sequential(
            nn.Conv2d(512, 1024, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.ConvTranspose2d(1024, 512, 2, stride=2),
            nn.ReLU(inplace=True),
        )
        self.dec4 = UNetDec(1024, 512, 256)
        self.dec3 = UNetDec(512, 256, 128)
        self.dec2 = UNetDec(256, 128, 64)
        self.dec1 = nn.Sequential(
            nn.Conv2d(128, 64, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3),
            nn.ReLU(inplace=True),
        )
        self.final = nn.Conv2d(64, num_classes, 1)

    # 前向传播过程
    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        center = self.center(enc4)
        # 包含了同层分辨率级联的解码块
        dec4 = self.dec4(torch.cat([
            center, F.upsample_bilinear(enc4, center.size()[2:])], 1))
        dec3 = self.dec3(torch.cat([
            dec4, F.upsample_bilinear(enc3, dec4.size()[2:])], 1))
        dec2 = self.dec2(torch.cat([
            dec3, F.upsample_bilinear(enc2, dec3.size()[2:])], 1))
        dec1 = self.dec1(torch.cat([
            dec2, F.upsample_bilinear(enc1, dec2.size()[2:])], 1))
        
        return F.upsample_bilinear(self.final(dec1), x.size()[2:])