使用Transformer实现可扩展的扩散模型.

0. TL; DR

这篇论文提出了一种新的基于Transformer架构的扩散模型(Diffusion Transformers,DiTs),用于图像生成任务。DiTsImageNet数据集上取得了最先进的图像质量,通过将传统的U-Net骨干网络替换为Transformer,实现了更好的可扩展性和性能。研究发现,随着模型计算复杂度(以Gflops衡量)的增加,DiTs的性能持续提升,表明Transformer架构在扩散模型中具有良好的扩展性。

1. 背景介绍

近年来,扩散模型在图像生成领域取得了显著进展,成为与生成对抗网络(GANs)相媲美的技术。传统的扩散模型通常采用U-Net作为骨干网络,尽管有效,但其扩展性和适应性受到限制。与此同时,Transformer架构在自然语言处理和计算机视觉等领域展现了强大的性能和可扩展性。然而,Transformer在扩散模型中的应用尚未得到充分探索。这篇论文旨在填补这一空白,通过将Transformer引入扩散模型,探索其在图像生成任务中的潜力和优势。

2. DiT 结构

DiTs基于Vision Transformers(ViTs)的设计,将输入图像分解为图像块序列,并通过Transformer块进行处理。模型的核心组件包括:

(1)Patchify层

对于$3\times 256\times 256$的图片,隐变量$z$的维度是$32 \times 32 \times 4$。Patchify把隐变量拆分成patch,并通过线性嵌入转换为$T\times d$的tokenstoken的数量由 Patch 的大小 $p$ 决定,两者满足$T=(I/p)^2$的关系。

(2)Transformer

上述输入的 tokensTransformer 进行处理。除了噪声图像输入之外,模型还会接收额外的条件信息,比如噪声时间步长、类标签等。作者探索了4种不同类型的 Transformer Block,以不同的方式处理条件输入。

在最后一个 DiT Block 之后需要将序列解码为输出噪声的均值和方差预测结果。这两个输出的形状与原始的空间输入一致,因此使用标准的线性解码器,然后将解码的 tokens 重新排列到其原始空间布局中,得到预测的噪声和方差。

(3)模型缩放

作者在缩放 DiT 时从下面几个维度进行考虑:网络深度、特征维度、注意力头数,从而设计出4种不同尺寸的 DiT 模型:

3. 实验分析

通过训练12种不同配置的DiT模型,发现随着模型大小的增加和图像块大小的减小($p=8,4,2$,即增加输入序列长度),FID持续降低,表明模型性能显著提升。

模型的GflopsFID之间存在强烈的负相关关系,即增加模型的计算复杂度是提升性能的关键。即使在保持模型参数数量大致不变的情况下,通过增加输入序列长度(减小图像块大小)来提升Gflops,也能显著提高生成质量。

在四种条件机制中,adaLN-Zero块在训练的各个阶段均优于交叉注意力和上下文条件机制,同时计算效率最高。它通过将每个DiT块初始化为恒等函数,加速了训练过程并提高了模型性能。

实验表明,增加采样步骤的计算量无法弥补模型计算量的不足。较大的DiT模型即使使用较少的采样步骤,也能生成更高质量的图像,证明了模型规模在生成性能中的重要性。