InfoVAE: Balancing Learning and Inference in Variational Autoencoders.
1. Shortcomings of the VAE
⚪ The mutual information between the observed variable $x$ and the latent variable $z$ can be too small
The VAE objective is the evidence lower bound (ELBO) on the log-likelihood, which can be decomposed as follows:
\[\begin{aligned} \text{ELBO} &= \mathbb{E}_{z \text{~} q(z|x)} [\log p(x | z)]- KL[q(z|x)||p(z)] \\ &=\mathbb{E}_{z \text{~} q(z|x)} [\log \frac{p(z | x)p(x)}{p(z)}]-\mathbb{E}_{z \text{~} q(z|x)} [\log \frac{q(z|x)}{p(z)}] \\ &=\mathbb{E}_{z \text{~} q(z|x)} [\log \frac{p(z | x)}{q(z|x)}]+\mathbb{E}_{z \text{~} q(z|x)} [\log p(x)] \\ &= - KL[q(z|x)||p(z|x)] + \log p(x) \end{aligned}\]In practice the ELBO is evaluated in expectation over the data distribution $p_{data}(x)$:
\[\begin{aligned} \mathbb{E}_{x \text{~} p_{data}(x)} [\text{ELBO}]&=\mathbb{E}_{x \text{~} p_{data}(x)} [- KL[q(z|x)||p(z|x)] + \log p(x)] \\ &= -\mathbb{E}_{x \text{~} p_{data}(x)} [ KL[q(z|x)||p(z|x)]] + \mathbb{E}_{x \text{~} p_{data}(x)}[\log p(x)] \\ &= -\mathbb{E}_{x \text{~} p_{data}(x)} [ KL[q(z|x)||p(z|x)]]-KL[p_{data}(x)||p(x)] + \mathbb{E}_{x \text{~} p_{data}(x)}[\log p_{data}(x)] \end{aligned}\]Note that \(\mathbb{E}_{x \text{~} p_{data}(x)}[\log p_{data}(x)]\) is a constant. The ELBO can therefore be maximized simply by constructing a suitable $p(x)$ satisfying $KL[p_{data}(x)||p(x)]=0$ while the network learns $p(x|z)=p(x)$, i.e. the latent variable $z$ becomes independent of the observed variable $x$. In that case the encoder can also achieve $KL[q(z|x)||p(z|x)]=0$ (since $p(z|x)=p(z)$, it suffices to set $q(z|x)=p(z)$), so the model converges, yet the latent code carries no information about the data and the solution is clearly not the one we want.
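To make this degenerate optimum explicit, the equalities below restate the argument (assuming, for illustration, that the decoder family is rich enough to represent $p_{data}$ exactly): both KL terms vanish and the bound is attained even though $z$ carries no information about $x$.
\[p(x|z)=p(x)=p_{data}(x), \quad q(z|x)=p(z)=p(z|x) \quad \Rightarrow \quad \mathbb{E}_{x \text{~} p_{data}(x)} [\text{ELBO}] = \mathbb{E}_{x \text{~} p_{data}(x)}[\log p_{data}(x)], \quad I(x,z)=0\]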
⚪ The approximate posterior over the latent variables does not easily match the true posterior
Now suppose the mutual information between $z$ and $x$ is made as large as possible. In the extreme case, the approximate posterior $q(z|x)$ maps each observation $x_i$ to a distinct normal distribution \(\mathcal{N}(\mu_i,\sigma_i^2)\) with $\mu_i \to \infty, \sigma_i \to 0$, which does satisfy the mutual-information requirement. However, $KL[q(z|x_i)||p(z|x_i)] \to \infty$. In other words, even when the mutual information between $z$ and $x$ is maximized, the approximate posterior over $z$ need not approach the true posterior at all.
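For intuition (assuming, purely for illustration, that the true posterior $p(z|x_i)$ is well approximated by some fixed Gaussian $\mathcal{N}(m_i,s_i^2)$ with $s_i>0$), the closed-form KL divergence between univariate Gaussians makes the blow-up explicit:
\[KL[\mathcal{N}(\mu_i,\sigma_i^2)||\mathcal{N}(m_i,s_i^2)] = \log\frac{s_i}{\sigma_i} + \frac{\sigma_i^2+(\mu_i-m_i)^2}{2s_i^2} - \frac{1}{2} \to \infty \quad \text{as } \sigma_i \to 0 \text{ or } |\mu_i| \to \infty\]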
2. InfoVAE
To address the issues above, the authors propose InfoVAE, which explicitly adds the mutual information between the observed variable $x$ and the latent variable $z$ to the log-likelihood objective, and imposes a constraint that the aggregated posterior over the latent variables matches the prior, which in turn discourages the approximate posterior from drifting far from the true posterior. The InfoVAE optimization objective is:
\[\begin{aligned} \mathop{\max} & I(x,z)+ \mathbb{E}_{x \text{~} p(x)}[\log p(x)] \\ \text{s.t. } & D(q(z)||p(z))=0 \end{aligned}\]The objective function can be rewritten as:
\[\begin{aligned} I(x,z)+ \mathbb{E}_{x \text{~} p(x)}[\log p(x)] & = \Bbb{E}_{q(z,x)}[\log\frac{q(z,x)}{q(z)p(x)}]+ \mathbb{E}_{p(x)}[\log p(x)] \\ &= \sum_{x}\sum_{z} q(z,x) \log \frac{q(z,x)}{q(z)p(x)}+\sum_{x} p(x)\log p(x)\\ &= \sum_{x}\sum_{z} q(z,x) \log \frac{p(x|z)}{p(x)}+\sum_{x}\sum_{z} q(z,x)\log p(x) \\ &= \sum_{x}\sum_{z} q(z,x) \log p(x|z)= \sum_{x}\sum_{z} p(x)q(z|x) \log p(x|z) \\ &= \mathbb{E}_{x \text{~} p(x)}[\mathbb{E}_{z \text{~} q(z|x)}[\log p(x|z)]] \end{aligned}\]Here $q(z,x)=p(x)q(z|x)$, and the step to the third line uses $q(z,x)/q(z)=q(x|z) \approx p(x|z)$, i.e. the decoder is assumed to match the conditional implied by the encoder. The constrained problem is then turned into an unconstrained one with a Lagrangian penalty:
\[\mathop{\max} I(x,z)+ \mathbb{E}_{x \text{~} p(x)}[\log p(x)] -\lambda \cdot D(q(z)||p(z)) \\= \mathop{\max} \mathbb{E}_{x \text{~} p(x)}[\mathbb{E}_{z \text{~} q(z|x)}[\log p(x|z)]] -\lambda \cdot D(q(z)||p(z))\]Building on this, the InfoVAE loss function is set as:
\[-\mathbb{E}_{x \text{~} p(x)}[\mathbb{E}_{z \text{~} q(z|x)}[\log p(x|z)]]+(1-\alpha) KL(q(z|x)||p(z)) +(\alpha+\lambda-1) \cdot D(q(z)||p(z))\]where $D(q(z)||p(z))$ is a divergence measuring the discrepancy between two distributions; possible choices include the KL divergence, the JS divergence, Stein Variational Gradient, and MMD.
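As a quick consistency check that follows directly from these coefficients, choosing $\alpha=0$ and $\lambda=1$ makes the $D(q(z)||p(z))$ term vanish and recovers the negative ELBO of a standard VAE:
\[-\mathbb{E}_{x \text{~} p(x)}[\mathbb{E}_{z \text{~} q(z|x)}[\log p(x|z)]]+ KL(q(z|x)||p(z))\]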
3. A PyTorch Implementation of InfoVAE
A complete PyTorch implementation of InfoVAE is available in PyTorch-VAE; compared with a standard VAE, the main difference lies in the loss function. If the MMD distance is chosen for $D(q(z)||p(z))$, it can be implemented as follows.
The maximum mean discrepancy (MMD) measures the discrepancy between two distributions. Given a positive-definite kernel $k$, the MMD is computed as:
\[\text{MMD}_k(P_Z,Q_Z) = || \int_{\mathcal{Z}} k(z,\cdot)dP_Z(z)-\int_{\mathcal{Z}} k(z,\cdot)dQ_Z(z) ||_{\mathcal{H}_k}\]An empirical estimate of the MMD (scaled by the weight $\lambda$) is given below; note that the prior samples $z$~$P_Z$ are drawn directly from a standard normal distribution, while $\tilde{z}$~$Q_Z$ are the outputs of the encoder.
\[\frac{\lambda}{n(n-1)}\sum_{l \ne j}k(z_l,z_j)+\frac{\lambda}{n(n-1)}\sum_{l \ne j}k(\tilde{z}_l,\tilde{z}_j)-\frac{2\lambda}{n^2} \sum_{l , j}k(z_l,\tilde{z}_j)\]

```python
def compute_mmd(self, z: Tensor) -> Tensor:
    # Bias-corrected weight lambda / (n(n-1)) shared by the three kernel terms
    bias_corr = self.batch_size * (self.batch_size - 1)
    reg_weight = self.reg_weight / bias_corr

    # Sample from prior (Gaussian) distribution
    prior_z = torch.randn_like(z)

    # Pairwise kernel matrices: prior-prior, posterior-posterior, prior-posterior
    prior_z__kernel = self.compute_kernel(prior_z, prior_z)
    z__kernel = self.compute_kernel(z, z)
    priorz_z__kernel = self.compute_kernel(prior_z, z)

    mmd = reg_weight * prior_z__kernel.mean() + \
          reg_weight * z__kernel.mean() - \
          2 * reg_weight * priorz_z__kernel.mean()
    return mmd
```
The positive-definite kernel $k$ can take different forms, for example (a sketch of both follows the list):
- RBF kernel: \(e^{-\frac{\|x_1-x_2\|^2}{\sigma}}\)
- Inverse Multi-Quadratics (IMQ) kernel: \(\frac{C}{C+\|x_1-x_2\|^2}\)
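The compute_mmd snippet above relies on self.compute_kernel, which is not reproduced here. Below is a minimal standalone sketch of how such a pairwise kernel evaluation could look for the two kernels listed above; the function signature, the kernel_type argument, and the bandwidth heuristic sigma = 2 * latent_dim * z_var are assumptions for illustration rather than the exact PyTorch-VAE API.

```python
import torch
from torch import Tensor

def compute_kernel(x1: Tensor, x2: Tensor,
                   kernel_type: str = 'imq', z_var: float = 2.0) -> Tensor:
    # Pairwise squared Euclidean distances between all rows of x1 (N, D) and x2 (M, D)
    d2 = torch.cdist(x1, x2) ** 2
    # Bandwidth heuristic tied to the latent dimension (an assumption, not a fixed rule)
    sigma = 2.0 * x1.size(1) * z_var
    if kernel_type == 'rbf':
        # RBF kernel: exp(-||x1 - x2||^2 / sigma)
        return torch.exp(-d2 / sigma)
    elif kernel_type == 'imq':
        # Inverse multi-quadratics kernel: C / (C + ||x1 - x2||^2) with C = sigma
        return sigma / (sigma + d2)
    else:
        raise ValueError('Undefined kernel type.')
```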
The InfoVAE loss function is then expressed as:
```python
# In loss_function: recons is the reconstruction of input, (mu, log_var, z) come from
# the encoder, and kld_weight rescales the KL term for the minibatch
bias_corr = self.batch_size * (self.batch_size - 1)

recons_loss = F.mse_loss(recons, input)
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
mmd_loss = self.compute_mmd(z)

loss = self.beta * recons_loss + \
       (1. - self.alpha) * kld_weight * kld_loss + \
       (self.alpha + self.reg_weight - 1.) / bias_corr * mmd_loss
```
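For completeness, here is a self-contained numerical sketch of the loss above using randomly generated stand-in tensors; the hyperparameter values, tensor shapes, and the plain RBF-kernel MMD estimate are illustrative assumptions, not the PyTorch-VAE configuration.

```python
import torch
import torch.nn.functional as F

# Hypothetical hyperparameters and shapes, chosen only for illustration
alpha, beta, reg_weight = -0.5, 1.0, 100.0
batch_size, latent_dim = 64, 16

# Stand-in tensors for one batch of flattened 28x28 images and encoder outputs
x = torch.rand(batch_size, 784)
recons = torch.sigmoid(torch.randn(batch_size, 784))
mu = torch.randn(batch_size, latent_dim)
log_var = torch.zeros(batch_size, latent_dim)
z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)   # reparameterization trick

def rbf_kernel(a, b, sigma=2.0 * latent_dim):
    # exp(-||a_i - b_j||^2 / sigma) for every pair in the batch
    return torch.exp(-torch.cdist(a, b) ** 2 / sigma)

# Plain MMD estimate between encoder samples z and prior samples (cf. compute_mmd above)
prior_z = torch.randn_like(z)
mmd = rbf_kernel(prior_z, prior_z).mean() + rbf_kernel(z, z).mean() \
      - 2 * rbf_kernel(prior_z, z).mean()

# InfoVAE loss with the coefficients from the formula above (kld_weight taken as 1)
recons_loss = F.mse_loss(recons, x)
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
loss = beta * recons_loss + (1. - alpha) * kld_loss + (alpha + reg_weight - 1.) * mmd
print(loss.item())
```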