WGAN-GP:在WGAN中引入梯度惩罚.
1. WGAN与Lipschitz约束
在Wasserstein GAN中,作者采用Wasserstein距离构造了GAN的目标函数,优化目标为真实分布\(P_{data}\)和生成分布$P_G$之间的Wasserstein距离:
\[\mathop{\min}_{G} \mathop{\max}_{D, ||D||_L \leq K} \{ \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \}\]或写作交替优化的形式:
\[\begin{aligned} θ_D &\leftarrow \mathop{\arg \max}_{\theta_D} \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\ \theta_G &\leftarrow \mathop{\arg \min}_{\theta_G} -\frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \end{aligned}\]其中要求判别器$D$是$K$阶Lipschitz连续的,即应满足:
\[| D(x_1)-D(x_2) | ≤K | x_1-x_2 |\]Lipschitz连续性保证了函数的输出变化相对输入变化是缓慢的。若没有该限制,优化过程可能会使函数的输出趋向正负无穷。
⚪ Lipschitz连续性
一般地,一个实值函数$f$是$K$阶Lipschitz连续的,是指存在一个实数$K\geq 0$,使得对\(\forall x_1,x_2 \in \Bbb{R}\),有:
\[| f(x_1)-f(x_2) | ≤K | x_1-x_2 |\]通常一个连续可微函数满足Lipschitz连续,这是因为其微分(用$\frac{|f(x_1)-f(x_2)|}{|x_1-x_2|}$近似)是有界的。但是一个Lipschitz连续函数不一定是处处可微的,比如$f(x) = |x|$。
⚪ 实现Lipschitz连续性
在实践中把判别器$D(\cdot)$约束为Lipschitz连续函数是比较困难的。
在WGAN中,通过weight clipping实现该约束:在每次梯度更新后,把判别器$D$的参数$θ_D$的取值限制在$[-c,c]$之间($c$常取$0.01$):
\[\begin{aligned} θ_D &\leftarrow\text{clip}(\theta_D,-c,c) \end{aligned}\]然而该做法也有一些问题。若$c$值取得太大,则模型训练容易不稳定,收敛速度慢;若$c$值取得太小,则容易造成梯度消失。
2. WGAN-GP
本文作者提出引入梯度惩罚项(gradient penalty)来实现Lipschitz约束。
若$D$是$1$阶Lipschitz函数,则对\(\forall x_1,x_2 \in \Bbb{R}\),有:
\[| D(x_1)-D(x_2) | ≤ | x_1-x_2 |\]将其作为惩罚项引入判别器的目标函数:
\[\begin{aligned} θ_D \leftarrow \mathop{\arg \max}_{\theta_D}& \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\&- \frac{\lambda}{n} \sum_{i,j}^{} \max(\frac{| D(x_i)-D(x_j) |}{| x_i-x_j |},1) \end{aligned}\]或写作:
\[\begin{aligned} θ_D \leftarrow \mathop{\arg \max}_{\theta_D} &\frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))}\\& - \frac{\lambda}{n} \sum_{i,j}^{} (\frac{| D(x_i)-D(x_j) |}{| x_i-x_j |}-1)^2 \end{aligned}\]上式引入的差分形式的惩罚项计算量比较大;因此考虑将其替换为梯度形式;即约束$D$在任意位置的梯度的模小于等于$1$。
\[\begin{aligned} θ_D \leftarrow \mathop{\arg \max}_{\theta_D}& \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\&- \lambda \max(||\frac{\partial D(x)}{\partial x}||,1) \end{aligned}\]或写作:
\[\begin{aligned} θ_D \leftarrow \mathop{\arg \max}_{\theta_D} &\frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\&- \lambda (||\frac{\partial D(x)}{\partial x}||-1)^2 \end{aligned}\]理论上应该对$D(x)$的所有自变量取值进行计算并取平均,在实践中采用对真实样本和生成样本之间的随机插值:
\(P_{penalty}\)定义为从\(P_{data}\)和\(P_G\)中各抽取一个样本,再在其连线上抽取的样本。这样的操作是合理的,因为直观上,优化过程是使\(P_G\)靠近\(P_{data}\),样本点大多从这两个分布之间选取,而不是整个空间。
当真实样本的类别数比较多时,梯度惩罚的效果比较差。这是因为线性插值的梯度惩罚只能保证在一小块数据空间上满足,当类别数比较多时,不同类别之间进行插值往往会落在满足Lipschitz约束的空间之外。
最终使用梯度惩罚实现Lipschitz约束的Wasserstein Generative Adversarial Nets - Gradient Penalty (WGAN-GP)的判别器目标为:
\[\begin{aligned} θ_D \leftarrow & \mathop{\arg \max}_{\theta_D} \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\& - \frac{\lambda}{n} \sum_{i=1}^{n} \max(||\frac{\partial D(x)}{\partial x}||_{x = \epsilon_ix^i+(1-\epsilon_i)G(z^i)},1) \end{aligned}\]或写作:
\[\begin{aligned} θ_D \leftarrow & \mathop{\arg \max}_{\theta_D} \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\ &- \frac{\lambda}{n} \sum_{i=1}^{n} (||\frac{\partial D(x)}{\partial x}||_{x = \epsilon_ix^i+(1-\epsilon_i)G(z^i)}-1)^2 \end{aligned}\]其中$\epsilon_i$是从$U[0,1]$中采样的随机数。
WGAN-GP的完整pytorch实现可参考PyTorch-GAN。
下面给出梯度惩罚项的计算过程。可以使用torch.autograd.grad()方法实现网络对输入变量的求导。
def compute_gradient_penalty(D, real_samples, fake_samples):
"""Calculates the gradient penalty loss for WGAN GP"""
# Random weight term for interpolation between real and fake samples
epsilon = torch.rand(real_samples.size(0), 1, 1, 1).requires_grad_(False)
# Get random interpolation between real and fake samples
interpolates = (epsilon * real_samples + ((1 - epsilon) * fake_samples)).requires_grad_(True)
d_interpolates = D(interpolates)
# Get gradient w.r.t. interpolates
gradients = autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=torch.ones_like(d_interpolates).requires_grad_(False),
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty
下面给出参数更新过程:
for epoch in range(opt.n_epochs):
for i, real_imgs in enumerate(dataloader):
z = torch.randn(real_imgs.shape[0], opt.latent_dim)
gen_imgs = generator(z)
# 训练判别器
optimizer_D.zero_grad()
# 真实图像得分
real_validity = discriminator(real_imgs)
# 生成图像得分
gen_validity = discriminator(gen_imgs.detach())
# 梯度惩罚项
gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, gen_imgs.data)
d_loss = -torch.mean(real_validity) + torch.mean(gen_validity) + opt.lambda_gp*gradient_penalty
d_loss.backward()
optimizer_D.step()
# 训练生成器
if i % opt.d_iter == 0:
optimizer_G.zero_grad()
g_loss = -torch.mean(discriminator(gen_imgs))
g_loss.backward()
optimizer_G.step()