ESRGAN:增强的图像超分辨率生成对抗网络.

作者在SRGAN的基础上进行改进,提出了增强的图像超分辨率生成对抗网络(Enhanced Super-Resolution Generative Adversarial Network, ESRGAN)。对于网络结构,作者引入了没有批量归一化的残差密集块(Residual-in-Residual Dense Block, RRDB)作为基本的网络构建单元;对于对抗损失,作者采用相对判别器的思想,让判别器预测相对真实度;对于感知损失,作者利用激活前的特征来构造损失,为亮度一致性和纹理恢复提供更强的监督。

1. 网络结构

生成器的整体结构与SRGAN相似。

为了进一步提高SRGAN的恢复图像质量,生成器结构主要有两处改进:

  1. 去除所有BN层;
  2. 用提出的RRDB块替换原始残差模块。

BN层在训练期间使用一批数据的均值和方差来归一化特征,并且在测试期间使用整个训练数据集估计均值和方差。当训练和测试数据集的统计差异很大时,BN层可能带来伪像。为了稳定训练和维持性能,作者去除了BN层,这有助于提高泛化能力并减少计算复杂度和内存使用。

RRDB块在主路径中使用密集连接,网络容量变得更高。

class DenseResidualBlock(nn.Module):
    def __init__(self, filters, res_scale=0.2):
        super(DenseResidualBlock, self).__init__()
        self.res_scale = res_scale

        def block(in_features, non_linearity=True):
            layers = [nn.Conv2d(in_features, filters, 3, 1, 1, bias=True)]
            if non_linearity:
                layers += [nn.LeakyReLU()]
            return nn.Sequential(*layers)

        self.b1 = block(in_features=1 * filters)
        self.b2 = block(in_features=2 * filters)
        self.b3 = block(in_features=3 * filters)
        self.b4 = block(in_features=4 * filters)
        self.b5 = block(in_features=5 * filters, non_linearity=False)
        self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]

    def forward(self, x):
        inputs = x
        for block in self.blocks:
            out = block(inputs)
            inputs = torch.cat([inputs, out], 1)
        return out.mul(self.res_scale) + x


class ResidualInResidualDenseBlock(nn.Module):
    def __init__(self, filters, res_scale=0.2):
        super(ResidualInResidualDenseBlock, self).__init__()
        self.res_scale = res_scale
        self.dense_blocks = nn.Sequential(
            DenseResidualBlock(filters), DenseResidualBlock(filters), DenseResidualBlock(filters)
        )

    def forward(self, x):
        return self.dense_blocks(x).mul(self.res_scale) + x

2. 损失函数

作者采用相对判别器的思想,让判别器学习判断“一幅图像是否比另一幅图像更真实”,而不是“一幅图像是真实的还是假的”。

adversarial_loss = torch.nn.BCEWithLogitsLoss()

for epoch in range(n_epochs):
    for i, imgs_hr in enumerate(dataloader):
        # 构造对抗标签
        valid = torch.ones(imgs_hr.shape[0], 1).requires_grad_.(False)
        fake = torch.zeros(imgs_hr.shape[0], 1).requires_grad_.(False)

        # 从噪声中采样生成图像
        z = torch.randn(imgs_hr.shape[0], latent_dim)
        gen_hr = generator(z)

        # 训练判别器
        optimizer_D.zero_grad()
        pred_real = discriminator(imgs_hr)
        pred_fake = discriminator(gen_hr.detach())
        # 相对平均判别损失
        real_loss = adversarial_loss(pred_real - pred_fake.mean(0, keepdim=True), valid)
        fake_loss = adversarial_loss(pred_fake - pred_real.mean(0, keepdim=True), fake)
        d_loss = (real_loss + fake_loss) / 2
        # 更新判别器参数
        d_loss.backward()
        optimizer_D.step()

        # 训练生成器
        optimizer_G.zero_grad()
        pred_real = discriminator(imgs_hr).detach()
        pred_fake = discriminator(gen_hr)
        g_loss = adversarial_loss(pred_fake - pred_real.mean(0, keepdim=True), valid)
        g_loss.backward()
        optimizer_G.step()

此外,作者将感知损失调整为在激活前而不是激活后限制特征,这将克服激活后感知损失的两个缺点。首先,激活后的特征非常稀疏,尤其是在非常深的网络之后,例如在VGG19-54(使用预训练的19VGG网络,其中54表示在第5个最大池化层之前通过第4次卷积获得的特征)中下列图像“狒狒”的激活神经元的平均百分比仅为$11.17\%$,稀疏激活仅能提供较弱的监督,因此导致较差的性能。第二,激活后使用特征会导致与真实图像不一致的重建亮度。

from torchvision.models import vgg19
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        self.vgg19_54 = nn.Sequential(*list(vgg19_model.features.children())[:35])

    def forward(self, img):
        return self.vgg19_54(img)

feature_extractor = FeatureExtractor()
feature_extractor.eval()
criterion_content = torch.nn.L1Loss()
# Content loss
gen_features = feature_extractor(gen_hr)
real_features = feature_extractor(imgs_hr)
loss_content = criterion_content(gen_features, real_features.detach())

相比于SRGAN,作者还使用了L1重构损失:

criterion_pixel = torch.nn.L1Loss()
loss_pixel = criterion_pixel(gen_hr, imgs_hr)

ESRGAN的完整pytorch实现可参考PyTorch-GAN