Joint VAE:学习解耦的联合连续和离散表示.

1. Joint VAE

VAE的解耦模型中,一些方法把隐变量z设置为连续形式(如β-VAE中的标准正态分布),另一些方法把隐变量z设置为离散形式(如Categorical VAE中的类别均匀分布)。而本文提出的Joint VAE在隐变量中将连续和离散变量结合起来,若z是连续变量部分,c是离散变量部分,并且假设zc是相互独立的,损失函数设置为Disentangled β-VAE的形式:

Ez,c~q(z,c|x)[logp(x|z,c)]+γz|KL[q(z|x)||p(z)]Cz|+γc|KL[q(c|x)||p(c)]Cc|

⚪ 重构损失

重构损失Ez,c~q(z,c|x)[logp(x|z,c)]选用均方误差损失:

recons_loss = F.mse_loss(recons, input, reduction='mean')

⚪ 连续隐变量的正则化项

连续隐变量z的先验分布p(z)选定为标准正态分布N(0,I),而后验分布人为指定为对角正态分布N(μ,σ2),两者的KL散度KL[q(z|x)||p(z)]具有解析表达式:

KL[N(μ,σ2)||N(0,1)]=12(logσ2+μ2+σ21)

为了防止KL散度过小使得重构效果变差,控制KL散度的数值在Cz左右,且Cz随着训练轮数逐渐增大,一方面可以提高重构效果,另一方面保留模型的解耦能力。则正则化项γz|KL[q(z|x)||p(z)]Cz|表示为:

self.cont_gamma = latent_gamma # float = 30.
self.cont_min = latent_min_capacity # float = 0.
self.cont_max = latent_max_capacity # float = 25.
self.cont_iter = latent_num_iter # int = 25000

# Compute Continuous loss
# Adaptively increase the continuous capacity
cont_curr = (self.cont_max - self.cont_min) * \
            self.num_iter/ float(self.cont_iter) + self.cont_min
cont_curr = min(cont_curr, self.cont_max)

kld_cont_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(),
                                            dim=1),
                           dim=0)
cont_capacity_loss = self.cont_gamma * torch.abs(cont_curr - kld_cont_loss)

⚪ 离散隐变量的正则化项

离散隐变量c的先验分布p(c)选定为k类离散均匀分布(1/k,,1/k),而后验分布q(c|x)为类别分布(需要归一化),两者的KL散度KL[q(c|x)||p(c)]计算为:

KL[q(c|x)||p(c)]=cq(c|x)logq(c|x)q(c|x)logp(c)
self.disc_gamma = categorical_gamma # float = 30.
self.disc_min = categorical_min_capacity # float = 0.
self.disc_max = categorical_max_capacity # float = 25.
self.disc_iter = categorical_num_iter # int = 25000

# Adaptively increase the discrinimator capacity
disc_curr = (self.disc_max - self.disc_min) * \
            self.num_iter/ float(self.disc_iter) + self.disc_min
disc_curr = min(disc_curr, np.log(self.categorical_dim))

q = self.encode(input)[0]
q_p = F.softmax(q, dim=-1) # Convert the categorical codes into probabilities
eps = 1e-7

# Entropy of the logits
h1 = q_p * torch.log(q_p + eps)
# Cross entropy with the categorical distribution
h2 = q_p * np.log(1. / self.categorical_dim + eps)
kld_disc_loss = torch.mean(torch.sum(h1 - h2, dim =1), dim=0)

disc_capacity_loss = self.disc_gamma * torch.abs(disc_curr - kld_disc_loss)

Joint VAE的完整pytorch实现可参考PyTorch-VAE

2. Joint VAE的重参数化

Joint VAE涉及分别从连续分布q(z|x)和离散分布q(c|x)中采样的过程,因此需要借助重参数化技巧。

⚪ 连续变量的重参数化

连续分布q(z|x)通常选择正态分布:z~N(μθ,σθ2)。此时重参数化技巧就是“从N(μθ,σθ2)中采样z”变成“从N(0,1)中采样ϵ,然后计算ϵσθ+μθ”。此时目标函数变为:

Ez~N(μθ,σθ2)[f(z)]=Eϵ~N(0,1)[f(ϵσθ+μθ)]

Pytorch实现如下:

def reparameterize(mu, log_var):
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std

⚪ 离散变量的重参数化

为实现离散分布q(c|x)的重参数化,引入Gumbel Softmax方法。Gumbel Softmax方法实现从离散的类别分布中采样的过程,且采样的随机性转移到无参数的均匀分布U[0,1]上:

softmax(cilog(logϵi)τ)i=1k,ϵi~U[0,1]

其中τ为退火参数,其数值越小会使结果越接近onehot形式,对应类别分布越尖锐,然而梯度消失情况也越严重。

Pytorch实现如下:

def reparameterize(self, c: Tensor, eps:float = 1e-7) -> Tensor:
    """
    Gumbel-softmax trick to sample from Categorical Distribution
    :param c: (Tensor) Latent Codes [B x D x K]
    :return: (Tensor) [B x D]
    """
    # Sample from Gumbel
    u = torch.rand_like(c)
    g = - torch.log(- torch.log(u + eps) + eps)

    # Gumbel-Softmax sample
    s = F.softmax((c + g) / self.temp, dim=-1)
    s = s.view(-1, self.latent_dim * self.categorical_dim)
    return s