PGGAN: Progressively Generating High-Quality, Diverse Images.

This paper proposes Progressive Growing, a training method for generative adversarial networks that gradually increases the resolution of generated images to obtain higher-quality results. The authors also discuss minibatch standard deviation, a technique for increasing the diversity of generated images.

1. Progressive Growing

Progressive training starts by generating low-resolution images and gradually increases the output resolution by adding new layers to the network. Intuitively, this lets the model first learn the large-scale structure of the image distribution (at low resolution) and then progressively refine the details (at high resolution).

When transitioning from a lower to a higher resolution, the newly added layers are randomly initialized. To prevent them from disrupting the already-trained layers, the authors fade the new layers in gradually: a weight $\alpha$ that increases linearly from $0$ to $1$ blends the old and new outputs, avoiding a sudden collapse of the network.
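Concretely, during the fade-in phase the output of the new stage is a convex combination of the two paths (the notation below is ours, not from the paper):

$$ \text{output} = \alpha \cdot \text{new}(x) + (1 - \alpha) \cdot \text{upsample}(\text{old}(x)) $$

At $\alpha = 0$ the network is equivalent to the previous stage; at $\alpha = 1$ the new layers fully take over.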

The figure below shows the network architecture PGGAN uses to generate $1024 \times 1024$ images.

Taking two-stage progressive generation ($4 \times 4 \to 8 \times 8$) as an example, PGGAN can be implemented as follows:

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_dim=512):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.mlp = nn.Linear(latent_dim, latent_dim*4*4)
        self.b1, self.b1_rgb, _ = self._get_block(latent_dim, 512)
        self.b2, self.b2_rgb, self.p2 = self._get_block(512, 256)

    def _get_block(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 3, 1, 1),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(out_channels, out_channels, 3, 1, 1),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.1),
        )
        to_rgb = nn.Sequential(
            nn.Conv2d(out_channels, 3, 1, 1, 0),
            nn.Tanh(),
        )
        # 1x1 convolution that matches channel counts, so the old path's
        # features can be fed to the new stage's to_rgb during fade-in
        project = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
        return block, to_rgb, project

    def forward(self, inputs):
        current_layer, alpha, x = inputs
        x = self.mlp(x)
        x = x.view(-1, self.latent_dim, 4, 4)
        x = self.b1(x)
        x_lr = self.b1_rgb(x)  # low-resolution (old-path) RGB output
        x_hr = self.b1_rgb(x)  # high-resolution (new-path) RGB output; identical at stage 0
        if current_layer >= 1:
            x = self.upsample(x)
            x_lr = self.b2_rgb(self.p2(x))  # old path: upsample, project channels, to RGB
            x = self.b2(x)
            x_hr = self.b2_rgb(x)           # new path: pass through the new block, to RGB
        x = x_hr * alpha + x_lr * (1 - alpha)  # fade-in blend
        return x
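A minimal usage sketch of the generator (batch size and $\alpha$ value are illustrative assumptions, not from the original):

G = Generator(latent_dim=512)
z = torch.randn(8, 512)
imgs_4 = G((0, 1.0, z))  # stage 0: 4x4 images
imgs_8 = G((1, 0.5, z))  # stage 1, mid fade-in: 8x8 images
print(imgs_4.shape, imgs_8.shape)  # torch.Size([8, 3, 4, 4]) torch.Size([8, 3, 8, 8])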

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.b2_rgb, self.b2 = self._get_block(256, 512)
        self.b1_rgb, self.b1 = self._get_block(512, 512)
        self.downsample = nn.MaxPool2d(2)
        self.tail = nn.Sequential(
            nn.Conv2d(512, 512, 4, 1, 0),
            nn.Flatten(),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

    def _get_block(self, in_channels, out_channels):
        from_rgb = nn.Sequential(
            nn.Conv2d(3, in_channels, 3, 1, 1),
            nn.LeakyReLU(0.1),
        )
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.1),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.1),
        )
        return from_rgb, block

    def forward(self, inputs):
        current_layer, alpha, x = inputs
        x_lr = self.downsample(x)  # old path: downsample the input image directly
        if current_layer >= 1:
            if current_layer == 1:
                x = self.b2_rgb(x)  # from RGB at the new (highest) resolution
            x = self.b2(x)
            x = self.downsample(x)
            if current_layer == 1:
                x_lr = self.b1_rgb(x_lr)           # old path: from RGB at the previous resolution
                x = x * alpha + x_lr * (1 - alpha)  # fade-in blend of new and old paths
        if current_layer == 0:
            x = self.b1_rgb(x)
        x = self.b1(x)
        x = self.tail(x)
        return x
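Continuing the sketch above, the fade-in weight can be ramped linearly during training; the same $(\text{layer}, \alpha)$ pair is passed to both networks (the schedule and step count below are illustrative assumptions, not from the paper):

D = Discriminator()
fade_steps = 1000  # assumed length of the fade-in phase
for step in range(fade_steps):
    alpha = min(1.0, step / fade_steps)  # linearly ramp alpha from 0 to 1
    z = torch.randn(8, 512)
    fake = G((1, alpha, z))
    score = D((1, alpha, fake))  # shape (8, 1), probabilities from the sigmoid head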

2. Minibatch Standard Deviation

To increase the diversity of generated images, the authors propose minibatch standard deviation. Inspired by minibatch discrimination, it augments the discriminator's hidden features with statistics computed across the samples in a batch, letting the discriminator explicitly detect when generated images are too 'similar' to one another.

For a feature tensor $f \in \Bbb{R}^{N \times C \times H \times W}$ in the discriminator, the standard deviation across the batch is computed as $\sigma \in \Bbb{R}^{1 \times C \times H \times W}$; averaging it yields a scalar $s \in \Bbb{R}$, which is then replicated into a tensor $S \in \Bbb{R}^{N \times 1 \times H \times W}$ and concatenated to the original features as an extra channel.

def minibatch_std(x):
    # std over the batch dimension for each feature (channel x pixel),
    # averaged into a single scalar, then replicated as one extra channel
    # of shape (N, 1, H, W); concatenating it gives the discriminator
    # information about the amount of variation within the batch
    batch_statistics = (
        torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
    )
    return torch.cat([x, batch_statistics], dim=1)
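
A quick shape check (the feature dimensions below are assumptions for illustration):

f = torch.randn(8, 512, 4, 4)  # hidden features from the discriminator
f_aug = minibatch_std(f)
print(f_aug.shape)             # torch.Size([8, 513, 4, 4]) - one extra channel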