Semi-Supervised Learning with Generative Adversarial Networks.

This post uses a GAN for semi-supervised learning. The original supervised classification task is folded into the GAN's discriminator, so the discriminator both judges whether a sample is real and predicts its class. Since samples produced by the generator have no ground-truth label, one extra class is appended to the original ones and used as the label for generated data.

\[\begin{aligned} \mathop{\min}_{G} \mathop{\max}_{D} \; & \mathbb{E}_{x \sim P_{data}(x)}\left[\log D_r(x) + \log D_c(\hat{y}=y \mid x)\right] \\ & + \mathbb{E}_{z \sim P(z)} \left[\log(1-D_r(G(z))) + \log D_c(\hat{y}=y' \mid G(z))\right] \end{aligned}\]
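Concretely, \(D_r\) and \(D_c\) can share one feature extractor with two output heads. A minimal sketch of such a discriminator is shown below; the module name, layer sizes, and flat-image input are illustrative assumptions, not the PyTorch-GAN architecture:

```python
import torch
import torch.nn as nn

class SGANDiscriminator(nn.Module):
    """Two-head discriminator: adversarial head D_r and auxiliary head D_c."""

    def __init__(self, img_dim=784, num_classes=10):
        super().__init__()
        # Shared feature extractor over flattened images
        self.features = nn.Sequential(
            nn.Linear(img_dim, 256),
            nn.LeakyReLU(0.2),
        )
        # D_r: scalar real/fake probability
        self.adv_head = nn.Sequential(nn.Linear(256, 1), nn.Sigmoid())
        # D_c: num_classes real classes + 1 extra class for generated samples
        self.aux_head = nn.Linear(256, num_classes + 1)

    def forward(self, x):
        h = self.features(x.view(x.size(0), -1))
        return self.adv_head(h).squeeze(1), self.aux_head(h)
```

The `num_classes + 1` output width of the auxiliary head is what allows generated samples to be assigned the extra class label in the training loop below.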

A complete PyTorch implementation of the Semi-Supervised GAN is available in PyTorch-GAN; the loss computation and parameter-update steps are shown below:

import torch

# Loss functions: BCE for real/fake, cross-entropy for class labels
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

for epoch in range(opt.n_epochs):
    for i, (real_imgs, labels) in enumerate(dataloader):
        batch_size = real_imgs.shape[0]
        # Adversarial ground truths
        valid = torch.ones(batch_size).requires_grad_(False)
        fake = torch.zeros(batch_size).requires_grad_(False)
        # Class label for generated samples: the extra (num_classes-th) class
        fake_aux_gt = torch.full(
            (batch_size,), opt.num_classes, dtype=torch.long
        )

        # Sample noise and labels as generator input
        z = torch.randn((batch_size, opt.latent_dim))
        # Generate a batch of images
        gen_imgs = generator(z)

        # -----------------------
        #  Train Discriminator
        # -----------------------
        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake images
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # -------------------------------
        #  Train Generator
        # -------------------------------
        optimizer_G.zero_grad()

        # Loss measures generator's ability to fool the discriminator
        validity, _ = discriminator(gen_imgs)
        g_loss = adversarial_loss(validity, valid)
        g_loss.backward()
        optimizer_G.step()