TransGAN:用Transformer实现GAN.
- paper:TransGAN: Two Transformers Can Make One Strong GAN
- arXiv:link
作者提出了一个用Transformer构建的生成对抗网络:TransGAN。传统的GAN使用卷积网络作为基本结构,作者使用自注意力机制取代了卷积,认为卷积可能并不是GAN所必需的。
网络结构
TransGAN是由生成器和判别器组成的。生成器用于生成图像,判别器用于判别图像的真实性。
生成器
如果直接将原始图像的每个像素看作一个token,逐像素地生成图像,则即使是较低分辨率的图像(如$32 \times 32$)也是一个长序列($1024$),会引入巨大的计算量。为了避免过大的计算开销,作者使用分段式设计迭代地增加输入序列长度,进而提升图像分辨率。
在每个阶段中,堆叠多个Transformer编码器模块,每个编码器的输入和输出序列长度不变。具体地,生成器接收随机噪声作为输入,并通过一个多层感知机生成长度为$H \times W \times C$的1D向量序列。该向量序列可看作是一个尺寸为$H \times W$的2D特征图,其中每个像素点都是一个$C$维token。与位置编码结合后,输入生成器。为了生成更高分辨率的图像,作者在每个阶段后使用了由Reshape和Pixel Shuffle组成的上采样模块。该模块先将尺寸为$(HW \times C)$的1D token序列输入变换到尺寸为$(H \times W \times C)$的2D特征图,再使用Pixel Shuffle对其进行上采样,得到尺寸为$(2H \times 2W \times \frac{C}{4})$的2D特征图,再变换回尺寸为$(4HW \times \frac{C}{4})$的1D token序列。通过在多个阶段重复上述操作,以降低通道数为代价实现了具有较小计算量的分辨率增加。
判别器
判别器用于判断输入图像是真实的还是合成的,因此不需要关注每个像素位置,可以在语义上分辨图像。将输入图像划分成若干patch(文中分解为$8 \times 8$),每个patch通过线性Flatten层转化为token序列,再在头部增加类别token后,通过若干Transformer编码器模块进行分类。
实验分析
作者通过实验发现TransGAN在图像生成任务中的表现仅次于StyleGAN v2,取得了不错的效果。