SWAE: constructing a VAE with the Sliced-Wasserstein distance.

1. WAE: Wasserstein Autoencoder

The Wasserstein Autoencoder (WAE) is designed to constrain the Wasserstein distance between the distribution $P_X$ of the input data and the distribution $P_G$ of data generated by the decoder $G$:

\[\mathcal{W}[P_X,P_G] = \mathop{\inf}_{\Gamma \in \mathcal{P}(X \sim P_X,Y \sim P_G)} \Bbb{E}_{(X,Y)\sim \Gamma} [c(X,Y)]\]

where $c(X,Y)$ is a cost function and $\Gamma$ is a joint distribution (coupling) of $X$ and $Y$. Introducing the constraint $Q_Z=P_Z$ (the marginal of the encoded codes must match the prior) relaxes the expression to:

\[\mathop{\inf}_{Q:Q_Z=P_Z} \Bbb{E}_{P_X}\Bbb{E}_{Q(Z|X)} [c(X,G(Z))]\]

The overall WAE loss is then:

\[D_{WAE}(P_X,P_G) = \mathop{\inf}_{Q(Z|X) \in \mathcal{Q}} \Bbb{E}_{P_X}\Bbb{E}_{Q(Z|X)} [c(X,G(Z))]+\lambda \cdot \mathcal{D}_Z(Q_Z,P_Z)\]

Note that the WAE relaxes the constraint on the encoder $Q(Z|X)$: it no longer forces the encoder to map into a normal distribution, and only constrains the prior $P_Z$ to be normal. A non-random (deterministic) encoder can then map each input to a fixed latent code, behaving much like an ordinary autoencoder's encoder, so the reparameterization trick is no longer needed; see the sketch below.
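A minimal sketch of such a deterministic WAE-style autoencoder (the module name and layer sizes below are illustrative placeholders, not the paper's architecture):

import torch
import torch.nn as nn

class DeterministicAE(nn.Module):
    """Sketch: WAE-style autoencoder with a deterministic encoder."""
    def __init__(self, in_dim: int = 784, latent_dim: int = 64):
        super().__init__()
        # the encoder maps x directly to z: no mu/log_var heads,
        # hence no reparameterization trick is needed
        self.encoder = nn.Sequential(nn.Linear(in_dim, 256), nn.ReLU(),
                                     nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(),
                                     nn.Linear(256, in_dim))

    def forward(self, x: torch.Tensor):
        z = self.encoder(x)        # deterministic latent code
        return self.decoder(z), z  # reconstruction G(z) and code z

Only the aggregate of these codes, $Q_Z$, is pushed toward the prior $P_Z$ through the regularizer $\mathcal{D}_Z(Q_Z,P_Z)$.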

2. SWAE: Sliced-Wasserstein Autoencoder

In the WAE, the starting point of the optimization is the Wasserstein distance $\mathcal{W}[P_X,P_G]$ between the input data $P_X$ and the generated data $P_G$, while the prior-matching term $\mathcal{D}_Z(Q_Z,P_Z)$ is a relaxation of the constraint. The authors propose an improved WAE, the Sliced-Wasserstein Autoencoder (SWAE), whose loss function is:

\[\mathop{\arg \min}_{Q,G} \mathcal{W}^+[P_X,P_G]+\lambda \cdot SW_c(Q_Z,P_Z)\]

⚪ Reconstruction loss

where $\mathcal{W}^+$ is an upper bound on the Wasserstein distance $\mathcal{W}$:

\[\mathcal{W}[P_X,P_G] \leq \mathcal{W}^+[P_X,P_G] = \Bbb{E}_{P_X}[c(X,G(Q(Z|X)))] \\ = \int_{X} c(X,G(Q(Z|X)))P_X(X)dX \approx \frac{1}{N} \sum_{n=1}^{N} c(X_n,G(Q(Z|X_n)))\]

In practice, the reconstruction loss is implemented as the sum of the mean-squared error and the mean absolute error:

import torch.nn.functional as F

# reconstruction loss: MSE (L2) plus MAE (L1) between input and reconstruction
recons_loss_l2 = F.mse_loss(recons, input)
recons_loss_l1 = F.l1_loss(recons, input)
recons_loss = recons_loss_l2 + recons_loss_l1

⚪ Sliced-Wasserstein distance

$SW_c$ denotes the Sliced-Wasserstein distance. Since $Q_Z$ and $P_Z$ have no explicit relationship, their Wasserstein distance cannot be constructed directly. The Sliced-Wasserstein distance instead projects the high-dimensional distributions onto one-dimensional subspaces and computes Wasserstein distances there.

For one-dimensional distributions, the Wasserstein distance has a closed-form solution in terms of the inverse cumulative distribution functions $P_X^{-1}, P_Y^{-1}$:

\[\mathcal{W}[P_X,P_Y] = \int_{0}^{1} c(P_X^{-1}(\tau),P_Y^{-1}(\tau))d \tau\]
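
For empirical samples, the inverse CDFs are realized by sorting: the $n$-th smallest sample approximates the quantile $P^{-1}(n/N)$. A minimal sketch with toy data, taking $c$ to be the squared difference:

import torch

x = torch.randn(1000)          # samples from P_X
y = 2 * torch.randn(1000) + 1  # samples from P_Y
# sorted samples play the role of the inverse CDFs on a uniform grid,
# so the closed form reduces to comparing sorted values element-wise
w2 = (torch.sort(x)[0] - torch.sort(y)[0]).pow(2).mean()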

The Sliced-Wasserstein distance is defined as:

\[SW_c(P_X,P_Y) = \int_{\Bbb{S}^{d-1}} \mathcal{W}[\mathcal{R}_{P_X}(\cdot ;\theta),\mathcal{R}_{P_Y}(\cdot ;\theta)] d\theta\]

where \(\Bbb{S}^{d-1}\) is the unit sphere in $d$-dimensional space, and \(\mathcal{R}_{P_X}\) and \(\mathcal{R}_{P_Y}\) are the one-dimensional distributions obtained by projecting the high-dimensional distributions $P_X,P_Y$ along the direction $\theta$:

\[\mathcal{R}_{P_X}(t ;\theta) = \int_{X} P_X(x)\delta(t-\theta\cdot x)dx, \forall \theta \in \Bbb{S}^{d-1},\forall t \in \Bbb{R}\]
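
In practice, the integral over the unit sphere is approximated by Monte-Carlo sampling a finite number $L$ of random directions (the implementation below uses $L=50$ projections):

\[SW_c(P_X,P_Y) \approx \frac{1}{L} \sum_{l=1}^{L} \mathcal{W}[\mathcal{R}_{P_X}(\cdot ;\theta_l),\mathcal{R}_{P_Y}(\cdot ;\theta_l)], \quad \theta_l \sim \mathcal{U}(\Bbb{S}^{d-1})\]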

In practice, computing the Sliced-Wasserstein distance involves two steps: computing random projections of the codes $Q_Z$ and of the prior $P_Z$, and computing the Wasserstein distance between the two sets of projections.

z = self.encode(input) # latent codes, computed in forward() # [N x D]

def compute_swd(self,
                z: Tensor,
                p: float = 2.) -> Tensor:
    """
    Computes the Sliced Wasserstein Distance (SWD) - which consists of
    randomly projecting the encoded and prior vectors and computing
    their Wasserstein distance along those projections.
    :param z: Latent samples # [N x D]
    :param p: Value for the p^th Wasserstein distance
    :return:
    """
    prior_z = torch.randn_like(z) # samples from the prior N(0, I) # [N x D]
    device = z.device

    # self.latent_dim (D) and self.num_projections (S, e.g. 50) are set in __init__
    proj_matrix = self.get_random_projections(self.latent_dim,
                                              num_samples=self.num_projections).transpose(0, 1).to(device)

    latent_projections = z.matmul(proj_matrix) # [N x S]
    prior_projections = prior_z.matmul(proj_matrix) # [N x S]

    # The 1D Wasserstein distance is computed by sorting the two projections
    # along the batch dimension and taking their element-wise difference
    w_dist = torch.sort(latent_projections.t(), dim=1)[0] - \
             torch.sort(prior_projections.t(), dim=1)[0]
    w_dist = w_dist.pow(p)
    return w_dist.mean()

The one-dimensional Wasserstein distance along each projection is computed by sorting the two sets of projected samples (which aligns their empirical quantiles) and averaging the $p$-th power of their element-wise differences. The random projection directions are drawn from a normal or a Cauchy distribution and then normalized onto the unit sphere:

import torch
from torch import Tensor
from torch import distributions as dist

def get_random_projections(self, latent_dim: int, num_samples: int) -> Tensor:
    """
    Returns random samples from latent distribution's (Gaussian)
    unit sphere for projecting the encoded samples and the
    distribution samples.
    :param latent_dim: (Int) Dimensionality of the latent space (D)
    :param num_samples: (Int) Number of samples required (S)
    :return: Random projections from the latent unit sphere
    """
    if self.proj_dist == 'normal':
        rand_samples = torch.randn(num_samples, latent_dim)
    elif self.proj_dist == 'cauchy':
        rand_samples = dist.Cauchy(torch.tensor([0.0]),
                                   torch.tensor([1.0])).sample((num_samples, latent_dim)).squeeze()
    else:
        raise ValueError('Unknown projection distribution.')

    # normalize each direction onto the unit sphere S^{D-1}
    rand_proj = rand_samples / rand_samples.norm(dim=1).view(-1, 1)
    return rand_proj # [S x D]

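Putting the pieces together, a single training step evaluates both terms of the SWAE objective. A minimal sketch, assuming `reg_weight` as the name of the trade-off weight $\lambda$:

z = self.encode(input)   # [N x D]
recons = self.decode(z)

# W^+ upper bound: reconstruction cost (MSE + L1, as above)
recons_loss = F.mse_loss(recons, input) + F.l1_loss(recons, input)
# λ·SW_c(Q_Z, P_Z): sliced-Wasserstein regularizer on the latent codes
swd_loss = self.compute_swd(z, p=2.)

loss = recons_loss + reg_weight * swd_loss
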
A complete PyTorch implementation of SWAE is available in PyTorch-VAE.