WGAN-GP:在WGAN中引入梯度惩罚.

1. WGAN与Lipschitz约束

Wasserstein GAN中,作者采用Wasserstein距离构造了GAN的目标函数,优化目标为真实分布\(P_{data}\)和生成分布$P_G$之间的Wasserstein距离:

\[\mathop{\min}_{G} \mathop{\max}_{D, ||D||_L \leq K} \{ \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \}\]

或写作交替优化的形式:

\[\begin{aligned} θ_D &\leftarrow \mathop{\arg \max}_{\theta_D} \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\ \theta_G &\leftarrow \mathop{\arg \min}_{\theta_G} -\frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \end{aligned}\]

其中要求判别器$D$是$K$阶Lipschitz连续的,即应满足:

\[| D(x_1)-D(x_2) | ≤K | x_1-x_2 |\]

Lipschitz连续性保证了函数的输出变化相对输入变化是缓慢的。若没有该限制,优化过程可能会使函数的输出趋向正负无穷。

⚪ Lipschitz连续性

一般地,一个实值函数$f$是$K$阶Lipschitz连续的,是指存在一个实数$K\geq 0$,使得对\(\forall x_1,x_2 \in \Bbb{R}\),有:

\[| f(x_1)-f(x_2) | ≤K | x_1-x_2 |\]

通常一个连续可微函数满足Lipschitz连续,这是因为其微分(用$\frac{|f(x_1)-f(x_2)|}{|x_1-x_2|}$近似)是有界的。但是一个Lipschitz连续函数不一定是处处可微的,比如$f(x) = |x|$。

⚪ 实现Lipschitz连续性

在实践中把判别器$D(\cdot)$约束为Lipschitz连续函数是比较困难的。

WGAN中,通过weight clipping实现该约束:在每次梯度更新后,把判别器$D$的参数$θ_D$的取值限制在$[-c,c]$之间($c$常取$0.01$):

\[\begin{aligned} θ_D &\leftarrow\text{clip}(\theta_D,-c,c) \end{aligned}\]

然而该做法也有一些问题。若$c$值取得太大,则模型训练容易不稳定,收敛速度慢;若$c$值取得太小,则容易造成梯度消失。

2. WGAN-GP

本文作者提出引入梯度惩罚项(gradient penalty)来实现Lipschitz约束。

若$D$是$1$阶Lipschitz函数,则对\(\forall x_1,x_2 \in \Bbb{R}\),有:

\[| D(x_1)-D(x_2) | ≤ | x_1-x_2 |\]

将其作为惩罚项引入判别器的目标函数:

\[\begin{aligned} θ_D \leftarrow \mathop{\arg \max}_{\theta_D}& \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\&- \frac{\lambda}{n} \sum_{i,j}^{} \max(\frac{| D(x_i)-D(x_j) |}{| x_i-x_j |},1) \end{aligned}\]

或写作:

\[\begin{aligned} θ_D \leftarrow \mathop{\arg \max}_{\theta_D} &\frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))}\\& - \frac{\lambda}{n} \sum_{i,j}^{} (\frac{| D(x_i)-D(x_j) |}{| x_i-x_j |}-1)^2 \end{aligned}\]

上式引入的差分形式的惩罚项计算量比较大;因此考虑将其替换为梯度形式;即约束$D$在任意位置的梯度的模小于等于$1$。

\[\begin{aligned} θ_D \leftarrow \mathop{\arg \max}_{\theta_D}& \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\&- \lambda \max(||\frac{\partial D(x)}{\partial x}||,1) \end{aligned}\]

或写作:

\[\begin{aligned} θ_D \leftarrow \mathop{\arg \max}_{\theta_D} &\frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\&- \lambda (||\frac{\partial D(x)}{\partial x}||-1)^2 \end{aligned}\]

理论上应该对$D(x)$的所有自变量取值进行计算并取平均,在实践中采用对真实样本和生成样本之间的随机插值:

\(P_{penalty}\)定义为从\(P_{data}\)和\(P_G\)中各抽取一个样本,再在其连线上抽取的样本。这样的操作是合理的,因为直观上,优化过程是使\(P_G\)靠近\(P_{data}\),样本点大多从这两个分布之间选取,而不是整个空间。

当真实样本的类别数比较多时,梯度惩罚的效果比较差。这是因为线性插值的梯度惩罚只能保证在一小块数据空间上满足,当类别数比较多时,不同类别之间进行插值往往会落在满足Lipschitz约束的空间之外。

最终使用梯度惩罚实现Lipschitz约束的Wasserstein Generative Adversarial Nets - Gradient Penalty (WGAN-GP)的判别器目标为:

\[\begin{aligned} θ_D \leftarrow & \mathop{\arg \max}_{\theta_D} \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\& - \frac{\lambda}{n} \sum_{i=1}^{n} \max(||\frac{\partial D(x)}{\partial x}||_{x = \epsilon_ix^i+(1-\epsilon_i)G(z^i)},1) \end{aligned}\]

或写作:

\[\begin{aligned} θ_D \leftarrow & \mathop{\arg \max}_{\theta_D} \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\ &- \frac{\lambda}{n} \sum_{i=1}^{n} (||\frac{\partial D(x)}{\partial x}||_{x = \epsilon_ix^i+(1-\epsilon_i)G(z^i)}-1)^2 \end{aligned}\]

其中$\epsilon_i$是从$U[0,1]$中采样的随机数。

WGAN-GP的完整pytorch实现可参考PyTorch-GAN

下面给出梯度惩罚项的计算过程。可以使用torch.autograd.grad()方法实现网络对输入变量的求导。

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    epsilon = torch.rand(real_samples.size(0), 1, 1, 1).requires_grad_(False)
    # Get random interpolation between real and fake samples
    interpolates = (epsilon * real_samples + ((1 - epsilon) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates).requires_grad_(False),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

下面给出参数更新过程:

for epoch in range(opt.n_epochs):
    for i, real_imgs in enumerate(dataloader):

        z = torch.randn(real_imgs.shape[0], opt.latent_dim) 
        gen_imgs = generator(z)            

        # 训练判别器
        optimizer_D.zero_grad()
        # 真实图像得分
        real_validity = discriminator(real_imgs)
        # 生成图像得分
        gen_validity = discriminator(gen_imgs.detach())
        # 梯度惩罚项
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, gen_imgs.data)
        d_loss = -torch.mean(real_validity) + torch.mean(gen_validity) + opt.lambda_gp*gradient_penalty
        d_loss.backward()
        optimizer_D.step()

        # 训练生成器
        if i % opt.d_iter == 0:
           optimizer_G.zero_grad()
            g_loss = -torch.mean(discriminator(gen_imgs))
            g_loss.backward()
            optimizer_G.step()