Variational Autoencoder.
本文目录:
- 变分自编码器之“自编码器”:概率编码器与概率解码器
- 变分自编码器之“变分”:优化目标与重参数化
- 变分自编码器的各种变体
1. 变分自编码器之“自编码器”:概率编码器与概率解码器
变分自编码器(Variational Autoencoder,VAE)是一种深度生成模型,旨在学习已有数据集{x1,x2,...,xn}的概率分布p(x),并从数据分布中采样生成新的数据ˆx~p(x)。由于已有数据(也称观测数据, observed data)的概率分布形式是未知的,VAE把输入数据编码到隐空间(latent space)中,构造隐变量(latent variable)的概率分布p(z),从隐变量中采样并重构新的数据,整个过程与自编码器类似。VAE的概率模型如下:
p(x)=∑zp(x,z)=∑zp(x|z)p(z)如果人为指定隐变量的概率分布p(z)形式,则可以从中采样并通过解码器p(x|z)(通常用神经网络拟合)生成新的数据。然而注意到此时隐变量的概率分布p(z)与输入数据无关,即给定一个输入数据xn,从p(z)随机采样并重构为ˆxn,将无法保证xn与ˆxn的对应性!此时生成模型常用的优化指标Distance(xn,ˆxn)等也无法使用。
在VAE中并不是直接指定隐变量的概率分布p(z)形式,而是为每个输入数据xn指定一个后验分布q(z|xn)(通常为标准正态分布),则从该后验分布中采样并重构的ˆxn对应于xn。VAE指定后验分布q(z|xn)为标准正态分布N(0,I),则隐变量分布p(z)实际上也会是标准正态分布N(0,I):
p(z)=∑xq(z|x)p(x)=∑xN(0,I)p(x)=N(0,I)∑xp(x)=N(0,I)VAE使用编码器(通常用神经网络)拟合后验分布q(z|xn)的均值μn和方差σ2n(其数量由每批次样本量决定),通过训练使其接近标准正态分布N(0,I)。实际上后验分布不可能完全精确地被拟合为标准正态分布,因为这样会使得q(z|xn)完全独立于输入数据xn,从而使得重构效果极差。VAE的训练过程中隐含地存在着对抗的过程,最终使得q(z|xn)保留一定的输入数据xn信息,并且对输入数据也具有一定的重构效果。
VAE的整体结构如下图所示。从给定数据中学习后验分布q(z|x)的均值μn和方差σ2n的过程称为推断(inference),实现该过程的结构被称作概率编码器(probabilistic encoder)。从后验分布的采样结果中重构数据的过程p(x|z)称为生成(generation),实现该过程的结构被称作概率解码器(probabilistic decoder)。
⚪讨论:后验分布可以选取其他分布吗?
理论上,后验分布q(z|xn)可以选取任意可行的概率分布形式。然而从后续讨论中会发现对后验分布的约束是通过KL散度实现的,KL散度对于概率为0的点会发散,选择概率密度全局非负的标准正态分布N(0,I)不会出现这种问题,且具有可以计算梯度的简洁的解析解。此外,由于服从正态分布的独立随机变量的和仍然是正态分布,因此隐空间中任意两点间的线性插值也是有意义的,并且可以通过线性插值获得一系列生成结果的展示。
⚪讨论:VAE的Bayesian解释
自编码器AE将观测数据x编码为特征向量z,每一个特征向量对应特征空间中的一个离散点,所有特征向量的分布是无序、散乱的,并且无法保证不存在特征向量的空间点能够重构出真实样本。VAE是AE的Bayesian形式,将特征向量看作随机变量,使其能够覆盖特征空间中的一片区域。进一步通过强迫所有数据的特征向量服从多维正态分布,从而解耦特征维度,使得特征的空间分布有序、规整。
2. 变分自编码器之“变分”:优化目标与重参数化
VAE是一种隐变量模型p(x,z),其优化目标为最大化观测数据的对数似然logp(x)=log∑zp(x,z)。该问题是不可解的,因此采用变分推断求解。变分推断的核心思想是引入一个新的分布q(z|x)作为后验分布p(z|x)的近似,从而构造对数似然logp(x)的置信下界ELBO(也称变分下界, variational lower bound),通过最大化ELBO来代替最大化logp(x)。采用Jensen不等式可以快速推导ELBO的表达式:
logp(x)=log∑zp(x,z)=log∑zp(x,z)q(z|x)q(z|x)=logEz~q(z|x)[p(x,z)q(z|x)]≥Ez~q(z|x)[logp(x,z)q(z|x)]上式表明变分下界ELBO是原优化目标logp(x)的一个下界,两者的差距可以通过对logp(x)的另一种写法获得:
logp(x)=∑zq(z|x)logp(x)=Ez~q(z|x)[logp(x)]=Ez~q(z|x)[logp(x,z)p(z|x)]=Ez~q(z|x)[logp(x,z)p(z|x)q(z|x)q(z|x)]=Ez~q(z|x)[logp(x,z)q(z|x)]+Ez~q(z|x)[logq(z|x)p(z|x)]因此VAE的变分下界与原目标之间存在的gap为Ez~q(z|x)[logq(z|x)p(z|x)]=KL(q(z|x)||p(z|x))。让gap为0的条件是q(z|x)=p(z|x),即找到一个与真实后验分布p(z|x)相同的分布q(z|x)。然而q(z|x)通常假设为较为简单的分布形式(如正态分布),不能拟合足够复杂的分布。因此VAE通常只是一个近似模型,优化的是代理(surrogate)目标,生成的图像比较模糊。
在VAE中,最大化ELBO等价于最小化如下损失函数:
L=−Ez~q(z|x)[logp(x,z)q(z|x)]=−Ez~q(z|x)[logp(x|z)p(z)q(z|x)]=−Ez~q(z|x)[logp(x|z)]−Ez~q(z|x)[logp(z)q(z|x)]=Ez~q(z|x)[−logp(x|z)]+KL[q(z|x)||p(z)]直观上损失函数可以分成两部分:其中Ez~q(z|x)[−logp(x|z)]表示生成模型p(x|z)的重构损失,KL[q(z|x)||p(z)]表示后验分布q(z|x)的正则化项(KL损失)。这两个损失并不是独立的,因为重构损失很小表明p(x|z)置信度较大,即解码器重构比较准确,则编码器q(z|x)不会太随机(即应和x相关性较高),此时KL损失不会小;另一方面KL损失很小表明编码器q(z|x)随机性较高(即和x无关),此时重构损失不可能小。因此VAE的损失隐含着对抗的过程,在优化过程中总损失减小才对应模型的收敛。下面分别讨论这两种损失的具体形式。
(1) 后验分布q(z|x)的正则化项
损失KL[q(z|x)||p(z)]衡量后验分布q(z|x)和先验分布p(z)之间的KL散度。q(z|x)优化的目标是趋近标准正态分布,此时p(z)指定为标准正态分布z~N(0,I)。q(z|x)通过神经网络进行拟合(即概率编码器),其形式人为指定为多维对角正态分布 N(μ,σ2)。
由于两个分布都是正态分布,KL散度有闭式解(closed-form solution),计算如下:
KL[q(z|x)||p(z)]=KL[N(μ,σ2)||N(0,1)]=∫1√2πσ2e−(x−μ)22σ2log1√2πσ2e−(x−μ)22σ21√2πe−x22dx=∫1√2πσ2e−(x−μ)22σ2[−12logσ2+x22−(x−μ)22σ2]dx=12(−logσ2+μ2+σ2−1)在实际实现时拟合logσ2而不是σ2,因为σ2总是非负的,需要增加激活函数进行限制;而logσ2的取值是任意的。KL损失的Pytorch实现如下:
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
(2) 生成模型p(x|z)的重构损失
重构损失Ez~q(z|x)[−logp(x|z)]表示把观测数据映射到隐空间后再重构为原数据的过程。其中生成模型p(x|z)也是通过神经网络进行拟合的(即概率解码器),根据所处理数据的类型不同,p(x|z)应选择不同的分布形式。
⚪二值数据:伯努利分布
如果观测数据x为二值数据(如二值图像),则生成模型p(x|z)建模为伯努利分布:
p(x|z)={ρ(z),x=11−ρ(z),x=0=(ρ(z))x(1−ρ(z))1−x使用神经网络拟合参数ρ:
ρ=argmaxρlogp(x|z)=argminρ−xlogρ(z)−(1−x)log(1−ρ(z))上式表示交叉熵损失函数,且ρ(z)需要经过sigmoid等函数压缩到[0,1]。
⚪一般数据:正态分布
对于一般的观测数据x,将生成模型p(x|z)建模为具有固定方差σ20的正态分布:
p(x|z)=1√2πσ20e−(x−μ)22σ20使用神经网络拟合参数μ:
μ=argmaxμlogp(x|z)=argminμ(x−μ)22σ20上式表示均方误差(MSE),Pytorch实现如下:
recons_loss = F.mse_loss(recons, input, reduction = 'sum')
注意reduction
参数可选'sum'
和'mean'
,应该使用'sum'
,这使得损失函数计算与原式保持一致。笔者在实现时曾选用'mean'
,导致即使训练损失有下降,也只能生成噪声图片,推测是因为取平均使重构损失误差占比过小,无法正常训练。
(3) 重参数化技巧
VAE的损失函数如下:
L=Ez~q(z|x)[−logp(x|z)]+KL[q(z|x)||p(z)]其中期望Ez~q(z‖x)[⋅]表示从从q(z|x)中采样z的过程。由于采样过程是不可导的,不能直接参与梯度传播,因此引入重参数化(reparameterization)技巧。
已经假设z~q(z|x)服从N(μ,σ2),则ϵ=z−μσ服从标准正态分布N(0,I)。因此有如下关系:
z=μ+σ⋅ϵ则从N(0,I)中采样ϵ,再经过参数变换构造z,可使得采样操作不用参与梯度下降,从而实现模型端到端的训练。
在实现时对于每个样本只进行一次采样,采样的充分性是通过足够多的批量样本与训练轮数来保证的。则损失函数也可写作:
L=−logp(x|z)+KL[q(z|x)||p(z)],z~q(z|x)重参数化技巧的Pytorch实现如下:
def reparameterize(mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
⚪讨论:VAE的另一种建模方式
对于一批已有样本,代表一个真实但形式未知的概率分布˜p(x),可以构建一个带参数ϕ的后验分布qϕ(z|x),从而组成联合分布q(x,z)=˜p(x)qϕ(z|x)。如果人为定义一个先验分布p(z),并构建一个带参数θ的生成分布pθ(x|z),则可以构造另一个联合分布p(x,z)=p(z)pθ(x|z)。VAE的目的是联合分布q(x,z),p(x,z)足够接近,因此最小化两者之间的KL散度:
KL(q(x,z)||p(x,z))=Eq(x,z)[logq(x,z)p(x,z)]=E˜p(x)[Eqϕ(z|x)[log˜p(x)qϕ(z|x)p(z)pθ(x|z)]]=E˜p(x)[Eqϕ(z|x)[−logpθ(x|z)]+Eqϕ(z|x)[logqϕ(z|x)p(z)]+Eqϕ(z|x)[log˜p(x)]]=E˜p(x)[Eqϕ(z|x)[−logpθ(x|z)]+KL(qϕ(z|x)||p(z))+Const.]⚪讨论:KL散度与互信息
上述联合分布的KL散度也可写作:
KL(q(x,z)||p(x,z))=E˜p(x)[Eqϕ(z|x)[log˜p(x)qϕ(z|x)p(z)pθ(x|z)]]=E˜p(x)[Eqϕ(z|x)[logqϕ(z|x)p(z)]]−E˜p(x)[Eqϕ(z|x)[logpθ(x|z)˜p(x)]]=E˜p(x)[KL(qϕ(z|x)||p(z))]−E˜p(x)[Eqϕ(z|x)[logpθ(x,z)˜p(x)p(z)]]其中第一项为隐变量z的后验分布与先验分布之间的KL散度,第二项为观测变量x与为隐变量z之间的点互信息。因此VAE的优化目标也可以解释为最小化隐变量KL散度的同时最大化隐变量与观测变量的互信息。
3. 变分自编码器的各种变体
VAE的损失函数共涉及三个不同的概率分布:由概率编码器表示的后验分布q(z|x)、隐变量的先验分布p(z)以及由概率解码器表示的生成分布p(x|z)。对VAE的各种改进可以落脚于对这些概率分布的改进:
- 后验分布q(z|x):后验分布为模型引入了正则化;一种改进思路是通过调整后验分布的正则化项增强模型的解耦能力(如β-VAE, Disentangled β-VAE, InfoVAE, DIP-VAE, FactorVAE, β-TCVAE, HFVAE)。
- 先验分布p(z):先验分布描绘了隐变量分布的隐空间;一种改进思路是通过引入标签实现半监督学习(如CVAE, CMMA);一种改进思路是通过对隐变量离散化实现聚类或分层特征表示(如Categorical VAE, Joint VAE, VQ-VAE, VQ-VAE-2, FSQ);一种改进思路是更换隐变量的概率分布形式(如Hyperspherical VAE, TD-VAE, f-VAE, NVAE)。
- 生成分布p(x|z):生成分布代表模型的数据重构能力;一种改进思路是将均方误差损失替换为其他损失(如EL-VAE, DFCVAE, LogCosh VAE)。
- 改进整体损失函数:也有方法通过调整整体损失改进模型,如紧凑变分下界(如IWAE, MIWAE)或引入Wasserstein距离(如WAE, SWAE)。
- 改进模型结构:如BN-VAE通过引入BatchNorm缓解KL散度消失问题(指较强的解码器允许训练时KL散度项KL[q(z|x)||p(z)]=0);引入对抗训练(如AAE, VAE-GAN)。
方法 | 损失函数 |
---|---|
VAE | Ez~q(z|x)[−logp(x|z)]+KL[q(z|x)||p(z)] |
CVAE 引入条件 |
Ez~q(z|x,y)[−logp(x|z,y)]+KL[q(z|x,y)||p(z|y)] |
CMMA 隐变量z由标签y决定 |
Ez~q(z|x,y)[−logp(x|z)]+KL[q(z|x,y)||p(z|y)] |
β-VAE 特征解耦 |
Ez~q(z|x)[−logp(x|z)]+β⋅KL[q(z|x)||p(z)] |
Disentangled β-VAE 特征解耦 |
Ez~q(z|x)[−logp(x|z)]+γ⋅|KL[q(z|x)||p(z)]−C| |
InfoVAE 特征解耦 |
Ez~q(z|x)[−logp(x|z)]+(1−α)⋅KL[q(z|x)||p(z)]+(α+λ−1)⋅DZ(q(z),p(z)) |
DIP-VAE 分离推断先验 |
Ez~q(z|x)[−logp(x|z)]+KL[q(z|x)||p(z)]+λod∑i≠j[Covq(z)[z]]2ij+λd∑i([Covq(z)[z]]ii−1)2 |
FactorVAE 特征解耦 |
Ez~q(z|x)[−logp(x|z)]+KL[q(z|x)||p(z)]+γKL(q(z)||∏jq(zj)) |
β-TCVAE 分离全相关项 |
Ez~q(z|x)[−logp(x|z)]+αKL(q(z,x)||q(z)p(x))+βKL(q(z)||∏jq(zj))+γ∑jKL(q(zj)||p(zj)) |
HFVAE 隐变量特征分组 |
Ez~q(z|x)[−logp(x|z)]+∑iKL[q(zi)||p(zi)]+αKL(q(z,x)||q(z)p(x))+βEq(z)[logq(z)∏gq(zg)−logp(z)∏gp(zg)]+γ∑gEq(zg)[logq(zg)∏jq(zgj)−logp(zg)∏jp(zgj)] |
Categorical VAE 离散隐变量: Gumbel Softmax |
Ez~q(c|x)[−logp(x|c)]+KL[q(c|x)||p(c)] |
Joint VAE 连续+离散隐变量 |
Ez,c~q(z,c|x)[−logp(x|z,c)]+γz⋅|KL[q(z|x)||p(z)]−Cz|+γc⋅|KL[q(c|x)||p(c)]−Cc| |
VQ-VAE 向量量化隐变量 |
Ez~q(z|x)[−logp(x|ze+sg[zq−ze])]+||sg[ze]−zq||22+β||ze−sg[zq]||22 |
VQ-VAE-2 VQ-VAE分层 |
同上,z=zbottom,ztop |
FSQ 有限标量量化 |
Ez~q(z|x)[−logp(x|z+sg[round(⌊L2⌋tanh(z))−z])] |
Hyperspherical VAE 引入vMF分布 |
Ez~q(z|x)[−logp(x|z)],p(z)~Cd,κeκ<μ(x),z> |
TD-VAE Markov链状态空间 |
E(zt1,zt2)~q[−logpD(xt2|zt2)−logpB(zt1|bt1)−logpT(zt2|zt1)+logpB(zt2|bt2)+logqS(zt1|zt2,bt1,bt2)] |
f-VAE 引入流模型 |
Eu~q(u)[−logp(x|Fx(u))−log|det[∂Fx(u)∂u]|]+KL[q(u)||p(Fx(u))] |
NVAE 引入自回归高斯模型 |
Ez~q(z|x,y)[−logp(x|z,y)]+KL[q(z|x,y)||p(z|y)]p(z)=L∏l=1p(zl|z<l),q(z|x)=L∏l=1q(zl|z<l,x) |
EL-VAE 引入MS-SSIM损失 |
IM(x,ˆx)αM∏Mj=1Cj(x,ˆx)βjSj(x,ˆx)γj+KL[q(z|x)||p(z)] |
DFCVAE 引入特征感知损失 |
α∑Ll=112ClHlWl∑Clc=1∑Hlh=1∑Wlw=1(Φ(x)lc,h,w−Φ(ˆx)lc,h,w)2+β⋅KL[q(z|x)||p(z)] |
LogCosh VAE 引入log cosh损失 |
1alog(ea(x−ˆx)+e−a(x−ˆx)2)+β⋅KL[q(z|x)||p(z)] |
IWAE 紧凑变分下界 |
Ez1,z2,⋯zK~q(z|x)[log1K∑Kk=1p(x,zk)q(zk|x)] |
MIWAE 紧凑变分下界 |
1M∑Mm=1Ezm,1,zm,2,⋯zm,K~q(zm|x)[log1K∑Kk=1p(x,zm,k)q(zm,k|x)] |
WAE 引入Wasserstein距离 |
Ex~p(z)Ez~q(z|x)[c(x,p(x|z))]+λ⋅DZ(q(z),p(z)) |
SWAE 引入Sliced-Wasserstein距离 |
Ex~p(z)Ez~q(z|x)[c(x,p(x|z))]+∫Sd−1W[Rq(z)(⋅;θ),Rp(z)(⋅;θ)]dθ |
⚪ 参考文献
- Auto-Encoding Variational Bayes:(arXiv1312)VAE的原始论文。
- From Autoencoder to Beta-VAE:Blog by Lilian Weng.
- 变分自编码器(一):原来是这么一回事:Blog by 苏剑林.
- Recent Advances in Autoencoder-Based Representation Learning:(arXiv1812)一篇VAE综述。
- PyTorch-VAE: A Collection of Variational Autoencoders (VAE) in PyTorch.:(github)VAE的PyTorch实现。
- Learning Structured Output Representation using Deep Conditional Generative Models:(NeurIPS2015)CVAE: 使用深度条件生成模型学习结构化输出表示。
- Importance Weighted Autoencoders:(arXiv1509)IWAE:重要性加权自编码器。
- Learning to Generate Images with Perceptual Similarity Metrics:(arXiv1511)使用多尺度结构相似性度量MS-SSIM学习图像生成。
- Adversarial Autoencoders:(arXiv1511)AAE:对抗自编码器。
- Autoencoding beyond pixels using a learned similarity metric:(arXiv1512)VAE-GAN:结合VAE和GAN。
- Variational methods for Conditional Multimodal Learning: Generating Human Faces from Attributes:(arXiv1603)CMMA: 条件多模态学习的变分方法。
- Deep Feature Consistent Variational Autoencoder:(arXiv1610)DFCVAE:使用特征感知损失约束深度特征一致性。
- Categorical Reparameterization with Gumbel-Softmax:(arXiv1611)使用Gumble-Softmax实现离散类别隐变量的重参数化。
- β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework:(ICLR1704)β-VAE:学习变分自编码器隐空间的解耦表示。
- InfoVAE: Balancing Learning and Inference in Variational Autoencoders:(arXiv1706)InfoVAE:平衡变分自编码器的学习和推断过程。
- Variational Inference of Disentangled Latent Concepts from Unlabeled Observations:(arXiv1711)DIP-VAE: 分离推断先验VAE。
- Wasserstein Auto-Encoders:(arXiv1711)WAE: 使用Wasserstein距离的变分自编码器。
- Neural Discrete Representation Learning:(arXiv1711)VQ-VAE:向量量化的变分自编码器。
- Tighter Variational Bounds are Not Necessarily Better:(arXiv1802)MIWAE:紧凑的变分下界阻碍推理网络训练。
- Disentangling by Factorising:(arXiv1802)FactorVAE:通过分解特征表示的分布进行解耦。
- Isolating Sources of Disentanglement in Variational Autoencoders:(arXiv1802)β-TCVAE: 分离VAE解耦源中的全相关项。
- Understanding disentangling in β-VAE:(arXiv1804)使用信息瓶颈解释β-VAE的解耦表示能力。
- Structured Disentangled Representations:(arXiv1804)HFVAE:通过层级分解VAE实现结构化解耦表示。
- Learning Disentangled Joint Continuous and Discrete Representations:(arXiv1804)Joint VAE:学习解耦的联合连续和离散表示。
- Sliced-Wasserstein Autoencoder: An Embarrassingly Simple Generative Model:(arXiv1804)SWAE:引入Sliced-Wasserstein距离构造VAE。
- Hyperspherical Variational Auto-Encoders:(arXiv1804)Hyperspherical VAE: 为隐变量引入vMF分布。
- Temporal Difference Variational Auto-Encoder:(arXiv1806)TD-VAE: 时间差分变分自编码器。
- f-VAEs: Improve VAEs with Conditional Flows:(arXiv1809)f-VAE: 基于流的变分自编码器。
- Log Hyperbolic Cosine Loss Improves Variational Auto-Encoder:(OpenReview2018)使用对数双曲余弦损失改进变分自编码器。
- Generating Diverse High-Fidelity Images with VQ-VAE-2:(arXiv1906)VQ-VAE-2:改进VQ-VAE生成高保真度图像。
- A Batch Normalized Inference Network Keeps the KL Vanishing Away:(arXiv2004)BN-VAE: 通过批量归一化缓解KL散度消失问题。
- NVAE: A Deep Hierarchical Variational Autoencoder:(arXiv2007)Nouveau VAE: 深度层次变分自编码器。
- Finite Scalar Quantization: VQ-VAE Made Simple:(arXiv2309)有限标量量化:简化向量量化的变分自编码器。
Related Issues not found
Please contact @0809zheng to initialize the comment