Linear Transformer: 使用线性注意力实现快速自回归的Transformer.
- paper:Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
- arXiv:link
作者提出了一种通过“线性化”降低自注意力机制的计算复杂度的方法,并构造了一种自回归的Transformer结构,能够更快地实现长句子生成等任务。
标准的Attention首先将输入序列X=[x1,...,xn]∈Rn×d(n个维度为d的特征向量,通常n>d)转换成查询矩阵Q,键矩阵K,值矩阵V:
Q=XWq∈Rn×d,Wq∈Rd×d
K=XWk∈Rn×d,Wk∈Rd×d
V=XWv∈Rn×d,Wv∈Rd×d
并通过下式计算自注意力,对于第i个输入xi,其输出计算为:
(softmax(QKT√d)V)i=∑nj=1eqTikj√dvj∑nj=1eqTikj√d
上式计算中矩阵乘法QKT会引入O(n2)计算复杂度。
一般地,引入相似度函数sim(⋅,⋅)≥0,则Attention也可表示为一般形式:
Attention(Q,K,V)i=∑nj=1sim(qi,kj)vj∑nj=1sim(qi,kj)
注意到标准的Attention计算相当于选择了相似度函数:
sim(qi,kj)=eqTikj√d
若把相似度函数看作核函数,即sim(qi,kj)=ϕ(qi)Tϕ(kj),则有:
Attention(Q,K,V)i=∑nj=1ϕ(qi)Tϕ(kj)vj∑nj=1ϕ(qi)Tϕ(kj)=ϕ(qi)T∑nj=1ϕ(kj)vTjϕ(qi)T∑nj=1ϕ(kj)
注意到通过上述转换,将计算复杂度从O(n2)降为O(n),即循环内从两次乘法减少为一次乘法。本文选择的ϕ如下:
ϕ(x)=elu(x)+1,elu(x)={x,x>0ex−1,x≤0
使用上述线性注意力(linear attention)构建Transformer时需要注意,由于Transformer采用语言模型的训练策略,因此需要引入mask,即mask掉未来的输入信息。实践中只需将求和∑nj=1替换为∑ij=1:
Attention(Q,K,V)i=ϕ(qi)T∑ij=1ϕ(kj)vTjϕ(qi)T∑ij=1ϕ(kj)
若记Si=∑ij=1ϕ(kj)vTj,Zi=∑ij=1ϕ(kj),则:
Attention(Q,K,V)i=ϕ(qi)TSiϕ(qi)TZi
Si=Si−1+ϕ(ki)vTi
Zi=Zi−1+ϕ(ki)
这种线性注意力可以通过上式递归地计算,类似于RNN。该机制的计算复杂度为O(n),但需要串行计算。作者使用该线性注意力构造了线性Transformer。随着输入序列长度n的增长,推理时间和内存占用也呈线性增长(标准的softmax呈平方增长):

Related Issues not found
Please contact @0809zheng to initialize the comment