通过VQGAN和Transformer实现高分辨率图像合成.
本文作者设计了一种两阶段的图像合成方法。第一阶段使用VQGAN对图像块进行编码,学习隐空间中离散的编码表;第二阶段使用Transformer自回归地生成隐编码,从而生成高分辨率图像。
1. VQGAN
VQGAN包括一个生成器和一个判别器。生成器采用VQ-VAE,学习图像块的量化表示;判别器采用Pix2Pix提出的PatchGAN结构,用于提高生成图像的感知质量。
生成器(VQ-VAE)将尺寸为$H \times W \times c$的输入图像$x$通过卷积网络构成的编码器,得到尺寸为$H’ \times W’ \times D$的特征映射(隐变量)$z_e(x)$。构建一个字典$E=[e_1,e_2, \cdots e_K],e_k \in \Bbb{R}^D$,也称为编码表。VQ-VAE通过最邻近搜索,将$z_e(x)$中$H’ \times W’$个$D$维向量映射为这$K$个字典向量之一:
\[z_e^{(i)}(x)\to e_k, \quad k = \mathop{\arg \min}_{j} ||z_e^{(i)}(x)- e_j||_2\]由于$e_k$是编码表$E$中的向量之一,所以它实际上等价于其index ($1,2,…,K$这$K$个整数之一),因此该过程相当于将图像编码为一个$H’ \times W’$的整数矩阵$q(z | x)$,实现了离散型编码。把$z_e(x)$的向量替换为编码表$E$中对应的向量$e_k$,就可以得到最终的尺寸为$H’ \times W’ \times D$的编码结果$z_q(x)$。
将$z_q(x)$喂入解码器,重构最终图像$p(x | z_q)$。损失函数包括重构误差:
\[|| x - p(x | z_q) ||_2^2\]还应该期望编码向量$z_e$和量化向量$z_q$足够接近。通常编码向量$z_e$相对比较自由,而量化向量$z_q$要保证重构效果,因此将$||z_e-z_q||^2_2$分解为两个损失,分别更新量化向量和编码向量:
\[|| sg[z_e] - z_q ||_2^2 + || z_e - sg[z_q] ||_2^2\]上述两项损失共同构成了生成器损失\(\mathcal{L}_{VQ}(E,G)\)。
判别器结构采用Pix2Pix提出的PatchGAN结构,把判别器设计为全卷积网络,输出为一个$N \times N$矩阵,其中的每个元素对应输入图像的一个子区域,用来评估该子区域的真实性。
\[\begin{aligned} \mathcal{L}_{GAN}(D,E,G) & = \Bbb{E}_{x \text{~} P_{data}(x)}[\log D(x)] + \Bbb{E}_{z \text{~} P_{Z}(z)}[\log(1-D(G(z)))] \end{aligned}\]VQGAN的总目标函数为:
\[\mathop{ \min}_{E,G} \mathop{\max}_{D} \mathcal{L}_{VQ}(E,G) + \lambda \mathcal{L}_{GAN}(D,E,G)\]其中损失权重设置为:
\[\lambda = \frac{\nabla_{G_L}[\mathcal{L}_{VQ}]}{\nabla_{G_L}[\mathcal{L}_{GAN}]+1e-6}\]$\nabla_{G_L}$表示求解码器最后一层关于输入的梯度。
2. 图像生成
VQGAN训练完成后,学习到图像块的编码表$E=[e_1,e_2, \cdots e_K],e_k \in \Bbb{R}^D$。从编码表中采样可以生成新的图像。作者使用一个Transformer实现自回归的采样过程。对于高分辨率图像,由于整体像素数量较大,作者设置了一个注意力窗口,自回归旨在窗口内的像素中进行:
对于条件生成,输入条件既可以是单个类别标签,也可以是另一副图像。对于另一副图像作为条件,作者额外训练了一个VQGAN用于将条件图像编码为索引向量$z_e(x)$作为输入。