ESRGAN:增强的图像超分辨率生成对抗网络.
作者在SRGAN的基础上进行改进,提出了增强的图像超分辨率生成对抗网络(Enhanced Super-Resolution Generative Adversarial Network, ESRGAN)。对于网络结构,作者引入了没有批量归一化的残差密集块(Residual-in-Residual Dense Block, RRDB)作为基本的网络构建单元;对于对抗损失,作者采用相对判别器的思想,让判别器预测相对真实度;对于感知损失,作者利用激活前的特征来构造损失,为亮度一致性和纹理恢复提供更强的监督。
1. 网络结构
生成器的整体结构与SRGAN相似。
为了进一步提高SRGAN的恢复图像质量,生成器结构主要有两处改进:
- 去除所有BN层;
- 用提出的RRDB块替换原始残差模块。
BN层在训练期间使用一批数据的均值和方差来归一化特征,并且在测试期间使用整个训练数据集估计均值和方差。当训练和测试数据集的统计差异很大时,BN层可能带来伪像。为了稳定训练和维持性能,作者去除了BN层,这有助于提高泛化能力并减少计算复杂度和内存使用。
RRDB块在主路径中使用密集连接,网络容量变得更高。
class DenseResidualBlock(nn.Module):
def __init__(self, filters, res_scale=0.2):
super(DenseResidualBlock, self).__init__()
self.res_scale = res_scale
def block(in_features, non_linearity=True):
layers = [nn.Conv2d(in_features, filters, 3, 1, 1, bias=True)]
if non_linearity:
layers += [nn.LeakyReLU()]
return nn.Sequential(*layers)
self.b1 = block(in_features=1 * filters)
self.b2 = block(in_features=2 * filters)
self.b3 = block(in_features=3 * filters)
self.b4 = block(in_features=4 * filters)
self.b5 = block(in_features=5 * filters, non_linearity=False)
self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]
def forward(self, x):
inputs = x
for block in self.blocks:
out = block(inputs)
inputs = torch.cat([inputs, out], 1)
return out.mul(self.res_scale) + x
class ResidualInResidualDenseBlock(nn.Module):
def __init__(self, filters, res_scale=0.2):
super(ResidualInResidualDenseBlock, self).__init__()
self.res_scale = res_scale
self.dense_blocks = nn.Sequential(
DenseResidualBlock(filters), DenseResidualBlock(filters), DenseResidualBlock(filters)
)
def forward(self, x):
return self.dense_blocks(x).mul(self.res_scale) + x
2. 损失函数
作者采用相对判别器的思想,让判别器学习判断“一幅图像是否比另一幅图像更真实”,而不是“一幅图像是真实的还是假的”。
adversarial_loss = torch.nn.BCEWithLogitsLoss()
for epoch in range(n_epochs):
for i, imgs_hr in enumerate(dataloader):
# 构造对抗标签
valid = torch.ones(imgs_hr.shape[0], 1).requires_grad_.(False)
fake = torch.zeros(imgs_hr.shape[0], 1).requires_grad_.(False)
# 从噪声中采样生成图像
z = torch.randn(imgs_hr.shape[0], latent_dim)
gen_hr = generator(z)
# 训练判别器
optimizer_D.zero_grad()
pred_real = discriminator(imgs_hr)
pred_fake = discriminator(gen_hr.detach())
# 相对平均判别损失
real_loss = adversarial_loss(pred_real - pred_fake.mean(0, keepdim=True), valid)
fake_loss = adversarial_loss(pred_fake - pred_real.mean(0, keepdim=True), fake)
d_loss = (real_loss + fake_loss) / 2
# 更新判别器参数
d_loss.backward()
optimizer_D.step()
# 训练生成器
optimizer_G.zero_grad()
pred_real = discriminator(imgs_hr).detach()
pred_fake = discriminator(gen_hr)
g_loss = adversarial_loss(pred_fake - pred_real.mean(0, keepdim=True), valid)
g_loss.backward()
optimizer_G.step()
此外,作者将感知损失调整为在激活前而不是激活后限制特征,这将克服激活后感知损失的两个缺点。首先,激活后的特征非常稀疏,尤其是在非常深的网络之后,例如在VGG19-54(使用预训练的19层VGG网络,其中54表示在第5个最大池化层之前通过第4次卷积获得的特征)中下列图像“狒狒”的激活神经元的平均百分比仅为$11.17\%$,稀疏激活仅能提供较弱的监督,因此导致较差的性能。第二,激活后使用特征会导致与真实图像不一致的重建亮度。
from torchvision.models import vgg19
class FeatureExtractor(nn.Module):
def __init__(self):
super(FeatureExtractor, self).__init__()
vgg19_model = vgg19(pretrained=True)
self.vgg19_54 = nn.Sequential(*list(vgg19_model.features.children())[:35])
def forward(self, img):
return self.vgg19_54(img)
feature_extractor = FeatureExtractor()
feature_extractor.eval()
criterion_content = torch.nn.L1Loss()
# Content loss
gen_features = feature_extractor(gen_hr)
real_features = feature_extractor(imgs_hr)
loss_content = criterion_content(gen_features, real_features.detach())
相比于SRGAN,作者还使用了L1重构损失:
criterion_pixel = torch.nn.L1Loss()
loss_pixel = criterion_pixel(gen_hr, imgs_hr)
ESRGAN的完整pytorch实现可参考PyTorch-GAN。