Modeling the Generative Adversarial Network as a Softmax Function.

The authors model the generative adversarial network as a Softmax function. Concretely, suppose there are $|P| = |P_{data}| + |P_G|$ samples in total, consisting of $|P_{data}|$ real samples and $|P_G|$ generated samples. For each sample $x$, the discriminator produces a logit $D(x)$, which is treated as an energy; a probability distribution over the whole batch is then modeled by a Softmax over the negated logits:

$$
P(x) = \frac{e^{-D(x)}}{\sum_{x' \in P} e^{-D(x')}} = \frac{e^{-D(x)}}{Z_P}
$$
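
As a sanity check, $P(x)$ can be computed directly; the sketch below (tensor names `d_real` and `d_fake` and the batch sizes are illustrative, not from the paper) builds this distribution over a combined batch:

```python
import torch

# Illustrative batch: discriminator logits for 4 real and 4 generated samples
d_real = torch.randn(4)
d_fake = torch.randn(4)
logits = torch.cat([d_real, d_fake])   # all |P| samples

Z_P = torch.exp(-logits).sum()         # partition function Z_P
P = torch.exp(-logits) / Z_P           # P(x) for each sample in the batch
assert torch.isclose(P.sum(), torch.tensor(1.0))

# Numerically stabler equivalent
P_stable = torch.softmax(-logits, dim=0)
assert torch.allclose(P, P_stable)
```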

The discriminator $D(x)$ should correctly distinguish real samples from generated ones. It therefore assigns equal target probability to every real sample in $P$ and zero probability to the generated samples, so the target distribution the discriminator learns is:

$$
T(x) = \begin{cases} \frac{1}{|P_{data}|}, & \text{if } x \in P_{data} \\ 0, & \text{if } x \in P_G \end{cases}
$$
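
Written as a vector over the combined batch, this target looks like the following sketch (sizes illustrative):

```python
import torch

n_real, n_fake = 4, 4
# Mass 1/|P_data| on every real sample, 0 on every generated sample
T_d = torch.cat([torch.full((n_real,), 1 / n_real), torch.zeros(n_fake)])
assert torch.isclose(T_d.sum(), torch.tensor(1.0))
```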

A cross-entropy loss is constructed against this target:

$$
\begin{aligned}
\mathcal{L}_D &= -\sum_{x \in P} T(x) \log P(x) = -\sum_{x \in P} T(x) \log \frac{e^{-D(x)}}{Z_P} \\
&= -\sum_{x \in P_{data}} \frac{1}{|P_{data}|} \log \frac{e^{-D(x)}}{Z_P} = \frac{1}{|P_{data}|} \sum_{x \in P_{data}} D(x) + \log Z_P
\end{aligned}
$$
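
The closed form in the last step can be checked numerically; a minimal sketch (all names illustrative):

```python
import torch

d_real, d_fake = torch.randn(4), torch.randn(4)
logits = torch.cat([d_real, d_fake])
Z_P = torch.exp(-logits).sum()
log_P = -logits - torch.log(Z_P)       # log P(x) for every sample

T_d = torch.cat([torch.full((4,), 1 / 4), torch.zeros(4)])
cross_entropy = -(T_d * log_P).sum()
closed_form = d_real.sum() / 4 + torch.log(Z_P)
assert torch.isclose(cross_entropy, closed_form)
```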

The generator $G$ should produce samples close enough to the real ones that the discriminator cannot tell them apart. It therefore spreads the probability mass evenly over all samples, so the target distribution the generator learns is uniform:

$$
T(x) = \frac{1}{|P|} = \frac{1}{|P_{data}| + |P_G|}
$$
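
In the same vector notation as the earlier sketch, the generator's target is simply a uniform vector over the combined batch:

```python
import torch

n_real, n_fake = 4, 4
T_g = torch.full((n_real + n_fake,), 1 / (n_real + n_fake))
assert torch.isclose(T_g.sum(), torch.tensor(1.0))
```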

A cross-entropy loss is constructed in the same way:

$$
\begin{aligned}
\mathcal{L}_G &= -\sum_{x \in P} T(x) \log P(x) = -\sum_{x \in P} \frac{1}{|P_{data}|+|P_G|} \log \frac{e^{-D(x)}}{Z_P} \\
&= \frac{1}{|P_{data}|+|P_G|} \sum_{x \in P} D(x) + \log Z_P \\
&= \frac{1}{|P_{data}|+|P_G|} \left( \sum_{x \in P_{data}} D(x) + \sum_{x \in P_G} D(x) \right) + \log Z_P
\end{aligned}
$$
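
Again, the closed form can be verified numerically (a sketch with illustrative names):

```python
import torch

d_real, d_fake = torch.randn(4), torch.randn(4)
logits = torch.cat([d_real, d_fake])
Z_P = torch.exp(-logits).sum()
log_P = -logits - torch.log(Z_P)

T_g = torch.full((8,), 1 / 8)
cross_entropy = -(T_g * log_P).sum()
closed_form = (d_real.sum() + d_fake.sum()) / 8 + torch.log(Z_P)
assert torch.isclose(cross_entropy, closed_form)
```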

The complete objective of Softmax GAN is therefore:

$$
\begin{aligned}
& \min_D \, \frac{1}{|P_{data}|} \sum_{x \in P_{data}} D(x) + \log Z_P \\
& \min_G \, \frac{1}{|P_{data}|+|P_G|} \left( \sum_{x \in P_{data}} D(x) + \sum_{x \in P_G} D(x) \right) + \log Z_P
\end{aligned}
$$

A complete PyTorch implementation of Softmax GAN can be found in PyTorch-GAN; the loss computation and parameter-update steps are shown below:

```python
import torch

for epoch in range(opt.n_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        batch_size = real_imgs.shape[0]
        # Target probabilities: 1/|P_data| for the discriminator,
        # 1/(|P_data| + |P_G|) for the generator
        g_target = 1 / (batch_size * 2)
        d_target = 1 / batch_size

        z = torch.randn(batch_size, opt.latent_dim)
        gen_imgs = generator(z)

        # Train the discriminator
        optimizer_D.zero_grad()
        d_real = discriminator(real_imgs)
        d_fake = discriminator(gen_imgs.detach())  # no gradient into G here
        # Partition function Z_P over the combined batch
        Z = torch.sum(torch.exp(-d_real)) + torch.sum(torch.exp(-d_fake))
        d_loss = d_target * torch.sum(d_real) + torch.log(Z)
        d_loss.backward()
        optimizer_D.step()

        # Train the generator (recompute logits and Z_P with the updated D)
        optimizer_G.zero_grad()
        d_real = discriminator(real_imgs)
        d_fake = discriminator(gen_imgs)
        Z = torch.sum(torch.exp(-d_real)) + torch.sum(torch.exp(-d_fake))
        g_loss = g_target * (torch.sum(d_real) + torch.sum(d_fake)) + torch.log(Z)
        g_loss.backward()
        optimizer_G.step()
```
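
The loop assumes that `generator`, `discriminator`, `opt`, `optimizer_G`, `optimizer_D`, and `dataloader` are defined elsewhere. A minimal sketch of that setup, with placeholder architectures far smaller than the ones actually used in PyTorch-GAN:

```python
import argparse
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200)
parser.add_argument("--latent_dim", type=int, default=100)
opt = parser.parse_args()

img_dim = 28 * 28  # flattened MNIST images

generator = nn.Sequential(
    nn.Linear(opt.latent_dim, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, img_dim), nn.Tanh(),
)
# One unbounded logit D(x) per sample, used as the Softmax energy
discriminator = nn.Sequential(
    nn.Linear(img_dim, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

dataset = torchvision.datasets.MNIST(
    "data", train=True, download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),     # match the Tanh output range
        transforms.Lambda(lambda x: x.view(-1)),  # flatten to img_dim
    ]),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
```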