Linear Transformer: 使用线性注意力实现快速自回归的Transformer.

作者提出了一种通过“线性化”降低自注意力机制的计算复杂度的方法,并构造了一种自回归的Transformer结构,能够更快地实现长句子生成等任务。

标准的Attention首先将输入序列X=[x1,...,xn]Rn×d(n个维度为d的特征向量,通常n>d)转换成查询矩阵Q,键矩阵K,值矩阵V

Q=XWqRn×d,WqRd×d K=XWkRn×d,WkRd×d V=XWvRn×d,WvRd×d

并通过下式计算自注意力,对于第i个输入xi,其输出计算为:

(softmax(QKTd)V)i=j=1neqiTkjdvjj=1neqiTkjd

上式计算中矩阵乘法QKT会引入O(n2)计算复杂度。

一般地,引入相似度函数sim(,)0,则Attention也可表示为一般形式:

Attention(Q,K,V)i=j=1nsim(qi,kj)vjj=1nsim(qi,kj)

注意到标准的Attention计算相当于选择了相似度函数:

sim(qi,kj)=eqiTkjd

若把相似度函数看作核函数,即sim(qi,kj)=ϕ(qi)Tϕ(kj),则有:

Attention(Q,K,V)i=j=1nϕ(qi)Tϕ(kj)vjj=1nϕ(qi)Tϕ(kj)=ϕ(qi)Tj=1nϕ(kj)vjTϕ(qi)Tj=1nϕ(kj)

注意到通过上述转换,将计算复杂度从O(n2)降为O(n),即循环内从两次乘法减少为一次乘法。本文选择的ϕ如下:

ϕ(x)=elu(x)+1,elu(x)={x,x>0ex1,x0

使用上述线性注意力(linear attention)构建Transformer时需要注意,由于Transformer采用语言模型的训练策略,因此需要引入mask,即mask掉未来的输入信息。实践中只需将求和j=1n替换为j=1i

Attention(Q,K,V)i=ϕ(qi)Tj=1iϕ(kj)vjTϕ(qi)Tj=1iϕ(kj)

若记Si=j=1iϕ(kj)vjTZi=j=1iϕ(kj),则:

Attention(Q,K,V)i=ϕ(qi)TSiϕ(qi)TZi Si=Si1+ϕ(ki)viT Zi=Zi1+ϕ(ki)

这种线性注意力可以通过上式递归地计算,类似于RNN。该机制的计算复杂度为O(n),但需要串行计算。作者使用该线性注意力构造了线性Transformer。随着输入序列长度n的增长,推理时间和内存占用也呈线性增长(标准的softmax呈平方增长):