Variational Discriminator Bottleneck: Improving Deep Learning Models by Constraining Information Flow
A well-known difficulty in training GANs is balancing the generator against the discriminator. If the discriminator is too strong, it trivially classifies the generator's outputs as fake and can no longer provide informative gradients to the generator. Inspired by the variational information bottleneck, this work constrains the discriminator's capacity and accuracy by limiting the mutual information between the discriminator's internal representation and its input.
The discriminator is split into an encoder network and a discriminator network. The encoder maps an input image (real or generated) $x$ to a latent variable $z$, and the discriminator network maps the latent variable $z$ to a real/fake label $y$:
\[x \to z \to y\]To restrict the amount of information that the latent variable $z$ carries about the input, we can measure it with the mutual information $I(x,z)$. The mutual information $I(x,z)$ quantifies how much the uncertainty about the random variable $x$ is reduced by knowing the random variable $z$, and is computed as:
\[I(x,z) = \mathbb{E}_{p(x,z)} \left[ \log \frac{p(x,z)}{p(x)p(z)} \right]\]Minimizing the information content of the latent variable $z$ is therefore equivalent to minimizing the mutual information $I(x,z)$, which yields the loss function:
\[\mathcal{L} = \iint p(x,z)\log \frac{p(x,z)}{p(x)p(z)} dxdz\]The marginal distribution $p(z)$ of the latent variable is generally intractable, so we introduce a distribution $q(z)$ of known form and derive an upper bound on this loss:
\[\begin{aligned} \mathcal{L} &= \iint p(x,z)\log \frac{p(x,z)q(z)}{p(x)p(z)q(z)} dxdz \\ &= \iint p(x,z)\log \frac{p(z|x)q(z)}{p(z)q(z)} dxdz \\ &= \iint p(x,z)\log \frac{p(z|x)}{q(z)} dxdz + \iint p(x,z)\log \frac{q(z)}{p(z)} dxdz \\ &= \iint p(z|x)p(x)\log \frac{p(z|x)}{q(z)} dxdz - \iint p(x,z)\log \frac{p(z)}{q(z)} dxdz \\ &= \int p(x) \left[ \int p(z|x)\log \frac{p(z|x)}{q(z)} dz\right]dx - \int p(z)\log \frac{p(z)}{q(z)} dz \\ &= \int p(x) KL\left[ p(z|x) \mid\mid q(z)\right]dx - KL\left[ p(z) \mid\mid q(z)\right] \\ &\leq \int p(x) KL\left[ p(z|x) \mid\mid q(z)\right]dx \\ &= \mathbb{E}_{p(x)} \left[ KL\left[ p(z|x) \mid\mid q(z)\right] \right] \\ \end{aligned}\]The variational prior $q(z)$ is chosen as the standard normal distribution $\mathcal{N}(0,1)$, and the posterior $p(z|x)$ is modeled as $\mathcal{N}(\mu, \sigma^2)$. Since both distributions are Gaussian, their KL divergence has a closed-form solution, computed as:
\[\begin{aligned} KL[p(z|x)||q(z)] &= KL[\mathcal{N}(\mu,\sigma^{2})||\mathcal{N}(0,1)] \\ &= \int \frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}} \log \frac{\frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}}}{\frac{1}{\sqrt{2\pi}}e^{-\frac{x^2}{2}}} dx \\&= \int \frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}} \left[-\frac{1}{2}\log \sigma^2 + \frac{x^2}{2}-\frac{(x-\mu)^2}{2\sigma^2}\right] dx \\ &= \frac{1}{2} \left(-\log \sigma^2 + \mu^2+\sigma^2-1\right) \end{aligned}\]For the discriminator, we augment its original objective with a constraint on this mutual information term, requiring that it not exceed a budget $I_c$:
\[\begin{aligned}\mathop{ \min}_{D,E} & \mathbb{E}_{x \sim P_{data}(x)}[ \mathbb{E}_{z \sim E(z|x)}[-\log D(z)]] + \mathbb{E}_{x \sim G(x)}[\mathbb{E}_{z \sim E(z|x)}[-\log(1-D(z))]] \\ \text{s.t. } & \mathbb{E}_{x \sim P_{data}(x)} \left[ KL\left[ E(z|x) \mid\mid q(z)\right] \right] \leq I_c \end{aligned}\]Introducing a Lagrange multiplier $\beta \geq 0$, the constrained problem can be written as a Lagrangian:
\[\begin{aligned}\mathop{ \min}_{D,E} \mathop{ \max}_{\beta \geq 0} & \mathbb{E}_{x \sim P_{data}(x)}[ \mathbb{E}_{z \sim E(z|x)}[-\log D(z)]] + \mathbb{E}_{x \sim G(x)}[\mathbb{E}_{z \sim E(z|x)}[-\log(1-D(z))]] \\ & + \beta \left( \mathbb{E}_{x \sim P_{data}(x)} \left[ KL\left[ E(z|x) \mid\mid q(z)\right] \right] - I_c \right) \end{aligned}\]This is optimized with alternating updates:
\[\begin{aligned} D,E & \leftarrow \mathop{\arg \min}_{D,E} \mathcal{L}(D,E,\beta) \\ \beta & \leftarrow \max\left(0, \beta+\alpha_{\beta} \left( \mathbb{E}_{x \sim P_{data}(x)} \left[ KL\left[ E(z|x) \mid\mid q(z)\right] \right] - I_c \right)\right) \end{aligned}\]The bottleneck loss (the batch-averaged KL term minus the budget $I_c$) is computed as:
import torch

def _bottleneck_loss(mus, sigmas, i_c, alpha=1e-8):
    """
    Calculate the bottleneck loss for the given mus and sigmas.
    :param mus: means of the Gaussian distributions, shape (batch, K)
    :param sigmas: stds of the Gaussian distributions, shape (batch, K)
    :param i_c: value of the information bottleneck budget I_c
    :param alpha: small value for numerical stability
    :return: loss_value: scalar tensor
    """
    # closed-form KL[N(mu, sigma^2) || N(0, 1)] per sample, summed over the
    # K latent dimensions; alpha keeps the log away from log(0)
    kl_divergence = 0.5 * torch.sum((mus ** 2) + (sigmas ** 2)
                                    - torch.log((sigmas ** 2) + alpha) - 1, dim=1)
    # bottleneck loss: batch-averaged KL minus the budget i_c
    bottleneck_loss = torch.mean(kl_divergence) - i_c
    return bottleneck_loss
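The dual update on $\beta$ from the equations above then reduces to a one-line projected gradient ascent step. The sketch below assumes `beta` is kept as a plain Python float and uses a hypothetical step size `alpha_beta`; both names are illustrative, not from the original code:

def _update_beta(beta, bottleneck_loss, alpha_beta=1e-5):
    # dual gradient ascent on the Lagrange multiplier:
    # beta <- max(0, beta + alpha_beta * (E[KL] - i_c)),
    # where bottleneck_loss is the scalar tensor returned by _bottleneck_loss
    return max(0.0, beta + alpha_beta * bottleneck_loss.item())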
The discriminator is modeled in encoder + decoder form, with the reparameterization trick:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyNet(nn.Module):
    def __init__(self, K=256):
        super(ToyNet, self).__init__()
        self.K = K  # dimension of the latent variable z
        # encoder E(z|x): outputs 2K statistics (K means, K pre-softplus stds)
        self.encode = nn.Sequential(
            nn.Linear(784, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 2*self.K))
        # decoder D(z): classifies the latent code as fake / real
        self.decode = nn.Sequential(
            nn.Linear(self.K, 2))

    def forward(self, x):
        if x.dim() > 2: x = x.view(x.size(0), -1)
        statistics = self.encode(x)
        mu = statistics[:, :self.K]
        # shift by -5 so the initial std is small; softplus keeps it positive
        std = F.softplus(statistics[:, self.K:] - 5, beta=1)
        encoding = self.reparametrize_n(mu, std)
        logit = self.decode(encoding)
        return (mu, std), logit

    def reparametrize_n(self, mu, std):
        # reparameterization trick: z = mu + eps * std with eps ~ N(0, I)
        eps = torch.randn_like(std)
        return mu + eps * std
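A minimal sketch of one discriminator step tying these pieces together, assuming MNIST-sized (784-dimensional) inputs and a two-class cross-entropy form of the objective above; the batch size, budget `i_c`, learning rate, and the random stand-in batches are illustrative assumptions:

model = ToyNet(K=256)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # illustrative lr
beta, i_c = 0.0, 0.5  # initial multiplier and information budget (illustrative)

real = torch.rand(32, 784)  # stand-in for a batch of real images
fake = torch.rand(32, 784)  # stand-in for a batch of generator samples

(mu_r, std_r), logit_r = model(real)
(mu_f, std_f), logit_f = model(fake)

# discriminator loss: class 1 for real samples, class 0 for fake samples
d_loss = (F.cross_entropy(logit_r, torch.ones(32, dtype=torch.long))
          + F.cross_entropy(logit_f, torch.zeros(32, dtype=torch.long)))

# bottleneck constraint, evaluated on the real-data encoder statistics
# as in the constrained objective above
b_loss = _bottleneck_loss(mu_r, std_r, i_c)

optimizer.zero_grad()
(d_loss + beta * b_loss).backward()
optimizer.step()

# dual gradient ascent on beta after the parameter update
beta = _update_beta(beta, b_loss.detach())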