MIWAE:紧凑的变分下界阻碍推理网络训练.

IWAE提供了比标准VAE更紧凑的证据下界(ELBOs),尽管这有利于生成网络(概率解码器)的梯度更新,但不利于推理网络(概率编码器)的更新。本文作者提出了三种新的算法:部分重要性加权自编码器(PIWAE)、多重重要性加权自编码器(MIWAE)和组合重要性加权自编码器(CIWAE),每一种方法都比IWAE具有更好的结果。

1. VAE和IWAE的ELBO

变分自编码器(VAE)的变分下界为:

logp(x)=logEz~q(z|x)[p(x,z)q(z|x)]Ez~q(z|x)[logp(x,z)q(z|x)]logp(x)=logEz~q(z|x)[p(x,z)q(z|x)]Ez~q(z|x)[logp(x,z)q(z|x)]

IWAE变分下界为:

logp(x)=logEz~q(z|x)[p(x,z)q(z|x)]=≥Ez1,z2,zk~q(z|x)[log1kki=1p(x,zi)q(zi|x)]logp(x)=logEz~q(z|x)[p(x,z)q(z|x)]=Ez1,z2,zk~q(z|x)[log1kki=1p(x,zi)q(zi|x)]

可以证明IWAE变分下界更接近原优化目标:

logp(x)Lk+1LkL1logp(x)Lk+1LkL1

2. 紧凑的变分下界

分析可知VAE的变分下界与原目标之间存在的gapEz~q(z|x)[logq(z|x)/p(z|x)]=KL(q(z|x)||p(z|x))Ez~q(z|x)[logq(z|x)/p(z|x)]=KL(q(z|x)||p(z|x))。更紧凑的变分下界意味着KL(q(z|x)||p(z|x))0KL(q(z|x)||p(z|x))0,此时VAE优化目标中正则化项KL[q(z|x)||p(z)]KL[q(z|x)||p(z)]被放宽,模型将重点关注重构损失Ez~q(zx)[logp(xz)]Ez~q(zx)[logp(xz)]。因此生成网络(解码器)会被进一步优化,而推理网络(编码器)的质量会下降(两者的优化目标是冲突的)。

对于网络的优化参数θθ,定义信噪比为参数梯度的均值与标准差之比:

SNR(θ)=B[(θ)]σ[(θ)]SNR(θ)=B[(θ)]σ[(θ)]

作者绘制了推理网络和生成网络的信噪比图像,观察得到当提高采样数量时,VAE推理网络和生成网络的信噪比均提高;然而IWAE的推理网络信噪比下降。

3. PIWAE, MIWAE 和 CIWAE

IWAE是通过K次采样对损失函数和优化梯度进行一次估计,若总计进行了M次估计,则可以证明推理网络的信噪比服从O(M/K)而生成网络的信噪比服从O(MK)。作者发现,通过设置不同的KM,能够同时增大推理网络和生成网络的信噪比,从而提高模型的表现。

⚪ MIWAE:multiply importance weighted autoencoder

MIWAE通过引入M>1同时增大两个网络的信噪比,其目标函数如下:

1MMm=1Ezm,1,zm,2,zm,K~q(zm|x)[log1KKk=1p(x,zm,k)q(zm,k|x)]

⚪ CIWAE:combination importance weighted autoencoder

CIWAE将优化目标构造为VAEIWAE变分下界的凸组合:

ELBOCIWAE=βELBOVAE+(1β)ELBOIWAE

⚪ PIWAE: partially importance weighted autoencoder

PIWAE是指在训练推理网络qϕ(z|x)时使用VAE变分下界,在训练生成网络pθ(x|z)时使用IWAE变分下界

ϕ=argmaxϕELBOVAEθ=argmaxθELBOIWAE

4. MIWAE的pytorch实现

MIWAE的完整pytorch实现可参考PyTorch-VAE,下面进行分析。

在标准的VAE中,采样是通过重参数化过程实现的。因此在MIWAE中,对每个样本重参数化时进行S次采样,并构造M次估计:

def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    mu, log_var = self.encode(input)
    mu = mu.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]
    log_var = log_var.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]
    z = self.reparameterize(mu, log_var) # [B x M x S x D]
    eps = (z - mu) / log_var # Prior samples
    return  [self.decode(z), input, mu, log_var, z, eps]