ACGAN:基于辅助分类器GAN的条件图像合成.

Auxiliary Classifier GAN (ACGAN)可以生成给定条件(标签或类别)的图像。

ACGAN的生成器接收随机噪声$z$和随机标签$c$,生成给定标签$c$时的图像$G(z|c)$;判别器$D(x)$接收图像$x$,同时输出两个结果:图像$x$是否为真实图像(二分类),以及图像$x$所属的类别标签$c$(多分类)。

# Per-image tensor shape: channel count first, followed by two equal
# dimensions of size opt.img_size.
img_shape = (opt.channels,) + (opt.img_size,) * 2

class Generator(nn.Module):
    """Fully connected conditional generator.

    Each class label is embedded into a vector of length ``opt.n_classes``,
    concatenated with the noise vector, and passed through an MLP whose final
    ``Tanh`` output is reshaped to ``img_shape``.
    """

    def __init__(self):
        super().__init__()

        # nn.Embedding(num_embeddings, embedding_dim): one learned vector of
        # length n_classes per class label.
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        # Widths of the hidden stack; consecutive pairs are (in, out) features.
        widths = [opt.latent_dim + opt.n_classes, 128, 256, 512, 1024]

        layers = []
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if idx > 0:  # the first hidden layer is not normalized
                # NOTE(review): 0.8 is BatchNorm1d's second positional
                # parameter, which is `eps` (not `momentum`). Kept as-is to
                # preserve behavior; confirm whether momentum was intended.
                layers.append(nn.BatchNorm1d(n_out, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Project to one value per pixel and squash into [-1, 1].
        layers.append(nn.Linear(widths[-1], int(np.prod(img_shape))))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate a batch of images.

        Args:
            noise: tensor of shape (batch, opt.latent_dim).
            labels: integer class indices of shape (batch,), used to index
                the label embedding.

        Returns:
            Tensor of shape (batch, *img_shape) with values in [-1, 1].
        """
        # Concatenate the label embedding and the noise to form the MLP input
        label_vec = self.label_emb(labels)
        gen_input = torch.cat((label_vec, noise), dim=-1)
        flat = self.model(gen_input)
        return flat.view(flat.size(0), *img_shape)

class Discriminator(nn.Module):
    """Fully connected discriminator with two output heads.

    The image is flattened and passed through a shared MLP producing a
    128-dimensional feature vector. Two heads read that vector:

    * ``adv_layer``: a single sigmoid unit scoring real vs. fake.
    * ``aux_layer``: a softmax over ``opt.n_classes`` classes.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Shared feature extractor: flattened image -> 128 features.
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 128),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(128, 1), nn.Sigmoid())
        # dim=1 normalizes over the class axis of the (batch, n_classes)
        # output. Omitting `dim` relies on a deprecated implicit choice and
        # emits a warning; for this 2-D input the result is the same.
        self.aux_layer = nn.Sequential(
            nn.Linear(128, opt.n_classes),
            nn.Softmax(dim=1),
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must total ``prod(img_shape)``.

        Returns:
            A tuple ``(validity, label)``:
            ``validity`` has shape (batch, 1) with values in (0, 1);
            ``label`` has shape (batch, opt.n_classes) and each row is a
            probability distribution (softmax output, not raw logits).
        """
        out = self.model(img.view(img.size(0), -1))
        validity = self.adv_layer(out)
        label = self.aux_layer(out)
        return validity, label

ACGAN的目标函数如下:

\[\begin{aligned} \mathop{\max}_{D} & \Bbb{E}_{x \sim P_{data}(x)}[\log D(x)] + \Bbb{E}_{z \sim P_{Z}(z)}[\log(1-D(G(z|c)))] \\ & + \Bbb{E}_{c,x \sim P_{data}(x)}[\log D_c(x)]+ \Bbb{E}_{z \sim P_{Z}(z)}[\log D_c(G(z|c))] \\ \mathop{ \min}_{G} & \Bbb{E}_{z \sim P_{Z}(z)}[\log(1-D(G(z|c)))] - \Bbb{E}_{z \sim P_{Z}(z)}[\log D_c(G(z|c))] \end{aligned}\]

ACGAN的完整PyTorch实现可参考PyTorch-GAN,下面给出其损失函数的计算和参数更新过程:

# Loss functions
# BCELoss scores the real/fake output; CrossEntropyLoss scores the class output.
# NOTE(review): CrossEntropyLoss applies log-softmax to its input, while the
# Discriminator's auxiliary head already ends in Softmax, so class scores are
# normalized twice. Left unchanged here; confirm whether this is intended.
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

for epoch in range(opt.n_epochs):
    for i, (real_imgs, labels) in enumerate(dataloader):
        # real_imgs.type() == torch.FloatTensor
        # labels.type() == torch.LongTensor
        batch_size = real_imgs.shape[0]

        # Adversarial ground truths: 1 for "real", 0 for "fake"
        valid = torch.ones(batch_size, 1).requires_grad_(False)
        fake = torch.zeros(batch_size, 1).requires_grad_(False)

        # Sample noise and labels as generator input
        z = torch.randn(batch_size, opt.latent_dim)
        # `size` must be a tuple, hence the trailing comma
        gen_labels = torch.randint(low=0, high=opt.n_classes, size=(batch_size,))

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake images; detach() keeps this backward pass from
        # reaching into the generator
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Loss measures generator's ability to fool the discriminator:
        # generated images are scored against the "real" target and against
        # the labels they were conditioned on
        validity, pred_label = discriminator(gen_imgs)
        g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))
        g_loss.backward()
        optimizer_G.step()