SWAE: Constructing a VAE with the Sliced-Wasserstein Distance.
1. WAE: Wasserstein Autoencoder
The Wasserstein Autoencoder (WAE) is designed around constraining the Wasserstein distance between the distribution $P_X$ of the input data and the distribution $P_G$ of the data generated through the decoder $G$:
\[\mathcal{W}[P_X,P_G] = \mathop{\inf}_{\Gamma \in \mathcal{P}(X \sim P_X,Y \sim P_G)} \Bbb{E}_{(X,Y)\sim \Gamma} [c(X,Y)]\]where $c(X,Y)$ is a cost function and $\Gamma$ ranges over joint distributions (couplings) with marginals $P_X$ and $P_G$. If the constraint $Q_Z=P_Z$ is imposed on the aggregated code distribution $Q_Z$, the objective above can be rewritten as:
\[\mathop{\inf}_{Q:Q_Z=P_Z} \Bbb{E}_{P_X}\Bbb{E}_{Q(Z|X)} [c(X,G(Z))]\]Relaxing this hard constraint into a penalty term with weight $\lambda$ gives the total WAE loss:
\[D_{WAE}(P_X,P_G) = \mathop{\inf}_{Q(Z|X) \in \mathcal{Q}} \Bbb{E}_{P_X}\Bbb{E}_{Q(Z|X)} [c(X,G(Z))]+\lambda \cdot \mathcal{D}_Z(Q_Z,P_Z)\]Note that WAE relaxes the constraint on the encoder $Q(Z|X)$: each individual posterior is no longer forced to be Gaussian; instead, only the aggregated code distribution $Q_Z$ is pushed toward the prior $P_Z$ (e.g., a standard normal). A deterministic encoder that maps each input directly to a latent code therefore behaves like an ordinary encoder, and the reparameterization trick is no longer needed.
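As a concrete illustration (not code from the paper), below is a minimal sketch of this objective in the WAE-MMD variant, where the divergence $\mathcal{D}_Z$ is realized as the maximum mean discrepancy with an RBF kernel; `encoder`, `decoder`, and the bandwidth `sigma2` are hypothetical stand-ins:
import torch
import torch.nn.functional as F

def rbf_mmd(x: torch.Tensor, y: torch.Tensor, sigma2: float = 1.0) -> torch.Tensor:
    """Biased estimate of MMD^2 between two sample sets under an RBF kernel."""
    def k(a, b):
        return torch.exp(-torch.cdist(a, b).pow(2) / (2 * sigma2))
    return k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()

def wae_loss(x, encoder, decoder, lam=10.0):
    z = encoder(x)                     # deterministic code: no reparameterization needed
    recon = F.mse_loss(decoder(z), x)  # E[c(X, G(Q(Z|X)))] with a squared-error cost
    prior_z = torch.randn_like(z)      # samples from the prior P_Z = N(0, I)
    return recon + lam * rbf_mmd(z, prior_z)  # D_Z(Q_Z, P_Z) realized as MMD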
2. SWAE: Sliced-Wasserstein Autoencoder
In WAE, the optimization starts from the Wasserstein distance \(\mathcal{W}[P_X,P_G]\) between the input distribution $P_X$ and the generated distribution $P_G$, while the prior-matching term \(\mathcal{D}_Z(Q_Z,P_Z)\) is the relaxed constraint. The authors propose an improved WAE, the Sliced-Wasserstein Autoencoder (SWAE), with the following loss:
\[\mathop{\arg \min}_{G,Q} \mathcal{W}^+[P_X,P_G]+\lambda \cdot SW_c(Q_Z,P_Z)\]⚪ Reconstruction Loss
where \(\mathcal{W}^+\) is an upper bound on the Wasserstein distance \(\mathcal{W}\):
\[\mathcal{W}[P_X,P_G] \leq \mathcal{W}^+[P_X,P_G] = \Bbb{E}_{P_X}[c(X,G(Q(Z|X)))] \\ = \int_{\mathcal{X}} c(X,G(Q(Z|X)))\,dP_X(X) \approx \frac{1}{N} \sum_{n=1}^{N} c(X_n,G(Q(Z|X_n)))\]In the implementation, the reconstruction loss is taken as the sum of the mean squared error and the mean absolute error:
import torch.nn.functional as F

recons_loss_l2 = F.mse_loss(recons, input)  # mean squared error
recons_loss_l1 = F.l1_loss(recons, input)   # mean absolute error
recons_loss = recons_loss_l2 + recons_loss_l1
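Summing the two terms presumably combines the smooth gradients of the squared error with the outlier robustness of the absolute error; note that any measurable cost $c(X,Y)$ is admissible in the bound above.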
⚪ Sliced-Wasserstein Distance
$SW_c$ denotes the Sliced-Wasserstein distance. Since there is no explicit relationship between $Q_Z$ and $P_Z$, the Wasserstein distance between them cannot be constructed directly. The Sliced-Wasserstein distance instead projects the high-dimensional distributions onto one-dimensional slices and computes the Wasserstein distance there.
For one-dimensional distributions, the Wasserstein distance has a closed-form solution in terms of the quantile functions (inverse CDFs):
\[\mathcal{W}[P_X,P_Y] = \int_{0}^{1} c(P_X^{-1}(\tau),P_Y^{-1}(\tau))d \tau\]For empirical samples, the quantile functions are realized simply by sorting, as the snippet below illustrates.
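As a quick numerical check (a hypothetical snippet, not part of the SWAE code), with cost $c(x,y)=|x-y|^p$ the integral reduces to an average over sorted sample pairs:
import torch

# Empirical 1-D Wasserstein distance: sorting each sample set
# realizes its quantile function P^{-1}(tau) at tau = 1/N, 2/N, ...
def wasserstein_1d(x: torch.Tensor, y: torch.Tensor, p: float = 2.) -> torch.Tensor:
    return (torch.sort(x)[0] - torch.sort(y)[0]).abs().pow(p).mean()

x = torch.randn(10000)        # samples from N(0, 1)
y = torch.randn(10000) + 3.0  # samples from N(3, 1)
print(wasserstein_1d(x, y))   # close to 9 = (mean shift)^2 for p = 2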
The Sliced-Wasserstein distance is then defined as:\[SW_c(P_X,P_Y) = \int_{\Bbb{S}^{d-1}} \mathcal{W}[\mathcal{R}_{P_X}(\cdot ;\theta),\mathcal{R}_{P_Y}(\cdot ;\theta)] d\theta\]where \(\Bbb{S}^{d-1}\) is the unit sphere in $d$-dimensional space, and \(\mathcal{R}_{P_X}\), \(\mathcal{R}_{P_Y}\) are the one-dimensional distributions obtained by projecting (slicing) the high-dimensional distributions $P_X$, $P_Y$ along the direction $\theta$:
\[\mathcal{R}_{P_X}(t ;\theta) = \int_{X} P_X(x)\delta(t-\theta\cdot x)dx, \quad \forall \theta \in \Bbb{S}^{d-1},\forall t \in \Bbb{R}\]In practice, computing the Sliced-Wasserstein distance involves two steps: computing random projections of the codes $Q_Z$ and of the prior samples $P_Z$, and computing the one-dimensional Wasserstein distance between each pair of projections:
z = self.encode(input)  # [N x D]

def compute_swd(self,
                z: Tensor,
                p: float = 2.) -> Tensor:
    """
    Computes the Sliced Wasserstein Distance (SWD) - which consists of
    randomly projecting the encoded and prior vectors and computing
    their Wasserstein distance along those projections.
    :param z: Latent samples # [N x D]
    :param p: Value for the p^th Wasserstein distance
    :return: Scalar SWD estimate
    """
    prior_z = torch.randn_like(z)  # samples from the prior P_Z = N(0, I), [N x D]
    device = z.device
    # self.latent_dim (D) and self.num_projections (S, e.g. 50) are set in __init__
    proj_matrix = self.get_random_projections(self.latent_dim,
                                              num_samples=self.num_projections
                                              ).transpose(0, 1).to(device)  # [D x S]
    latent_projections = z.matmul(proj_matrix)       # [N x S]
    prior_projections = prior_z.matmul(proj_matrix)  # [N x S]
    # The 1D Wasserstein distance along each projection is computed by sorting
    # the two projections across the batch (aligning their empirical quantiles)
    # and taking their element-wise differences
    w_dist = torch.sort(latent_projections.t(), dim=1)[0] - \
             torch.sort(prior_projections.t(), dim=1)[0]
    w_dist = w_dist.pow(p)  # raise to the p-th power (p = 2 by default)
    return w_dist.mean()
The one-dimensional Wasserstein distance along each projection is thus implemented by sorting the projected samples of the two distributions, which aligns their empirical quantiles, and averaging the element-wise differences raised to the $p$-th power. The random projection directions are constructed by sampling from a normal or Cauchy distribution and normalizing onto the unit sphere:
import torch
from torch import distributions as dist

def get_random_projections(self, latent_dim: int, num_samples: int) -> Tensor:
    """
    Returns random unit-norm directions on the latent unit sphere for
    projecting the encoded samples and the prior samples.
    :param latent_dim: (Int) Dimensionality of the latent space (D)
    :param num_samples: (Int) Number of projection directions required (S)
    :return: Random projections from the latent unit sphere
    """
    if self.proj_dist == 'normal':
        rand_samples = torch.randn(num_samples, latent_dim)
    elif self.proj_dist == 'cauchy':
        rand_samples = dist.Cauchy(torch.tensor([0.0]),
                                   torch.tensor([1.0])).sample((num_samples, latent_dim)).squeeze()
    else:
        raise ValueError('Unknown projection distribution.')
    # Normalize each row to unit length so the directions lie on the unit sphere
    rand_proj = rand_samples / rand_samples.norm(dim=1).view(-1, 1)
    return rand_proj  # [S x D]
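Putting the pieces together, the full SWAE training loss from the objective above could be assembled roughly as follows (a sketch assuming the methods shown here live on the same model class, with a hypothetical weighting attribute `self.reg_weight`, following the structure of PyTorch-VAE):
def loss_function(self, input: Tensor) -> Tensor:
    z = self.encode(input)  # [N x D]
    recons = self.decode(z)
    # Reconstruction term: Monte-Carlo estimate of the upper bound W+
    recons_loss = F.mse_loss(recons, input) + F.l1_loss(recons, input)
    # Prior term: Sliced-Wasserstein distance between Q_Z and P_Z
    swd = self.compute_swd(z, p=2.)
    return recons_loss + self.reg_weight * swd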
A complete PyTorch implementation of SWAE can be found in PyTorch-VAE.