CGAN:条件生成对抗网络(Conditional GAN)。
conditional GAN可以生成给定条件(标签或类别)的图像。
CGAN的生成器接收随机噪声$z$和随机标签$y$,生成给定标签$y$时的图像$G(z|y)$:
# Shape of a single image, laid out as (channels, img_size, img_size).
img_shape = (opt.channels,) + (opt.img_size,) * 2
class Generator(nn.Module):
    """Conditional generator G(z|y): map a noise vector and a class label to an image.

    The label is embedded into an ``opt.n_classes``-dimensional vector,
    concatenated with the noise vector and passed through an MLP whose final
    ``Tanh`` keeps every output value in [-1, 1].

    Reads the module-level ``opt`` (``n_classes``, ``latent_dim``) and
    ``img_shape`` at construction / call time.
    """

    def __init__(self):
        super().__init__()

        # nn.Embedding(num_embeddings, embedding_dim):
        # one learnable n_classes-dimensional vector per class label.
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear -> [BatchNorm1d] -> LeakyReLU stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): the second positional parameter of BatchNorm1d
                # is `eps`, not `momentum`. The original passed 0.8 positionally,
                # so it is spelled out here to keep the numerics unchanged;
                # confirm whether `momentum=0.8` was actually intended.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # Input is the label embedding concatenated with the noise vector.
            *block(opt.latent_dim + opt.n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: tensor whose last dimension is ``opt.latent_dim``.
            labels: integer class indices, one per sample.

        Returns:
            Tensor of shape ``(batch, *img_shape)``.
        """
        # Concatenate label embedding and noise to produce the generator input
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        # Reshape the flat MLP output back into image layout.
        return img.view(img.size(0), *img_shape)
CGAN的判别器接收图像$x$和对应的标签$y$,判断图像$x$是否为给定标签$y$时的真实图像$D(x|y)$:
class Discriminator(nn.Module):
    """Conditional discriminator D(x|y): score an (image, label) pair.

    The image is flattened, concatenated with an ``opt.n_classes``-dimensional
    label embedding and passed through an MLP ending in ``Sigmoid``, so the
    output is a value in [0, 1] per sample.

    Reads the module-level ``opt.n_classes`` and ``img_shape`` at construction.
    """

    def __init__(self):
        super().__init__()

        # One learnable n_classes-dimensional vector per class label.
        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

        self.model = nn.Sequential(
            # Input is the flattened image concatenated with the label embedding.
            nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 128),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        """Return a ``(batch, 1)`` tensor of validity scores.

        Args:
            img: image batch; everything after the first dimension is flattened.
            labels: integer class indices, one per sample.
        """
        # Concatenate label embedding and image to produce input
        flat_img = img.view(img.size(0), -1)
        d_in = torch.cat((flat_img, self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity
CGAN的目标函数如下:
\[\begin{aligned} \mathop{\min}_{G} \mathop{\max}_{D} \; \Bbb{E}_{x \sim P_{data}(x)}[\log D(x|y)] + \Bbb{E}_{z \sim P_{Z}(z)}[\log(1-D(G(z|y)|y))] \end{aligned}\]
CGAN的完整PyTorch实现可参考PyTorch-GAN,下面给出其损失函数的计算和参数更新过程:
# Loss functions
# Binary cross-entropy between the discriminator's sigmoid output and 0/1 targets.
adversarial_loss = torch.nn.BCELoss()

for epoch in range(opt.n_epochs):
    for i, (real_imgs, labels) in enumerate(dataloader):
        # Expected: real_imgs is a float tensor, labels is a long tensor
        # (labels are used as indices into nn.Embedding).
        batch_size = real_imgs.shape[0]
        # Create every new tensor on the same device as the incoming batch.
        device = real_imgs.device

        # Adversarial ground truths (freshly created tensors do not require grad)
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # Sample noise and labels as generator input
        z = torch.randn(batch_size, opt.latent_dim, device=device)
        # `size` must be a tuple; `high` is exclusive, so labels lie in [0, n_classes).
        gen_labels = torch.randint(
            low=0, high=opt.n_classes, size=(batch_size,), device=device
        )

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Loss for real images
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

        # Loss for fake images; detach() keeps this backward pass from
        # propagating gradients into the generator.
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Loss measures generator's ability to fool the discriminator:
        # generated images are scored against the "valid" target.
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)
        g_loss.backward()
        optimizer_G.step()