Joint VAE:学习解耦的联合连续和离散表示.

1. Joint VAE

VAE的解耦模型中,一些方法把隐变量$z$设置为连续形式(如β-VAE中的标准正态分布),另一些方法把隐变量$z$设置为离散形式(如Categorical VAE中的类别均匀分布)。而本文提出的Joint VAE在隐变量中将连续和离散变量结合起来,若$z$是连续变量部分,$c$是离散变量部分,并且假设$z$和$c$是相互独立的,损失函数设置为Disentangled β-VAE的形式:

\[\mathbb{E}_{z,c \text{~} q(z,c|x)} [-\log p(x|z,c)]+\gamma_z \cdot |KL[q(z|x)||p(z)]-C_z|+\gamma_c \cdot |KL[q(c|x)||p(c)]-C_c|\]

⚪ 重构损失

重构损失$\mathbb{E}_{z,c \text{~} q(z,c|x)} [-\log p(x|z,c)]$选用均方误差损失:

recons_loss = F.mse_loss(recons, input, reduction='mean')

⚪ 连续隐变量的正则化项

连续隐变量$z$的先验分布$p(z)$选定为标准正态分布\(\mathcal{N}(0,I)\),而后验分布人为指定为对角正态分布\(\mathcal{N}(\mu,\sigma^2)\),两者的KL散度$KL[q(z|x)||p(z)]$具有解析表达式:

\[KL[\mathcal{N}(\mu,\sigma^{2})||\mathcal{N}(0,1)] = \frac{1}{2} (-\log \sigma^2 + \mu^2+\sigma^2-1)\]

为了防止KL散度过小使得重构效果变差,控制KL散度的数值在$C_z$左右,且$C_z$随着训练轮数逐渐增大,一方面可以提高重构效果,另一方面保留模型的解耦能力。则正则化项$\gamma_z \cdot |KL[q(z|x)||p(z)]-C_z|$表示为:

self.cont_gamma = latent_gamma # float = 30.
self.cont_min = latent_min_capacity # float = 0.
self.cont_max = latent_max_capacity # float = 25.
self.cont_iter = latent_num_iter # int = 25000

# Compute Continuous loss
# Adaptively increase the continuous capacity
cont_curr = (self.cont_max - self.cont_min) * \
            self.num_iter/ float(self.cont_iter) + self.cont_min
cont_curr = min(cont_curr, self.cont_max)

kld_cont_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(),
                                            dim=1),
                           dim=0)
cont_capacity_loss = self.cont_gamma * torch.abs(cont_curr - kld_cont_loss)

⚪ 离散隐变量的正则化项

离散隐变量$c$的先验分布$p(c)$选定为$k$类离散均匀分布$(1/k,…,1/k)$,而后验分布$q(c|x)$为类别分布(需要归一化),两者的KL散度$KL[q(c|x)||p(c)]$计算为:

\[KL[q(c|x)||p(c)] = \sum_{c}^{} q(c|x) \log q(c|x)-q(c|x) \log p(c)\]
self.disc_gamma = categorical_gamma # float = 30.
self.disc_min = categorical_min_capacity # float = 0.
self.disc_max = categorical_max_capacity # float = 25.
self.disc_iter = categorical_num_iter # int = 25000

# Adaptively increase the discrinimator capacity
disc_curr = (self.disc_max - self.disc_min) * \
            self.num_iter/ float(self.disc_iter) + self.disc_min
disc_curr = min(disc_curr, np.log(self.categorical_dim))

q = self.encode(input)[0]
q_p = F.softmax(q, dim=-1) # Convert the categorical codes into probabilities
eps = 1e-7

# Entropy of the logits
h1 = q_p * torch.log(q_p + eps)
# Cross entropy with the categorical distribution
h2 = q_p * np.log(1. / self.categorical_dim + eps)
kld_disc_loss = torch.mean(torch.sum(h1 - h2, dim =1), dim=0)

disc_capacity_loss = self.disc_gamma * torch.abs(disc_curr - kld_disc_loss)

Joint VAE的完整pytorch实现可参考PyTorch-VAE

2. Joint VAE的重参数化

Joint VAE涉及分别从连续分布$q(z|x)$和离散分布$q(c|x)$中采样的过程,因此需要借助重参数化技巧。

⚪ 连续变量的重参数化

连续分布$q(z|x)$通常选择正态分布:$z\text{~}\mathcal{N}(\mu_{\theta},\sigma_{\theta}^2)$。此时重参数化技巧就是“从$\mathcal{N}(\mu_{\theta},\sigma_{\theta}^2)$中采样$z$”变成“从$\mathcal{N}(0,1)$中采样$\epsilon$,然后计算$\epsilon \cdot \sigma_{\theta}+\mu_{\theta}$”。此时目标函数变为:

\[\Bbb{E}_{z \text{~} \mathcal{N}(\mu_{\theta},\sigma_{\theta}^2)} [f(z)] = \Bbb{E}_{\epsilon \text{~} \mathcal{N}(0,1)} [f(\epsilon \cdot \sigma_{\theta}+\mu_{\theta})]\]

Pytorch实现如下:

def reparameterize(mu, log_var):
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std

⚪ 离散变量的重参数化

为实现离散分布$q(c|x)$的重参数化,引入Gumbel Softmax方法。Gumbel Softmax方法实现从离散的类别分布中采样的过程,且采样的随机性转移到无参数的均匀分布$U[0,1]$上:

\[softmax (\frac{c_i - \log (-\log \epsilon_i)}{\tau})_{i=1}^k, \quad \epsilon_i\text{~}U[0,1]\]

其中$\tau$为退火参数,其数值越小会使结果越接近onehot形式,对应类别分布越尖锐,然而梯度消失情况也越严重。

Pytorch实现如下:

def reparameterize(self, c: Tensor, eps:float = 1e-7) -> Tensor:
    """
    Gumbel-softmax trick to sample from Categorical Distribution
    :param c: (Tensor) Latent Codes [B x D x K]
    :return: (Tensor) [B x D]
    """
    # Sample from Gumbel
    u = torch.rand_like(c)
    g = - torch.log(- torch.log(u + eps) + eps)

    # Gumbel-Softmax sample
    s = F.softmax((c + g) / self.temp, dim=-1)
    s = s.view(-1, self.latent_dim * self.categorical_dim)
    return s