DRAGAN: Adjusting the Interpolation Space of the Gradient Penalty

In WGAN-GP, the authors construct the GAN objective from the Wasserstein distance; the optimization target is the Wasserstein distance between the real distribution \(P_{data}\) and the generated distribution \(P_G\):

\[\mathop{\min}_{G} \mathop{\max}_{D, ||D||_L \leq 1} \{ \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \}\]

Here the discriminator $D$ is required to be 1-Lipschitz continuous, and the authors introduce a gradient penalty term to enforce this Lipschitz constraint.

If $D$ is a 1-Lipschitz function, then for any \(x_1,x_2 \in \Bbb{R}^n\):

\[| D(x_1)-D(x_2) | \leq || x_1-x_2 ||\]

or, equivalently:

\[|| \nabla_x D(x) || \approx \frac{| D(x_1)-D(x_2) |}{|| x_1-x_2 ||} \leq 1\]

The objective of WGAN-GP is then:

\[\begin{aligned} \mathop{ \max}_{D} & \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \\ & - \lambda \Bbb{E}_{x \text{~} \epsilon P_{data}(x) + (1-\epsilon)P_{G}(x) }[(|| \nabla_xD(x) || -1)^2] \\ \mathop{ \min}_{G}& -\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \end{aligned}\]

In theory the penalty should be computed over every possible input of $D(x)$ and averaged; in practice it is computed on random interpolations between real and generated samples:

$P_{penalty}$ is defined by drawing one sample from \(P_{data}\) and one from \(P_G\), then sampling a point on the line segment between them. This is reasonable because, intuitively, optimization pulls \(P_G\) toward \(P_{data}\), so the sample points should mostly come from the region between the two distributions rather than from the whole space.
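For reference, sampling from $P_{penalty}$ in the WGAN-GP style can be sketched as follows (a minimal sketch, assuming real_imgs and fake_imgs are same-shaped batches and fake_imgs has been detached from the generator; the names are placeholders):

import torch
from torch import autograd

def wgan_gp_penalty(D, real_imgs, fake_imgs):
    """Gradient penalty on random interpolations between real and fake samples"""
    # One epsilon per sample, broadcast across the remaining dimensions
    eps = torch.rand(real_imgs.size(0), *([1] * (real_imgs.dim() - 1)), device=real_imgs.device)
    x_hat = (eps * real_imgs + (1 - eps) * fake_imgs).requires_grad_(True)
    d_hat = D(x_hat)
    grads = autograd.grad(
        outputs=d_hat,
        inputs=x_hat,
        grad_outputs=torch.ones_like(d_hat),
        create_graph=True,
    )[0]
    grads = grads.view(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

DRAGAN keeps the same penalty form but changes where these points are sampled, as described next.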

After analyzing the training process of WGAN-GP, the authors of DRAGAN found that the value of $D(x)$ varies over a wide range, which causes the gradients to spike and makes training unstable. They therefore propose constructing the gradient penalty from real data only, evaluating the gradient in a small neighborhood around each real sample $x$:

\[\begin{aligned} \mathop{ \max}_{D} & \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \\ & - \lambda \Bbb{E}_{x \text{~} P_{data}(x), \delta \text{~} N(0,cI) }[(|| \nabla_xD(x+\delta) || -k)^2] \\ \mathop{ \min}_{G}& -\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \end{aligned}\]

The experiments use $\lambda=10, k=1, c=10$.

A complete PyTorch implementation of DRAGAN is available in PyTorch-GAN.

The computation of the gradient penalty term is given below. The torch.autograd.grad() method can be used to differentiate the network output with respect to its inputs.
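As a minimal standalone illustration of this API (separate from the DRAGAN code), the snippet below differentiates $y=\sum_i x_i^2$ with respect to $x$:

import torch
from torch import autograd

x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()
# autograd.grad returns a tuple with one gradient per input tensor
(grad_x,) = autograd.grad(outputs=y, inputs=x)
print(torch.allclose(grad_x, 2 * x))  # True

The DRAGAN gradient penalty below uses the same call, with the discriminator output as outputs and the perturbed samples as inputs.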

import torch
from torch import autograd


def compute_gradient_penalty(D, X, c=10, k=1):
    """Calculates the gradient penalty loss for DRAGAN"""
    # Perturb the real samples with Gaussian noise delta ~ N(0, cI)
    delta = (c ** 0.5) * torch.randn_like(X)
    interpolates = (X + delta).requires_grad_(True)
    d_interpolates = D(interpolates)
    # Get gradient of D w.r.t. the perturbed inputs
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    # Penalize deviations of the gradient norm from k
    gradient_penalty = ((gradients.norm(2, dim=1) - k) ** 2).mean()
    return gradient_penalty
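A quick usage sketch (the toy discriminator and batch below are placeholders, only meant to show the expected shapes):

import torch
from torch import nn

# Hypothetical discriminator: maps a 1x28x28 image to a single score
toy_D = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))
real_batch = torch.randn(8, 1, 28, 28)
gp = compute_gradient_penalty(toy_D, real_batch)
print(gp.shape)  # scalar tensor: torch.Size([])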

The parameter update procedure is given below. It assumes a setup roughly like the following sketch (the model definitions, hyperparameter values, and dataloader here are placeholders):
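import argparse
import torch
from torch import nn

# Hypothetical hyperparameters, mirroring the names used in the training loop
opt = argparse.Namespace(n_epochs=200, latent_dim=100, d_iter=5, lambda_gp=10)

# Placeholder models; any generator/discriminator pair with these interfaces works
generator = nn.Sequential(nn.Linear(opt.latent_dim, 28 * 28), nn.Tanh(),
                          nn.Unflatten(1, (1, 28, 28)))
discriminator = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))

optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# dataloader is assumed to yield batches of real images (e.g. a DataLoader over MNIST)

With this setup, the training loop is: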

for epoch in range(opt.n_epochs):
    for i, real_imgs in enumerate(dataloader):

        z = torch.randn(real_imgs.shape[0], opt.latent_dim)
        gen_imgs = generator(z)

        # Train the discriminator
        optimizer_D.zero_grad()
        # Scores for real images
        real_validity = discriminator(real_imgs)
        # Scores for generated images (detached so this step does not update G)
        gen_validity = discriminator(gen_imgs.detach())
        # Gradient penalty term, computed around the real samples only
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data)
        d_loss = -torch.mean(real_validity) + torch.mean(gen_validity) + opt.lambda_gp * gradient_penalty
        d_loss.backward()
        optimizer_D.step()

        # Train the generator once every d_iter discriminator updates
        if i % opt.d_iter == 0:
            optimizer_G.zero_grad()
            g_loss = -torch.mean(discriminator(gen_imgs))
            g_loss.backward()
            optimizer_G.step()