PixArt-Σ:4K文本到图像生成的扩散Transformer的由弱到强的训练.
0. TL; DR
本文介绍了一种名为PixArt-Σ的文本到图像(T2I)扩散模型,它能够在4K分辨率下直接生成高质量图像。PixArt-Σ基于PixArt-α模型,通过“从弱到强”的训练策略,利用高质量数据和高效的Token压缩技术,显著提升了生成图像的保真度和对文本提示的对齐能力。
该模型仅使用0.6B参数,相比现有的T2I扩散模型(如SDXL的2.6B参数和SD Cascade的5.1B参数)更小,同时在图像质量和语义对齐方面表现出色。此外,PixArt-Σ能够直接生成4K分辨率的图像,无需后处理,为电影和游戏等行业的高质量视觉内容创作提供了有力支持。
1. 背景介绍
近年来,文本到图像(T2I)生成模型取得了显著进展,如DALL·E 3、Midjourney和Stable Diffusion等模型,它们能够生成逼真的图像,对图像编辑、视频生成和3D资产创建等下游应用产生了深远影响。然而,开发一个顶级的T2I模型需要大量的计算资源,例如从头开始训练Stable Diffusion v1.5需要大约6000个A100 GPU天,这对资源有限的研究人员构成了重大障碍,阻碍了AIGC社区的创新。因此,如何在有限资源下高效地提升T2I模型的性能成为了一个重要的研究方向。
2. PixArt-Σ 模型
PixArt-Σ的核心思想是通过“从弱到强”的训练策略,从PixArt-α的“较弱”基线模型逐步演进为“更强”的模型。具体来说,这一过程包括以下几个方面:
(1)高质量训练数据
PixArt-Σ采用了比PixArt-α更高质量的图像数据,这些数据具有更高的分辨率(超过1K,包括约2.3M的4K分辨率图像)和更丰富的艺术风格。同时,为了提供更精确和详细的描述,PixArt-Σ使用了更强大的图像描述生成器Share-Captioner,替换了PixArt-α中使用的LLaVA,并将文本编码器的Token长度从120扩展到300,以增强模型对文本和视觉概念之间对齐的能力。
(2)高效的Token压缩
为了应对4K超高分辨率图像生成带来的计算挑战,PixArt-Σ基于Diffusion Transformer(DiT)架构,通过引入KV Token压缩技术,显著提高了模型在处理长序列Token时的效率。具体来说,PixArt-Σ在Transformer的深层(14-27层)引入了KV Token压缩,通过组卷积将2×2的Token压缩为一个Token,从而减少了计算复杂度。
此外,PixArt-Σ还采用了“Conv Avg Init”初始化策略,通过将卷积核的权重初始化为平均操作符,使得模型在初始状态下能够产生粗略的结果,加速了微调过程。这种设计有效地减少了训练和推理时间,对于4K图像生成的训练和推理时间减少了约34%。
(3)训练细节
PixArt-Σ的训练过程包括以下几个阶段:
- VAE适应:将PixArt-α的VAE替换为SDXL的VAE,并继续微调扩散模型。这一过程仅需5个V100 GPU天。
- 文本-图像对齐:使用高质量数据集进行微调,以提高模型对文本和图像对齐的能力。这一过程需要50个V100 GPU天。
- 高分辨率微调:从低分辨率模型(如512px)微调到高分辨率模型(如1024px),并引入KV Token压缩。这一过程通过“PE Interpolation”技巧初始化高分辨率模型的位置嵌入,显著提高了高分辨率模型的初始状态,加快了微调过程。
3. 实验分析
PixArt-Σ在图像质量和语义对齐方面表现出色。通过与现有的T2I模型(如PixArt-α、SDXL和Stable Cascade)进行比较,PixArt-Σ在FID和CLIP-Score等指标上均取得了更好的结果。此外,PixArt-Σ还能够直接生成4K分辨率的图像,无需后处理,这为电影和游戏等行业的高质量视觉内容创作提供了有力支持。
为了评估PixArt-Σ的性能,作者进行了人类偏好研究和AI偏好研究。在人类偏好研究中,PixArt-Σ在图像质量和对文本提示的遵循能力方面优于其他六种T2I生成器。在AI偏好研究中,使用GPT-4 Vision作为评估器,PixArt-Σ同样表现出色,证明了其在图像质量和语义对齐方面的优势。
作者还进行了消融研究,以评估不同KV Token压缩设计对生成性能的影响。实验结果表明,将KV Token压缩应用于Transformer的深层(14-27层)能够取得最佳性能。此外,使用“Conv 2×2”方法进行Token压缩的效果优于其他方法,如随机丢弃和平均池化。在不同分辨率下,KV Token压缩对图像质量(FID)有轻微影响,但对语义对齐(CLIP-Score)没有影响。尽管随着压缩比的增加,图像质量略有下降,但KV Token压缩显著提高了训练和推理的速度,特别是在高分辨率图像生成中。