DRAGAN: Adjusting the Interpolation Space of the Gradient Penalty
In WGAN-GP, the GAN objective is constructed from the Wasserstein distance; the optimization target is the Wasserstein distance between the real distribution \(P_{data}\) and the generated distribution $P_G$:
\[\mathop{\min}_{G} \mathop{\max}_{D, \|D\|_L \leq 1} \{ \Bbb{E}_{x \sim P_{data}(x)}[D(x)]-\Bbb{E}_{x \sim P_{G}(x)}[D(x)] \}\]where the discriminator $D$ is required to be 1-Lipschitz continuous. The authors introduce a gradient penalty term to enforce this Lipschitz constraint.
If $D$ is a 1-Lipschitz function, then for any $x_1,x_2$ in its input space:
\[| D(x_1)-D(x_2) | \leq \| x_1-x_2 \|\]or, equivalently,
\[\| \nabla_x D(x) \| \approx \frac{|D(x_1)-D(x_2)|}{\| x_1-x_2 \|} \leq 1\](letting $x_2$ approach $x_1$ shows that the gradient norm of $D$ is bounded by $1$ everywhere). The WGAN-GP objective is then:
\[\begin{aligned} \mathop{\max}_{D}\ & \Bbb{E}_{x \sim P_{data}(x)}[D(x)]-\Bbb{E}_{x \sim P_{G}(x)}[D(x)] \\ & - \lambda \Bbb{E}_{\hat{x} \sim P_{penalty}}[(\| \nabla_{\hat{x}}D(\hat{x}) \| -1)^2] \\ \mathop{\min}_{G}\ & -\Bbb{E}_{x \sim P_{G}(x)}[D(x)] \end{aligned}\]In theory the penalty should be computed and averaged over every possible input of $D$; in practice it is estimated on random interpolations between real and generated samples:
$P_{penalty}$ is defined by drawing one sample from \(P_{data}\) and one from \(P_G\), then sampling a point on the line segment between them. This is reasonable because, intuitively, the optimization pushes \(P_G\) toward \(P_{data}\), so the relevant points mostly lie between these two distributions rather than across the whole space.
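For reference, here is a minimal sketch of how points from $P_{penalty}$ are drawn in WGAN-GP, assuming real_samples and fake_samples are batches of equal shape (sample_penalty_points is an illustrative name, not part of any library); DRAGAN replaces this interpolation, as described next.

import torch

def sample_penalty_points(real_samples, fake_samples):
    """Draw one point on the segment between each paired real and fake sample (WGAN-GP's P_penalty)."""
    # One interpolation coefficient per sample, broadcast over the remaining dimensions
    eps = torch.rand(real_samples.size(0), *([1] * (real_samples.dim() - 1)))
    return eps * real_samples + (1 - eps) * fake_samples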
After analyzing the training process of WGAN-GP, the DRAGAN authors observed that large swings in the value of $D(x)$ can produce sharply increasing gradients, which destabilizes training. They therefore propose constructing the gradient penalty from real data only, evaluating the gradient in a small neighborhood of each real sample $x$:
\[\begin{aligned} \mathop{\max}_{D}\ & \Bbb{E}_{x \sim P_{data}(x)}[D(x)]-\Bbb{E}_{x \sim P_{G}(x)}[D(x)] \\ & - \lambda \Bbb{E}_{x \sim P_{data}(x),\, \delta \sim N(0,cI) }[(\| \nabla_xD(x+\delta) \| -k)^2] \\ \mathop{\min}_{G}\ & -\Bbb{E}_{x \sim P_{G}(x)}[D(x)] \end{aligned}\]The experiments use $\lambda=10, k=1, c=10$.
A complete PyTorch implementation of DRAGAN is available in PyTorch-GAN.
The computation of the gradient penalty term is shown below. torch.autograd.grad() can be used to compute the gradient of the network output with respect to its input.
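As a quick standalone illustration of torch.autograd.grad() itself (a minimal example, independent of DRAGAN):

import torch
from torch import autograd

x = torch.randn(4, 3, requires_grad=True)
y = (x ** 2).sum()
# For a scalar output, grad returns a tuple with one gradient per input; here dy/dx = 2*x
(dy_dx,) = autograd.grad(outputs=y, inputs=x)
print(torch.allclose(dy_dx, 2 * x))  # True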
import torch
from torch import autograd

def compute_gradient_penalty(D, X, c=10):
    """Calculates the DRAGAN gradient penalty on Gaussian perturbations of the real samples X."""
    # Perturb the real samples with noise delta ~ N(0, c*I)
    interpolates = (X + (c ** 0.5) * torch.randn_like(X)).requires_grad_(True)
    d_interpolates = D(interpolates)
    # Get gradient of D w.r.t. the perturbed inputs
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    # Penalize the deviation of the per-sample gradient norm from k = 1
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
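A quick sanity check with a toy fully-connected discriminator (toy_D and the image shape below are illustrative only, not part of PyTorch-GAN):

import torch
import torch.nn as nn

# Toy discriminator, for illustration only
toy_D = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 1))
real_batch = torch.randn(16, 1, 28, 28)           # stand-in for a batch of real images
gp = compute_gradient_penalty(toy_D, real_batch)  # scalar tensor, differentiable w.r.t. toy_D's parameters
print(gp.item())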
The parameter update procedure is shown below:
for epoch in range(opt.n_epochs):
    for i, real_imgs in enumerate(dataloader):
        # Sample noise and generate a batch of images
        z = torch.randn(real_imgs.shape[0], opt.latent_dim)
        gen_imgs = generator(z)

        # Train the discriminator
        optimizer_D.zero_grad()
        # Scores for real images
        real_validity = discriminator(real_imgs)
        # Scores for generated images (detached so the generator is not updated here)
        gen_validity = discriminator(gen_imgs.detach())
        # Gradient penalty computed around the real samples
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data)
        d_loss = -torch.mean(real_validity) + torch.mean(gen_validity) + opt.lambda_gp * gradient_penalty
        d_loss.backward()
        optimizer_D.step()

        # Train the generator every opt.d_iter discriminator updates
        if i % opt.d_iter == 0:
            optimizer_G.zero_grad()
            g_loss = -torch.mean(discriminator(gen_imgs))
            g_loss.backward()
            optimizer_G.step()
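The loop above assumes that opt, generator, discriminator, optimizer_D, optimizer_G, and dataloader are already defined. A minimal sketch of such a setup with placeholder models and random data follows; only lambda_gp = 10 comes from the text, every other value and module here is an assumption (in practice, use the PyTorch-GAN architectures and a real dataset).

import argparse
import torch
import torch.nn as nn

# Hypothetical hyperparameter container; lambda_gp = 10 follows the text, the rest are placeholders
opt = argparse.Namespace(n_epochs=200, latent_dim=100, lambda_gp=10, d_iter=5)

# Placeholder generator and discriminator for 1x28x28 images
generator = nn.Sequential(nn.Linear(opt.latent_dim, 28 * 28), nn.Tanh(), nn.Unflatten(1, (1, 28, 28)))
discriminator = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))

optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Placeholder dataloader yielding batches of random "real" images of shape (B, 1, 28, 28)
dataloader = torch.utils.data.DataLoader(torch.randn(256, 1, 28, 28), batch_size=64)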