WGAN:使用Wasserstein距离构造GAN.

GAN的损失函数衡量真实分布\(P_{data}(x)\)与生成分布\(P_G(x)\)之间的JS散度JS散度在两个分布不相交时没有意义。可以选择具有更平滑的值空间的分布度量指标,比如Wasserstein距离

1. Wasserstein Distance

Wasserstein距离又叫推土机距离(Earth Mover’s Distance),是指把一个概率分布$P=p(x)$变成另一个概率分布$Q=q(x)$时所需要的最小变换距离。

记从位置$x$运输到位置$y$的成本为$d(x,y)$,联合分布$\gamma(x,y)$描述了一种可行的运输方案,表示应该从位置$x$处运输多少货物到位置$y$处,才能使$p(\textbf{x})$和$q(\textbf{x})$具有相同的概率分布。在离散形势下,联合分布$\gamma(x,y)$表示为一个矩阵:

其中矩阵的每一行代表概率分布$p(\textbf{x})$的某个位置$x_p$要分配到概率分布$q(\textbf{x})$不同位置处的值;每一列代表概率分布$q(\textbf{x})$的某个位置$x_q$接收到概率分布$p(\textbf{x})$的不同位置分配的值。在该联合分布下,概率分布变换的总成本为:

\[\sum_{x_p,x_q} \gamma(x_p,x_q) d(x_p,x_q) = \Bbb{E}_{(x,y) \in \gamma(\textbf{x},\textbf{y})} [d(x,y)]\]

一般地,Wasserstein距离定义为如下最优化问题:

\[\begin{aligned} \mathcal{W}[p,q] = \mathop{\inf}_{\gamma \in \Pi[p,q]} & \int \int \gamma(x,y) d(x,y) dxdy \\ \text{s.t. } & \int \gamma(x,y) dy = p(x) \\ & \int \gamma(x,y)dx = q(y) \\ & \gamma(x,y) \geq 0 \end{aligned}\]

⚪ 为什么Wasserstein距离比JS散度或KL散度更好?

Wasserstein距离可以衡量两个概率分布之间的距离;JS散度KL散度也具有类似的功能。相比于后两者,Wasserstein距离在两个概率分布没有重叠的情况下仍然是有意义的,并且距离表示更加平滑。

假设具有以下概率分布:

\[\begin{aligned} &\forall (x,y) \in P, x=0,y\text{ ~ } U(0,1) \\ &\forall (x,y) \in Q, x=\theta, 0 \leq \theta \leq 1,y\text{ ~ } U(0,1) \end{aligned}\]

当$\theta \neq 0$时分布$P,Q$没有重叠。此时计算两个分布的KL散度、JS散度和Wasserstein距离:

\[\begin{aligned} D_{KL}[P || Q] &= \sum P \log \frac{P}{Q} = \sum_{x=0,y\text{ ~ } U(0,1)} 1 \cdot \log \frac{1}{0} = + \infty \\ D_{KL}[Q || P] &= \sum Q \log \frac{Q}{P} = \sum_{x=\theta,y\text{ ~ } U(0,1)} 1 \cdot \log \frac{1}{0} = + \infty \\ D_{JS}[P || Q] &= \frac{1}{2}D_{KL}[P || \frac{P+Q}{2}] + \frac{1}{2}D_{KL}[Q || \frac{P+Q}{2}] \\ & = \frac{1}{2}\sum_{x=0,y\text{ ~ } U(0,1)} 1 \cdot \log \frac{1}{1/2} + \frac{1}{2} \sum_{x=\theta,y\text{ ~ } U(0,1)} 1 \cdot \log \frac{1}{1/2} \\ & = \log 2 \\ \mathcal{W}[P,Q] & = |\theta| \end{aligned}\]

当$\theta = 0$时分布$P,Q$完全重叠,此时有:

\[D_{KL}[P || Q] = D_{KL}[Q || P] = D_{JS}[P || Q] = 0 \\ \mathcal{W}[P,Q] = 0 = |\theta|\]

根据上述结论,当$P,Q$没有重叠时KL散度变为无穷大;当参数$\theta$变化时JS散度的取值发生跳变;只有Wasserstein距离提供了一种平滑的分布差异测量,能够保证学习过程的稳定性。

通常GAN的目标函数被认为是最小化数据的真实分布与生成分布之间的JS散度;如果将其替换为Wasserstein距离,可能会取得更好的效果。

2. Wasserstein GAN

Wasserstein距离具有如下对偶形式

\[\mathcal{W}[p,q] = \mathop{\sup}_{f, ||f||_L \leq K} \{ \Bbb{E}_{x \text{~} p(x)} [ f(x)] -\Bbb{E}_{x \text{~}q(x)}[f(x)]\}\]

上式要求函数$f$是$K$阶Lipschitz连续的,即应满足\(\|f\|_L \leq K\)。

一般地,一个实值函数$f$是$K$阶Lipschitz连续的,是指存在一个实数$K\geq 0$,使得对\(\forall x_1,x_2 \in \Bbb{R}\),有:

\[| f(x_1)-f(x_2) | ≤K | x_1-x_2 |\]

通常一个连续可微函数满足Lipschitz连续,这是因为其微分(用$\frac{|f(x_1)-f(x_2)|}{|x_1-x_2|}$近似)是有界的。但是一个Lipschitz连续函数不一定是处处可微的,比如$f(x) = |x|$。

Lipschitz连续性保证了函数的输出变化相对输入变化是缓慢的。若没有该限制,优化过程可能会使函数的输出趋向正负无穷。

Wasserstein GAN中,把判别器$D(\cdot)$约束为Lipschitz连续函数,则优化目标为真实分布\(P_{data}\)和生成分布$P_G$之间的Wasserstein距离:

\[\mathcal{W}[P_{data},P_G] = \mathop{\sup}_{D, ||D||_L \leq K} \{ \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \}\]

此时判别器$D$不再充当区分真实数据和生成数据的二分类器,而是通过学习一个Lipschitz连续函数来计算分布的Wasserstein距离。当判别器的损失下降时,Wasserstein距离变小,生成器的输出更接近真实的数据分布。

下面讨论如何把判别器$D(\cdot)$约束为Lipschitz连续函数,即引入约束:

\[| D(x_1)-D(x_2) | ≤K | x_1-x_2 |\]

在实践中,通过weight clipping实现该约束:在每次梯度更新后,把判别器$D$的参数$w$的取值限制在$[-c,c]$之间($c$常取$0.01$)。

与标准的GAN相比,Wasserstein GAN的主要改进如下:

  1. 在判别器的更新中,对判别器的参数取值进行裁剪$[-c,c]$;
  2. 判别器不再是二分类器,而是提供了计算Wasserstein距离的近似估计;
  3. 判别器的优化器选用RMSProp,相比于基于动量的优化器(如Adam)训练更稳定。

⚪ Wasserstein GAN的pytorch实现

Wasserstein GAN的优化目标为:

\[\mathop{ \min}_{G} \mathop{\max}_{D, ||D||_L \leq K} \{ \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \}\]

或写作交替优化的形式:

\[\begin{aligned} θ_D &\leftarrow \mathop{\arg \max}_{\theta_D} \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\ θ_D &\leftarrow\text{clip}(\theta_D,-c,c) \\ \theta_G &\leftarrow \mathop{\arg \min}_{\theta_G} -\frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \end{aligned}\]

Wasserstein GAN的完整训练流程如下:

Wasserstein GAN的完整pytorch实现可参考PyTorch-GAN,下面给出其损失函数的计算和参数更新过程:

for epoch in range(opt.n_epochs):
    for i, real_imgs in enumerate(dataloader):

        z = torch.randn(real_imgs.shape[0], opt.latent_dim) 
        gen_imgs = generator(z)  

        # 训练判别器
        optimizer_D.zero_grad()
        d_loss = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(gen_imgs.detach()))
        d_loss.backward()
        optimizer_D.step()
            
        # 裁剪判别器参数
        for p in discriminator.parameters():
            p.data.clamp_(-opt.clip_value, opt.clip_value)

        # 训练生成器
        if i % opt.d_iter == 0:
            optimizer_G.zero_grad()
            g_loss = -torch.mean(discriminator(gen_imgs))
            g_loss.backward()
            optimizer_G.step()