Mamba:具有选择性状态空间的线性时间序列建模.
0. TL;DR
Mamba是一种新型的序列建模方法,它通过引入选择性状态空间模型(SSMs)实现了线性时间复杂度的序列建模,并在多种模态(语言、音频、基因组学)上达到了与Transformer相当甚至更好的性能。Mamba的核心创新在于选择性机制,它允许模型根据输入动态调整状态空间参数,从而在保持高效的同时具备强大的表达能力。
1. 背景介绍
随着深度学习的发展,Transformer架构在各种序列建模任务中取得了巨大成功。然而,其核心的自注意力机制带来了计算效率和内存占用方面的挑战,尤其是在处理长序列时。许多研究致力于开发更高效的替代方案,如线性注意力、门控卷积和循环模型,以及结构化状态空间模型(SSMs)。尽管这些方法在效率上有所提升,但在性能上往往不如自注意力机制,特别是在处理语言等信息密集的模态时。
Mamba的提出旨在解决这一问题,它通过引入选择性状态空间模型,在保持线性时间复杂度的同时,实现了与Transformer相当的性能。选择性机制允许模型根据输入动态调整状态空间参数,从而在处理长序列时更加高效,同时具备更强的表达能力。
2. Mamba 模型
Mamba是选择性结构化状态空间模型(selective SSM 或 S6),具有以下特点:
- 选择性机制:允许模型根据输入动态调整状态空间参数(离散化步长$∆$、矩阵$B,C$),这些参数是根据输入序列的每个时间步动态生成,且在批量维度和序列维度上是独立的。
- 硬件感知扫描:通过内核融合、并行扫描和重新计算,优化了模型的计算和内存使用。
(1)选择性机制
SSMs可以表示为线性时不变系统,通过状态矩阵$A$、输入矩阵$B$和输出矩阵$C$来参数化。
\[\begin{aligned} \mathbf{x}^\prime(t) &= A \mathbf{x}(t) + B \mathbf{u}(t) \\ \mathbf{y}(t) &= C \mathbf{x}(t) \\ \end{aligned}\]或进行基于零阶保持的离散化:
\[\begin{aligned} \overline{A} &= e^{\Delta \mathbf{A}} \\ \overline{B} &= (e^{\Delta A}-I) A^{-1} B \\ \overline{C} &= C \\ \end{aligned}\]此时对于不同时刻的输入$\mathbf{u}(t)$,矩阵$A,B,C$都是固定的,导致SSMs对输入具有静态性,在内容感知等任务上表现不佳。
对于输入尺寸为$(B,L,D)$的序列$\mathbf{u}$,SSMs参数在$B,L$维度上是共享的,在$D$维度上是独立的,即对所有数据的所有序列token的每个特征维度分别使用一个SSM进行建模。此时矩阵$A,B,C$的形状为:
Mamba引入了选择性机制,它允许模型根据输入动态调整状态空间参数。具体来说,模型的参数(离散化步长$∆$、矩阵$B,C$)不再是固定的,而是根据输入序列的每个时间步动态生成。这种选择性机制使得模型能够根据输入内容有选择地传播或忽略信息,从而在处理长序列时更加高效。
注意到矩阵$A$保持不变,这是因为我们希望状态本身保持静态,但其受输入的影响是动态的。而矩阵$B,C$在$D$维度上是共享的,在$B,L$维度上是独立的,即对每个数据的每个序列token所有特征维度分别使用一个SSM进行建模。
由于输入$\mathbf{u}(t)$的尺寸为$(B,L,D)$,因此$∆,B,C$的构造直接通过线性层实现:
\[\begin{aligned} B &= \text{Linear}_N(\mathbf{u}(t)) \\ C &= \text{Linear}_N(\mathbf{u}(t)) \\ \Delta &= \text{Linear}_D(\mathbf{u}(t)) \\ \end{aligned}\](2)硬件感知扫描
对于SSMs,构造下列形式的卷积核$\overline{K}$,即可将序列运算转化为卷积运算:
\[\begin{aligned} \overline{K} &= (C\overline{B}, C\overline{A} \overline{B},...,C \overline{A}^k \overline{B}, ...) \\ y &= \overline{K} * u \end{aligned}\]在Mamba中,由于矩阵$B,C$随输入变化而改变,因此不能直接转换为卷积运算。为了在现代硬件(如GPU)上高效实现选择性状态空间模型,Mamba采用了一种硬件感知算法。该算法通过内核融合、并行扫描和重新计算等技术,优化了模型的计算和内存使用。
硬件感知扫描算法通过以下步骤实现:
- 内核融合:将状态空间模型的前向传播、扫描和输出投影融合成一个内核,减少内存I/O操作。
- 并行扫描:利用并行扫描算法,高效地处理序列数据。
- 重新计算:在反向传播时重新计算中间状态,而不是存储它们,从而减少内存占用。
① 内核融合
GPU 的一个缺点是其体积小但效率高的 SRAM 与体积大但效率稍低的 DRAM 之间的传输 (IO) 速度有限。在 SRAM 和 DRAM 之间频繁复制信息会成为瓶颈。
Mamba 通过内核融合来实现限制从 DRAM 到 SRAM 以及从 SRAM 到 DRAM 的次数,这使得模型可以避免写入中间结果并持续执行计算直到完成。Mamba把离散化步骤、选择性扫描算法和矩阵$C$乘法融合到一个内核中。
② 并行扫描
SSMs的扫描操作是指每个状态$x_t$是前一个状态$x_{t-1}$乘以$A$加上当前输入$u_t$乘以$B$,计算复杂度为$O(n)$。Mamba通过并行扫描算法,使用$t$个执行任务的处理器提高计算效率,计算复杂度降低为$O(n/t)$。并行扫描算法是指首先并行计算中间值,再通过扫描获取更新的状态。
③ 重新计算
中间状态在前向传播中没有被保存,但对于反向传播计算梯度来说却是必需的。因此在反向传播期间重新计算了这些中间状态。该操作的成本比从相对较慢的 DRAM 读取所有中间状态的成本要低得多。
(3)Mamba模块
完整的Mamba模块如图所示,首先通过线性投影来扩展输入嵌入,然后在Selective SSM之前应用卷积来建立token之间的局部上下文关系。
3. 实验分析
作者讨论了Mamba在两个任务上的优势:
- 选择性复制任务(Selective Copying)在输入之间具有随机间距,需要内容感知推理,在内容上能够灵活地选择记忆或忽略输入。
- 归纳头任务(Induction Head)是联想回忆的一个例子,需要根据上下文检索答案。
在上述两个任务上,Mamba都取得了最好的表现:
语言建模使用The Pile数据集,包含多种文本数据。Mamba在语言建模任务上表现出色,其性能与Transformer相当,甚至在某些情况下超过了Transformer。具体来说:
- Mamba-3B模型在预训练困惑度和下游任务评估上均优于相同大小的Transformer模型,并且与两倍大小的Transformer模型相当。
- Mamba的生成吞吐量是Transformer的5倍,这得益于其线性时间复杂度和无需缓存的特性。
DNA建模使用HG38数据集,包含人类基因组序列。Mamba在DNA建模任务上也表现出色,特别是在处理长序列时。具体来说:
- Mamba模型在预训练困惑度上随着序列长度的增加而持续改进,而其他模型如HyenaDNA则在长序列上表现不佳。
- 在物种分类任务上,Mamba模型的准确率随着序列长度的增加而显著提高,表明其能够有效利用长序列信息。
音频建模使用YouTubeMix数据集,包含钢琴音乐波形。Mamba在音频建模任务上同样表现出色,特别是在处理长序列时。具体来说:
- Mamba模型在预训练的BPB指标上优于其他模型,并且随着序列长度的增加而持续改进。
- 在语音生成任务上,Mamba模型的生成质量显著优于其他模型,特别是在长序列生成时。