DIP-VAE: 分离推断先验VAE.
作者通过实验发现VAE具有一定的特征解耦能力,这得益于其损失函数中存在$KL(q(z)||p(z))$项。为了学习解耦表示,推断先验(inferred prior) $q(z)$各个特征维度应该是独立的,这可以通过最小化$q(z)$与解耦的生成先验(generated prior) $p(z)$之间的距离实现,比如两者的KL散度。
VAE在MNIST等简单数据集上能够表现处特征解耦,然而对于更复杂的数据集其解耦能力较弱。主要原因包括:
- 真实的数据分布$p(x)$与建模的数据分布$p_{\theta}(x)$具有一定差距,导致$p(z)$和$p_{\theta}(z)$也有差异;
- ELBO目标的非凸性阻碍了实现全局最小值。
为了增强模型的解耦能力,作者显式地在VAE的目标函数中增加了损失项$D(q(z)||p(z))$。具体地,采用一种简单而有效的方法来匹配两个分布的矩(比如协方差),这种改进的VAE模型称为分离推断先验VAE (Disentangled Inferred Prior-VAE, DIP-VAE)。
1. DIP-VAE
由于期望$q(z)$各个维度是独立的,因此约束其不同元素之间的协方差为$0$,对角协方差为$1$:
\[D(q(z)||p(z)) = \lambda_{od} \sum_{i\ne j} [\text{Cov}_{q(z)}[z]]^2_{ij}+\lambda_{d} \sum_{i} ([\text{Cov}_{q(z)}[z]]_{ii}-1)^2\]下面讨论协方差$\text{Cov}_{q(z)}[z]$的计算。根据总协方差公式(the law of total covariance) $Var(X)=E(Var(X|Y))+Var(E(X|Y))$,有:
\[\text{Cov}_{q(z)}[z] = \Bbb{E}_{p(x)}[\text{Cov}_{q(z|x)}[z]] + \text{Cov}_{p(x)}[\Bbb{E}_{q(z|x)}[z]]\]把$q(z|x)$建模为对角正态分布\(\mathcal{N}(\mu,\Sigma^{2})\),则上式表示为:
\[\text{Cov}_{q(z)}[z] = \Bbb{E}_{p(x)}[\Sigma^{2}] + \text{Cov}_{p(x)}[\mu]\]DIP-VAE的完整pytorch实现可参考PyTorch-VAE,下面给出DIP损失的实现过程。
self.lambda_diag = lambda_diag # float = 10.
self.lambda_offdiag = lambda_offdiag # float = 5.
# DIP Loss
centered_mu = mu - mu.mean(dim=1, keepdim = True) # [B x D]
cov_mu = centered_mu.t().matmul(centered_mu).squeeze() # [D X D]
cov_z = cov_mu + torch.mean(torch.diagonal((2. * log_var).exp(), dim1 = 0), dim = 0) # [D x D]
cov_diag = torch.diag(cov_z) # [D]
cov_offdiag = cov_z - torch.diag(cov_diag) # [D x D]
dip_loss = self.lambda_offdiag * torch.sum(cov_offdiag ** 2) + \
self.lambda_diag * torch.sum((cov_diag - 1) ** 2)
2. Separated Attribute Predictability (SAP)
作者提出了一种新的度量模型解耦能力的指标,即独立属性可预测性(Separated Attribute Predictability, SAP)得分。对于隐变量的$d$个特征维度和$k$个解耦特征因子,构造一个$d×k$的得分矩阵S,其第$ij$项是仅使用第$i$个隐变量特征预测第$j$个因子的线性回归或分类得分。
得分矩阵的每一列对应一个生成因子$j$,计算每一列前两个最大得分的差值(对应于前两个预测置信度最高的特征维度),然后计算这些差值的平均值作为最终SAP得分。SAP得分较高表明每个生成因子主要受到隐变量其中一个特征维度的影响。实验结果表明SAP得分越高,生成图像的解耦效果越好。