Semi-Supervised Learning with Generative Adversarial Networks
This post applies a GAN to semi-supervised learning. The original supervised classification task is merged into the GAN's discriminator, so the discriminator both judges whether its input is real or generated and predicts the input's class. Because the labels of samples produced by the generator are unknown, one extra class is added on top of the original classes and used as the label for generated data.
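As a concrete sketch of this two-head discriminator (a minimal illustration, not the reference implementation: the class name SGANDiscriminator, the head names adv_head and aux_head, the convolutional body, and the layer sizes are all assumptions), a shared feature extractor feeds one head that outputs a real/fake probability and another that outputs logits over num_classes + 1 classes, with the last index reserved for generated samples:

import torch
import torch.nn as nn

class SGANDiscriminator(nn.Module):
    """Shared body with two heads: real/fake probability and (num_classes + 1)-way class logits."""

    def __init__(self, num_classes=10, img_channels=1):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(img_channels, 64, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.adv_head = nn.Sequential(nn.Linear(128, 1), nn.Sigmoid())  # real vs. fake, for BCELoss
        self.aux_head = nn.Linear(128, num_classes + 1)                 # class logits, for CrossEntropyLoss

    def forward(self, img):
        features = self.body(img)
        return self.adv_head(features), self.aux_head(features)

The training excerpt further down unpacks exactly this pair of outputs, applying the adversarial loss to the real/fake probability and the auxiliary classification loss to the class logits.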
\[\begin{aligned} \mathop{\min}_{G} \mathop{\max}_{D}\; & \mathbb{E}_{x \sim P_{data}(x)}\big[\log D_r(x) + \log D_c(\hat{y}=y \mid x)\big] \\ & + \mathbb{E}_{z \sim P(z)} \big[\log(1-D_r(G(z))) + \log D_c(\hat{y}=y' \mid G(z))\big] \end{aligned}\]
Here D_r denotes the discriminator's real/fake output, D_c its predicted class distribution, and y' the extra class assigned to generated samples. A complete PyTorch implementation of the Semi-Supervised GAN is available in PyTorch-GAN; the loss computation and parameter-update steps are shown below:
import torch

# Loss functions
adversarial_loss = torch.nn.BCELoss()         # adversarial (real / fake) term
auxiliary_loss = torch.nn.CrossEntropyLoss()  # auxiliary classification term
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
for epoch in range(opt.n_epochs):
    for i, (real_imgs, labels) in enumerate(dataloader):
        batch_size = real_imgs.shape[0]

        # Adversarial ground truths
        valid = torch.ones(batch_size, 1, requires_grad=False)
        fake = torch.zeros(batch_size, 1, requires_grad=False)
        # Generated samples are assigned the extra class index opt.num_classes
        fake_aux_gt = torch.full(
            (batch_size,), opt.num_classes, dtype=torch.long, requires_grad=False
        )

        # Sample noise as generator input
        z = torch.randn((batch_size, opt.latent_dim))

        # Generate a batch of images
        gen_imgs = generator(z)

        # -----------------------
        #  Train Discriminator
        # -----------------------
        optimizer_D.zero_grad()

        # Loss for real images: adversarial term + classification against the true labels
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake images: adversarial term + classification against the extra class
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # -----------------------
        #  Train Generator
        # -----------------------
        optimizer_G.zero_grad()

        # Loss measures the generator's ability to fool the discriminator
        validity, _ = discriminator(gen_imgs)
        g_loss = adversarial_loss(validity, valid)
        g_loss.backward()
        optimizer_G.step()
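At test time only the auxiliary head is used: the trained discriminator doubles as the semi-supervised classifier. The sketch below is an illustration under stated assumptions (a test_loader yielding (imgs, labels) batches and the two-head discriminator interface used above); the extra class reserved for generated samples is simply ignored when making predictions.

# Evaluate the discriminator's auxiliary head as a classifier (assumes `test_loader` exists)
correct, total = 0, 0
discriminator.eval()
with torch.no_grad():
    for imgs, labels in test_loader:
        _, class_logits = discriminator(imgs)
        # Keep only the real classes 0 .. num_classes - 1; the extra "fake" class is dropped
        preds = class_logits[:, : opt.num_classes].argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"Test accuracy: {correct / total:.4f}")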