SN-GAN:在WGAN中引入谱归一化.

1. WGAN与Lipschitz约束

Wasserstein GAN中,作者采用Wasserstein距离构造了GAN的目标函数,优化目标为真实分布\(P_{data}\)和生成分布$P_G$之间的Wasserstein距离:

\[\mathop{\min}_{G} \mathop{\max}_{D, ||D||_L \leq K} \{ \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \}\]

或写作交替优化的形式:

\[\begin{aligned} θ_D &\leftarrow \mathop{\arg \max}_{\theta_D} \frac{1}{n} \sum_{i=1}^{n} { D(x^i)} - \frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \\ \theta_G &\leftarrow \mathop{\arg \min}_{\theta_G} -\frac{1}{n} \sum_{i=1}^{n} {D(G(z^i))} \end{aligned}\]

其中要求判别器$D$是$K$阶Lipschitz连续的,即应满足:

\[|| D(x_1)-D(x_2) || ≤K || x_1-x_2 ||\]

Lipschitz连续性保证了函数的输出变化相对输入变化是缓慢的。若没有该限制,优化过程可能会使函数的输出趋向正负无穷。

⚪ Lipschitz连续性

一般地,一个实值函数$f$是$K$阶Lipschitz连续的,是指存在一个实数$K\geq 0$,使得对\(\forall x_1,x_2 \in \Bbb{R}\),有:

\[|| f(x_1)-f(x_2) || ≤K || x_1-x_2 ||\]

通常一个连续可微函数满足Lipschitz连续,这是因为其微分(用$\frac{|f(x_1)-f(x_2)|}{|x_1-x_2|}$近似)是有界的。但是一个Lipschitz连续函数不一定是处处可微的,比如$f(x) = |x|$。

2. 实现Lipschitz连续性

若神经网络具有Lipschitz连续性,意味着该网络对输入扰动不敏感,具有更好的泛化性。下面讨论如何对判别器$D$施加Lipschitz约束。

假设判别器$D(x)$具有参数$W$,则Lipschitz常数$K$通常是由参数$W$决定的,此时Lipschitz约束为:

\[|| D_W(x_1)-D_W(x_2) ||\leq K(W) || x_1-x_2 ||\]

首先考虑判别器为单层全连接层$D_W(x)=\sigma(Wx)$,其中$\sigma$是激活函数,对应Lipschitz约束:

\[|| \sigma(Wx_1)-\sigma(Wx_2) || \leq K(W) || x_1-x_2 ||\]

对$\sigma(Wx)$进行Taylor展开并取一阶近似可得:

\[|| \frac{\partial \sigma}{\partial Wx} W(x_1-x_2) || \leq K(W) || x_1-x_2 ||\]

$\frac{\partial \sigma}{\partial Wx}$表示激活函数的导数。通常激活函数的导数是有界的,比如ReLU函数的导数范围是$[0,1]$;因此这一项可以被忽略。则全连接层的Lipschitz约束为:

\[|| W(x_1-x_2) || \leq K(W) || x_1-x_2 ||\]

上式对全连接层的参数$W$进行了约束。在实践中全连接网络是由全连接层组合而成,而卷积网络、循环网络等也可以表示为特殊的全连接网络,因此上述分析具有一般性。

3. 矩阵范数问题

全连接层的Lipschitz约束可以转化为一个矩阵范数问题(由向量范数诱导出来的矩阵范数,作用相当于向量的模长),定义为:

\[||W||_2 = \mathop{\max}_{x \neq 0} \frac{||Wx||}{||x||}\]

当$W$为方阵时,上述矩阵范数称为谱范数(spectral norm)。此时问题转化为:

\[|| W(x_1-x_2) || \leq ||W||_2 \cdot || x_1-x_2 ||\]

谱范数$||W||_2$等于$W^TW$的最大特征值(主特征值)的平方根;若$W$为方阵,则$||W||_2$等于$W$的最大特征值的绝对值。

⚪ 谱范数的证明

谱范数$||W||_2$的平方为:

\[||W||_2^2 = \mathop{\max}_{x \neq 0} \frac{x^TW^TWx}{x^Tx}\]

上式右端为瑞利商(Rayleigh Quotient),取值范围是:

\[\lambda_{min}≤\frac{x^TW^TWx}{x^Tx}≤\lambda_{max}\]

因此谱范数$||W||_2$的平方的取值为$W^TW$的最大特征值。

⚪ 谱范数的计算:幂迭代

$W^TW$的最大特征值可以通过幂迭代(power iteration)方法求解。

迭代格式1:

\[u \leftarrow \frac{(W^TW)u}{||(W^TW)u||}, ||W||_2^2 ≈ u^TW^TWu\]

迭代格式2:

\[v \leftarrow \frac{W^Tu}{||W^Tu||},u \leftarrow \frac{Wv}{||Wv||}, ||W||_2 ≈ u^TWv\]

其中$u,v$可以初始化为全$1$向量。下面以迭代格式1为例简单证明迭代过程收敛,记$A=W^TW$,初始化$u^{(0)}$,若$A$可对角化,则$A$的特征向量\(\{v_1 v_2 \cdots v_n\}\)构成一组完备的基,$u^{(0)}$可由这组基表示:

\[u^{(0)} = c_1v_1+c_2v_2+\cdots c_nv_n\]

先不考虑迭代中分母的归一化,则迭代过程$u \leftarrow Au$经过$t$次后为:

\[A^tu^{(0)} = c_1A^tv_1+c_2A^tv_2+\cdots c_nA^tv_n\]

注意到$Av=\lambda v$,则有:

\[A^tu^{(0)} = c_1\lambda_1^tv_1+c_2\lambda_2^tv_2+\cdots c_n\lambda_n^tv_n\]

不失一般性地假设$\lambda_1$为最大特征值,则有:

\[\frac{A^tu^{(0)}}{\lambda_1^t} = c_1v_1+c_2(\frac{\lambda_2}{\lambda_1})^tv_2+\cdots c_n(\frac{\lambda_n}{\lambda_1})^tv_n\]

注意到当$t \to \infty$时,$(\frac{\lambda_2}{\lambda_1})^t,\cdots (\frac{\lambda_n}{\lambda_1})^t \to 0$。则有:

\[\frac{A^tu^{(0)}}{\lambda_1^t} ≈ c_1v_1\]

上述结果表明当迭代次数$t$足够大时,$A^tu^{(0)}$提供了最大特征根对应的特征向量的近似方向,对其归一化后相当于单位特征向量:

\[\begin{aligned} u &= \frac{A^tu^{(0)}}{||A^tu^{(0)}||} \\ A u &≈ \lambda_1 u \end{aligned}\]

因此可求$A=W^TW$的最大特征值:

\[u^T A u ≈ \lambda_1\]

4. 谱归一化

谱归一化(Spectral Normalization)是指使用谱范数对网络参数进行归一化:

\[W \leftarrow \frac{W}{||W||_2^2}\]

根据前述分析,如果激活函数的导数是有界的,应用谱归一化约束参数后,可以精确地使网络满足Lipschiitz连续性:

\[|| D_W(x_1)-D_W(x_2) ||\leq K(W) || x_1-x_2 ||\]

5. Hinge损失

WGAN构造的目标函数为:

\[\begin{aligned} (D^*, G^*) & \leftarrow \mathop{ \min}_{G} \mathop{ \max}_{D,\|D\|_L \leq K} \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \end{aligned}\]

若将谱归一化应用到网络,则自动满足Lipschiitz约束\(\|D\|_L \leq K\),目标函数简化为:

\[\begin{aligned} (D^*, G^*) & \leftarrow \mathop{ \min}_{G} \mathop{ \max}_{D} \Bbb{E}_{x \text{~} P_{data}(x)}[D(x)]-\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \end{aligned}\]

直接优化上式容易导致数值不稳定。在没有约束的情况下对于真实样本有$D(x) \to + \infty$,对于生成样本有$D(x) \to - \infty$。

因此考虑在优化判别器时分别对真实样本和生成样本设置一个阈值,可以通过hinge loss的形式实现:

\[\begin{aligned} D^* & \leftarrow \mathop{ \max}_{D} \Bbb{E}_{x \text{~} P_{data}(x)}[\max(0,1-D(x))]-\Bbb{E}_{x \text{~} P_{G}(x)}[\max(0,1+D(x))] \end{aligned}\]

其中左式只有$1-D(x)>0$才有意义,从而约束对于真实样本$D(x) < 1$,同时$D(x)$取值尽可能大;右式只有$1+D(x)>0$才有意义,从而约束对于生成样本$D(x) > -1$,同时$D(x)$取值尽可能小。

6. SN-GAN

作者将谱归一化应用到生成对抗网络,并将损失函数修改为hinge loss形式,从而构造了spectrally normalized GAN(SN-GAN)。其目标函数如下:

\[\begin{aligned} D^* & \leftarrow \mathop{ \max}_{D} \Bbb{E}_{x \text{~} P_{data}(x)}[\max(0,1-D(x))]-\Bbb{E}_{x \text{~} P_{G}(x)}[\max(0,1+D(x))] \\ G^* & \leftarrow \mathop{ \min}_{G} -\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \end{aligned}\]

pytorch实现时,对判别器应用谱归一化的代码如下:

discriminator = Discriminator()

def add_sn(m):
        for name, layer in m.named_children():
             m.add_module(name, add_sn(layer))
        if isinstance(m, (nn.Conv2d, nn.Linear)):
             return nn.utils.spectral_norm(m)
        else:
             return m

discriminator = add_sn(discriminator)

损失函数的计算和参数更新过程如下:

for epoch in range(opt.n_epochs):
    for i, real_imgs in enumerate(dataloader):

        z = torch.randn(real_imgs.shape[0], opt.latent_dim) 
        gen_imgs = generator(z)  

        # 训练判别器
        optimizer_D.zero_grad()
        real_validaty = torch.clamp(discriminator(real_imgs), max=1)
        fake_validaty = torch.clamp(discriminator(gen_imgs.detach()), min=-1)
        d_loss = -torch.mean(real_validaty) + torch.mean(fake_validaty)
        d_loss.backward()
        optimizer_D.step()

        # 训练生成器
        if i % opt.d_iter == 0:
            optimizer_G.zero_grad()
            fake_validaty = discriminator(gen_imgs)
            g_loss = -torch.mean(fake_validaty)
            g_loss.backward()
            optimizer_G.step()