StarGAN: A Unified Framework for Multi-Domain Image-to-Image Translation.

In image-to-image translation, most methods learn a single fixed mapping, such as translating zebras to horses. To support several different translations, a separate network has to be trained for every pair of domains, so the computational burden grows quadratically: $k$ domains require $k(k-1)$ generators.

This paper proposes StarGAN, which performs multiple types of image translation with a single model. During training, every image must be annotated with a domain label, each element of which encodes one attribute (for example, hair color or gender in face images).
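For illustration, a domain label is simply a binary attribute vector. The sketch below assumes the five CelebA attributes used in the paper (black hair, blond hair, brown hair, male, young); the exact attribute set depends on the dataset:

import torch

# A hypothetical label with c_dim = 5, attribute order assumed to be
# [black_hair, blond_hair, brown_hair, male, young]:
# "blond-haired young woman" -> blond_hair = 1, young = 1, all others = 0
c = torch.tensor([[0., 1., 0., 0., 1.]])  # shape: (batch_size, c_dim)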

1. The Overall Structure of StarGAN

StarGAN consists of a single discriminator and a single generator. The discriminator judges whether an image is real and, for real images, additionally predicts the domain label; the generator takes an image together with a target domain label and produces an image in that domain.

StarGAN is trained with a cycle procedure. The input image and a target label are fed to the generator to produce an image in the target domain; that generated image and the original label are then fed to the generator again to reconstruct the original image. Meanwhile, the discriminator tries to tell generated images apart from real input images.

2. Network Design of StarGAN

The generator of StarGAN is an encoder-decoder built from residual blocks; the label $c$ is spatially replicated and concatenated directly with the input image along the channel dimension.

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [
            nn.Conv2d(in_features, in_features, 3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(in_features, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, 3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(in_features, affine=True, track_running_stats=True),
        ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class GeneratorResNet(nn.Module):
    def __init__(self, img_shape=(3, 128, 128), res_blocks=9, c_dim=5):
        super(GeneratorResNet, self).__init__()
        channels, img_size, _ = img_shape

        # Initial convolution block
        model = [
            nn.Conv2d(channels + c_dim, 64, 7, stride=1, padding=3, bias=False),
            nn.InstanceNorm2d(64, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
        ]

        # Downsampling
        curr_dim = 64
        for _ in range(2):
            model += [
                nn.Conv2d(curr_dim, curr_dim * 2, 4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(curr_dim * 2, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ]
            curr_dim *= 2

        # Residual blocks
        for _ in range(res_blocks):
            model += [ResidualBlock(curr_dim)]

        # Upsampling
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(curr_dim, curr_dim // 2, 4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(curr_dim // 2, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ]
            curr_dim = curr_dim // 2

        # Output layer
        model += [nn.Conv2d(curr_dim, channels, 7, stride=1, padding=3), nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x, c):
        # Spatially replicate the label vector to the image resolution
        c = c.view(c.size(0), c.size(1), 1, 1)
        c = c.repeat(1, 1, x.size(2), x.size(3))
        # Concatenate image and label along the channel dimension
        x = torch.cat((x, c), 1)
        return self.model(x)
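A quick shape check of the generator under the default settings (batch size and labels below are purely illustrative):

G = GeneratorResNet(img_shape=(3, 128, 128), res_blocks=9, c_dim=5)
x = torch.randn(4, 3, 128, 128)           # a batch of input images
c = torch.randint(0, 2, (4, 5)).float()   # target domain labels
y = G(x, c)
print(y.shape)  # torch.Size([4, 3, 128, 128]): same resolution, target domain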

The discriminator of StarGAN adopts the PatchGAN structure proposed in Pix2Pix: its adversarial output is an $N \times N$ map in which each element corresponds to a sub-region (patch) of the input image and scores the realness of that patch. In addition, the discriminator predicts the domain label of the input image.

class Discriminator(nn.Module):
    def __init__(self, img_shape=(3, 128, 128), c_dim=5, n_strided=6):
        super(Discriminator, self).__init__()
        channels, img_size, _ = img_shape

        def discriminator_block(in_filters, out_filters):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1), nn.LeakyReLU(0.01)]
            return layers

        layers = discriminator_block(channels, 64)
        curr_dim = 64
        for _ in range(n_strided - 1):
            layers.extend(discriminator_block(curr_dim, curr_dim * 2))
            curr_dim *= 2

        self.model = nn.Sequential(*layers)

        # Output 1: PatchGAN patch-wise realness probability
        self.out1 = nn.Sequential(nn.Conv2d(curr_dim, 1, 3, padding=1, bias=False), nn.Sigmoid())
        # Output 2: Class prediction
        kernel_size = img_size // 2 ** n_strided
        self.out2 = nn.Sequential(nn.Conv2d(curr_dim, c_dim, kernel_size, bias=False), nn.Sigmoid())

    def forward(self, img):
        feature_repr = self.model(img)
        out_adv = self.out1(feature_repr)
        out_cls = self.out2(feature_repr)
        return out_adv, out_cls.view(out_cls.size(0), -1)
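With a $128 \times 128$ input and n_strided=6, the feature map is reduced to $128 / 2^6 = 2$, so the adversarial head produces a 2×2 grid of patch scores while the classification head outputs c_dim attribute probabilities. A quick shape check (values are random and purely illustrative):

D = Discriminator(img_shape=(3, 128, 128), c_dim=5, n_strided=6)
out_adv, out_cls = D(torch.randn(4, 3, 128, 128))
print(out_adv.shape)  # torch.Size([4, 1, 2, 2]): patch-wise realness scores
print(out_cls.shape)  # torch.Size([4, 5]): predicted domain label per image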

3. The Objective Function of StarGAN

The discriminator's objective consists of an adversarial loss and a domain classification loss; the generator's objective consists of an adversarial loss, a domain classification loss, and a reconstruction loss.

\[\begin{aligned} \mathop{\max}_{D} \quad & \Bbb{E}_{x \sim P_{data}(x)}[\log D(x)] + \Bbb{E}_{x \sim P_{data}(x)}[\log (1-D(G(x, y^t)))] \\ &+ \Bbb{E}_{(x,y) \sim P_{data}(x,y)}[\log D_{y}(x)] \\ \mathop{\min}_{G} \quad & -\Bbb{E}_{x \sim P_{data}(x)}[\log D(G(x, y^t))]-\Bbb{E}_{x \sim P_{data}(x)}[\log D_{y^t}(G(x, y^t))] \\ &+ \Bbb{E}_{x \sim P_{data}(x)}[\|x-G(G(x, y^t),y^s)\|_1] \end{aligned}\]

Here $y^s$ and $y^t$ denote the source and target domain labels of image $x$, and $D_{y}(\cdot)$ is the probability the discriminator assigns to domain $y$.

A complete PyTorch implementation of StarGAN can be found in PyTorch-GAN; the loss computation and parameter-update procedure is shown below. Up to the weights lambda_cls and lambda_rec, the terms above correspond to loss_D_adv, loss_D_cls, loss_G_adv, loss_G_cls and loss_G_rec in the code:

# Losses
gan_loss = torch.nn.BCELoss()
cycle_loss = torch.nn.L1Loss()
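
# Loss-term weights; the values below follow the StarGAN paper
# (lambda_cls = 1, lambda_rec = 10) and are an assumption of this sketch
lambda_cls = 1
lambda_rec = 10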

# Initialize model
generator = GeneratorResNet(img_shape=img_shape, res_blocks=opt.residual_blocks, c_dim=c_dim)
discriminator = Discriminator(img_shape=img_shape, c_dim=c_dim)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Shape of the PatchGAN output (the discriminator downsamples the input by 2 ** 6)
patch = (1, opt.img_height // 2 ** 6, opt.img_width // 2 ** 6)

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        # Adversarial ground truths
        valid = torch.ones(imgs.shape[0], *patch).requires_grad_(False)
        fake = torch.zeros(imgs.shape[0], *patch).requires_grad_(False)

        # ----------------------------------
        #  Forward propagation
        # ----------------------------------
        # Sample random binary target domain labels as generator inputs
        sampled_c = torch.randint(0, 2, (imgs.size(0), c_dim)).float()
        # Generate fake batch of images
        fake_imgs = generator(imgs, sampled_c)
        # Reconstruct the original image using the original labels
        recov_imgs = generator(fake_imgs, labels)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Real images
        real_validity, pred_cls = discriminator(imgs)
        # Fake images
        fake_validity, _ = discriminator(fake_imgs.detach())
        # Adversarial loss
        loss_D_adv = gan_loss(real_validity, valid) + gan_loss(fake_validity, fake)
        # Domain classification loss on real images
        loss_D_cls = gan_loss(pred_cls, labels)
        # Total loss
        loss_D = loss_D_adv + lambda_cls * loss_D_cls

        loss_D.backward()
        optimizer_D.step()

        # -------------------------------
        #  Train Generator
        # -------------------------------
        optimizer_G.zero_grad()

        # Discriminator evaluates the translated images
        fake_validity, pred_cls = discriminator(fake_imgs)
        # Adversarial loss
        loss_G_adv = gan_loss(fake_validity, valid)
        # Domain classification loss on translated images
        loss_G_cls = gan_loss(pred_cls, sampled_c)
        # Reconstruction loss
        loss_G_rec = cycle_loss(recov_imgs, imgs)
        # Total loss
        loss_G = loss_G_adv + lambda_cls * loss_G_cls + lambda_rec * loss_G_rec

        loss_G.backward()
        optimizer_G.step()