IWAE: Importance Weighted Autoencoder

1. The Optimization Gap of the VAE

The optimization objective of the variational autoencoder (VAE) is a variational lower bound on the log-likelihood:

\[\log p(x) = \log \Bbb{E}_{z \text{~} q(z|x)}[\frac{p(x,z)}{q(z|x)}] \geq \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{q(z|x)}]\]

Since the variational lower bound is only a lower bound on the original objective, there is a gap between the two. An analytic expression for this gap can be obtained by expanding the log-likelihood, which can also be written as:

\[\begin{aligned} \log p(x) &= \int q(z|x)\log p(x)dz= \Bbb{E}_{z \text{~} q(z|x)}[\log p(x)]\\ &= \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{p(z|x)}] = \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{p(z|x)}\frac{q(z|x)}{q(z|x)}] \\ &= \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{q(z|x)}] + \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{q(z|x)}{p(z|x)}] \end{aligned}\]

Therefore the gap between the VAE's variational lower bound and the original objective is $\Bbb{E}_{z \text{~} q(z|x)}[\log \frac{q(z|x)}{p(z|x)}]$, i.e. the KL divergence between the approximate posterior $q(z|x)$ and the true posterior $p(z|x)$. The starting point of IWAE is to reduce this gap.
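
As a quick sanity check of this decomposition, the toy sketch below assumes a conjugate Gaussian model $p(z)=\mathcal{N}(0,1)$, $p(x|z)=\mathcal{N}(z,1)$ (so that $p(x)=\mathcal{N}(0,2)$ and $p(z|x)=\mathcal{N}(x/2,1/2)$ are available in closed form) together with a deliberately imperfect encoder $q(z|x)$, and verifies numerically that the ELBO plus the KL gap recovers $\log p(x)$:

# Toy conjugate Gaussian model (illustrative assumption): p(z)=N(0,1), p(x|z)=N(z,1),
# hence p(x)=N(0,2) and the exact posterior is p(z|x)=N(x/2, 1/2).
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
x = torch.tensor(1.5)
prior = Normal(0.0, 1.0)
posterior = Normal(x / 2, 0.5 ** 0.5)      # exact p(z|x)
q = Normal(0.6 * x, 0.9)                   # a deliberately imperfect encoder q(z|x)

log_px = Normal(0.0, 2 ** 0.5).log_prob(x) # analytic log p(x)

z = q.sample((100000,))                    # Monte Carlo estimate of the ELBO
elbo = (prior.log_prob(z) + Normal(z, 1.0).log_prob(x) - q.log_prob(z)).mean()

gap = kl_divergence(q, posterior)          # analytic KL(q(z|x) || p(z|x))
print(log_px.item(), (elbo + gap).item())  # the two numbers should (approximately) match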

2. The IWAE Objective

The IWAE objective is as follows:

\[\log p(x) = \log \Bbb{E}_{z \text{~} q(z|x)}[\frac{p(x,z)}{q(z|x)}] = \log \Bbb{E}_{z_1,z_2,\cdots z_k \text{~} q(z|x)}[\frac{1}{k}\sum_{i=1}^{k}\frac{p(x,z_i)}{q(z_i|x)}]\]

A variational lower bound can again be obtained from Jensen's inequality:

\[\log \Bbb{E}_{z_1,z_2,\cdots z_k \text{~} q(z|x)}[\frac{1}{k}\sum_{i=1}^{k}\frac{p(x,z_i)}{q(z_i|x)}] \geq \Bbb{E}_{z_1,z_2,\cdots z_k \text{~} q(z|x)}[\log \frac{1}{k}\sum_{i=1}^{k}\frac{p(x,z_i)}{q(z_i|x)}]\]

Intuitively, IWAE replaces the expectation inside the variational lower bound with a sample average. When $k=1$ the objective reduces to the standard VAE, i.e. the VAE corresponds to drawing a single sample. As $k \to \infty$ the variational lower bound approaches the original log-likelihood. It can be shown that the gap of IWAE is smaller than that of the standard VAE (i.e. \(\mathcal{L}_{1}\)):

\[\log p(x) \geq \mathcal{L}_{k+1} \geq \mathcal{L}_{k} \geq \mathcal{L}_{1}\]
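
Continuing with the same hypothetical Gaussian toy model as in the check above, the following rough Monte Carlo sketch illustrates how the bound $\mathcal{L}_k$ increases with $k$ toward $\log p(x)$:

# Same toy model as before (illustrative assumption): p(z)=N(0,1), p(x|z)=N(z,1), q(z|x)=N(0.6x, 0.9).
import math
import torch
from torch.distributions import Normal

torch.manual_seed(0)
x = torch.tensor(1.5)
prior, q = Normal(0.0, 1.0), Normal(0.6 * x, 0.9)

def iwae_bound(k, n_runs=20000):
    z = q.sample((n_runs, k))                                     # k importance samples per run
    log_w = prior.log_prob(z) + Normal(z, 1.0).log_prob(x) - q.log_prob(z)
    return (torch.logsumexp(log_w, dim=1) - math.log(k)).mean()  # estimate of L_k

for k in [1, 5, 50]:
    print(k, iwae_bound(k).item())                                # increases with k ...
print("log p(x):", Normal(0.0, 2 ** 0.5).log_prob(x).item())     # ... toward the true log-likelihood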

3. PyTorch Implementation of IWAE

A complete PyTorch implementation of IWAE is available in PyTorch-VAE; its key pieces are analyzed below.

In the standard VAE, sampling is implemented through the reparameterization trick. In IWAE, the reparameterization is therefore repeated $S$ times for each input (corresponding to $k$ in the formulas above):

def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    mu, log_var = self.encode(input)
    mu = mu.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D]
    log_var = log_var.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D]
    z = self.reparameterize(mu, log_var) # [B x S x D]
    eps = (z - mu) / log_var # Prior samples
    return [self.decode(z), input, mu, log_var, z, eps]
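
The reparameterize method itself is not shown in the snippet; a minimal sketch of the usual Gaussian reparameterization over [B x S x D] tensors (a hypothetical helper, which may differ in details from the actual PyTorch-VAE code) might look like this:

import torch
from torch import Tensor

def reparameterize(mu: Tensor, log_var: Tensor) -> Tensor:
    """Draw z = mu + sigma * eps with eps ~ N(0, I), element-wise over [B x S x D]."""
    std = torch.exp(0.5 * log_var)   # sigma from the log-variance
    eps = torch.randn_like(std)      # independent noise for every (batch, sample, dim)
    return mu + eps * std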

When computing the loss, the sum inside the logarithm makes a direct evaluation awkward. Since, for training purposes, specifying the loss and specifying its gradient are equivalent, it is convenient to first compare the gradients of the VAE and IWAE objectives, and then extend the VAE loss to the IWAE loss accordingly.

The gradient of the VAE objective is computed as follows (the parameters are denoted by $\theta$; moving the gradient inside the expectation is justified by the reparameterization trick, which removes the dependence of the sampling distribution on $\theta$):

\[\nabla_{\theta}\Bbb{E}_{z \text{~} q_{\theta}(z|x)}[\log \frac{p_{\theta}(x,z)}{q_{\theta}(z|x)}] = \Bbb{E}_{z \text{~} q_{\theta}(z|x)}[\nabla_{\theta}\log \frac{p_{\theta}(x,z)}{q_{\theta}(z|x)}]\]

Writing $w=\frac{p_{\theta}(x,z)}{q_{\theta}(z|x)}$, the VAE gradient can be further written as:

\[\Bbb{E}_{z \text{~} q_{\theta}(z|x)}[\nabla_{\theta}\log \frac{p_{\theta}(x,z)}{q_{\theta}(z|x)}] = \Bbb{E}_{z \text{~} q_{\theta}(z|x)}[\nabla_{\theta}\log w]\]

The IWAE gradient is computed as follows:

\[\begin{aligned} &\nabla_{\theta}\Bbb{E}_{z_1,z_2,\cdots z_k \text{~} q_{\theta}(z|x)}[\log \frac{1}{k}\sum_{i=1}^{k}\frac{p_{\theta}(x,z_i)}{q_{\theta}(z_i|x)}] = \nabla_{\theta}\Bbb{E}_{z_1,z_2,\cdots z_k \text{~} q_{\theta}(z|x)}[\log \frac{1}{k}\sum_{i=1}^{k}w_i] \\ &= \Bbb{E}_{z_1,z_2,\cdots z_k \text{~} q_{\theta}(z|x)}[\nabla_{\theta}\log \frac{1}{k}\sum_{i=1}^{k}w_i] = \Bbb{E}_{z_1,z_2,\cdots z_k \text{~} q_{\theta}(z|x)}[\frac{1}{\frac{1}{k}\sum_{j=1}^{k}w_j}\nabla_{\theta} \frac{1}{k}\sum_{i=1}^{k}w_i] \\ & = \Bbb{E}_{z_1,z_2,\cdots z_k \text{~} q_{\theta}(z|x)}[\frac{1}{\sum_{j=1}^{k}w_j}\sum_{i=1}^{k}\nabla_{\theta} w_i] = \Bbb{E}_{z_1,z_2,\cdots z_k \text{~} q_{\theta}(z|x)}[\frac{1}{\sum_{j=1}^{k}w_j}\sum_{i=1}^{k}w_i\nabla_{\theta} \log w_i] \end{aligned}\]

Writing \(\tilde{w}_i=\frac{w_i}{\sum_{j=1}^{k}w_j}\), the IWAE gradient can be further written as:

\[\Bbb{E}_{z_1,z_2,\cdots z_k \text{~} q_{\theta}(z|x)}[\frac{1}{\sum_{j=1}^{k}w_j}\sum_{i=1}^{k}w_i\nabla_{\theta} \log w_i] = \Bbb{E}_{z_1,z_2,\cdots z_k \text{~} q_{\theta}(z|x)}[\sum_{i=1}^{k}\tilde{w}_i\nabla_{\theta} \log w_i]\]

Comparing the two gradients, both compute $\log w_i$ from the drawn samples. The VAE draws a single sample, whereas IWAE draws $k$ samples and combines them with a weighted average using the normalized weights $\tilde{w}_i$. Since specifying the gradient is equivalent to specifying the training loss, the IWAE loss is likewise a weighted average of the per-sample losses, which is where the name Importance Weighted Autoencoder (IWAE) comes from. A quick numerical check of this gradient identity is sketched below.
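
The following small autograd sketch checks the identity $\nabla_{\theta}\log \frac{1}{k}\sum_{i}w_i = \sum_{i}\tilde{w}_i\nabla_{\theta}\log w_i$, using an arbitrary smooth $\log w_i(\theta)$ purely for illustration:

# Autograd check: grad of log((1/k) * sum_i w_i) equals sum_i w_tilde_i * grad(log w_i).
import math
import torch

torch.manual_seed(0)
theta = torch.randn(3, requires_grad=True)
k = 4
log_w = torch.stack([(i + 1) * torch.sin(theta).sum() + theta[i % 3] for i in range(k)])  # arbitrary smooth log w_i(theta)

# Left-hand side: differentiate the log of the sample average directly (log-mean-exp).
lhs = torch.autograd.grad(torch.logsumexp(log_w, 0) - math.log(k), theta, retain_graph=True)[0]

# Right-hand side: weight each grad(log w_i) by the normalized weight w_tilde_i = softmax(log w)_i.
w_tilde = torch.softmax(log_w, 0).detach()
rhs = torch.autograd.grad((w_tilde * log_w).sum(), theta)[0]

print(torch.allclose(lhs, rhs, atol=1e-5))  # True: the two gradients coincide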

Note that the log-weight $\log w=\log \frac{p_{\theta}(x,z)}{q_{\theta}(z|x)}$ is exactly the VAE objective, so the normalized weights $\tilde{w}_i$ of the different samples can be obtained by applying a softmax over the per-sample losses:

log_p_x_z = ((recons - input) ** 2).flatten(2).mean(-1) # reconstruction (MSE) term, [B x S]
kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=2) # KL(q(z|x) || p(z)) per sample, [B x S]

# Get importance weights: log_weight is the per-sample VAE loss
log_weight = (log_p_x_z + kld_weight * kld_loss)
weight = F.softmax(log_weight, dim = -1) # normalize over the S importance samples

# weighted average of the per-sample losses over S, then mean over the batch
loss = torch.mean(torch.sum(weight * log_weight, dim=-1), dim = 0)
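
Strictly speaking, the per-sample loss above corresponds to $-\log w_i$ up to additive constants, so recovering $\tilde{w}_i$ exactly would require the softmax to be taken over the negative loss, with the weights detached from the computation graph. Equivalently, the bound $\mathcal{L}_k$ can be computed directly and in a numerically stable way with log-sum-exp. A minimal sketch, assuming a unit-variance Gaussian decoder, a standard-normal prior, and the [B x S x D] tensor shapes from the forward pass above:

# Alternative sketch: compute the IWAE bound L_k directly via a numerically stable log-sum-exp.
import math
import torch

def iwae_loss(recons, input, mu, log_var, z):
    log_p_x_z = -0.5 * ((recons - input) ** 2).flatten(2).sum(-1)         # log p(x|z_i) up to a constant, [B x S]
    log_p_z = -0.5 * (z ** 2).sum(-1)                                     # log p(z_i) up to a constant, [B x S]
    log_q_z_x = -0.5 * (log_var + (z - mu) ** 2 / log_var.exp()).sum(-1)  # log q(z_i|x) up to a constant, [B x S]
    log_w = log_p_x_z + log_p_z - log_q_z_x                               # log importance weights, [B x S]
    bound = torch.logsumexp(log_w, dim=1) - math.log(log_w.size(1))       # L_k for each example in the batch
    return -bound.mean()                                                  # minimize the negative bound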