Modeling the generative adversarial network as a Softmax function.
- paper: Softmax GAN
The authors of this paper model the generative adversarial network as a Softmax function. Concretely, suppose a batch contains $|P|=|P_{data}|+|P_G|$ samples in total, consisting of $|P_{data}|$ real samples and $|P_G|$ generated samples. For each sample $x$, the discriminator $D(x)$ computes a logit, and a probability distribution over the batch is modeled with a Softmax function:
\[P(x) = \frac{e^{-D(x)}}{\sum_{x' \in P} e^{-D(x')}} = \frac{e^{-D(x)}}{Z_P}\]
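As a quick illustration of this batch-level normalization, here is a minimal sketch (the tensors `d_real` and `d_fake` are hypothetical discriminator outputs, not code from the paper) showing that $P(x)$ is simply a softmax of the negated logits taken over all $|P|$ samples:

```python
import torch

# Hypothetical discriminator outputs D(x) for 4 real and 4 generated samples
d_real = torch.randn(4)
d_fake = torch.randn(4)

# P(x) = exp(-D(x)) / Z_P is a softmax of the negated logits over the whole batch
P = torch.softmax(-torch.cat([d_real, d_fake]), dim=0)
print(P.sum())  # tensor(1.) -- Z_P normalizes over all |P| samples
```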
For the discriminator $D(x)$, we want it to correctly tell real samples from generated ones. The discriminator therefore assigns the target probability uniformly to all real samples in $P$, while each generated sample receives target probability $0$; the target distribution that the discriminator learns is:

\[T(x) = \begin{cases} \frac{1}{|P_{data}|}, & \text{if } x \in P_{data} \\ 0, & \text{if } x \in P_G \end{cases}\]

The cross-entropy loss between $T(x)$ and $P(x)$ is:
\[\begin{aligned} L_D &= - \sum_{x \in P} T(x) \log P(x) \\ &= - \sum_{x \in P_{data}} \frac{1}{|P_{data}|} \log \frac{e^{-D(x)}}{Z_P} \\ &= \frac{1}{|P_{data}|}\sum_{x \in P_{data}} D(x) + \log Z_P \end{aligned}\]

For the generator $G$, we want the generated samples to be close enough to the real ones. The generator therefore spreads the target probability evenly over all samples, so the target distribution that the generator learns is the uniform distribution:
\[T(x) = \frac{1}{|P|}= \frac{1}{|P_{data}|+|P_{G}|}\]

The corresponding cross-entropy loss is:
\[\begin{aligned} L_G &= - \sum_{x \in P} T(x) \log P(x) \\ &= - \sum_{x \in P} \frac{1}{|P_{data}|+|P_{G}|} \log \frac{e^{-D(x)}}{Z_P} \\ &= \frac{1}{|P_{data}|+|P_{G}|}\sum_{x \in P} D(x) + \log Z_P \\ &= \frac{1}{|P_{data}|+|P_{G}|}\Big(\sum_{x \in P_{data}} D(x)+\sum_{x \in P_G} D(x) \Big)+ \log Z_P \end{aligned}\]
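As a sanity check, the closed-form losses derived above can be compared against the explicit cross-entropy $-\sum_x T(x)\log P(x)$ on a toy batch. The sketch below uses hypothetical logits `d_real` and `d_fake`; it is only an illustration of the formulas, not code from the paper:

```python
import torch

n = 4  # hypothetical batch: |P_data| = |P_G| = 4
d_real, d_fake = torch.randn(n), torch.randn(n)

d_all = torch.cat([d_real, d_fake])
Z = torch.exp(-d_all).sum()        # partition function Z_P
P = torch.exp(-d_all) / Z          # P(x) over the whole batch

# Discriminator target: uniform over real samples, zero on generated samples
T_D = torch.cat([torch.full((n,), 1 / n), torch.zeros(n)])
# Generator target: uniform over all 2n samples
T_G = torch.full((2 * n,), 1 / (2 * n))

ce_D = -(T_D * torch.log(P)).sum()  # -sum_x T(x) log P(x)
ce_G = -(T_G * torch.log(P)).sum()

L_D = d_real.sum() / n + torch.log(Z)       # closed form derived above
L_G = d_all.sum() / (2 * n) + torch.log(Z)

print(torch.allclose(ce_D, L_D), torch.allclose(ce_G, L_G))  # True True
```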
The complete objective of Softmax GAN is:

\[\begin{aligned} & \mathop{ \min}_{D} \frac{1}{|P_{data}|}\sum_{x \in P_{data}} D(x) + \log Z_P \\ & \mathop{ \min}_{G} \frac{1}{|P_{data}|+|P_{G}|}\Big(\sum_{x \in P_{data}} D(x)+\sum_{x \in P_G} D(x)\Big)+ \log Z_P \end{aligned}\]

A complete PyTorch implementation of Softmax GAN is available in PyTorch-GAN; the loss computation and parameter-update steps are shown below:
```python
for epoch in range(opt.n_epochs):
    for i, real_imgs in enumerate(dataloader):
        batch_size = real_imgs.shape[0]
        # Target probabilities: 1/(|P_data|+|P_G|) for the generator,
        # 1/|P_data| for the discriminator
        g_target = 1 / (batch_size * 2)
        d_target = 1 / batch_size

        z = torch.randn(batch_size, opt.latent_dim)
        gen_imgs = generator(z)

        # ---- Train the discriminator ----
        optimizer_D.zero_grad()
        d_real = discriminator(real_imgs)
        d_fake = discriminator(gen_imgs.detach())  # no gradient into G here
        # Partition function Z_P over the whole batch
        Z = torch.sum(torch.exp(-d_real)) + torch.sum(torch.exp(-d_fake))
        d_loss = d_target * torch.sum(d_real) + torch.log(Z)
        d_loss.backward()
        optimizer_D.step()

        # ---- Train the generator ----
        optimizer_G.zero_grad()
        d_real = discriminator(real_imgs)
        d_fake = discriminator(gen_imgs)
        Z = torch.sum(torch.exp(-d_real)) + torch.sum(torch.exp(-d_fake))
        g_loss = g_target * (torch.sum(d_real) + torch.sum(d_fake)) + torch.log(Z)
        g_loss.backward()
        optimizer_G.step()
```
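The loop above assumes that the models, optimizers, options, and dataloader already exist. For completeness, a minimal setup under which it runs end-to-end could look like the following; the MLP architectures, hyperparameters, and synthetic dataset are illustrative assumptions, not the configuration used in the paper or in PyTorch-GAN:

```python
import argparse

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Illustrative hyperparameters (not taken from the paper)
opt = argparse.Namespace(n_epochs=200, latent_dim=100, lr=2e-4, batch_size=64)
img_dim = 28 * 28

# Simple MLP generator and discriminator; D outputs one logit per sample
generator = nn.Sequential(
    nn.Linear(opt.latent_dim, 256), nn.ReLU(),
    nn.Linear(256, img_dim), nn.Tanh(),
)
discriminator = nn.Sequential(
    nn.Linear(img_dim, 256), nn.ReLU(),
    nn.Linear(256, 1),
)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr)

# Synthetic stand-in for a real image dataset: flattened samples in [-1, 1]
real_data = torch.rand(1024, img_dim) * 2 - 1
dataloader = DataLoader(real_data, batch_size=opt.batch_size, shuffle=True)
```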