Joint VAE:学习解耦的联合连续和离散表示.
1. Joint VAE
在VAE的解耦模型中,一些方法把隐变量设置为连续形式(如β-VAE中的标准正态分布),另一些方法把隐变量设置为离散形式(如Categorical VAE中的类别均匀分布)。而本文提出的Joint VAE在隐变量中将连续和离散变量结合起来,若是连续变量部分,是离散变量部分,并且假设和是相互独立的,损失函数设置为Disentangled β-VAE的形式:
⚪ 重构损失
重构损失选用均方误差损失:
recons_loss = F.mse_loss(recons, input, reduction='mean')
⚪ 连续隐变量的正则化项
连续隐变量的先验分布选定为标准正态分布,而后验分布人为指定为对角正态分布,两者的KL散度具有解析表达式:
为了防止KL散度过小使得重构效果变差,控制KL散度的数值在左右,且随着训练轮数逐渐增大,一方面可以提高重构效果,另一方面保留模型的解耦能力。则正则化项表示为:
self.cont_gamma = latent_gamma # float = 30.
self.cont_min = latent_min_capacity # float = 0.
self.cont_max = latent_max_capacity # float = 25.
self.cont_iter = latent_num_iter # int = 25000
# Compute Continuous loss
# Adaptively increase the continuous capacity
cont_curr = (self.cont_max - self.cont_min) * \
self.num_iter/ float(self.cont_iter) + self.cont_min
cont_curr = min(cont_curr, self.cont_max)
kld_cont_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(),
dim=1),
dim=0)
cont_capacity_loss = self.cont_gamma * torch.abs(cont_curr - kld_cont_loss)
⚪ 离散隐变量的正则化项
离散隐变量的先验分布选定为类离散均匀分布,而后验分布为类别分布(需要归一化),两者的KL散度计算为:
self.disc_gamma = categorical_gamma # float = 30.
self.disc_min = categorical_min_capacity # float = 0.
self.disc_max = categorical_max_capacity # float = 25.
self.disc_iter = categorical_num_iter # int = 25000
# Adaptively increase the discrinimator capacity
disc_curr = (self.disc_max - self.disc_min) * \
self.num_iter/ float(self.disc_iter) + self.disc_min
disc_curr = min(disc_curr, np.log(self.categorical_dim))
q = self.encode(input)[0]
q_p = F.softmax(q, dim=-1) # Convert the categorical codes into probabilities
eps = 1e-7
# Entropy of the logits
h1 = q_p * torch.log(q_p + eps)
# Cross entropy with the categorical distribution
h2 = q_p * np.log(1. / self.categorical_dim + eps)
kld_disc_loss = torch.mean(torch.sum(h1 - h2, dim =1), dim=0)
disc_capacity_loss = self.disc_gamma * torch.abs(disc_curr - kld_disc_loss)
Joint VAE的完整pytorch实现可参考PyTorch-VAE。
2. Joint VAE的重参数化
Joint VAE涉及分别从连续分布和离散分布中采样的过程,因此需要借助重参数化技巧。
⚪ 连续变量的重参数化
连续分布通常选择正态分布:。此时重参数化技巧就是“从中采样”变成“从中采样,然后计算”。此时目标函数变为:
Pytorch实现如下:
def reparameterize(mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
⚪ 离散变量的重参数化
为实现离散分布的重参数化,引入Gumbel Softmax方法。Gumbel Softmax方法实现从离散的类别分布中采样的过程,且采样的随机性转移到无参数的均匀分布上:
其中为退火参数,其数值越小会使结果越接近onehot形式,对应类别分布越尖锐,然而梯度消失情况也越严重。
Pytorch实现如下:
def reparameterize(self, c: Tensor, eps:float = 1e-7) -> Tensor:
"""
Gumbel-softmax trick to sample from Categorical Distribution
:param c: (Tensor) Latent Codes [B x D x K]
:return: (Tensor) [B x D]
"""
# Sample from Gumbel
u = torch.rand_like(c)
g = - torch.log(- torch.log(u + eps) + eps)
# Gumbel-Softmax sample
s = F.softmax((c + g) / self.temp, dim=-1)
s = s.view(-1, self.latent_dim * self.categorical_dim)
return s
Related Issues not found
Please contact @0809zheng to initialize the comment