β-TCVAE: 分离VAE解耦源中的全相关项.

本文作者对VAE的证据下界ELBO进行分解,发现其中存在隐变量的全相关(Total Correlation)项。根据这一项作者设计了β-TCVAE(Total Correlation Variational Autoencoder),能够学习特征的解耦表示而不需要引入额外的超参数。作者进一步提出一种模型解耦程度的度量方法,称为互信息间隙(Mutual Information Gap, MIG)。作者通过实验发现全相关项和解耦表示之间存在很强的相关性。

1. 分解ELBO

VAE优化对数似然的变分下界:

\[\begin{aligned} \log p(x) &= \log \Bbb{E}_{z \text{~} q(z|x)}[\frac{p(x,z)}{q(z|x)}] \geq \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{q(z|x)}] \\ \text{ELBO} &= - KL[q(z|x)||p(z)]+\mathbb{E}_{z \text{~} q(z|x)} [\log p(x | z)] \end{aligned}\]

作者对ELBO中的KL散度进行分解。首先对每个训练样本指定唯一的索引$n$,并且定义一个在$[1,N]$上均匀的随机变量$p(n)$与训练样本相关联,表示每个样本被选择的概率相同。分解过程如下:

\[\begin{aligned} KL[q(z|x)||p(z)] &= \frac{1}{N}\sum_{n=1}^{N}KL[q(z|x_n)||p(z)] = \Bbb{E}_{p(n)} [KL[q(z|n)||p(z)]] \\ &= \Bbb{E}_{p(n)} [\Bbb{E}_{q(z|n)}[\log \frac{q(z|n)}{p(z)}]] = \Bbb{E}_{p(n)} [\Bbb{E}_{q(z|n)}[\log q(z|n)-\log p(z)]] \\ &= \Bbb{E}_{p(n)} [\Bbb{E}_{q(z|n)}[\log q(z|n)-\log p(z) +\log q(z) - \log q(z) +\log \prod_{j}q(z_j)-\log \prod_{j}q(z_j)]] \\ & = \Bbb{E}_{p(n)} [\Bbb{E}_{q(z|n)}[\log\frac{q(z|n)}{q(z)}+\log\frac{q(z)}{\prod_{j}q(z_j)}+\log\frac{\prod_{j}q(z_j)}{p(z)} ]] \\ &= \Bbb{E}_{q(z,n)}[\log\frac{q(z|n)}{q(z)}]+\Bbb{E}_{q(z)}[\log\frac{q(z)}{\prod_{j}q(z_j)}]+\Bbb{E}_{q(z)}[\log\frac{\prod_{j}q(z_j)}{\prod_{j}p(z_j)} ]\\& = \Bbb{E}_{q(z,n)}[\log\frac{q(z|n)p(n)}{q(z)p(n)}]+\Bbb{E}_{q(z)}[\log\frac{q(z)}{\prod_{j}q(z_j)}]+\sum_{j}\Bbb{E}_{q(z)}[\log\frac{q(z_j)}{p(z_j)} ] \\ &= \Bbb{E}_{q(z,n)}[\log\frac{q(z,n)}{q(z)p(n)}]+\Bbb{E}_{q(z)}[\log\frac{q(z)}{\prod_{j}q(z_j)}]+\sum_{j}\Bbb{E}_{q(z_j)}[\log\frac{q(z_j)}{p(z_j)} ] \\& = KL(q(z,n)||q(z)p(n)) + KL(q(z)||\prod_{j}q(z_j)) + \sum_{j}KL(q(z_j)||p(z_j)) \end{aligned}\]

分解式的第一项$KL(q(z,n)||q(z)p(n))$是索引-编码的互信息项(Index-Code MI),衡量索引分布$p(n)$和隐变量分布$q(z)$之间的互信息。通常认为较高的互信息能够获得更好的解耦效果,因此有些研究者去掉对这一项的惩罚。然而也有研究者认为对该项进行惩罚能够获得更紧凑的解耦表示。

分解式的第二项$KL(q(z)||\prod_{j}q(z_j))$是全相关(Total Correlation)项,衡量隐变量的不同维度之间的相互依赖程度。这一项惩罚使得模型在数据分布中寻找统计独立的因子,提高模型的解耦能力,是β-VAE成功的关键。

分解式的第三项$\sum_{j}KL(q(z_j)||p(z_j))$是每个维度上的KL散度,防止隐变量的某个维度偏离其相应的正态先验太远。

2. β-TCVAE及其PyTorch实现

根据ELBO的分解项可以构造β-TCVAE模型:

\[\begin{aligned} \text{ELBO} = &\mathbb{E}_{z \text{~} q(z|x)} [\log p(x | z)] - \alpha KL(q(z,n)||q(z)p(n)) \\&-\beta KL(q(z)||\prod_{j}q(z_j)) -\gamma \sum_{j}KL(q(z_j)||p(z_j)) \end{aligned}\]

消融实验表明,仅调节$\beta$也能获得较好的解耦表现,因此通常设置$\alpha=\gamma=1$。

β-TCVAE的完整pytorch实现可参考PyTorch-VAE,与标准VAE的主要区别在于损失函数的实现不同。在计算损失函数中的KL散度时,需要计算形如$\log p$的表达式,其中概率$p$通常取正态分布,因此定义计算正态概率密度的对数函数:

def log_density_gaussian(self, x: Tensor, mu: Tensor, logvar: Tensor):
    """
    Computes the log pdf of the Gaussian with parameters mu and logvar at x
    :param x: (Tensor) Point at whichGaussian PDF is to be evaluated
    :param mu: (Tensor) Mean of the Gaussian distribution
    :param logvar: (Tensor) Log variance of the Gaussian distribution
    :return:
    """
    norm = - 0.5 * (math.log(2 * math.pi) + logvar)
    log_density = norm - 0.5 * ((x - mu) ** 2 * torch.exp(-logvar))
    return log_density

⚪计算$\mathbb{E}_{z \text{~} q(z|x)} [\log p(x | z)]$

采用与标准VAE相同的均方误差损失:

recons_loss = F.mse_loss(recons, input, reduction='sum')

⚪计算$KL(q(z,n)||q(z)p(n))$

注意到:

\[KL(q(z,n)||q(z)p(n))=\Bbb{E}_{p(n)} [\Bbb{E}_{q(z|n)}[\log q(z|n)-\log q(z)]]\]

其中$\log q(z|n)$可以直接计算:

log_q_zx = self.log_density_gaussian(z, mu, log_var).sum(dim = 1)

而$q(z)=\Bbb{E}_{p(n)}[q(z|n)]$依赖于整个数据集,因此不能直接计算得到。作者采用重要性采样的思路,随机采样$M$个数据,对$q(z)$的对数项进行批量加权采样:

\[\Bbb{E}_{q(z)} [\log q(z)] ≈ \frac{1}{M} \sum_{i=1}^{M} [\log \frac{1}{NM} \sum_{j=1}^{M} q(z(n_i)|n_j)]\]

其中$z(n_i)$是从$q(z|n_i)$中的采样。尽管上述估计是有偏的,该估计不需要引入额外的超参数。在实现时首先计算重要性权重(的对数):

dataset_size = (1 / kwargs['M_N']) * batch_size # dataset size
strat_weight = (dataset_size - batch_size + 1) / (dataset_size * (batch_size - 1))
importance_weights = torch.Tensor(batch_size, batch_size).fill_(1 / (batch_size -1)).to(input.device)
importance_weights.view(-1)[::batch_size] = 1 / dataset_size
importance_weights.view(-1)[1::batch_size] = strat_weight
importance_weights[batch_size - 2, 0] = strat_weight
log_importance_weights = importance_weights.log()

将重要性权重应用于采样结果(对应对数形式的相加),即可得到$\log q(z)$的采样估计:

batch_size, latent_dim = z.shape
mat_log_q_z = self.log_density_gaussian(z.view(batch_size, 1, latent_dim),
                                        mu.view(1, batch_size, latent_dim),
                                        log_var.view(1, batch_size, latent_dim))
mat_log_q_z += log_importance_weights.view(batch_size, batch_size, 1)
log_q_z = torch.logsumexp(mat_log_q_z.sum(2), dim=1, keepdim=False)

则损失函数$KL(q(z,n)||q(z)p(n))$计算为:

mi_loss  = (log_q_zx - log_q_z).mean()

⚪计算$KL(q(z)||\prod_{j}q(z_j))$

注意到:

\[KL(q(z)||\prod_{j}q(z_j))=\Bbb{E}_{p(n)} [\Bbb{E}_{q(z|n)}[\log q(z)-\log \prod_{j}q(z_j)]]\]

其中$\log q(z)$已求得,则$\log \prod_{j}q(z_j)$可以用类似的方式计算:

log_prod_q_z = torch.logsumexp(mat_log_q_z, dim=1, keepdim=False).sum(1)

则损失函数$KL(q(z)||\prod_{j}q(z_j))$计算为:

tc_loss = (log_q_z - log_prod_q_z).mean()

⚪计算$\sum_{j}KL(q(z_j)||p(z_j))$

注意到:

\[\sum_{j}KL(q(z_j)||p(z_j))=\Bbb{E}_{p(n)} [\Bbb{E}_{q(z|n)}[\prod_{j}q(z_j)-p(z)]]\]

其中$\log \prod_{j}q(z_j)$已求得,而$\log p(z)$建模为标准正态分布:

zeros = torch.zeros_like(z)
log_p_z = self.log_density_gaussian(z, zeros, zeros).sum(dim = 1)

则损失函数$\sum_{j}KL(q(z_j)||p(z_j))$计算为:

kld_loss = (log_prod_q_z - log_p_z).mean()

3. 互信息间隔 Mutual Information Gap

作者提出了一种衡量模型解耦程度的度量方法:互信息间隔(Mutual Information Gap, MIG)。该方法的出发点是隐变量的某一维度$z_j$和用于描述物体特征的真实因子$v_k$之间的经验互信息可以用联合分布$q(z_j,v_k)=\sum_{n=1}^{N}p(v_k)p(n|v_k)q(z_j|n)$衡量。两者的互信息计算为:

\[I_n(z_j;v_k) = \Bbb{E}_{q(z_j,v_k)}[\log \sum_{n \in \mathcal{X}_{v_k}}q(z_j|n)p(n|v_k)]+H(z_j)\]

有时真实因子$v_k$可能与隐变量的多个维度相关,此时只考虑具有最大互信息的维度,因此在计算时减去第二大的互信息值:

\[\frac{1}{K} \sum_{k=1}^{K} \frac{1}{H(v_k)} (I_n(z_{j^{(k)}};v_k)-\mathop{\max}_{j \ne j^{(k)}}I_n(z_j;v_k))\]

4. 实验分析

作者在两个数据集上对模型的解耦表示能力进行测试。dSprites是一个二维形状数据集,包含6种尺度因子、40种旋转因子、32种X和Y位置因子。3D Faces是一个三维人脸数据集,包含21种方位向因子、11种俯仰向因子和11种光照因子。

下图表明,当逐渐提高$\beta$值时,全相关的惩罚程度变大,模型的解耦能力增强,但也阻碍了数据重构种获取更多有用的信息。

作者分析了隐变量种不同因子之间分布的相关性和独立性,具体地模拟了两个因子的四种不同分布。结果表明β-TCVAE学习到特征的解耦表示。

最后作者展示了一些生成结果,结果表明β-TCVAE能够学习到数据集的更多特征因子。