FLASH: 基于门控注意力单元的线性Transformer.

本文提出了一种Transformer的改进模型,把自注意力运算和全连接层融合成一种门控注意力单元(gated attention unit, GAU)层,使用GAU构造了Transformer的变体FLASH-Quad,尽管仍具有二次计算复杂度,但比标准的Transformer速度更快、显存占用更低,且效果更好。作者进一步提出了一种分块混合注意力(Mixed Chunk Attention)的注意力线性化方法,从而实现具有线性计算复杂度的FLASH

1. 门控注意力单元 GAU

标准的Transformer是由自注意力(Attention)运算和全连接层(FFN)交替构成的,作者将这两层融合成一种新的门控注意力单元GAU。下面介绍GAU的设计思路。

全连接层FFN是由两层MLP构成的:

\[O=\phi(XW_u)W_o\]

其中输入特征$X \in \Bbb{R}^{n \times d}$,全连接作用于特征的每一个$d$维token,分别应用$W_u \in \Bbb{R}^{d \times e}$和的$W_o \in \Bbb{R}^{e \times d}$仿射变换,$\phi$是激活函数。在全连接层中不同token之间没有交互。

T5.1.1mT5等模型中指出,可以使用门控线性单元(gated linear unit, GLU)代替全连接层FFN

\[O=(U\odot V)W_o, \quad U=\phi_u(XW_u), \quad V=\phi_v(XW_v)\]

其中$W_u,W_v \in \Bbb{R}^{d \times e}$,$\odot$表示逐点相乘(Hadamard积)。此处$\phi_u$和$\phi_v$均为Swish激活函数。

注意到在上面的运算中不同token之间没有交互,为了引入自注意力机制,计算中引入自注意力矩阵$A \in \Bbb{R}^{n \times n}$:

\[O=(U\odot AV)W_o\]

观察上式不难发现,当$A$为单位阵时,上式就是GLU形式的全连接层;当$U$为全$1$矩阵时,上式是标准的注意力运算。因此上式可作为自注意力和全连接层的融合。

在计算注意力矩阵$A$时,使用一种计算量更小的简化形式。首先把输入特征变换为中间特征$Z \in \Bbb{R}^{n \times s}$:

\[Z = \phi_z(XW_z)\]

其中$W_z \in \Bbb{R}^{d \times s}$,对$Z$做两次仿射变换$\mathcal{Q}$和$\mathcal{K}$(乘$\gamma$加$\beta$),则注意力矩阵计算如下:

\[A=\frac{1}{n} \text{relu}^2(\frac{\mathcal{Q}(Z)\mathcal{K}(Z)^T}{\sqrt{s}})=\frac{1}{ns} \text{relu}^2(\mathcal{Q}(Z)\mathcal{K}(Z)^T)\]

此处$\text{relu}^2$表示应用relu激活函数后再平方,这个操作是通过NAS搜索出来的。

使用GAU可以替代自注意力和全连接层,其结构图和伪代码如下:

在实践中设置中间特征的维度$e=2d$,则两层GAU的参数量与一层自注意力与全连接层的参数量差不多。GAU的一些消融实验如下:

2. 分块混合注意力

基于上面介绍的GAU构造的模型称为FLASH-QuadQuad表示该模型仍具有二次复杂度。作者进一步提出了具有线性复杂度的FLASH(Fast Linear Attention with a Single Head),使用分块混合注意力实现注意力计算的线性化。

对于长度为$n$的输入序列,将其不重叠地划分为$n/c$个长度为$c$的序列块。对于第$g$个序列块$X_g \in \Bbb{R}^{c \times d}$,计算:

\[U_g=\phi_u(X_gW_u) \in \Bbb{R}^{c \times e} \\ V_g=\phi_v(X_gW_v) \in \Bbb{R}^{c \times e} \\ Z_g = \phi_z(X_gW_z) \in \Bbb{R}^{c \times s}\]

将$Z_g$通过四个简单的仿射变换得到$Q_g^{quad},K_g^{quad},Q_g^{lin},K_g^{lin}$。

$Q_g^{quad},K_g^{quad}$用于在块内计算自注意力,实现每个块的token内部交互:

\[\hat{V}_g^{quad}=\frac{1}{ns} \text{relu}^2(Q_g^{quad}{K_g^{quad}}^T)V_g\]

$Q_g^{lin},K_g^{lin}$用于计算全局的注意力,通过线性注意力实现:

\[\hat{V}_g^{lin}=\frac{1}{n}Q_g^{lin} \sum_{h=1}^{n/c} {K_h^{lin}}^TV_h\]

使用上述两种注意力共同计算输出:

\[O_g=[U_g\odot (\hat{V}_g^{quad}+\hat{V}_g^{lin})]W_o\]

作者针对这种局部注意力和全局注意力进行了消融实验。结果表明局部注意力比全局注意力更重要,且结合两者效果最好。

3. 实验分析

作者对比了标准Transformer与基于GAU的模型在三种不同的层数设置下的模型表现。结果表明GAU在速度和精度上均超过了多头自注意力机制MHSA

作者测试了FLASH模型的性能表现。结果表明,尽管FLASH-QuadTransformer都是二次复杂度,但FLASH-Quad效果更好、速度更快;在序列较长时,线性复杂度的FLASH速度更快,且仍然具有较高的精度。