MIWAE:紧凑的变分下界阻碍推理网络训练.
IWAE提供了比标准VAE更紧凑的证据下界(ELBOs),尽管这有利于生成网络(概率解码器)的梯度更新,但不利于推理网络(概率编码器)的更新。本文作者提出了三种新的算法:部分重要性加权自编码器(PIWAE)、多重重要性加权自编码器(MIWAE)和组合重要性加权自编码器(CIWAE),每一种方法都比IWAE具有更好的结果。
1. VAE和IWAE的ELBO
变分自编码器(VAE)的变分下界为:
logp(x)=logEz~q(z|x)[p(x,z)q(z|x)]≥Ez~q(z|x)[logp(x,z)q(z|x)]logp(x)=logEz~q(z|x)[p(x,z)q(z|x)]≥Ez~q(z|x)[logp(x,z)q(z|x)]IWAE的变分下界为:
logp(x)=logEz~q(z|x)[p(x,z)q(z|x)]=≥Ez1,z2,⋯zk~q(z|x)[log1kk∑i=1p(x,zi)q(zi|x)]logp(x)=logEz~q(z|x)[p(x,z)q(z|x)]=≥Ez1,z2,⋯zk~q(z|x)[log1kk∑i=1p(x,zi)q(zi|x)]可以证明IWAE的变分下界更接近原优化目标:
logp(x)≥Lk+1≥Lk≥L1logp(x)≥Lk+1≥Lk≥L12. 紧凑的变分下界
分析可知VAE的变分下界与原目标之间存在的gap为Ez~q(z|x)[logq(z|x)/p(z|x)]=KL(q(z|x)||p(z|x))Ez~q(z|x)[logq(z|x)/p(z|x)]=KL(q(z|x)||p(z|x))。更紧凑的变分下界意味着KL(q(z|x)||p(z|x))≈0KL(q(z|x)||p(z|x))≈0,此时VAE优化目标中正则化项KL[q(z|x)||p(z)]KL[q(z|x)||p(z)]被放宽,模型将重点关注重构损失Ez~q(z‖x)[−logp(x‖z)]Ez~q(z∥x)[−logp(x∥z)]。因此生成网络(解码器)会被进一步优化,而推理网络(编码器)的质量会下降(两者的优化目标是冲突的)。
对于网络的优化参数θθ,定义信噪比为参数梯度的均值与标准差之比:
SNR(θ)=B[∇(θ)]σ[∇(θ)]SNR(θ)=B[∇(θ)]σ[∇(θ)]作者绘制了推理网络和生成网络的信噪比图像,观察得到当提高采样数量时,VAE推理网络和生成网络的信噪比均提高;然而IWAE的推理网络信噪比下降。
3. PIWAE, MIWAE 和 CIWAE
IWAE是通过K次采样对损失函数和优化梯度进行一次估计,若总计进行了M次估计,则可以证明推理网络的信噪比服从O(√M/K)而生成网络的信噪比服从O(√MK)。作者发现,通过设置不同的K和M,能够同时增大推理网络和生成网络的信噪比,从而提高模型的表现。
⚪ MIWAE:multiply importance weighted autoencoder
MIWAE通过引入M>1同时增大两个网络的信噪比,其目标函数如下:
1MM∑m=1Ezm,1,zm,2,⋯zm,K~q(zm|x)[log1KK∑k=1p(x,zm,k)q(zm,k|x)]⚪ CIWAE:combination importance weighted autoencoder
CIWAE将优化目标构造为VAE和IWAE的变分下界的凸组合:
ELBOCIWAE=βELBOVAE+(1−β)ELBOIWAE⚪ PIWAE: partially importance weighted autoencoder
PIWAE是指在训练推理网络qϕ(z|x)时使用VAE的变分下界,在训练生成网络pθ(x|z)时使用IWAE的变分下界。
ϕ∗=argmaxϕELBOVAEθ∗=argmaxθELBOIWAE4. MIWAE的pytorch实现
MIWAE的完整pytorch实现可参考PyTorch-VAE,下面进行分析。
在标准的VAE中,采样是通过重参数化过程实现的。因此在MIWAE中,对每个样本重参数化时进行S次采样,并构造M次估计:
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
mu, log_var = self.encode(input)
mu = mu.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]
log_var = log_var.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]
z = self.reparameterize(mu, log_var) # [B x M x S x D]
eps = (z - mu) / log_var # Prior samples
return [self.decode(z), input, mu, log_var, z, eps]
Related Issues not found
Please contact @0809zheng to initialize the comment