VQ-VAE:向量量化的变分自编码器.

一. 研究背景

生成模型是一种无监督学习方法,旨在学习数据的分布或生成属于分布的新样本,广泛应用于图像生成任务。作者提出了一种向量量化的变分自编码器,将编码后的隐空间离散化,并采用自回归的方法在隐空间进行分布估计,实现对高清图片的生成。

1. 自编码器 AutoEncoder

自编码器(AutoEncoder, AE)模型包含两部分,编码器(Encoder)解码器(Decoder)

编码器将高维输入图像$x$压缩成低维的隐变量$z$,解码器将隐变量$z$解码重构图像$x$。通过计算重构误差并反向传播,即可更新网络参数。

从概率角度,编码器学习到给定输入$x$得到隐变量$z$的条件概率分布$P(z | x)$,解码器学习到给定隐变量$z$得到输入$x$的条件概率分布$P(x | z)$。

自编码器只能求得编码器和解码器,无法直接得到输入数据的概率分布$P(x)$,从而无法直接采样获得新的数据样本。由Bayes公式:

\[P(x | z) = \frac{P(z | x)P(x)}{P(z)}\]

要想获得输入数据的概率分布$P(x)$,需要知道隐变量$z$的概率分布$P(z)$。

2. 变分自编码器 Variational AutoEncoder

变分自编码器(Variational AutoEncoder,VAE)对隐变量$z$加上了先验知识,将其看作满足$z \text{~} N(0,1)$的先验分布;训练完成后抽样得到潜在表示$z$,使用解码器$P(x | z)$可以生成新样本。

具体实现时,在重构误差之外增加编码分布和先验分布的KL散度,是模型学习满足正态分布的$z$。优化置信下界(evidence low bound,ELBO)

\[ELBO(\theta, \phi) = E_{z \text{~} q_{\phi}(z | x)}[\log p_{\theta}(x | z)]-KL(q_{\phi}(z | x) || p(z))\]

3. 自回归 AutoRegressive

自回归(AutoRegressive)模型(如PixelCNN)是一种逐像素生成的图像生成模型,通过概率的chain rule可以表示输入数据的概率分布$P(x)$:

\[P(x) = P(x_0)P(x_1 | x_0)P(x_2 | x_0x_1) \cdot\cdot\cdot P(x_N | x_0x_1x_2 \cdot\cdot\cdot x_{N-1})\]

具体抽样过程是先从$P(x_0)$抽样$x_0$, 然后根据条件概率依次抽取剩余像素数值。当图片尺寸较大时,自回归模型的计算需求极大。

二. VQ-VAE:Vector Quantised Variational AutoEncoder

作者将变分自编码器和自回归模型的特点结合起来,提出了VQ-VAE模型。

自回归模型在图像比较大的情况下因计算需求暴增而失效,如果能将图像压缩到低维空间,在低维空间训练自回归网络,再解码回高维空间即可。

作者对编码后的隐变量$z$建立自回归模型,得到结构化的全局语义概率分布:

\[P(z) = P(z_0)P(z_1 | z_0)P(z_2 | z_0z_1) \cdot\cdot\cdot P(z_m | z_0z_1z_2 \cdot\cdot\cdot z_{m-1})\]

变分自编码器的隐变量$z$每一维度都是连续值,如果将其离散化为整数,更符合自然界的模态,不需要学习特别细节的东西。如图像中类别的概念是离散的,两个不同类别的中间值没有意义。

将$z$离散化的关键是向量量化(vector quatization,VQ)。首先建立一个字典存储一系列编码表,在这个字典中找到和隐变量最接近(比如欧氏距离最近)的一个编码,用这个编码的index来代表这个隐变量。

1. 模型结构

VQ-VAE的模型结构如上图所示。

将尺寸为$H \times W \times c$的输入图像$x$通过卷积网络构成的编码器,得到尺寸为$H’ \times W’ \times D$的特征映射(隐变量)$z_e(x)$,该特征映射的每一个空间位置是一个$D$维向量:

\[z_e(x) = Encoder(x)\]

构建一个字典$E=[e_1,e_2, \cdots e_K],e_k \in \Bbb{R}^D$,也称为编码表,通过Embedding层实现:

self.K = num_embeddings
self.D = embedding_dim

self.embedding = nn.Embedding(self.K, self.D)
self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)

VQ-VAE通过最邻近搜索,将$z_e(x)$中$H’ \times W’$个$D$维向量映射为这$K$个字典向量之一:

\[z_e^{(i)}(x)\to e_k, \quad k = \mathop{\arg \min}_{j} ||z_e^{(i)}(x)- e_j||_2\]

由于$e_k$是编码表$E$中的向量之一,所以它实际上等价于其index ($1,2,…,K$这$K$个整数之一),因此该过程相当于将图像编码为一个$H’ \times W’$的整数矩阵$q(z | x)$,实现了离散型编码。整数矩阵也可以表示成onehot形式。

latents = latents.permute(0, 2, 3, 1).contiguous()  # [B x D x H x W] -> [B x H x W x D]
latents_shape = latents.shape
flat_latents = latents.view(-1, self.D)  # [BHW x D]

# Compute L2 distance between latents and embedding weights
dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
       torch.sum(self.embedding.weight ** 2, dim=1) - \
       2 * torch.matmul(flat_latents, self.embedding.weight.t())  # [BHW x K]

# Get the encoding that has the min distance
encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW, 1]

# Convert to one-hot encodings
device = latents.device
encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]

把$z_e(x)$的向量替换为编码表$E$中对应的向量$e_k$,就可以得到最终的尺寸为$H’ \times W’ \times D$的编码结果$z_q(x)$。

# Quantize the latents
quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)  # [BHW, D]
quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]
quantized_latents = quantized_latents.permute(0, 3, 1, 2).contiguous() # [B x D x H x W]

将$z_q(x)$喂入解码器,重构最终图像$p(x | z_q)$:

\[p(x \| z_q) = Decoder(z_q(x))\]

在实验中,设置输入图像的尺寸为$H \times W \times c = 256 \times 256 \times 3$,离散编码的尺寸为$H’ \times W’ = 32 \times 32$,字典长度为$K = 53$,字典编码维度为$D=64$。

2. 损失函数

对于自编码器,损失函数考虑重构误差:

\[①: || x - p(x | z_e) ||_2^2\]

但在VQ-VAE模型中,重构图像使用的是$z_q$而不是$z_e$,应该用下列重构误差:

\[②: || x - p(x | z_q) ||_2^2\]

但是$z_q$的构建过程无法求导(包含argmin等操作),无法直接对②式进行优化;而直接优化①式又不是直接的优化目标。

作者引入Straight-Through Estimator方法,梯度的直通估计是指前向传播的时使用目标变量(即使不可导),而反向传播时使用自己设计的梯度。因此所设计的重构误差为:

\[|| x - p(x | z_e + sg[z_q-z_e]) ||_2^2\]

其中$sg$表示stop gradient,即在反向传播时不计算其梯度,在pytorch中可以通过.detach()方法实现。因此在前向传播时损失等价于$|| x - p(x | z_q) ||_2^2$,即使用$z_q$计算损失函数;在反向传播时损失等价于$|| x - p(x | z_e) ||_2^2$,即使用$z_e$计算梯度。通过自定义函数的梯度可以同时更新编码器和解码器。

quantized_latents = latents + (quantized_latents - latents).detach()
recons = self.decode(quantized_latents)
recons_loss = F.mse_loss(recons, input)

在计算损失函数时,根据VQ-VAE的最邻近搜索设计,还应该期望编码向量$z_e$和量化向量$z_q$足够接近。通常编码向量$z_e$相对比较自由,而量化向量$z_q$要保证重构效果,因此将$||z_e-z_q||^2_2$分解为两个损失,分别更新量化向量和编码向量:

\[|| sg[z_e] - z_q ||_2^2 + || z_e - sg[z_q] ||_2^2\]

通过引入$\beta < 1$使得“让$z_q$靠近$z_e$”的损失(即$||sg[z_e]-z_q||^2_2$)具有更高的权重:

\[|| x - p(x | z_e + sg[z_q-z_e]) ||_2^2 + || sg[z_e] - z_q ||_2^2 + \beta || z_e - sg[z_q] ||_2^2\]
self.beta = beta # float = 0.25

# Compute the VQ Losses
commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
embedding_loss = F.mse_loss(quantized_latents, latents.detach())

vq_loss = commitment_loss * self.beta + embedding_loss
loss = recons_loss + vq_loss

3. 建模先验分布

VQ-VAE训练完成后,可以在先验分布$p(z)$ ($H’ \times W’$的整数矩阵)上采样,进一步通过解码器生成图像。

如果独立地采样$H’ \times W’$个离散值,再通过字典映射为维度是$H’ \times W’ \times D$的$z_q(x)$,那么生成的图像在每个空间位置上都是独立的。

为了建立不同空间位置之间的联系,建立自回归模型,具体地,使用PixelCNN:$p(z_1,z_2,z_3, \cdot\cdot\cdot) = p(z_1)p(z_2 | z_1)p(z_3 | z_1z_2) \cdot\cdot\cdot$,其中联合概率$p(z_1,z_2,z_3, \cdot\cdot\cdot)$就是先验分布。通过对其采样,可以得到互相关联的$H’ \times W’$个整数。

VQ-VAE的完整pytorch实现可参考PyTorch-VAE