WAE: 使用Wasserstein距离的变分自编码器.
本文作者提出了Wasserstein自编码器(WAE),最小化模型生成分布与真实数据分布之间的Wasserstein距离,从而构造一种具有新正则化器形式的VAE。实验表明WAE具有VAE的良好特性(训练稳定、良好的隐空间结构),同时能够生成质量更好的样本。
1. 构建WAE
VAE的目标函数包括两项:生成数据分布$P_G(x | z)$与真实数据分布$P_X$的重构损失、编码器分布$Q(z|x)$与先验分布$P(z)$的差异(正则化项)。
\[\mathcal{L} = \mathbb{E}_{z \text{~} Q(z|x)} [\log P_G(x | z)]- KL[Q(z|x)||P(z)]\]在正则化项中,VAE使不同样本$x$对应的分布$Q(z|x)$都趋近于$P(z)$(对每个样本只考虑了单次采样结果),在重构过程中可能会出现问题;而WAE使连续分布$Q(z)$趋近于$P(z)$,从而保留不同样本后验概率的差异性。
2. WAE的损失形式
(1)重构损失
重构损失衡量原输入数据的概率分布$P_X$和生成数据的概率分布$P_G(x | z)$之间的距离。在VAE中,通过假设生成模型服从固定方差的正态分布,将重构损失选定为均方误差。而在WAE中选用Wasserstein距离。标准的Wasserstein距离定义如下:
\[\mathcal{W}[P_X,P_G] = \mathop{\inf}_{\Gamma \in \mathcal{P}[X \text{~} P_X,Y \text{~} P_G]} \Bbb{E}_{(X,Y) \text{~} \Gamma} [c(x,y)]\]其中$c(x,y)$是代价函数,$\Gamma$是联合分布。如果引入限制$Q_Z=P_Z$,则上式被松弛为:
\[\mathop{\inf}_{Q:Q_Z=P_Z} \Bbb{E}_{P_X}\Bbb{E}_{Q(z|x)} [c(x,G(z))]\](2)正则化项
正则化项衡量编码器分布$Q(z|x)$与先验分布$P(z)$之间的差异\(\mathcal{D}_Z(Q_Z,P_Z)\),在标准的VAE中是通过计算KL散度实现的。而在WAE中,作者设计了两种实现形式。
⚪ 对抗训练:WAE-GAN
可以使用JS散度衡量两者的差异\(\mathcal{D}_Z(Q_Z,P_Z)=\mathcal{D}_{JS}(Q_Z,P_Z)\),此时等价于对抗训练。为隐变量$z$额外引入一个判别器,用于区分$z$来自先验分布$P(z)$还是采样分布$Q(z|x)$。
⚪ 最大平均差异:WAE-MMD
也可以使用最大平均差异(maximum mean discrepancy, MMD)衡量两个分布的差异。MMD通过引入正定核$k$计算如下:
\[\text{MMD}_k(P_Z,Q_Z) = || \int_{\mathcal{Z}} k(z,\cdot)dP_Z(z)-\int_{\mathcal{Z}} k(z,\cdot)dQ_Z(z) ||_{\mathcal{H}_k}\]当先验分布比较接近高维标准正态分布时,WAE-MMD的效果比较好。
上述两个模型的算法流程如下:
(3)总损失
WAE的总损失描述如下:
\[D_{WAE}(P_X,P_G) = \mathop{\inf}_{Q(Z|X) \in \mathcal{Q}} \Bbb{E}_{P_X}\Bbb{E}_{Q(Z|X)} [c(X,G(Z))]+\lambda \cdot \mathcal{D}_Z(Q_Z,P_Z)\]注意到WAE放松了对编码器$Q(Z|X)$的约束,即不再强制其映射到正态分布,而是仅约束先验分布$P_Z$为正态分布。此时非随机编码器将输入确定性地映射到隐变量,其表现形式与普通的编码器类似,因此也不再依赖重参数化技巧。
WAE的完整pytorch实现可参考PyTorch-VAE,下面以WAE-MMD为例分析模型的推理过程。
WAE的前向推理过程如下,编码器直接编码隐变量$z$(而不是概率分布的参数),然后使用隐变量直接重构数据。
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
z = self.encode(input)
return [self.decode(z), input, z]
损失函数如下,重构损失仍然选用均方误差,而MMD也有解析表达形式。
recons_loss =F.mse_loss(recons, input)
mmd_loss = self.compute_mmd(z, reg_weight)
loss = recons_loss + mmd_loss
MMD的解析式如下,注意先验分布$z$~$P_Z$直接指定为标准正态分布,$\tilde{z}$~$Q_Z$是编码器的输出。
\[\frac{\lambda}{n(n-1)}\sum_{l \ne j}k(z_l,z_j)+\frac{\lambda}{n(n-1)}\sum_{l \ne j}k(\tilde{z}_l,\tilde{z}_j)-\frac{2\lambda}{n^2} \sum_{l , j}k(z_l,\tilde{z}_j)\]def compute_mmd(self, z: Tensor) -> Tensor:
bias_corr = self.batch_size * (self.batch_size - 1)
reg_weight = self.reg_weight / bias_corr
# Sample from prior (Gaussian) distribution
prior_z = torch.randn_like(z)
prior_z__kernel = self.compute_kernel(prior_z, prior_z)
z__kernel = self.compute_kernel(z, z)
priorz_z__kernel = self.compute_kernel(prior_z, z)
mmd = reg_weight * prior_z__kernel.mean() + \
reg_weight * z__kernel.mean() - \
2 * reg_weight * priorz_z__kernel.mean()
return mmd
正定核$k$可以选择不同的形式,如:
- RBF核:\(e^{-\frac{\|x_1-x_2\|^2}{\sigma}}\)
- Inverse Multi-Quadratics(IMQ)核:\(\frac{C}{C+\|x_1-x_2\|^2}\)