把生成对抗网络建模为Softmax函数.
- paper:Softmax GAN
本文作者把生成对抗网络建模为Softmax函数。具体地,若记共有个样本,包括个真实样本和个生成样本。对于每个样本,使用判别器计算logits,并通过Softmax函数进行建模:
对于判别器,希望其能正确地区分真实样本和生成样本。因此判别器将目标概率均等地分配给中的所有真实样本,而生成样本的目标概率为;则判别器学习的目标分布为:
构造交叉熵损失函数:
对于生成器,希望其生成的样本足够接近真实样本。因此生成器将概率平均分配给所有样本;则生成器学习的目标分布为一个均匀分布:
构造交叉熵损失函数:
Softmax GAN的完整目标函数如下:
Softmax GAN的完整pytorch实现可参考PyTorch-GAN,下面给出损失函数计算和参数更新过程:
for epoch in range(opt.n_epochs):
for i, real_imgs in enumerate(dataloader):
batch_size = real_imgs.shape[0]
# Adversarial ground truths
g_target = 1 / (batch_size * 2)
d_target = 1 / batch_size
z = torch.randn(batch_size, opt.latent_dim)
gen_imgs = generator(z)
d_real = discriminator(real_imgs)
d_fake = discriminator(gen_imgs)
# Partition function
Z = torch.sum(torch.exp(-d_real)) + torch.sum(torch.exp(-d_fake))
# 训练判别器
optimizer_D.zero_grad()
d_loss = d_target * torch.sum(d_real) + log(Z)
d_loss.backward(retain_graph=True)
optimizer_D.step()
# 训练生成器
optimizer_G.zero_grad()
g_loss = g_target * (torch.sum(d_real) + torch.sum(d_fake)) + log(Z)
g_loss.backward()
optimizer_G.step()
Related Issues not found
Please contact @0809zheng to initialize the comment