GN-GAN:在WGAN中引入梯度归一化.
1. WGAN与Lipschitz约束
在Wasserstein GAN中,作者采用Wasserstein距离构造了GAN的目标函数,优化目标为真实分布\(P_{data}\)和生成分布$P_G$之间的Wasserstein距离:
\[\mathop{\min}_{G} \mathop{\max}_{D, ||D||_L \leq K} \{ \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \}\]或写作交替优化的形式:
\[\begin{aligned} θ_D &\leftarrow \mathop{\arg \max}_{\theta_D} \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\ \theta_G &\leftarrow \mathop{\arg \min}_{\theta_G} -\frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \end{aligned}\]其中要求判别器$D$是$K$阶Lipschitz连续的,即应满足:
\[| D(x_1)-D(x_2) | ≤K | x_1-x_2 |\]Lipschitz连续性保证了函数的输出变化相对输入变化是缓慢的。若没有该限制,优化过程可能会使函数的输出趋向正负无穷。
⚪ Lipschitz连续性
一般地,一个实值函数$f$是$K$阶Lipschitz连续的,是指存在一个实数$K\geq 0$,使得对\(\forall x_1,x_2 \in \Bbb{R}\),有:
\[| f(x_1)-f(x_2) | ≤K | x_1-x_2 |\]通常一个连续可微函数满足Lipschitz连续,这是因为其微分(用$\frac{|f(x_1)-f(x_2)|}{|x_1-x_2|}$近似)是有界的。但是一个Lipschitz连续函数不一定是处处可微的,比如$f(x) = |x|$。
⚪ 实现Lipschitz连续性
为判别器引入Lipschitz约束的方法主要有两种。第一种是施加硬约束,即通过约束参数使得网络每一层的Lipschitz常数都是有界的,则总Lipschitz常数也是有界的,这类方法包括权重裁剪、谱归一化。
这些方法强制网络的每一层都满足Lipschiitz约束,从而把网络限制为所有满足Lipschiitz约束的函数中的一小簇函数。事实上考虑到如果网络有些层不满足Lipschiitz约束,另一些层满足更强的Lipschiitz约束,则网络整体仍然满足Lipschiitz约束。这类方法无法顾及这种情况。
第二种是施加软约束,即选择Lipschitz约束的一个充分条件(通常是网络对输入的梯度),并在目标函数中添加相关的惩罚项。
2. 梯度归一化 gradient normalization
若判别器$D$是$1$阶Lipschitz函数,则对\(\forall x_1,x_2 \in \Bbb{R}\),有:
\[| D(x_1)-D(x_2) | ≤ | x_1-x_2 |\]上式的一个充分条件是:
\[||\nabla_x D(x)|| \leq 1\]如果将判别器$D$变换为$\hat{D}$,使得其自动满足\(\|\nabla_x \hat{D}(x)\| \leq 1\),则实现了Lipschitz约束的引入。
不妨取:
\[\hat{D}(x) = \frac{D(x)}{||\nabla_x D(x)||}\]注意到网络通常用ReLU或LeakyReLU作为激活函数,此时$D(x)$实际上是一个“分段线性函数”,除边界之外$D(x)$在局部的连续区域内是一个线性函数,因此$\nabla_x D(x)$是一个常向量。此时有:
\[||\nabla_x \hat{D}(x)|| = ||\nabla_x \frac{D(x)}{||\nabla_x D(x)||}|| = ||\frac{\nabla_x D(x)}{||\nabla_x D(x)||}|| = 1\]3. GN-GAN
上式可能会出现分母为零的情况,GN-GAN将$|D(x)|$引入分母,同时也保证了函数的有界性:
\[\hat{D}(x) = \frac{D(x)}{||\nabla_x D(x)||+|D(x)|} \in [-1,1]\]GN-GAN中梯度归一化的pytorch实现如下:
def grad_normlize(D, img):
"""Calculates the gradient normalization"""
img.requires_grad_(True)
out = D(img)
grad_out=torch.ones_like(out).requires_grad_(False),
# Get gradient w.r.t. img
gradients = autograd.grad(
outputs=out,
inputs=img,
grad_outputs=grad_out,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
grad_norm = gradients.view(gradients.size(0), -1).pow(2).sum(1) ** (1/2)
return out / (grad_norm + torch.abs(out))
下面给出参数更新过程:
for epoch in range(opt.n_epochs):
for i, real_imgs in enumerate(dataloader):
z = torch.randn(real_imgs.shape[0], opt.latent_dim)
gen_imgs = generator(z)
# 训练判别器
optimizer_D.zero_grad()
# 真实图像得分
real_validity = grad_normlize(discriminator,real_imgs)
# 生成图像得分
gen_validity = grad_normlize(discriminator,gen_imgs.detach())
d_loss = -torch.mean(real_validity) + torch.mean(gen_validity)
d_loss.backward()
optimizer_D.step()
# 训练生成器
if i % opt.d_iter == 0:
optimizer_G.zero_grad()
# 生成图像得分
gen_validity = grad_normlize(discriminator,gen_imgs)
g_loss = -torch.mean(gen_validity )
g_loss.backward()
optimizer_G.step()