Reformer: 使用局部敏感哈希和可逆FFN实现高效Transformer.

本文提出了Reformer,将Transformer的计算复杂度从$O(N^2)$降低为$O(N\log N)$,从而使得输入序列长度$N$较长时的也具有较快的处理速度。

标准的Transformer计算复杂度较高主要有两个原因。

  1. 自注意力机制需要两两计算相似度得分,具有序列长度平方的计算复杂度;
  2. FFN层通常是两层全连接网络,其中间层特征较大,需要存储层内的连接参数(activations)用于反向传播,占据更多内存。

为此,作者使用了一些方法以降低计算复杂度:

  1. 局部敏感哈希(Local Sensitive Hashing):将注意力得分相似的key分到同一个bucket中,使得每个query只与较高相似度的几个key计算注意力得分。
  2. 可逆层(Reversible Layer):网络需要存储每一层的activations用于反向传播,导致内存消耗较大。作者将模型的特定结构设计成可逆的形式,只需要存储最后一层的输出,便可以反推出中间层的结果。
  3. 分段FFN(Chunking FFN Layers):由于全连接层的输入之间是独立的,因此将全连接层分段处理,降低空间消耗。

1. Locality Sensitive Hashing

自注意力矩阵的计算是稀疏的,即对于输入序列的每一个token的查询向量$q_i$,并不需要与所有的token对应的键向量$k$进行计算,因为经过softmax后较小的值会接近$0$。因此对于每个$q_i$只需要考虑方向上与其最接近的一些$k$即可,寻找过程是通过局部敏感哈希(Local Sensitive Hashing)实现的。

通常的Hash函数是指使用一个hash表,把特定的值映射到不同的bucket中,不同的bucket具有不同的哈希值,根据哈希值可以在$O(1)$复杂度下获得特定的值。局部敏感哈希函数是指将相近的值映射到同一个bucket中,赋予相同的哈希值;不相近的值映射到不同的bucket中。

本文使用随机投影(random projection)的方法构造哈希函数。即对于维度是$d$的输入向量$x$,将其数值标准化(即投影到单位圆上),预先设定$b$个哈希bucket(即把空间等分成$b$个区域)。构造随机矩阵$R \in \Bbb{R}^{d \times \frac{b}{2}}$,则哈希值计算为:

\[h(x) = \arg\max [xR;-xR]\]

上式即把向量$xR$和$-xR$连接后选择最大值所在的索引。 如下图所示,空间中更接近的向量往往会被划分给相同的哈希值,即落入同一个子空间中。 注意到不能保证相似的输入总在同一个bucket中,因此采用多轮局部敏感哈希函数,然后将相同bucket中的向量取并集。

作者把上述局部敏感哈希函数引入Attention的计算中,构造了局部敏感哈希注意力(LSH Attention)。把输入序列的所有token映射到不同的bucket中,对于每一个token的查询向量$q_i$,只与其在同一个hash bucket中的键向量进行注意力得分计算。

LSH Attention的计算过程如下:

  1. 设置share-QK注意力,即$k_i=\frac{q_i}{|q_i|}$,能够降低参数量,并且减小后续哈希bucket中的$q$和$k$数量不均衡问题(如上图b所示)。注意到$Q=K$会使每个token与自身得分最高,因此当bucket中有其余token时,忽略该token与自身的交互。
  2. 对调整后的键序列应用局部敏感哈希函数,为每个$k_i$赋予一个哈希值;将不同哈希值对应的不同bucket排序,对每个bucket内按照原本的位置进行排序。
  3. 注意到不同的bucket内得分数量是不同的,因此进行分组截断(chunked)。将每个bucket内的元素数量设置为$m=\frac{2n_{\text{query}}}{n_{\text{bucket}}}$,对于bucket中的每一个向量,都可以与当前bucket以及前一个bucket中具有相同哈希值的向量进行交互。

2. Reversible Transformer

Transformer中的自注意力层和全连接层在计算时需要保存层内的连接参数(activations)用于反向传播,消耗大量内存。作者设计了一种可逆的Transformer,只需要保存最后输出层的结果,通过特殊的网络结构可以反推出中间层的参数,从而减少内存的消耗。

作者将输入$x$拆分成两部分$x_1$和$x_2$,计算如下:

\[y_1 = x_1 + \text{Attention}(x_2)\] \[y_2 = x_2 + \text{FFN}(y_1)\]

则可以反向推出:

\[x_2 = y_2 - \text{FFN}(y_1)\] \[x_1 = y_1 - \text{Attention}(x_2)\]

因此只需要保存输入与输出,而不需要保存中间结果。在计算全连接层时,通常中间隐藏层的维度比较大,一次性计算需要比较大的内存占用。由于全连接层的输入是独立的,为了降低内存使用,可以进行拆分计算, 即分成若干组进行计算,通过时间成本换取空间成本:

\[y_2 = x_2 + \text{FFN}(y_1)=[x_2^{(1)} + \text{FFN}(y_1^{(1)});...;x_2^{(c)} + \text{FFN}(y_1^{(c)})]\]

3. 实验分析

实验验证了shared-QK和可逆结构对模型准确率几乎没有影响,但提高了模型速度,降低了参数量和内存占用。