WAE: 使用Wasserstein距离的变分自编码器.

本文作者提出了Wasserstein自编码器(WAE),最小化模型生成分布与真实数据分布之间的Wasserstein距离,从而构造一种具有新正则化器形式的VAE。实验表明WAE具有VAE的良好特性(训练稳定、良好的隐空间结构),同时能够生成质量更好的样本。

1. 构建WAE

VAE的目标函数包括两项:生成数据分布$P_G(x | z)$与真实数据分布$P_X$的重构损失、编码器分布$Q(z|x)$与先验分布$P(z)$的差异(正则化项)。

\[\mathcal{L} = \mathbb{E}_{z \text{~} Q(z|x)} [\log P_G(x | z)]- KL[Q(z|x)||P(z)]\]

在正则化项中,VAE使不同样本$x$对应的分布$Q(z|x)$都趋近于$P(z)$(对每个样本只考虑了单次采样结果),在重构过程中可能会出现问题;而WAE使连续分布$Q(z)$趋近于$P(z)$,从而保留不同样本后验概率的差异性。

2. WAE的损失形式

(1)重构损失

重构损失衡量原输入数据的概率分布$P_X$和生成数据的概率分布$P_G(x | z)$之间的距离。在VAE中,通过假设生成模型服从固定方差的正态分布,将重构损失选定为均方误差。而在WAE中选用Wasserstein距离。标准的Wasserstein距离定义如下:

\[\mathcal{W}[P_X,P_G] = \mathop{\inf}_{\Gamma \in \mathcal{P}[X \text{~} P_X,Y \text{~} P_G]} \Bbb{E}_{(X,Y) \text{~} \Gamma} [c(x,y)]\]

其中$c(x,y)$是代价函数,$\Gamma$是联合分布。如果引入限制$Q_Z=P_Z$,则上式被松弛为:

\[\mathop{\inf}_{Q:Q_Z=P_Z} \Bbb{E}_{P_X}\Bbb{E}_{Q(z|x)} [c(x,G(z))]\]

(2)正则化项

正则化项衡量编码器分布$Q(z|x)$与先验分布$P(z)$之间的差异\(\mathcal{D}_Z(Q_Z,P_Z)\),在标准的VAE中是通过计算KL散度实现的。而在WAE中,作者设计了两种实现形式。

⚪ 对抗训练:WAE-GAN

可以使用JS散度衡量两者的差异\(\mathcal{D}_Z(Q_Z,P_Z)=\mathcal{D}_{JS}(Q_Z,P_Z)\),此时等价于对抗训练。为隐变量$z$额外引入一个判别器,用于区分$z$来自先验分布$P(z)$还是采样分布$Q(z|x)$。

⚪ 最大平均差异:WAE-MMD

也可以使用最大平均差异(maximum mean discrepancy, MMD)衡量两个分布的差异。MMD通过引入正定核$k$计算如下:

\[\text{MMD}_k(P_Z,Q_Z) = || \int_{\mathcal{Z}} k(z,\cdot)dP_Z(z)-\int_{\mathcal{Z}} k(z,\cdot)dQ_Z(z) ||_{\mathcal{H}_k}\]

当先验分布比较接近高维标准正态分布时,WAE-MMD的效果比较好。

上述两个模型的算法流程如下:

(3)总损失

WAE的总损失描述如下:

\[D_{WAE}(P_X,P_G) = \mathop{\inf}_{Q(Z|X) \in \mathcal{Q}} \Bbb{E}_{P_X}\Bbb{E}_{Q(Z|X)} [c(X,G(Z))]+\lambda \cdot \mathcal{D}_Z(Q_Z,P_Z)\]

注意到WAE放松了对编码器$Q(Z|X)$的约束,即不再强制其映射到正态分布,而是仅约束先验分布$P_Z$为正态分布。此时非随机编码器将输入确定性地映射到隐变量,其表现形式与普通的编码器类似,因此也不再依赖重参数化技巧。

WAE的完整pytorch实现可参考PyTorch-VAE,下面以WAE-MMD为例分析模型的推理过程。

WAE的前向推理过程如下,编码器直接编码隐变量$z$(而不是概率分布的参数),然后使用隐变量直接重构数据。

def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    z = self.encode(input)
    return  [self.decode(z), input, z]

损失函数如下,重构损失仍然选用均方误差,而MMD也有解析表达形式。

recons_loss =F.mse_loss(recons, input)
mmd_loss = self.compute_mmd(z, reg_weight)
loss = recons_loss + mmd_loss

MMD的解析式如下,注意先验分布$z$~$P_Z$直接指定为标准正态分布,$\tilde{z}$~$Q_Z$是编码器的输出。

\[\frac{\lambda}{n(n-1)}\sum_{l \ne j}k(z_l,z_j)+\frac{\lambda}{n(n-1)}\sum_{l \ne j}k(\tilde{z}_l,\tilde{z}_j)-\frac{2\lambda}{n^2} \sum_{l , j}k(z_l,\tilde{z}_j)\]
def compute_mmd(self, z: Tensor) -> Tensor:
    bias_corr = self.batch_size *  (self.batch_size - 1)
    reg_weight = self.reg_weight / bias_corr
    # Sample from prior (Gaussian) distribution
    prior_z = torch.randn_like(z)

    prior_z__kernel = self.compute_kernel(prior_z, prior_z)
    z__kernel = self.compute_kernel(z, z)
    priorz_z__kernel = self.compute_kernel(prior_z, z)

    mmd = reg_weight * prior_z__kernel.mean() + \
          reg_weight * z__kernel.mean() - \
          2 * reg_weight * priorz_z__kernel.mean()
    return mmd

正定核$k$可以选择不同的形式,如: