Linear Transformer: 使用线性注意力实现快速自回归的Transformer.

作者提出了一种通过“线性化”降低自注意力机制的计算复杂度的方法,并构造了一种自回归的Transformer结构,能够更快地实现长句子生成等任务。

标准的Attention首先将输入序列\(X=[x_1,...,x_n] \in \Bbb{R}^{n×d}\)($n$个维度为$d$的特征向量,通常$n>d$)转换成查询矩阵$Q$,键矩阵$K$,值矩阵$V$:

\[Q = XW_q \in \Bbb{R}^{n×d}, \quad W_q \in \Bbb{R}^{d×d}\] \[K = XW_k \in \Bbb{R}^{n×d}, \quad W_k \in \Bbb{R}^{d×d}\] \[V = XW_v \in \Bbb{R}^{n×d}, \quad W_v \in \Bbb{R}^{d×d}\]

并通过下式计算自注意力,对于第$i$个输入$x_i$,其输出计算为:

\[(\text{softmax}(\frac{QK^T}{\sqrt{d}})V)_i=\frac{\sum_{j=1}^{n}e^{\frac{q_i^Tk_j}{\sqrt{d}}}v_j}{\sum_{j=1}^{n}e^{\frac{q_i^Tk_j}{\sqrt{d}}}}\]

上式计算中矩阵乘法$QK^T$会引入$O(n^2)$计算复杂度。

一般地,引入相似度函数$\text{sim}(\cdot,\cdot)≥0$,则Attention也可表示为一般形式:

\[\text{Attention}(Q,K,V)_i=\frac{\sum_{j=1}^{n}\text{sim}(q_i,k_j)v_j}{\sum_{j=1}^{n}\text{sim}(q_i,k_j)}\]

注意到标准的Attention计算相当于选择了相似度函数:

\[\text{sim}(q_i,k_j) = e^{\frac{q_i^Tk_j}{\sqrt{d}}}\]

若把相似度函数看作核函数,即$\text{sim}(q_i,k_j)=\phi(q_i)^T\phi(k_j)$,则有:

\[\text{Attention}(Q,K,V)_i=\frac{\sum_{j=1}^{n}\phi(q_i)^T\phi(k_j)v_j}{\sum_{j=1}^{n}\phi(q_i)^T\phi(k_j)}=\frac{\phi(q_i)^T\sum_{j=1}^{n}\phi(k_j)v_j^T}{\phi(q_i)^T\sum_{j=1}^{n}\phi(k_j)}\]

注意到通过上述转换,将计算复杂度从$O(n^2)$降为$O(n)$,即循环内从两次乘法减少为一次乘法。本文选择的$\phi$如下:

\[\phi(x)=\text{elu}(x)+1 , \quad \text{elu}(x)=\begin{cases} x, & x>0 \\ e^x-1, & x≤0 \end{cases}\]

使用上述线性注意力(linear attention)构建Transformer时需要注意,由于Transformer采用语言模型的训练策略,因此需要引入mask,即mask掉未来的输入信息。实践中只需将求和$\sum_{j=1}^{n}$替换为$\sum_{j=1}^{i}$:

\[\text{Attention}(Q,K,V)_i=\frac{\phi(q_i)^T\sum_{j=1}^{i}\phi(k_j)v_j^T}{\phi(q_i)^T\sum_{j=1}^{i}\phi(k_j)}\]

若记$S_i=\sum_{j=1}^{i}\phi(k_j)v_j^T$,$Z_i=\sum_{j=1}^{i}\phi(k_j)$,则:

\[\text{Attention}(Q,K,V)_i=\frac{\phi(q_i)^TS_i}{\phi(q_i)^TZ_i}\] \[S_i=S_{i-1}+\phi(k_i)v_i^T\] \[Z_i=Z_{i-1}+\phi(k_i)\]

这种线性注意力可以通过上式递归地计算,类似于RNN。该机制的计算复杂度为$O(n)$,但需要串行计算。作者使用该线性注意力构造了线性Transformer。随着输入序列长度$n$的增长,推理时间和内存占用也呈线性增长(标准的softmax呈平方增长):