HFVAE:通过层级分解VAE实现结构化解耦表示.

1. 分解VAE的目标函数

VAE优化对数似然的变分下界ELBO:

logp(x)=logEq(z|x)[p(x,z)q(z|x)]Eq(z|x)[logp(x,z)q(z|x)]

ELBO又可以写作:

Eq(z|x)[logp(x,z)q(z|x)]=Eq(z|x)[logp(x,z)q(x,z)+logq(x)]

在实际中,VAE的目标定义为在有限数据点集的经验分布q(x)上每个数据点ELBO的期望:

Eq(x)[Eq(z|x)[logp(x,z)q(x,z)+logq(x)]]=Eq(x,z)[logp(x,z)q(x,z)]+Eq(x)[logq(x)]

其中Eq(x)[logq(x)]不包含可优化参数,则VAE的主要优化目标实际上为联合分布q(x,z)p(x,z)KL散度。作者对其进行进一步分解:

Eq(x,z)[logp(x,z)q(x,z)]=Eq(x,z)[logp(x,z)q(x,z)p(x)p(x)p(z)p(z)q(x)q(x)q(z)q(z)]=Eq(x,z)[logp(x,z)p(x)p(z)+logq(z)q(x)q(x,z)+logp(x)q(x)+logp(z)q(z)]=Eq(x,z)[logp(x,z)p(x)p(z)]+Eq(x,z)[logq(z)q(x)q(x,z)]+Eq(x,z)[logp(x)q(x)]+Eq(x,z)[logp(z)q(z)]=Eq(x,z)[logp(x|z)p(x)]+Eq(x,z)[logq(z)q(z|x)]+Eq(x)[logp(x)q(x)]+Eq(z)[logp(z)q(z)]=Eq(x,z)[logp(x|z)p(x)]Eq(x,z)[logq(z|x)q(z)]Eq(x)[logq(x)p(x)]Eq(z)[logq(z)p(z)]

VAEELBO最终可以分解为以下四项:

第①项和第②项增强条件分布之间的一致性。第①项Eq(x,z) [logp(x|z)/p(x)]衡量重构结果的唯一性,最大化生成每个样本x的隐变量z的可识别性;第②项Eq(x,z) [logq(z|x)/q(z)]衡量编码的唯一性,通过最小化互信息I(z,x)进行正则化,削弱隐变量z的可识别性。

第③项和第④项增强边际分布之间的一致性。第③项KL[q(x)||p(x)]匹配样本分布,等价于最大化对数似然Eq(x)[logp(x)];第④项KL[q(z)||p(z)]匹配先验分布。

第①项在实践中很难处理,因为p(x)无法直接获取。通过结合第①项和第③项能够避免这种困难:

Eq(x,z)[logp(x|z)p(x)]+Eq(x,z)[logp(x)q(x)]=Eq(x,z)[logp(x|z)q(x)]

为了研究每一项的影响,下图展示了从目标函数中去除每一项得到的结果。当去掉第③项或第④项时,可能会学习到p(x)偏离q(x)q(z)偏离p(z)的模型。去掉第①项意味着不需要每个样本x对应唯一的隐变量z。去掉第②项意味着不限制互信息I(z,x),将每个样本x映射到隐空间中的唯一区域。

2. Hierarchically Factorized VAE

作者旨在增强VAE学习特征之间的统计独立性,目标函数中的第④项KL[q(z)||p(z)]是实现特征解耦的关键,若预先指定先验分布p(z)的各维度之间是独立的,则学习到的隐变量特征分布q(z)也会倾向于特征独立。对第④项进行进一步分解:

Eq(z)[logq(z)p(z)]=Eq(z)[logq(z)p(z)dp(zd)dp(zd)dq(zd)dq(zd)]=Eq(z)[logq(z)dq(zd)+logdq(zd)dp(zd)+logdp(zd)p(z)]=Eq(z)[logp(z)dp(zd)]Eq(z)[logq(z)dq(zd)]Eq(z)[logdq(zd)dp(zd)]

其中前两项为全相关(Total Correlation)项,合计为第A项;第三项衡量隐变量的每一个边际分布。如果zd本身代表一组变量,可以继续将其分解为全相关项i和边际分布项ii,从而提供了归纳分离特征层次的机会。

原则上,可以在任何级别上继续此分解,从而构造HFVAE的目标函数++ii+α+βA+γi

Eq(x,z)[logp(x|z)p(x)]KL[q(x)||p(x)]d,eKL[q(zd,e)||p(zd,e)]αEq(x,z)[logq(z|x)q(z)]+βEq(z)[logp(z)dp(zd)logq(z)dq(zd)]+γdEq(zd)[logp(zd)ep(zd,e)logq(zd)eq(zd,e)]

对于该目标,前两项等价于重构损失,第三项衡量隐变量每个元素的KL散度,第四项通过α控制互信息I(z,x),第五项通过β控制变量组之间的全相关正则化,第六项通过γ控制组内的全相关正则化。