U-Net: 用于医学图像分割的卷积网络.
UNet最早应用于生物医学图像分割,采用一种对称的U型网络设计,主要包括下采样编码、上采样解码和跳跃连接。
UNet的Encoder进行4次最大池化下采样,Decoder进行4次转置卷积上采样。每一个上采样层和下采样层之间都有一个跳跃连接,实现了不同尺度的特征融合,从而可以进行多尺度预测;4次上采样也使得分割图恢复边缘等信息更加精细;而每一层的特征融合后都会经过一系列的卷积层,以此来处理特征图中的细节。
具体来说,高层(浅层)下采样倍数小,特征图具备更加细致的图特征;底层(深层)下采样倍数大,信息经过大量浓缩,空间损失大,但有助于目标区域(分类)判断;当两种特征进行融合时,分割效果往往会非常好。
由于输入医学图像的分辨率过大,对显存占用较高,因此在预测时采用滑动窗口的预测方式。此外网络在设计时没用使用padding参数,因为使用padding会导致图像边缘的损失,但是不使用padding会导致卷积过程中图像的分辨率越来越小,使得最后上采样回去的特征图尺寸和原图不匹配。
作者采用了一种重叠的切割策略(Overlap-tile strategy),该策略允许通过重叠的方法对任意大的图像进行无缝分割。为了预测图像边界区域中的像素,通过镜像输入图像来推断缺失的上下文。这种平铺策略对于将网络应用于大型图像很重要。比如需要预测图中黄色框的信息,就将蓝色框的数据作为输入,如果蓝色框内有一部分图像缺失,就对图像做镜像处理,从而构造黄色框区域的上下文信息。
U-Net的简单实现如下:
# 编码块
class UNetEnc(nn.Module):
def __init__(self, in_channels, out_channels, dropout=False):
super().__init__()
layers = [
nn.Conv2d(in_channels, out_channels, 3, dilation=2),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, dilation=2),
nn.ReLU(inplace=True),
]
if dropout:
layers += [nn.Dropout(.5)]
layers += [nn.MaxPool2d(2, stride=2, ceil_mode=True)]
self.down = nn.Sequential(*layers)
def forward(self, x):
return self.down(x)
# 解码块
class UNetDec(nn.Module):
def __init__(self, in_channels, features, out_channels):
super().__init__()
self.up = nn.Sequential(
nn.Conv2d(in_channels, features, 3),
nn.ReLU(inplace=True),
nn.Conv2d(features, features, 3),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(features, out_channels, 2, stride=2),
nn.ReLU(inplace=True),
)
def forward(self, x):
return self.up(x)
# U-Net
class UNet(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.enc1 = UNetEnc(3, 64)
self.enc2 = UNetEnc(64, 128)
self.enc3 = UNetEnc(128, 256)
self.enc4 = UNetEnc(256, 512, dropout=True)
self.center = nn.Sequential(
nn.Conv2d(512, 1024, 3),
nn.ReLU(inplace=True),
nn.Conv2d(1024, 1024, 3),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.ConvTranspose2d(1024, 512, 2, stride=2),
nn.ReLU(inplace=True),
)
self.dec4 = UNetDec(1024, 512, 256)
self.dec3 = UNetDec(512, 256, 128)
self.dec2 = UNetDec(256, 128, 64)
self.dec1 = nn.Sequential(
nn.Conv2d(128, 64, 3),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 3),
nn.ReLU(inplace=True),
)
self.final = nn.Conv2d(64, num_classes, 1)
# 前向传播过程
def forward(self, x):
enc1 = self.enc1(x)
enc2 = self.enc2(enc1)
enc3 = self.enc3(enc2)
enc4 = self.enc4(enc3)
center = self.center(enc4)
# 包含了同层分辨率级联的解码块
dec4 = self.dec4(torch.cat([
center, F.upsample_bilinear(enc4, center.size()[2:])], 1))
dec3 = self.dec3(torch.cat([
dec4, F.upsample_bilinear(enc3, dec4.size()[2:])], 1))
dec2 = self.dec2(torch.cat([
dec3, F.upsample_bilinear(enc2, dec3.size()[2:])], 1))
dec1 = self.dec1(torch.cat([
dec2, F.upsample_bilinear(enc1, dec2.size()[2:])], 1))
return F.upsample_bilinear(self.final(dec1), x.size()[2:])