MIWAE:紧凑的变分下界阻碍推理网络训练.

IWAE提供了比标准VAE更紧凑的证据下界(ELBOs),尽管这有利于生成网络(概率解码器)的梯度更新,但不利于推理网络(概率编码器)的更新。本文作者提出了三种新的算法:部分重要性加权自编码器(PIWAE)、多重重要性加权自编码器(MIWAE)和组合重要性加权自编码器(CIWAE),每一种方法都比IWAE具有更好的结果。

1. VAE和IWAE的ELBO

变分自编码器(VAE)的变分下界为:

\[\log p(x) = \log \Bbb{E}_{z \text{~} q(z|x)}[\frac{p(x,z)}{q(z|x)}] \geq \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{q(z|x)}]\]

IWAE变分下界为:

\[\log p(x) = \log \Bbb{E}_{z \text{~} q(z|x)}[\frac{p(x,z)}{q(z|x)}] = \geq \Bbb{E}_{z_1,z_2,\cdots z_k \text{~} q(z|x)}[\log \frac{1}{k}\sum_{i=1}^{k}\frac{p(x,z_i)}{q(z_i|x)}]\]

可以证明IWAE变分下界更接近原优化目标:

\[\log p(x) \geq \mathcal{L}_{k+1} \geq \mathcal{L}_{k} \geq \mathcal{L}_{1}\]

2. 紧凑的变分下界

分析可知VAE的变分下界与原目标之间存在的gap为$\Bbb{E}_{z \text{~} q(z|x)}[\log q(z|x)/p(z|x)]=KL(q(z|x)||p(z|x))$。更紧凑的变分下界意味着$KL(q(z|x)||p(z|x))≈0$,此时VAE优化目标中正则化项$KL[q(z|x)||p(z)]$被放宽,模型将重点关注重构损失\(\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x \| z)]\)。因此生成网络(解码器)会被进一步优化,而推理网络(编码器)的质量会下降(两者的优化目标是冲突的)。

对于网络的优化参数$\theta$,定义信噪比为参数梯度的均值与标准差之比:

\[SNR(\theta) = \frac{\Bbb{B}[\nabla (\theta)]}{\sigma[\nabla (\theta)]}\]

作者绘制了推理网络和生成网络的信噪比图像,观察得到当提高采样数量时,VAE推理网络和生成网络的信噪比均提高;然而IWAE的推理网络信噪比下降。

3. PIWAE, MIWAE 和 CIWAE

IWAE是通过$K$次采样对损失函数和优化梯度进行一次估计,若总计进行了$M$次估计,则可以证明推理网络的信噪比服从$O(\sqrt{M/K})$而生成网络的信噪比服从$O(\sqrt{MK})$。作者发现,通过设置不同的$K$和$M$,能够同时增大推理网络和生成网络的信噪比,从而提高模型的表现。

⚪ MIWAE:multiply importance weighted autoencoder

MIWAE通过引入$M>1$同时增大两个网络的信噪比,其目标函数如下:

\[\frac{1}{M}\sum_{m=1}^{M} \Bbb{E}_{z_{m,1},z_{m,2},\cdots z_{m,K} \text{~} q(z_{m}|x)}[\log \frac{1}{K}\sum_{k=1}^{K}\frac{p(x,z_{m,k})}{q(z_{m,k}|x)}]\]

⚪ CIWAE:combination importance weighted autoencoder

CIWAE将优化目标构造为VAEIWAE变分下界的凸组合:

\[ELBO_{CIWAE} = \beta ELBO_{VAE} + (1-\beta) ELBO_{IWAE}\]

⚪ PIWAE: partially importance weighted autoencoder

PIWAE是指在训练推理网络$q_{\phi}(z|x)$时使用VAE变分下界,在训练生成网络$p_{\theta}(x|z)$时使用IWAE变分下界

\[\phi^* = \mathop{\arg \max}_{\phi} ELBO_{VAE} \\ \theta^* = \mathop{\arg \max}_{\theta} ELBO_{IWAE}\]

4. MIWAE的pytorch实现

MIWAE的完整pytorch实现可参考PyTorch-VAE,下面进行分析。

在标准的VAE中,采样是通过重参数化过程实现的。因此在MIWAE中,对每个样本重参数化时进行$S$次采样,并构造$M$次估计:

def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    mu, log_var = self.encode(input)
    mu = mu.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]
    log_var = log_var.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]
    z = self.reparameterize(mu, log_var) # [B x M x S x D]
    eps = (z - mu) / log_var # Prior samples
    return  [self.decode(z), input, mu, log_var, z, eps]