使用Gumble-Softmax实现离散类别隐变量的重参数化.

1. 重参数化

在一些深度学习任务中,经常会处理以下带有期望形式的目标函数:

\[L_{\theta} = \Bbb{E}_{z \text{~} p_{\theta}(z)} [f(z)]\]

为了计算目标函数的梯度,需要写出$L_{\theta}$的解析表达式,即需要从$p_{\theta}(z)$中采样以计算期望。然而直接采样的过程无法计算梯度,因此引入重参数化技巧。

重参数化(reparameterization)是指把从有参数的分布$p_{\theta}(z)$中采样转换为从无参数的分布$q(\epsilon)$中采样,从而实现目标函数及其梯度的计算。

2. 连续形式的重参数化

若分布$p_{\theta}(z)$为连续形式,则目标函数写作:

\[L_{\theta} = \Bbb{E}_{z \text{~} p_{\theta}(z)} [f(z)] = \int p_{\theta}(z)f(z) dz\]

直接计算上述积分比较困难,因此通过采样实现重参数化:先从一个无参数的分布$q(\epsilon)$中采样$\epsilon$,再通过变换$z=g_{\theta}(\epsilon)$生成$z$。此时目标函数变为:

\[L_{\theta} = \Bbb{E}_{\epsilon \text{~} q(\epsilon)} [f(g_{\theta}(\epsilon))]\]

⚪ 一个例子:VAE

在变分自编码器中,$z$的分布通常选择正态分布:$z\text{~}\mathcal{N}(\mu_{\theta},\sigma_{\theta}^2)$。此时重参数化技巧就是“从$\mathcal{N}(\mu_{\theta},\sigma_{\theta}^2)$中采样$z$”变成“从$\mathcal{N}(0,1)$中采样$\epsilon$,然后计算$\epsilon \cdot \sigma_{\theta}+\mu_{\theta}$”。此时目标函数变为:

\[\Bbb{E}_{z \text{~} \mathcal{N}(\mu_{\theta},\sigma_{\theta}^2)} [f(z)] = \Bbb{E}_{\epsilon \text{~} \mathcal{N}(0,1)} [f(\epsilon \cdot \sigma_{\theta}+\mu_{\theta})]\]

变分自编码器中重参数化技巧的Pytorch实现如下:

def reparameterize(mu, log_var):
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std

3. 离散形式的重参数化

若分布$p_{\theta}(y)$为离散形式,则目标函数写作:

\[L_{\theta} = \Bbb{E}_{y \text{~} p_{\theta}(y)} [f(y)] = \sum_{y} p_{\theta}(y)f(y)\]

此时离散变量$y$服从$k$类多项分布:

\[p_{\theta}(y) = softmax(s_1,s_2,\cdots s_k) = \frac{1}{\sum_{i=1}^{k}e^{s_i}}(e^{s_1},e^{s_2},\cdots e^{s_k})\]

上式表示为有限项的求和形式,理论上是可以求得解析解的。然而当类别$k$过大或函数$f(y)$比较复杂时,求解仍然比较困难。

为实现离散分布$p_{\theta}(z)$的重参数化,引入Gumbel Max方法。

⚪ Gumbel Max方法

Gumbel Max方法是一种依概率采样类别的方法,假设每个类别出现的概率是$p_1,p_2,\cdots p_k$,则有:

\[\mathop{\arg \max}_{i} (\log p_i - \log (-\log \epsilon_i))_{i=1}^k, \quad \epsilon_i\text{~}U[0,1]\]

上式表明,先计算各类别概率的对数$\log p_i$,然后从均匀分布$U[0,1]$中采样$k$个随机数$\epsilon_1,\epsilon_2,\cdots \epsilon_k$,把$- \log (-\log \epsilon_i)$加到$\log p_i$上,然后选择最大值对应的类别即可。

Gumbel Max方法等价于依概率$p_1,p_2,\cdots p_k$采样一个类别,即输出类别$i$的概率恰好是$p_i$。下面不失一般性地证明输出类别$1$的概率是$p_1$。

若输出类别$1$,则$\log p_1 - \log (-\log \epsilon_1)$应该是最大的:

\[\log p_1 - \log (-\log \epsilon_1) \geq \log p_2 - \log (-\log \epsilon_2) \\ \log p_1 - \log (-\log \epsilon_1) \geq \log p_3 - \log (-\log \epsilon_3) \\ \cdots \\ \log p_1 - \log (-\log \epsilon_1) \geq \log p_k - \log (-\log \epsilon_k)\]

若第一个不等式成立则有:

\[\epsilon_2 \leq \epsilon_1^{\frac{p_2}{p_1}} \leq 1\]

由于$\epsilon_2\text{~}U[0,1]$,则$\epsilon_2 \leq \epsilon_1^{\frac{p_2}{p_1}}$的概率为$\epsilon_1^{\frac{p_2}{p_1}}$,此即第一个不等式成立的概率。若上述不等式均成立,则对应的概率为:

\[\epsilon_1^{\frac{p_2}{p_1}}\epsilon_1^{\frac{p_3}{p_1}}\cdots \epsilon_1^{\frac{p_k}{p_1}}=\epsilon_1^{\frac{p_2+p_3+\cdots p_k}{p_1}} = \epsilon_1^{\frac{1}{p_1}-1}\]

由于$\epsilon_1\text{~}U[0,1]$,则概率进一步计算为:

\[\int_{0}^{1} \epsilon_1^{\frac{1}{p_1}-1} d\epsilon_1 = p_1\epsilon_1^{\frac{1}{p_1}}|_{0}^{1} = p_1\]

Gumbel Max方法实现了从离散的类别分布中采样的过程,且采样的随机性转移到无参数的均匀分布$U[0,1]$上。然而其argmax操作仍然是不可导的,为此引入Gumbel Softmax方法。

⚪ Gumbel Softmax方法

Gumbel Softmax方法是指使用softmax函数替代onehot(argmax)操作,从而实现光滑可导的类别采样:

\[softmax (\frac{\log p_i - \log (-\log \epsilon_i)}{\tau})_{i=1}^k, \quad \epsilon_i\text{~}U[0,1]\]

其中$\tau$为退火参数,其数值越小会使结果越接近onehot形式,对应类别分布越尖锐,然而梯度消失情况也越严重。只有$\tau \to 0$时Gumbel Softmax才是类别采样Gumbel Max的等价形式。在应用Gumbel Softmax时,早期阶段可以选择比较大的$\tau$,然后慢慢退火到接近$0$的数值:

self.temp = temperature # 0.5
self.min_temp = min_temperature # 0.01
self.anneal_rate = anneal_rate # float = 3e-5
self.anneal_interval = anneal_interval # int = 100
# Anneal the temperature at regular intervals
if batch_idx % self.anneal_interval == 0 and self.training:
    self.temp = np.maximum(self.temp * np.exp(- self.anneal_rate * batch_idx),
                           self.min_temp)

在实际实现Gumbel Softmax时,可以直接将$\log p_i$换成网络输出的logits

\[softmax (\frac{s_i - \log (-\log \epsilon_i)}{\tau})_{i=1}^k, \quad \epsilon_i\text{~}U[0,1]\]

Gumbel SoftmaxPytorch实现如下:

def reparameterize(self, z: Tensor, eps:float = 1e-7) -> Tensor:
    """
    Gumbel-softmax trick to sample from Categorical Distribution
    :param z: (Tensor) Latent Codes [B x D x K]
    :return: (Tensor) [B x D]
    """
    # Sample from Gumbel
    u = torch.rand_like(z)
    g = - torch.log(- torch.log(u + eps) + eps)

    # Gumbel-Softmax sample
    s = F.softmax((z + g) / self.temp, dim=-1)
    s = s.view(-1, self.latent_dim * self.categorical_dim)
    return s

4. Gumbel Softmax的应用:Categorical VAE

Categorical VAE是指将隐变量$z$设置为离散的类别分布,常见于文本生成等任务。回顾标准的VAE目标函数:

\[\mathcal{L} = \mathbb{E}_{z \text{~} q(z|x)} [-\log p(x | z)] + KL[q(z|x)||p(z)]\]

其中重构损失仍然选用均方误差:

recons_loss =F.mse_loss(recons, input, reduction='mean')

对于正则化项,$p(z)$设置为$k$类离散均匀分布$(1/k,…,1/k)$,$q(z|x)$是对编码器的输出结果(应用softmax转化成概率分布),则KL散度计算为:

\[KL[q(z|x)||p(z)] = \sum_{z}^{} q(z|x) \log \frac{q(z|x)}{p(z)}\\= \sum_{z}^{} q(z|x) \log q(z|x)-q(z|x) \log p(z)\]
q = self.encode(input)[0]
q_p = F.softmax(q, dim=-1) # Convert the categorical codes into probabilities

# KL divergence between gumbel-softmax distribution
eps = 1e-7

# Entropy of the logits
h1 = q_p * torch.log(q_p + eps)

# Cross entropy with the categorical distribution
h2 = q_p * np.log(1. / self.categorical_dim + eps)
kld_loss = torch.mean(torch.sum(h1 - h2, dim =(1,2)), dim=0)

Categorical VAE的完整pytorch实现可参考PyTorch-VAE