State Space Model.
1. 状态空间模型的建模
状态空间包含完整描述系统的最小变量数,这些变量称为状态向量。状态空间模型(State Space Model, SSM)是用于描述这些状态向量的模型,并根据额外的输入预测它们的下一个状态。
SSM是一种描述动态系统行为的数学模型,它使用一组一阶微分方程(连续时间系统)或差分方程(离散时间系统)来表示系统内部状态的演化,这组方程被称为状态方程;同时用另一组方程来描述系统状态和输出之间的关系,这组方程被称为观测方程(也称为输出方程)。
(1)连续形式的状态空间模型
状态方程描述了系统内部状态$h(t)$随时间和输入的演化:
\[h^\prime(t) = Ah(t) + Bx(t) + \omega(t)\]- $h(t)$是系统状态变量,包含所有必要的变量来描述系统在任意时刻的状态,是一个$n$维向量。
- $h’(t)$是$h(t)$关于时间的微分,表示系统状态的变化率。
- $A$是系统矩阵,描述了系统状态之间的关系,以及无控制输入时它们如何随时间自然演化,是一个$n \times n$矩阵。
- $x(t)$是控制输入向量,表示外部输入或控制信号的影响,是一个$m≤ n$维向量。
- $B$是输入矩阵,描述了控制输入如何影响系统状态,是一个$n \times m$矩阵。
- $\omega(t)$是过程噪声,代表系统内部的不确定性,一般假设为高斯噪声,可省略。
观测方程描述了系统的输出$y(t)$如何依赖于系统状态和控制输入:
\[y(t) = Ch(t) + Dx(t) + e(t)\]- $y(t)$是输出向量或观测向量,包含所有的测量或观测到的变量,是一个$p≤n$维向量。
- $C$是输出矩阵,描述了系统状态如何影响输出,是一个$p \times n$矩阵。
- $D$是直达传递矩阵,表示控制输入直接对输出的影响,是一个$p \times m$矩阵(在很多实际系统中,这个矩阵通常是零或者不显著)。
- $e(t)$是测量噪声,代表了测量过程中的不确定性或误差,一般假设为高斯噪声,可省略。
SSM将一个$m$维输入向量$x(t)$映射为一个$n$维隐状态变量$h(t)$,然后再映射为一个$p$维输出向量$y(t)$。其中参数$A,B,C,D$可以人为指定或设置为可学习参数。
(2)离散形式的状态空间模型
在实际场景中,输入$\mathbf{x}(t)$通常是离散的:$x_0,x_1,…,x_t,…$,比如文本输入,若希望用状态空间模型来实时记忆这些离散点,需要对状态空间模型离散化。
连续形式的状态空间模型通常表示为线性常微分方程的形式:
\[\dot{\mathbf{h}}(t) = \mathbf{A}(t) \mathbf{h}(t) + \mathbf{B}(t) \mathbf{x}(t) \\ \mathbf{y}(t) = \mathbf{C}(t) \mathbf{h}(t) + \mathbf{D}(t) \mathbf{x}(t)\]其中$\mathbf{h}(t)$是状态向量,$\mathbf{x}(t)$是输入向量,$\mathbf{y}(t)$是输出向量,$\mathbf{A}(t),\mathbf{B}(t),\mathbf{C}(t),\mathbf{D}(t)$是系统矩阵。 离散化过程涉及以下步骤:
① 采样时间的选择
确定一个合适的采样时间间隔$\Delta$,这个时间间隔要基于系统动态特性以及所需的控制精度(通常设置为可学习参数)。
② 离散化系统矩阵
使用适当的数值方法将连续系统矩阵$\mathbf{A}(t),\mathbf{B}(t)$转换为离散矩阵$\overline{A},\overline{B}$。
方法一:零阶保持
\[\begin{aligned} \overline{A} &= e^{\Delta \mathbf{A}} \\ \overline{B} &= (e^{\Delta A}-I) A^{-1} B \end{aligned}\]下面进行推导。零阶保持 (Zero-order hold)技术是指每次收到离散信号时都会保持其值,直到收到新的离散信号。通过该技术,离散输入信号可以被连续化。
常微分方程$\dot{h}(t) = A(t) h(t) + B(t) x(t)$的通解为:
\[h(t) = e^{A(t-t_0)} h(t_0) + \int_{t_0}^t e^{A(t-\tau) } B x(\tau) d\tau\]令$t_0=t,t=t+\Delta$,根据零阶保持有:
\[\begin{aligned} h_{t+\Delta} &= e^{A\Delta} h_t + \int_{t}^{t+\Delta} e^{A(t-\tau) } B x(\tau) d\tau \\ &= e^{A\Delta} h_t + \int_{t}^{t+\Delta} e^{A(t-\tau) } B d\tau x_t \\ \end{aligned}\]则有:
\[\begin{aligned} \overline{A} &= e^{\Delta \mathbf{A}} \\ \overline{B} &= \int_0^{\Delta} e^{A\tau } Bd\tau = (e^{\Delta A}-I) A^{-1} B \end{aligned}\]方法二:双线性方法
\[\begin{aligned} \overline{A} &= (I-\Delta A/2)^{-1}(I+\Delta A/2) \\ \overline{B} &= (I-\Delta A/2)^{-1}\Delta B \end{aligned}\]下面进行推导。设定离散化步长$\Delta$,将输入$\mathbf{x}(t)$表示为$x(t)=x_k,t\in[k,k+\Delta)$,即通过阶梯函数把离散输入$\mathbf{x}(t)$连续化。基于此,对连续形式的状态空间模型ODE两端积分:
\[\begin{aligned} h(t+\Delta)-h(t) &= A\int_t^{t+\Delta}h(s)ds + B \int_t^{t+\Delta}x(s)ds \\ \end{aligned}\]若假设在$[t,t+\Delta)$区间内$h(s)$近似等于$(h_t+h_{t+\Delta})/2$,则得到双线性格式:
\[\begin{aligned} h_{t+\Delta}-h_t &= \frac{1}{2}\Delta A(h_t+h_{t+\Delta}) + \Delta B x_t \\ &\downarrow \\ h_{t+\Delta} &= (I-\Delta A/2)^{-1}\left((I+\Delta A/2)h_t + \Delta B x_t\right) \\ &\downarrow \\ h_{t+\Delta} &= (I-\Delta A/2)^{-1}(I+\Delta A/2)h_t + (I-\Delta A/2)^{-1}\Delta B x_t \\ \end{aligned}\]③ 离散化状态和输出方程
离散时间的状态方程和观测方程可以表示为:
\[\begin{aligned} \mathbf{h}[k] &= \overline{A} \mathbf{h}[k-1] + \overline{B} \mathbf{x}[k] \\ \mathbf{y}[k] &= \overline{C} \mathbf{h}[k] + \overline{D} \mathbf{x}[k] \end{aligned}\]其中$\overline{C},\overline{D}$通常可以直接等同于连续时间模型中的$\mathbf{C},\mathbf{D}$。
2. 状态空间模型的实现
(1)状态空间模型的循环表示
状态空间模型的简化离散表示为:
\[\begin{aligned} \mathbf{h}[k] &= \overline{A} \mathbf{h}[k-1] + \overline{B} \mathbf{x}[k] \\ \mathbf{y}[k] &= C \mathbf{h}[k] \end{aligned}\]对于每个时间步,计算当前输入$x_k$如何影响之前的状态$h_{k-1}$,然后预测输出$y_k$。形式上SSM与循环神经网络RNN类似,主要区别在于SSM直接采用了线性变换,而没有使用激活函数进行非线性化;因此SSM也被称为线性RNN。
(2)状态空间模型的卷积表示
状态空间模型计算效率最大的提升在于其序列运算可以卷积化。从0时刻开始向后推导状态空间模型几个时刻后的输出,可得到如下的形式:
\[\begin{aligned} \mathbf{h}_0 &= \overline{B} \mathbf{x}_0 \\ \mathbf{y}_0 &= C \mathbf{h}_0= C\overline{B} \mathbf{x}_0 \\ \mathbf{h}_1 &= \overline{A} \mathbf{h}_0 + \overline{B} \mathbf{x}_1 = \overline{A} \overline{B} \mathbf{x}_0 + \overline{B} \mathbf{x}_1 \\ \mathbf{y}_1 &= C \mathbf{h}_1 = C\overline{A} \overline{B} \mathbf{x}_0 + C\overline{B} \mathbf{x}_1 \\ \mathbf{h}_2 &= \overline{A} \mathbf{h}_1 + \overline{B} \mathbf{x}_2 = \overline{A} (\overline{A} \overline{B} \mathbf{x}_0 + \overline{B} \mathbf{x}_1) + \overline{B} \mathbf{x}_2= \overline{A}^2 \overline{B} \mathbf{x}_0 +\overline{A} \overline{B} \mathbf{x}_1 + \overline{B} \mathbf{x}_2 \\ \mathbf{y}_2 &= C \mathbf{h}_2 = C\overline{A}^2 \overline{B} \mathbf{x}_0 +C\overline{A} \overline{B} \mathbf{x}_1 + C\overline{B} \mathbf{x}_2 \\ & \cdots \\ \mathbf{h}_k &= \overline{A}^k \overline{B} \mathbf{x}_0 +\overline{A}^{k-1} \overline{B} \mathbf{x}_1 + \cdots +\overline{A} \overline{B} \mathbf{x}_{k-1} + \overline{B} \mathbf{x}_k \\ \mathbf{y}_k &= C \overline{A}^k \overline{B} \mathbf{x}_0 +C\overline{A}^{k-1} \overline{B} \mathbf{x}_1 + \cdots +C\overline{A} \overline{B} \mathbf{x}_{k-1} + C\overline{B} \mathbf{x}_k \\ \end{aligned}\]形式上,构造下列形式的卷积核,即可将序列运算转化为卷积运算:
\[\begin{aligned} \overline{K} &= (C\overline{B}, C\overline{A} \overline{B},...,C \overline{A}^k \overline{B}, ...) \\ \mathbf{y} &= \mathbf{x} * \overline{K} \end{aligned}\]卷积运算在计算机中可进行并行训练,这大大加速了状态空间模型的训练速度。而RNN由于在输出时使用了激活函数,因此无法进行并行训练,这也是状态空间模型对比RNN的一大优势所在。然而由于内核大小固定,它们的推理速度不如 RNN 那样快且不受限制。
(3)状态空间模型的训练与推理
根据前面的讨论,循环形式的SSM具有高效的推理速度,但无法并行化训练;而卷积形式的SSM可以并行化训练,但无法进行无限长度上下文的推理。因此在实践中可以使用循环 SSM 进行有效推理,并使用卷积 SSM 进行可并行训练。
3. 深度学习中的状态空间模型
深度学习中的状态空间模型包括:
- 处理序列的SSM:如HiPPO, LSSL, S4, S5, DSS, S4D, H3, Mamba, Mamba-2, MoE-Mamba, RTF。
- 处理图像的SSM:Vim, VMamba, MambaOut, MambaR。
(1)处理序列的SSM
⚪ HiPPO
- (arXiv2008)HiPPO: Recurrent Memory with Optimal Polynomial Projections
HiPPO模型的出发点是用勒让德多项式构造的标准正交基函数的线性组合对输入$x(t)$进行实时逼近,此时$h(t)$为当前时刻的拟合系数。基于此导出了两种形式的SSM:
- HiPPO-LegT(Translated Legendre):只保留最邻近窗口$[T-w,T]$的输入信息。
- HiPPO-LegS(Scaled Legendre):保留所有历史输入信息,且随时间呈多项式衰减。
⚪ LSSL
- (arXiv2110)Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers
LSSL指出线性状态空间模型可以在连续时间形式、循环形式和卷积形式之间进行转换:
- 连续时间形式:
- 循环形式:
- 卷积形式:
⚪ S4
- (arXiv2111)Efficiently Modeling Long Sequences with Structured State Spaces
S4提出了一种高效计算SSM的卷积形式的方法:首先将输入序列$x(t)$和卷积核$\overline{K}$的 FFT 相乘,然后应用逆 FFT 来有效地计算卷积的输出。S4把矩阵$\overline{A}$分解为复数空间下的对角低秩矩阵:$\overline{A}=\Lambda-PQ^T$,基于此通过以下三个步骤克服了速度瓶颈:
- 不直接计算卷积核$\overline{K}$,而是计算卷积核$\overline{K}$的截断生成函数$\hat{K}$,把矩阵幂运算转化为矩阵逆运算。 \(\hat{K}_L(z;\overline{A},\overline{B},\overline{C}) = \overline{C} (I-\overline{A}^L z^L)(1-\overline{A} z)^{-1} \overline{B}\)
- 假设$A$为对角矩阵$\Lambda$,则矩阵逆运算等价于柯西核形式的点积运算。 \(\begin{aligned} \hat{K}_L(z;\overline{A},\overline{B},\overline{C}) &= c(z) \cdot k_{z,\Lambda}(\tilde{C},B) \\ \end{aligned}\)
- 通过引入低秩项对假设进行修正,结果为多个柯西核形式的点积运算。 \(\begin{aligned} \hat{K}_L(z;\overline{A},\overline{B},\overline{C}) &= c(z) \left[k_{z,\Lambda}(\tilde{C},B) -k_{z,\Lambda}(\tilde{C},P)\left(I+k_{z,\Lambda}(Q^T,P)\right)^{-1}k_{z,\Lambda}(Q^T,B)\right] \\ \end{aligned}\)
⚪ S5
- (arXiv2208)Simplified State Space Layers for Sequence Modeling
S5把S4的SISO-SSM形式调整为MIMO-SSM形式,通过以下两个步骤来生成输出序列:
- SSM计算:给定输入序列$u_{1:L} \in R^{L×H}$,MIMO SSM通过隐状态$x_k \in R^P$生成SSM输出(或预激活)$y_{1:L} \in R^{L×H}$。
- 非线性激活:将SSM输出通过非线性激活函数(如GELU)处理,生成最终的层输出。
⚪ DSS
- (arXiv2203)Diagonal State Spaces are as Effective as Structured State Spaces
DSS通过假设对角状态空间来简化卷积核$\overline{K}$的计算。
\[\begin{aligned} \overline{K} &= (C\overline{B}, C\overline{A} \overline{B},...,C \overline{A}^k \overline{B}, ...) \end{aligned}\]假设$A$可对角化:$A=V\Lambda V^{-1}$,特征值为$\lambda_1,…,\lambda_N$,记$(CV)(V^{-1}B)=\tilde{w}$,$P,P_{i,k}=\lambda_i k \Delta$,DSS提出了两种形式的卷积核$\overline{K}$计算:
- DSS-EXP: \(K = \tilde{w} \cdot \Lambda^{-1} (e^{\Lambda \Delta}-I) \cdot \text{elementwise-exp}(P)\)
- DSS-SOFTMAX: \(K = w \cdot \Lambda^{-1} \cdot \text{row-softmax}(P)\)
⚪ S4D
- (arXiv2206)On the Parameterization and Initialization of Diagonal State Space Models
S4D指出,当$A$是对角矩阵时,卷积核$K$可以通过Vandermonde矩阵乘法来高效计算:
\[\begin{aligned} \overline{K} &= (C\overline{B}, C\overline{A} \overline{B},...,C \overline{A}^k \overline{B}, ...) \\ &= \begin{bmatrix}\overline{B}_0C_0 & \cdots & \overline{B}_{N-1}C_{N-1} \end{bmatrix} \begin{bmatrix} 1 & \overline{A}_0 & \overline{A}_0^2 & \cdots & \overline{A}_0^{N-1} \\ 1 & \overline{A}_1 & \overline{A}_1^2 & \cdots & \overline{A}_1^{N-1} \\ \cdots & \cdots & \cdots & \cdots & \cdots \\ 1 & \overline{A}_{N-1} & \overline{A}_{N-1}^2 & \cdots & \overline{A}_{N-1}^{N-1} \end{bmatrix} \\ &= \overline{B}^T \cdot C \cdot \mathcal{V}_L(\overline{A}) \end{aligned}\]⚪ H3
- (arXiv2212)Hungry Hungry Hippos: Towards Language Modeling with State Space Models
H3为状态空间模型引入新的层设计(H3层)和硬件感知算法(FlashConv):
- H3层结合两个 SSM(移位 SSM 和对角 SSM)以及输入投影的乘法交互:移位 SSM 使用移位矩阵将状态向量的每个元素向下移动一位,从而实现对之前输入的记忆;对角 SSM 使用HiPPO初始化的对角矩阵 $A$,允许模型在整个序列上记忆状态;乘法交互使得模型能够对序列中的不同位置进行比较。
- FlashConv通过融合 FFT、逐点乘法和逆 FFT 操作,减少了内存读写,并利用块 FFT 算法提高了计算效率。
⚪ Mamba
- (arXiv2312)Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Mamba是选择性结构化状态空间模型(selective SSM 或 S6),具有以下特点:
- 选择性机制:允许模型根据输入动态调整状态空间参数(离散化步长$∆$、矩阵$B,C$),这些参数是根据输入序列的每个时间步动态生成,且在批量维度和序列维度上是独立的。
- 硬件感知扫描:通过内核融合、并行扫描和重新计算,优化了模型的计算和内存使用。
⚪ Mamba-2
- (arXiv2405)Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
Mamba-2相比于Mamba的主要改进在于:
- 状态转移矩阵$A$采用标量乘以恒等结构;
- $A,X,B,C$矩阵是并行生成的;
- 额外引入了归一化层以缓解训练不稳定。
⚪ MoE-Mamba
- (arXiv2401)MoE-Mamba: Efficient Selective State Space Models with Mixture of Experts
MoE-Mamba通过在Mamba层之间交错插入MoE层,实现了无条件处理和条件处理的分离:Mamba层负责高效地将整个序列上下文集成到内部表示中;MoE层应用最相关的专家来处理每个token。
⚪ RTF
- (arXiv2405)State-Free Inference of State-Space Models: The Transfer Function Approach
RTF指出卷积核$\overline{K}$的生成函数可以表示为一个有理函数,从而在生成函数空间对矩阵$\overline{A},\overline{B},\overline{C}$进行参数化,并通过离散傅立叶变换来加速。
(2)处理图像的SSM
⚪ Vim
- (arXiv2401)Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model
Vision Mamba(Vim)通过引入双向SSM块(第一个块以前向处理视觉序列,第二个块以后向处理视觉序列),实现了对视觉数据的全局上下文建模,同时克服了传统SSM模型在视觉任务中的局限性。
⚪ VMamba
- (arXiv2401)VMamba: Visual State Space Model
VMamba将Mamba的并行选择扫描改进为二维选择扫描(SS2D),这是一种为空间域遍历量身定制的四向扫描机制,促进选择性SSM在处理视觉数据方面的扩展。
给定输入数据,SS2D首先沿四条不同的遍历路径(分别从二维矩阵的左上角向右$\rightarrow$、左上角向下$↓$、右下角向左$←$、右下角向上$↑$出发)进行选择性扫描,使用单独的S6模块并行处理每个序列,随后重新整形并合并结果序列以形成输出。
⚪ MambaOut
- (arXiv2405)MambaOut: Do We Really Need Mamba for Vision?
本文构建了 MambaOut 模型(移除了 SSM 的 Mamba 块),通过实验证明:
- 在图像分类任务中,由于不满足长序列和自回归特性,Mamba 并非必要;
- 而在目标检测和语义分割任务中,由于符合长序列特性,Mamba 仍具潜力。
⚪ Mamba-R
- (arXiv2405)Mamba-R: Vision Mamba ALSO Needs Registers
为解决 Vision Mamba 的特征伪影问题(在低信息背景区域中出现大量高范数token),Mamba® 架构在输入序列中均匀插入寄存器token,并在最终预测时将寄存器token的输出通过线性层降维后拼接成一个全局表示。