WGAN-div:通过Wasserstein散度构造GAN.

本文作者提出了WGAN-div,使用Wasserstein散度构造GAN的目标函数,既具有Wasserstein GAN (WGAN)的良好性质,又避免了Lipschitz约束的引入。

WGAN-div的目标函数的基本形式为:

\[\mathop{ \min}_{G} \mathop{ \max}_{D} \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] - \frac{1}{2} \Bbb{E}_{x \text{~} r(x) }[|| \nabla_xD(x) ||^2]\]

其中$r(x)$是一个非常宽松的分布。

⚪ Wasserstein散度

$p(x)$和$q(x)$之间的Wasserstein散度定义为:

\[D_{W_{k,p}}[p || q] = \mathop{ \max}_{f} \int_x p(x)f(x)dx - \int_x q(x)f(x)dx - k\int_x r(x) || \nabla_xf(x) ||^p dx\]

或写作采样形式:

\[D_{W_{k,p}}[p || q] = \mathop{ \max}_{f} \Bbb{E}_{x \text{~} p(x)}[f(x)] - \Bbb{E}_{x \text{~} q(x)}[f(x)] - k\Bbb{E}_{x \text{~} r(x)}[ || \nabla_xf(x) ||^p ]\]

其中$f(x)$是任意函数,$r(x)$是一个样本空间跟$p(x)$和$q(x)$一样的分布,$k>0, p > 1$。

Wasserstein散度具有以下性质:

⚪ WGAN-div的目标函数

WGAN-div的训练过程为:

\[\begin{aligned} D^* &\leftarrow\mathop{ \max}_{D} \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] - k\Bbb{E}_{x \text{~} r(x)}[ || \nabla_xD(x) ||^p ] \\ G^* &\leftarrow \mathop{ \min}_{G} -\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \end{aligned}\]

⚪ WGAN-div与WGAN-GP的对比

WGAN-GP的目标函数为:

\[\mathop{ \max}_{f} \Bbb{E}_{x \text{~} p(x)}[f(x)] - \Bbb{E}_{x \text{~} q(x)}[f(x)] - k\Bbb{E}_{x \text{~} r(x)}[ (|| \nabla_xf(x) ||-n)^p ]\]

取$n=1,p=2$时即为梯度惩罚项。然而上式不总是一个散度,这意味着WGAN-GP在训练判别器时并非总是在拉大两个分布的距离。

尽管WGAN-divWGAN-GP的目标函数非常类似,但前者具有理论保证,而后者只是一种经验方案。

⚪ 实验分析

作者通过超参数搜索确定了$k=2, p =6$时效果最好:

$r(x)$的选择非常宽松,作者进行了如下对比实验:

  1. 真假样本随机插值;
  2. 真样本之间随机插值,假样本之间随机插值;
  3. 真假样本混合后,随机选两个样本插值;
  4. 直接选原始的真假样本混合;
  5. 直接只选原始的假样本;
  6. 直接只选原始的真样本。

实验结果表明这几种情况下WGAN-div的表现都差不多,而WGAN-GP受到了明显的影响。

⚪ WGAN-div的pytorch实现

WGAN-div的完整pytorch实现可参考PyTorch-GAN

下面给出损失函数计算和参数更新过程。可以使用torch.autograd.grad()方法实现网络对输入变量的求导。其中$r(x)$选择原始的真假样本混合。

k, p = 2, 6
for epoch in range(opt.n_epochs):
    for i, real_imgs in enumerate(dataloader):

        # 注意真实图像和生成图像都需要梯度计算
        real_imgs.requires_grad_(True)
        z = torch.randn(real_imgs.shape[0], opt.latent_dim)
        fake_imgs = generator(z)

        # 训练判别器
        optimizer_D.zero_grad()
        # 真实图像得分
        real_validity = discriminator(real_imgs)
        # 生成图像得分
        fake_validity = discriminator(fake_imgs)

        # 计算真实样本的梯度范数
        real_grad_out = torch.ones_like(real_validity).requires_grad_(False)
        real_grad = autograd.grad(
            outputs=real_validity,inputs=real_imgs, grad_outputs=real_grad_out,
            create_graph=True, retain_graph=True, only_inputs=True
            )[0]
        real_grad_norm = real_grad.view(real_grad.size(0), -1).pow(2).sum(1) ** (p/2)

        # 计算生成样本的梯度范数 
        fake_grad_out = torch.ones_like(fake_validity).requires_grad_(False)
        fake_grad = autograd.grad(
            outputs=fake_validity,inputs=fake_imgs, grad_outputs=fake_grad_out,
            create_graph=True, retain_graph=True, only_inputs=True
            )[0]
        fake_grad_norm = fake_grad.view(fake_grad.size(0), -1).pow(2).sum(1) ** (p/2)

        # W散度惩罚项
        w_div = torch.mean(real_grad_norm+fake_grad_norm) * k / 2
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + w_div
        d_loss.backward(retain_graph=True)
        optimizer_D.step()

        # 训练生成器
        if i % opt.d_iter == 0:
            optimizer_G.zero_grad()
            g_loss = -torch.mean(discriminator(fake_imgs))
            g_loss.backward()
            optimizer_G.step()