H3:使用状态空间模型进行语言建模.
0. TL;DR
H3是一种基于状态空间模型(SSMs)的高效语言建模方法,通过引入新的层设计(H3层)和硬件感知算法(FlashConv),在保持线性时间复杂度的同时,实现了与Transformer相当甚至更好的性能。H3在合成语言任务和实际语言建模任务上均表现出色,特别是在长序列处理和硬件效率方面具有显著优势。
1. 背景介绍
随着深度学习的发展,Transformer架构在语言建模任务中取得了巨大成功。然而,其核心的自注意力机制在处理长序列时面临计算效率和内存占用的挑战。状态空间模型(SSMs)作为一种高效的序列建模方法,虽然在某些领域表现出色,但在语言建模上仍不及Transformer。H3的提出旨在通过新的层设计和硬件优化,缩小SSMs与Transformer在语言建模上的差距。
2. H3 模型
(1)H3 层
SSMs可以表示为线性时不变系统,通过状态矩阵$A$、输入矩阵$B$和输出矩阵$C$来参数化。
\[\begin{aligned} \mathbf{x}^\prime(t) &= A \mathbf{x}(t) + B \mathbf{u}(t) \\ \mathbf{y}(t) &= C \mathbf{x}(t) \\ \end{aligned}\]或进行基于零阶保持的离散化:
\[\begin{aligned} \overline{A} &= e^{\Delta \mathbf{A}} \\ \overline{B} &= (e^{\Delta A}-I) A^{-1} B \\ \overline{C} &= C \\ \end{aligned}\]H3 层通过结合两个 SSM(移位 SSM 和对角 SSM)以及输入投影的乘法交互,解决了 SSMs 在语言建模中的两大短板:回忆序列中的早期标记和比较不同位置的标记。
移位 SSM 使用移位矩阵 $A:A_{ij}=\begin{cases}1, & i-1=j \ 0, & \text{otherwise}\end{cases}$,将状态向量的每个元素向下移动一位,从而实现对之前输入的记忆:
\[x_t = Ax_{t-1} + Bu_t\]对角 SSM 使用HiPPO初始化的对角矩阵 $A$,允许模型在整个序列上记忆状态:
\[x_t = Ax_{t-1} + Bu_t\]H3 层通过乘法交互,使得模型能够对序列中的不同位置进行比较,H3 层的输出可以表示为:
\[O_t=Q_t⊙SSM_{diag}(SSM_{shift}(K)⊙V)\](2)FlashConv 层
对于SSMs,构造下列形式的卷积核$\overline{K}$,即可将序列运算转化为卷积运算:
\[\begin{aligned} \overline{K} &= (C\overline{B}, C\overline{A} \overline{B},...,C \overline{A}^k \overline{B}, ...) \\ y &= \overline{K} * u \end{aligned}\]FlashConv 是一种用于提高 SSMs 在现代硬件上效率的算法,通过融合 FFT、逐点乘法和逆 FFT 操作,减少了内存读写,并利用块 FFT 算法提高了计算效率。
块 FFT 算法通过将 FFT 分解为多个块,并利用矩阵乘法单元进行计算,提高了计算效率。具体来说,块 FFT 算法可以表示为:
\[FFT(u)=P(IN_2⊗FN_1)P^TD(IN_1⊗FN_2)P\]其中,$P$ 是排列矩阵,$IN_i$ 和 $FN_i$ 分别是大小为 $i$ 的单位矩阵和 DFT 矩阵,$D$ 是对角矩阵。
对于长序列,FlashConv 通过状态传递算法将序列分割为多个块,逐块处理。具体来说,状态传递算法可以表示为:
\[y^{(c)}=M_{xy}x^{(c−1)}_{N^\prime}+BlockFFTConv(f,u^{(c)})+Du^{(c)}\]其中,$M_{xy}$是状态传递矩阵,$f$ 是滤波器,$u(c)$ 是当前块的输入。
3. 实验分析
H3 在合成语言任务上表现出色,能够有效回忆和比较序列中的关键信息。具体来说,H3 在归纳头任务和关联回忆任务上均取得了较高的准确率。
H3 在实际语言建模任务上也表现出色,特别是在处理长序列时。具体来说,H3 在 OpenWebText 和 The Pile 数据集上的困惑度均低于 Transformer。
H3 在零样本和少样本学习任务上也表现出色,能够在 SuperGLUE 基准上取得与 Transformer 相当甚至更好的性能。
H3 在硬件效率方面具有显著优势,特别是在长序列处理时。具体来说,H3 在 A100 GPU 上的训练速度比 Transformer 快 5.8 倍。