使用信息瓶颈解释β-VAE的解耦表示能力.
1. 使用信息瓶颈解释β-VAE
作者使用信息论中的信息瓶颈(information bottleneck)来解释β-VAE的解耦能力。
信息瓶颈是指,如果想用由随机变量$X$表示的随机变量$Z$来预测随机变量$Y$,则应尽可能最大化$Z$与$Y$的互信息,同时减少$X$与$Z$的互信息。即令$Z$尽可能保留$X$中与预测$Y$相关的信息,并提出其中对预测$Y$无用的信息,表示如下:
\[\mathop{\max} [I(Z;Y)- \beta I(X;Z)]\]上式对比β-VAE的目标函数:
\[L(\theta,\phi; x,z, \beta) = \mathbb{E}_{q_{\phi}(z|x)} [\log p_{\theta}(x | z)] - \beta D_{KL}(q_{\phi}(z|x)||p_{\theta}(z))\]该目标函数可用信息瓶颈的思想来解释。第一项重构误差衡量编码得到的$Z$预测$X$的能力,$Z$包含$X$的信息越多,预测能力就越强。第二项正则化项拉近变分后验概率$q_{\phi}(z|x)$与先验概率$p_{\theta}(z)$的距离,减少$Z$包含$X$的信息,引入了信息瓶颈。
作者进一步分析了β-VAE能够得到更加“平滑”的表示的原因。如下图所示,增大后验概率的方差与移动其均值都能使其与先验概率(标准正态分布)的KL散度减小。这将使对于不同的输入样本,后验概率之间的重叠更大。如图中数据点$\tilde{x}$的编码表示$z$处于$q(z_1|x_1)$和$q(z_2|x_2)$的重叠区域中,则容易造成样本的混淆。为了增大对数似然,最好的编码方式就是让相近的数据点对应的编码分布也是相近的。在这种编码方式下,数据空间$X$的微小变化只会引起隐空间$Z$的微小变化,从而得到更平滑的表示。
2. 实验与分析
作者比较了VAE和β-VAE的解耦能力。在一个黑色背景下放置一个白球的Toy数据集上,生成位置不同的图像。第一排是真实图像,第二排是模型的重构图像,之后每一排按照隐变量$z$不同维度与标准正态分布的KL距离从大到小进行排列,并从左到右逐渐地变化这个维度,并保持其他维度不变,生成一系列图像。
从上图可以看出,β-VAE只用$2$个维度表示小球的变化(上下位移、左右位移),并且这种变化是平滑的;而VAE用了更多维度表示其变化,并且均匀变化$X$会导致$Z$的不均匀变化。
作者分析了β-VAE为什么能够得到解耦表示,原因如下:
- β-VAE能够学习到更平滑的表示,在该平滑性假设下,如果要引入新的变化factor,最优的方式是用一个新的维度编码该factor,并且不会影响到其他维度;
- 若只是由于上面这一点,不能避免学习到的隐空间是解耦隐空间的一个旋转表示(多个维度共同决定一个factor),但人为假设后验分布$q_{\phi}(z|x)$为对角协方差的正态分布,因此不同维度之间是独立的,避免了这种旋转表示。
作者认为,如果一开始$\beta$值很大,再逐渐地减小其值,既不会降低模型的解耦能力,又可以逐渐降低重构误差。因此在训练过程中,可以逐渐减小KL项来改进β-VAE:
\[L(\theta,\phi; x,z, C) = \mathbb{E}_{q_{\phi}(z|x)} [\log p_{\theta}(x | z)] - \gamma | D_{KL}(q_{\phi}(z|x)||p_{\theta}(z)) - C|\]在训练中会逐渐增大$C$的值。
3. 所提β-VAE的pytorch实现
本文所提β-VAE的完整pytorch实现可参考PyTorch-VAE,与标准的β-VAE主要区别在损失函数上,逐渐放宽对后验分布与先验分布的KL散度限制:
recons_loss = F.mse_loss(recons, input)
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
self.C_max = max_capacity # 25
self.C_stop_iter = Capacity_max_iter #1e5
C = torch.clamp(self.C_max/self.C_stop_iter * self.num_iter, 0, self.C_max)
loss = recons_loss + self.gamma * (kld_loss - C).abs()