SRResNet/SRGAN: Image Super-Resolution with Generative Adversarial Networks

Super-resolution models trained directly with an MSE loss score well on metrics such as PSNR and SSIM, yet the resulting images still lack fine detail. The authors instead use a generative adversarial network to obtain results with better perceptual quality.

As shown in the figure above, the image space contains several structurally similar plausible images. Minimizing the MSE loss produces the average of these images and thus washes out detail, whereas the GAN approach drives the output toward one specific real image, preserving a realistic appearance.
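A minimal sketch of this argument (with made-up data, not from the paper): among several equally plausible high-resolution images, the prediction that minimizes the average pixel-wise MSE is their mean, which blurs every detail on which the candidates disagree.

import torch

# Three hypothetical, equally plausible HR images that differ in fine detail
candidates = torch.stack([torch.randn(3, 32, 32) for _ in range(3)])

# The MSE-optimal single prediction is the pixel-wise mean of the candidates
pred = candidates.mean(dim=0)

# The mean achieves a lower average MSE than any individual candidate,
# even though it is blurrier than each of them
mse_of_mean = ((candidates - pred) ** 2).mean()
mse_of_single = ((candidates - candidates[0]) ** 2).mean()
print(bool(mse_of_mean <= mse_of_single))  # True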

The main contributions of this paper are as follows:

1. Model Architecture

(1) Generator

The generator is based on ResNet: it takes a low-resolution image as input and outputs a high-resolution image. This part can be used on its own as SRResNet.

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        # conv -> BN -> PReLU -> conv -> BN, wrapped in an identity skip connection
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, momentum=0.8),
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, momentum=0.8),
        )

    def forward(self, x):
        return x + self.conv_block(x)

class GeneratorResNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super(GeneratorResNet, self).__init__()
        # First layer: large 9x9 receptive field, maps the image to 64 feature channels
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4), nn.PReLU())
        # Residual blocks
        res_blocks = []
        for _ in range(n_residual_blocks):
            res_blocks.append(ResidualBlock(64))
        self.res_blocks = nn.Sequential(*res_blocks)
        # Second conv layer post residual blocks
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64, momentum=0.8))
        # Upsampling layers: each block doubles the resolution, so two blocks give 4x
        upsampling = []
        for _ in range(2):
            upsampling += [
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.BatchNorm2d(256),
                # PixelShuffle rearranges 256 channels into 64 channels at 2x the spatial size
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)
        # Final output layer; Tanh maps outputs to [-1, 1]
        self.conv3 = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4), nn.Tanh())

    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)  # long skip connection around all residual blocks
        out = self.upsampling(out)
        out = self.conv3(out)
        return out
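A quick shape check (the 64x64 input size is an arbitrary assumption): the two PixelShuffle stages make the upscaling factor a fixed 4x.

generator = GeneratorResNet()
lr = torch.randn(1, 3, 64, 64)  # hypothetical 64x64 low-resolution batch
print(generator(lr).shape)      # torch.Size([1, 3, 256, 256]), i.e. 4x upscaling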

(2) Discriminator

The discriminator follows the VGGNet design: it takes real images and generated high-resolution images as input and classifies which is which.

class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        # Four stride-2 convolutions downsample the input by a factor of 2**4
        patch_h, patch_w = int(in_height / 2 ** 4), int(in_width / 2 ** 4)

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters
        # Classification head: flatten, then two dense layers; Sigmoid outputs a
        # probability in [0, 1] so the BCE adversarial loss below applies directly
        layers.extend([
            nn.Flatten(),
            nn.Linear(512 * patch_h * patch_w, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)
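A matching sanity check (the 256x256 HR size is an assumption, chosen to match the 4x generator output above):

discriminator = Discriminator(input_shape=(3, 256, 256))
hr = torch.randn(1, 3, 256, 256)
print(discriminator(hr).shape)  # torch.Size([1, 1]): one real/fake probability per image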

2. Loss Functions

To train SRGAN, the authors propose a perceptual loss $l^{SR}$, composed of a content loss $l_X^{SR}$ and an adversarial loss $l_{Gen}^{SR}$.
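In the paper, the two terms are combined as a weighted sum, with the adversarial term scaled down by a factor of $10^{-3}$:

\[l^{SR} = l_X^{SR} + 10^{-3} l_{Gen}^{SR}\]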

The content loss $l_X^{SR}$ is defined on a pretrained VGG19 network by comparing the features of the generated image and the real image. Here $\phi_{i,j}$ denotes the feature map produced by the $j$-th convolution (after activation) and before the $i$-th max-pooling layer of VGG19.
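Concretely, the VGG content loss in the paper is the mean squared error between these feature maps, where $W_{i,j}$ and $H_{i,j}$ denote the feature map dimensions:

\[l_{VGG/i,j}^{SR} = \frac{1}{W_{i,j}H_{i,j}} \sum_{x=1}^{W_{i,j}} \sum_{y=1}^{H_{i,j}} \left( \phi_{i,j}(I^{HR})_{x,y} - \phi_{i,j}(G_{\theta_G}(I^{LR}))_{x,y} \right)^2\]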

import torch.nn as nn
from torchvision.models import vgg19

class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        # Keep the first 18 layers of VGG19, i.e. up to the activation of
        # conv3_4, just before the third max-pooling layer
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])

    def forward(self, img):
        return self.feature_extractor(img)

feature_extractor = FeatureExtractor()
feature_extractor.eval()  # keep VGG fixed; it is only used to extract features
criterion_content = torch.nn.L1Loss()  # the paper uses MSE here; this implementation uses L1
# Content loss: distance between VGG features of the generated and real images
gen_features = feature_extractor(gen_hr)
real_features = feature_extractor(imgs_hr)
loss_content = criterion_content(gen_features, real_features.detach())

The adversarial loss $l_{Gen}^{SR}$ pushes the generator to produce outputs that the discriminator cannot correctly classify; training follows the standard GAN min-max objective:

\[\min_{\theta_G} \max_{\theta_D} \; \mathbb{E}_{I^{HR} \sim p_{train}(I^{HR})}\left[\log D_{\theta_D}(I^{HR})\right] + \mathbb{E}_{I^{LR} \sim p_{G}(I^{LR})}\left[\log\left(1 - D_{\theta_D}(G_{\theta_G}(I^{LR}))\right)\right]\]

generator = GeneratorResNet()
# input_shape must match the HR images; hr_height and hr_width come from the dataset
discriminator = Discriminator(input_shape=(3, hr_height, hr_width))
adversarial_loss = torch.nn.BCELoss()

# Define the optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

for epoch in range(n_epochs):
    for i, (imgs_lr, imgs_hr) in enumerate(dataloader):
        # Adversarial ground-truth labels (requires_grad is False by default)
        valid = torch.ones(imgs_hr.size(0), 1)
        fake = torch.zeros(imgs_hr.size(0), 1)

        # Generate a high-resolution image from the low-resolution input
        gen_hr = generator(imgs_lr)

        # Train the discriminator
        optimizer_D.zero_grad()
        # Discriminator loss on real and generated images
        real_loss = adversarial_loss(discriminator(imgs_hr), valid)
        fake_loss = adversarial_loss(discriminator(gen_hr.detach()), fake)  # detach: no generator gradients here
        d_loss = (real_loss + fake_loss) / 2
        # Update the discriminator parameters
        d_loss.backward()
        optimizer_D.step()

        # Train the generator
        optimizer_G.zero_grad()
        # Content loss in VGG feature space
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features.detach())
        # Perceptual loss: content loss plus the down-weighted adversarial loss
        g_loss = loss_content + 1e-3 * adversarial_loss(discriminator(gen_hr), valid)
        g_loss.backward()
        optimizer_G.step()

A complete PyTorch implementation of SRGAN is available in PyTorch-GAN.

3. Experiments

The authors ran a subjective evaluation in which $26$ raters scored the image quality of different models on a scale from $1$ to $5$; the average of these scores is reported as the mean opinion score (MOS).
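As a trivial illustration (the ratings below are made up), the MOS is simply the arithmetic mean of the individual ratings:

ratings = [4, 5, 3, 4, 4]          # hypothetical scores from five raters
mos = sum(ratings) / len(ratings)  # mean opinion score
print(mos)                         # 4.0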

The authors compare the objective metrics (PSNR, SSIM) and the subjective MOS of different models for $4\times$ super-resolution.