TransGAN:用Transformer实现GAN.

作者提出了一个用Transformer构建的生成对抗网络:TransGAN。传统的GAN使用卷积网络作为基本结构,作者使用自注意力机制取代了卷积,认为卷积可能并不是GAN所必需的。

网络结构

TransGAN是由生成器判别器组成的。生成器用于生成图像,判别器用于判别图像的真实性。

生成器

如果直接将原始图像的每个像素看作一个token,逐像素地生成图像,则即使是较低分辨率的图像(如32×32)也是一个长序列(1024),会引入巨大的计算量。为了避免过大的计算开销,作者使用分段式设计迭代地增加输入序列长度,进而提升图像分辨率。

在每个阶段中,堆叠多个Transformer编码器模块,每个编码器的输入和输出序列长度不变。具体地,生成器接收随机噪声作为输入,并通过一个多层感知机生成长度为H×W×C1D向量序列。该向量序列可看作是一个尺寸为H×W2D特征图,其中每个像素点都是一个Ctoken。与位置编码结合后,输入生成器。为了生成更高分辨率的图像,作者在每个阶段后使用了由ReshapePixel Shuffle组成的上采样模块。该模块先将尺寸为(HW×C)1D token序列输入变换到尺寸为(H×W×C)2D特征图,再使用Pixel Shuffle对其进行上采样,得到尺寸为(2H×2W×C4)2D特征图,再变换回尺寸为(4HW×C4)1D token序列。通过在多个阶段重复上述操作,以降低通道数为代价实现了具有较小计算量的分辨率增加。

判别器

判别器用于判断输入图像是真实的还是合成的,因此不需要关注每个像素位置,可以在语义上分辨图像。将输入图像划分成若干patch(文中分解为8×8),每个patch通过线性Flatten层转化为token序列,再在头部增加类别token后,通过若干Transformer编码器模块进行分类。

实验分析

作者通过实验发现TransGAN在图像生成任务中的表现仅次于StyleGAN v2,取得了不错的效果。