CVAE: 使用深度条件生成模型学习结构化输出表示.
1. 变分自编码器
变分自编码器(VAE)是一种有向图模型,其隐变量通常选择高斯隐变量。VAE的生成过程为由先验分布$p(z)$生成一组隐变量$z$,再由生成分布$p(x|z)$生成数据$x$。一般情况下,有向图模型的参数估计由于其后验推断难以处理而具有挑战性。在随机梯度变分贝叶斯(SGVB)框架中,将对数似然的变分下界作为替代目标函数,可以有效地估计参数。变分下界为:
\[\begin{aligned} \log p(x) &= \log \Bbb{E}_{z \text{~} q(z|x)}[\frac{p(x,z)}{q(z|x)}] \geq \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{q(z|x)}] \\ \text{ELBO} &= \mathbb{E}_{z \text{~} q(z|x)} [\log \frac{p(z)p(x|z)}{q(z|x)}] \\ &= \mathbb{E}_{z \text{~}q(z|x)} [\log \frac{p(z)}{q(z|x)}] + \mathbb{E}_{z \text{~} q(z|x)} [\log p(x | z)] \\ &= - KL[q(z|x)||p(z)]+\mathbb{E}_{z \text{~} q(z|x)} [\log p(x | z)] \end{aligned}\]在此框架下,引入建议分布$q(z|x)$逼近真正的后验$p(z|x)$。利用神经网络对建议分布和生成分布进行建模。
2. 条件变分自编码器
VAE是无监督训练的,如果数据有对应的标签,则可以把标签信息加进去辅助生成样本,或者通过控制某个变量来实现生成某一类样本。实现条件变分自编码器Conditional VAE的方式有很多,一种思路是把数据分布建模为条件概率分布$p(x|y)$,则对应的变分下界为:
\[\begin{aligned} \log p(x|y) &\geq \Bbb{E}_{z \text{~} q(z|x,y)}[\log \frac{p(x,z|y)}{q(z|x,y)}] \\ \text{ELBO} & = - KL[q(z|x,y)||p(z|y)]+\mathbb{E}_{z \text{~} q(z|x,y)} [\log p(x | z,y)] \end{aligned}\]此时建议分布$q(z|x,y)$的输入为样本$x$及其标签$y$,生成分布$p(x|z,y)$的输入为采样值$z$及其标签$y$。
也可以建模输出分布$p(y|x)$,即构造条件图模型,调换上式中$x$和$y$的顺序可以得到:
\[\begin{aligned} \log p(y|x) & \geq \Bbb{E}_{z \text{~} q(z|x,y)}[\log \frac{p(y,z|x)}{q(z|x,y)}] \\ \text{ELBO} &= - KL[q(z|x,y)||p(z|x)]+\mathbb{E}_{z \text{~} q(z|x,y)} [\log p(y | z,x)] \end{aligned}\]此时建议分布$q(z|x,y)$的输入仍然是样本$x$及其标签$y$;而生成分布$p(y|z,x)$的输入为采样值$z$和样本$x$,输出为其标签$y$。
3. CVAE的pytorch实现
CVAE的完整pytorch实现可参考PyTorch-VAE,下面分析模型的推理过程。
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
y = kwargs['labels'].float()
embedded_class = self.embed_class(y)
embedded_class = embedded_class.view(-1, self.img_size, self.img_size).unsqueeze(1)
embedded_input = self.embed_data(input)
x = torch.cat([embedded_input, embedded_class], dim = 1)
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
z = torch.cat([z, y], dim = 1)
return [self.decode(z), input, mu, log_var]
编码器self.encode
的输入同时包含样本与标签信息,具体做法是把onehot标签$y$通过全连接层嵌入到与图像$x$尺寸相同的特征中,连接后共同作为输入。
解码器self.decode
的输入同时包含采样与标签信息,具体做法是通过重参数化采样获得$z$后,与标签$y$连接共同作为输入。