Mamba:具有选择性状态空间的线性时间序列建模.

0. TL;DR

Mamba是一种新型的序列建模方法,它通过引入选择性状态空间模型(SSMs)实现了线性时间复杂度的序列建模,并在多种模态(语言、音频、基因组学)上达到了与Transformer相当甚至更好的性能。Mamba的核心创新在于选择性机制,它允许模型根据输入动态调整状态空间参数,从而在保持高效的同时具备强大的表达能力。

1. 背景介绍

随着深度学习的发展,Transformer架构在各种序列建模任务中取得了巨大成功。然而,其核心的自注意力机制带来了计算效率和内存占用方面的挑战,尤其是在处理长序列时。许多研究致力于开发更高效的替代方案,如线性注意力、门控卷积和循环模型,以及结构化状态空间模型(SSMs)。尽管这些方法在效率上有所提升,但在性能上往往不如自注意力机制,特别是在处理语言等信息密集的模态时。

Mamba的提出旨在解决这一问题,它通过引入选择性状态空间模型,在保持线性时间复杂度的同时,实现了与Transformer相当的性能。选择性机制允许模型根据输入动态调整状态空间参数,从而在处理长序列时更加高效,同时具备更强的表达能力。

2. Mamba 模型

Mamba是选择性结构化状态空间模型(selective SSMS6),具有以下特点:

(1)选择性机制

SSMs可以表示为线性时不变系统,通过状态矩阵$A$、输入矩阵$B$和输出矩阵$C$来参数化。

\[\begin{aligned} \mathbf{x}^\prime(t) &= A \mathbf{x}(t) + B \mathbf{u}(t) \\ \mathbf{y}(t) &= C \mathbf{x}(t) \\ \end{aligned}\]

或进行基于零阶保持的离散化:

\[\begin{aligned} \overline{A} &= e^{\Delta \mathbf{A}} \\ \overline{B} &= (e^{\Delta A}-I) A^{-1} B \\ \overline{C} &= C \\ \end{aligned}\]

此时对于不同时刻的输入$\mathbf{u}(t)$,矩阵$A,B,C$都是固定的,导致SSMs对输入具有静态性,在内容感知等任务上表现不佳。

对于输入尺寸为$(B,L,D)$的序列$\mathbf{u}$,SSMs参数在$B,L$维度上是共享的,在$D$维度上是独立的,即对所有数据的所有序列token的每个特征维度分别使用一个SSM进行建模。此时矩阵$A,B,C$的形状为:

Mamba引入了选择性机制,它允许模型根据输入动态调整状态空间参数。具体来说,模型的参数(离散化步长$∆$、矩阵$B,C$)不再是固定的,而是根据输入序列的每个时间步动态生成。这种选择性机制使得模型能够根据输入内容有选择地传播或忽略信息,从而在处理长序列时更加高效。

注意到矩阵$A$保持不变,这是因为我们希望状态本身保持静态,但其受输入的影响是动态的。而矩阵$B,C$在$D$维度上是共享的,在$B,L$维度上是独立的,即对每个数据的每个序列token所有特征维度分别使用一个SSM进行建模。

由于输入$\mathbf{u}(t)$的尺寸为$(B,L,D)$,因此$∆,B,C$的构造直接通过线性层实现:

\[\begin{aligned} B &= \text{Linear}_N(\mathbf{u}(t)) \\ C &= \text{Linear}_N(\mathbf{u}(t)) \\ \Delta &= \text{Linear}_D(\mathbf{u}(t)) \\ \end{aligned}\]

(2)硬件感知扫描

对于SSMs,构造下列形式的卷积核$\overline{K}$,即可将序列运算转化为卷积运算:

\[\begin{aligned} \overline{K} &= (C\overline{B}, C\overline{A} \overline{B},...,C \overline{A}^k \overline{B}, ...) \\ y &= \overline{K} * u \end{aligned}\]

Mamba中,由于矩阵$B,C$随输入变化而改变,因此不能直接转换为卷积运算。为了在现代硬件(如GPU)上高效实现选择性状态空间模型,Mamba采用了一种硬件感知算法。该算法通过内核融合、并行扫描和重新计算等技术,优化了模型的计算和内存使用。

硬件感知扫描算法通过以下步骤实现:

① 内核融合

GPU 的一个缺点是其体积小但效率高的 SRAM 与体积大但效率稍低的 DRAM 之间的传输 (IO) 速度有限。在 SRAMDRAM 之间频繁复制信息会成为瓶颈。

Mamba 通过内核融合来实现限制从 DRAMSRAM 以及从 SRAMDRAM 的次数,这使得模型可以避免写入中间结果并持续执行计算直到完成。Mamba把离散化步骤、选择性扫描算法和矩阵$C$乘法融合到一个内核中。

② 并行扫描

SSMs的扫描操作是指每个状态$x_t$是前一个状态$x_{t-1}$乘以$A$加上当前输入$u_t$乘以$B$,计算复杂度为$O(n)$。Mamba通过并行扫描算法,使用$t$个执行任务的处理器提高计算效率,计算复杂度降低为$O(n/t)$。并行扫描算法是指首先并行计算中间值,再通过扫描获取更新的状态。

③ 重新计算

中间状态在前向传播中没有被保存,但对于反向传播计算梯度来说却是必需的。因此在反向传播期间重新计算了这些中间状态。该操作的成本比从相对较慢的 DRAM 读取所有中间状态的成本要低得多。

(3)Mamba模块

完整的Mamba模块如图所示,首先通过线性投影来扩展输入嵌入,然后在Selective SSM之前应用卷积来建立token之间的局部上下文关系。

3. 实验分析

作者讨论了Mamba在两个任务上的优势:

在上述两个任务上,Mamba都取得了最好的表现:

语言建模使用The Pile数据集,包含多种文本数据。Mamba在语言建模任务上表现出色,其性能与Transformer相当,甚至在某些情况下超过了Transformer。具体来说:

DNA建模使用HG38数据集,包含人类基因组序列。MambaDNA建模任务上也表现出色,特别是在处理长序列时。具体来说:

音频建模使用YouTubeMix数据集,包含钢琴音乐波形。Mamba在音频建模任务上同样表现出色,特别是在处理长序列时。具体来说: