GraN-GAN:在WGAN中引入分段线性的梯度归一化.

1. WGAN与Lipschitz约束

Wasserstein GAN中,作者采用Wasserstein距离构造了GAN的目标函数,优化目标为真实分布\(P_{data}\)和生成分布$P_G$之间的Wasserstein距离:

\[\mathop{\min}_{G} \mathop{\max}_{D, ||D||_L \leq K} \{ \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \}\]

或写作交替优化的形式:

\[\begin{aligned} θ_D &\leftarrow \mathop{\arg \max}_{\theta_D} \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\ \theta_G &\leftarrow \mathop{\arg \min}_{\theta_G} -\frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \end{aligned}\]

其中要求判别器$D$是$K$阶Lipschitz连续的,即应满足:

\[| D(x_1)-D(x_2) | ≤K | x_1-x_2 |\]

Lipschitz连续性保证了函数的输出变化相对输入变化是缓慢的。若没有该限制,优化过程可能会使函数的输出趋向正负无穷。

⚪ Lipschitz连续性

一般地,一个实值函数$f$是$K$阶Lipschitz连续的,是指存在一个实数$K\geq 0$,使得对\(\forall x_1,x_2 \in \Bbb{R}\),有:

\[| f(x_1)-f(x_2) | ≤K | x_1-x_2 |\]

通常一个连续可微函数满足Lipschitz连续,这是因为其微分(用$\frac{|f(x_1)-f(x_2)|}{|x_1-x_2|}$近似)是有界的。但是一个Lipschitz连续函数不一定是处处可微的,比如$f(x) = |x|$。

⚪ 实现Lipschitz连续性

为判别器引入Lipschitz约束的方法主要有两种。第一种是施加硬约束,即通过约束参数使得网络每一层的Lipschitz常数都是有界的,则总Lipschitz常数也是有界的,这类方法包括权重裁剪、谱归一化。

这些方法强制网络的每一层都满足Lipschiitz约束,从而把网络限制为所有满足Lipschiitz约束的函数中的一小簇函数。事实上考虑到如果网络有些层不满足Lipschiitz约束,另一些层满足更强的Lipschiitz约束,则网络整体仍然满足Lipschiitz约束。这类方法无法顾及这种情况。

第二种是施加软约束,即选择Lipschitz约束的一个充分条件(通常是网络对输入的梯度),并在目标函数中添加相关的惩罚项。

2. 梯度归一化 gradient normalization

若判别器$D$是$1$阶Lipschitz函数,则对\(\forall x_1,x_2 \in \Bbb{R}\),有:

\[| D(x_1)-D(x_2) | ≤ | x_1-x_2 |\]

上式的一个充分条件是:

\[||\nabla_x D(x)|| \leq 1\]

如果将判别器$D$变换为$\hat{D}$,使得其自动满足\(\|\nabla_x \hat{D}(x)\| \leq 1\),则实现了Lipschitz约束的引入。

不妨取:

\[\hat{D}(x) = \frac{D(x)}{||\nabla_x D(x)||}\]

注意到网络通常用ReLULeakyReLU作为激活函数,此时$D(x)$实际上是一个“分段线性函数”,除边界之外$D(x)$在局部的连续区域内是一个线性函数,因此$\nabla_x D(x)$是一个常向量。此时有:

\[||\nabla_x \hat{D}(x)|| = ||\nabla_x \frac{D(x)}{||\nabla_x D(x)||}|| = ||\frac{\nabla_x D(x)}{||\nabla_x D(x)||}|| = 1\]

3. GraN-GAN

上式可能会出现分母为零的情况,GraN-GAN设计了如下变换:

\[\hat{D}(x) = \frac{D(x) \cdot ||\nabla_x D(x)||}{||\nabla_x D(x)||^2+\epsilon}\]

GraN-GAN中梯度归一化的pytorch实现如下:

def grad_normlize(D, img):
    """Calculates the gradient normalization"""
    img.requires_grad_(True)
    out = D(img)
    grad_out=torch.ones_like(out).requires_grad_(False),
    # Get gradient w.r.t. img
    gradients = autograd.grad(
        outputs=out,
        inputs=img,
        grad_outputs=grad_out,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    grad_norm = gradients.view(gradients.size(0), -1).pow(2).sum(1) ** (1/2)
    return (out * grad_norm) / (grad_norm**2 + epsilon)

下面给出参数更新过程:

for epoch in range(opt.n_epochs):
    for i, real_imgs in enumerate(dataloader):

        z = torch.randn(real_imgs.shape[0], opt.latent_dim) 
        gen_imgs = generator(z)

        # 训练判别器
        optimizer_D.zero_grad()
        # 真实图像得分
        real_validity = grad_normlize(discriminator,real_imgs)
        # 生成图像得分
        gen_validity = grad_normlize(discriminator,gen_imgs.detach())
        d_loss = -torch.mean(real_validity) + torch.mean(gen_validity)
        d_loss.backward()
        optimizer_D.step()

        # 训练生成器
        if i % opt.d_iter == 0:
            optimizer_G.zero_grad()
            # 生成图像得分
            gen_validity = grad_normlize(discriminator,gen_imgs)
            g_loss = -torch.mean(gen_validity)
            g_loss.backward()
            optimizer_G.step()