Variational Discriminator Bottleneck: Improving Deep Learning Models by Constraining Information Flow

A well-known difficulty in GAN training is balancing the generator and the discriminator. If the discriminator becomes too strong, it trivially classifies the generator's samples as fake and no longer provides informative gradients to the generator. Inspired by the variational information bottleneck, this work limits the discriminator's capacity and accuracy by constraining the mutual information between the discriminator's internal representation and its input.

The discriminator is split into an encoder network and a discriminator network. The encoder maps the input image (real or generated) $x$ to a latent variable $z$, and the discriminator maps the latent variable $z$ to a real/fake label $y$:

\[x \to z \to y\]

To limit the amount of information carried by the latent variable $z$, we can measure it with the mutual information $I(x,z)$. The mutual information $I(x,z)$ quantifies the reduction in uncertainty about the random variable $x$ obtained from knowing the random variable $z$, and is computed as:

\[I(x,z) = \mathbb{E}_{p(x,z)} \left[ \log \frac{p(x,z)}{p(x)p(z)} \right]\]
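As a quick numeric illustration (not part of the paper), the sketch below computes $I(x,z)$ directly from this definition for a small, made-up discrete joint distribution:

# A minimal numeric illustration (not from the paper): mutual information I(x, z)
# for a small discrete joint distribution, computed directly from the definition.
# All probabilities below are made up for demonstration.
import torch

# joint distribution p(x, z) over 2 values of x and 3 values of z
p_xz = torch.tensor([[0.20, 0.10, 0.10],
                     [0.05, 0.25, 0.30]])

p_x = p_xz.sum(dim=1, keepdim=True)   # marginal p(x)
p_z = p_xz.sum(dim=0, keepdim=True)   # marginal p(z)

# I(x, z) = sum_{x,z} p(x, z) * log( p(x, z) / (p(x) p(z)) )
mi = (p_xz * torch.log(p_xz / (p_x * p_z))).sum()
print(mi.item())  # > 0 unless x and z are independent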

Minimizing the amount of information contained in the latent variable $z$ is therefore equivalent to minimizing the mutual information $I(x,z)$, which gives the loss:

\[\mathcal{L} = \iint p(x,z)\log \frac{p(x,z)}{p(x)p(z)} dxdz\]

The marginal distribution $p(z)$ of the latent variable is usually intractable, so a distribution $q(z)$ of known form is introduced to derive an upper bound on this loss:

\[\begin{aligned} \mathcal{L} &= \iint p(x,z)\log \frac{p(x,z)q(z)}{p(x)p(z)q(z)} dxdz \\ &= \iint p(x,z)\log \frac{p(z|x)q(z)}{p(z)q(z)} dxdz \\ &= \iint p(x,z)\log \frac{p(z|x)}{q(z)} dxdz + \iint p(x,z)\log \frac{q(z)}{p(z)} dxdz \\ &= \iint p(z|x)p(x)\log \frac{p(z|x)}{q(z)} dxdz - \iint p(x,z)\log \frac{p(z)}{q(z)} dxdz \\ &= \int p(x) \left[ \int p(z|x)\log \frac{p(z|x)}{q(z)} dz\right]dx - \int p(z)\log \frac{p(z)}{q(z)} dz \\ &= \int p(x) KL\left[ p(z|x) \mid\mid q(z)\right]dx - KL\left[ p(z) \mid\mid q(z)\right] \\ &\leq \int p(x) KL\left[ p(z|x) \mid\mid q(z)\right]dx \\ &= \mathbb{E}_{p(x)} \left[ KL\left[ p(z|x) \mid\mid q(z)\right] \right] \\ \end{aligned}\]
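To see the bound at work, the following sketch assumes a toy linear-Gaussian model ($x \sim \mathcal{N}(0,1)$, $z|x \sim \mathcal{N}(ax, s^2)$, $q(z)=\mathcal{N}(0,1)$), for which both the exact mutual information and the variational bound $\mathbb{E}_{p(x)}\left[KL[p(z|x)\mid\mid q(z)]\right]$ have closed forms, and checks numerically that the bound is indeed an upper bound; the parameter values are arbitrary:

# Toy check of the variational upper bound (assumptions: x ~ N(0,1),
# z|x ~ N(a*x, s^2), q(z) = N(0,1)); not part of the paper.
import math

a, s = 1.5, 0.7                        # arbitrary model parameters
var_z = a ** 2 + s ** 2                # marginal p(z) = N(0, a^2 + s^2)

exact_mi = 0.5 * math.log(var_z / s ** 2)                    # exact I(x, z)
vib_bound = 0.5 * (a ** 2 + s ** 2 - math.log(s ** 2) - 1)   # E_x[KL[p(z|x) || q(z)]]

print(exact_mi, vib_bound)
assert vib_bound >= exact_mi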

The prior $q(z)$ is chosen as the standard normal distribution $\mathcal{N}(0,1)$, and the posterior $p(z|x)$ is modeled as $\mathcal{N}(\mu, \sigma^2)$; since both are Gaussian, their KL divergence has a closed-form solution, computed as:

\[\begin{aligned} KL[p(z|x)||q(z)] &= KL[\mathcal{N}(\mu,\sigma^{2})||\mathcal{N}(0,1)] \\ &= \int \frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(z-\mu)^2}{2\sigma^2}} \log \frac{\frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(z-\mu)^2}{2\sigma^2}}}{\frac{1}{\sqrt{2\pi}}e^{-\frac{z^2}{2}}} dz \\&= \int \frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(z-\mu)^2}{2\sigma^2}} \left[-\frac{1}{2}\log \sigma^2 + \frac{z^2}{2}-\frac{(z-\mu)^2}{2\sigma^2}\right] dz \\ &= \frac{1}{2} \left(-\log \sigma^2 + \mu^2+\sigma^2-1\right) \end{aligned}\]
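A minimal sanity check of this closed form, comparing it against torch.distributions.kl_divergence on arbitrary example values (not from the paper):

# Sanity check: closed-form KL[N(mu, sigma^2) || N(0, 1)] vs. torch.distributions
import torch
from torch.distributions import Normal, kl_divergence

mu, sigma = torch.tensor(0.8), torch.tensor(1.3)   # arbitrary example values

closed_form = 0.5 * (-torch.log(sigma ** 2) + mu ** 2 + sigma ** 2 - 1)
reference = kl_divergence(Normal(mu, sigma), Normal(0.0, 1.0))

print(closed_form.item(), reference.item())        # the two values match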

For the discriminator, a constraint on this mutual information term is added to its original objective, requiring the mutual information not to exceed $I_c$:

\[\begin{aligned}\mathop{ \min}_{D,E} & \ \mathbb{E}_{x \sim P_{data}(x)}[ \mathbb{E}_{z \sim E(z|x)}[-\log D(z)]] + \mathbb{E}_{x \sim G(x)}[\mathbb{E}_{z \sim E(z|x)}[-\log(1-D(z))]] \\ \text{s.t. } & \mathbb{E}_{x \sim P_{data}(x)} \left[ KL\left[ E(z|x) \mid\mid q(z)\right] \right] \leq I_c \end{aligned}\]

Written as a Lagrangian:

\[\begin{aligned}\mathop{ \min}_{D,E} \mathop{ \max}_{\beta \geq 0} & \ \mathbb{E}_{x \sim P_{data}(x)}[ \mathbb{E}_{z \sim E(z|x)}[-\log D(z)]] + \mathbb{E}_{x \sim G(x)}[\mathbb{E}_{z \sim E(z|x)}[-\log(1-D(z))]] \\ & + \beta \left( \mathbb{E}_{x \sim P_{data}(x)} \left[ KL\left[ E(z|x) \mid\mid q(z)\right] \right] - I_c \right) \end{aligned}\]

The corresponding update procedure is:

\[\begin{aligned} D,E & \leftarrow \mathop{\arg \min}_{D,E} \mathcal{L}(D,E,\beta) \\ \beta & \leftarrow \max\left(0, \beta+\alpha_{\beta} \left( \mathbb{E}_{x \sim P_{data}(x)} \left[ KL\left[ E(z|x) \mid\mid q(z)\right] \right] - I_c \right)\right) \end{aligned}\]

The bottleneck loss is computed as:

import torch

def _bottleneck_loss(mus, sigmas, i_c, alpha=1e-8):
    """
    calculate the bottleneck loss for the given mus and sigmas
    :param mus: means of the gaussian distributions
    :param sigmas: stds of the gaussian distributions
    :param i_c: value of the bottleneck
    :param alpha: small value for numerical stability
    :return: loss_value: scalar tensor
    """
    # closed-form KL[N(mu, sigma^2) || N(0, 1)], summed over the latent dimensions;
    # a small alpha is added inside the log to avoid log(0)
    kl_divergence = 0.5 * torch.sum(
        (mus ** 2) + (sigmas ** 2) - torch.log((sigmas ** 2) + alpha) - 1, dim=1)

    # bottleneck loss: batch-averaged KL minus the information constraint i_c
    bottleneck_loss = torch.mean(kl_divergence) - i_c

    return bottleneck_loss
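A small self-contained usage sketch of _bottleneck_loss with made-up encoder statistics, showing the projected dual update on $\beta$ from the update rule above; the values of $I_c$, $\alpha_{\beta}$ and $\beta$ are illustrative, not the authors' settings:

# Usage sketch: bottleneck loss on synthetic encoder outputs plus the dual step on beta
import torch

batch, latent_dim = 8, 16
mus = torch.randn(batch, latent_dim)           # pretend encoder means
sigmas = torch.rand(batch, latent_dim) + 0.1   # pretend encoder stds (kept positive)

i_c, alpha_beta, beta = 0.2, 1e-5, 0.0
b_loss = _bottleneck_loss(mus, sigmas, i_c)

# the full discriminator objective would be: adversarial_loss + beta * b_loss;
# beta itself follows dual gradient ascent, projected onto beta >= 0
beta = max(0.0, beta + alpha_beta * b_loss.item())
print(b_loss.item(), beta)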

The discriminator is modeled as an encoder plus a decoder, with the reparameterization trick:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyNet(nn.Module):
    def __init__(self, K=256):
        super(ToyNet, self).__init__()
        self.K = K  # dimension of the latent variable z

        # encoder E(z|x): outputs 2*K statistics (K means and K pre-activation stds)
        self.encode = nn.Sequential(
            nn.Linear(784, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 2 * self.K))

        # decoder / discriminator head: maps z to the class logits
        self.decode = nn.Sequential(
            nn.Linear(self.K, 2))

    def forward(self, x):
        # flatten image inputs into vectors
        if x.dim() > 2:
            x = x.view(x.size(0), -1)

        statistics = self.encode(x)
        mu = statistics[:, :self.K]
        # softplus keeps the std positive; the -5 shift makes it small at initialization
        std = F.softplus(statistics[:, self.K:] - 5, beta=1)

        encoding = self.reparametrize_n(mu, std)
        logit = self.decode(encoding)
        return (mu, std), logit

    def reparametrize_n(self, mu, std):
        # reparameterization trick: z = mu + std * eps, eps ~ N(0, I)
        eps = torch.randn_like(std)
        return mu + eps * std
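A quick end-to-end usage sketch with random data, combining the classification loss with the bottleneck loss; the fake inputs, labels, $\beta$ and $I_c$ values are arbitrary illustration choices, not the authors' training loop:

# Usage sketch: run ToyNet on a fake batch and combine its losses
import torch
import torch.nn.functional as F

model = ToyNet(K=256)
x = torch.randn(32, 1, 28, 28)      # fake batch of MNIST-sized images (flattened to 784 inside)
y = torch.randint(0, 2, (32,))      # fake real/fake labels

(mu, std), logit = model(x)
class_loss = F.cross_entropy(logit, y)

beta, i_c = 1e-3, 0.2
total_loss = class_loss + beta * _bottleneck_loss(mu, std, i_c)
total_loss.backward()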