通过结构化状态空间高效建模长序列.
本文提出了序列的结构化状态空间(Structured State Space for Sequences, S4),利用HiPPO的结果构造了状态空间模型,并从新的视角探讨了高效的计算和训练方式,在大量长序列建模任务上验证了它的有效性。
S4使用的序列建模框架,是HiPPO-LegS形式的线性ODE系统:
\[\begin{aligned} \mathbf{x}^\prime(t) &= A \mathbf{x}(t) + B \mathbf{u}(t) \\ \mathbf{y}(t) &= C \mathbf{x}(t) \\ A_{n,k} &= -\begin{cases} \sqrt{(2n+1)(2k+1)}, & k < n \\ n+1, & k = n\\ 0, & k > n \end{cases} \\ \end{aligned}\]其中$\mathbf{x}(t)$是状态向量,$\mathbf{u}(t)$是输入向量,$\mathbf{y}(t)$是输出向量。
在实际场景中,输入$\mathbf{u}(t)$通常是离散的:$u_0,u_1,…,u_t,…$,比如文本输入,若希望用上述状态空间模型来实时记忆这些离散点,需要对常微分方程离散化。设定离散化步长$\Delta $,将输入$\mathbf{u}(t)$表示为$u(t)=u_k,t\in[k\Delta ,(k+1)\Delta )$,即通过阶梯函数把离散输入$\mathbf{u}(t)$连续化。基于此,对上述ODE两端积分:
\[\begin{aligned} x(t+\Delta )-x(t) &= A_{n,k}\int_t^{t+\Delta }x(s)ds + B_n \int_t^{t+\Delta }\mathbf{u}(s)ds \\ &= A_{n,k}\int_t^{t+\Delta }x(s)ds + \Delta B_n u_k \\ \end{aligned}\]其中$t=k\Delta$。若假设在$[t,t+\Delta )$区间内$x(s)$近似等于$(x(t)+x(t+\Delta ))/2$,则得到双线性格式:
\[\begin{aligned} x(t+\Delta )-x(t) &= \frac{1}{2}\Delta A_{n,k}(x(t)+x(t+\Delta )) + \Delta B_n u_k \\ &\downarrow \\ x(t+\Delta ) &= (I-\Delta A_{n,k}/2)^{-1}\left((I+\Delta A_{n,k}/2)x(t) + \Delta B_n u_k\right) \\ \end{aligned}\]上式可离散化为:
\[\begin{aligned} x_{k+1} &= (I-\Delta A/2)^{-1}(I+\Delta A/2)x_k + (I-\Delta A/2)^{-1}\Delta B u_k \\ &= \overline{A}x_k + \overline{B} u_k \\ \end{aligned}\]从0时刻开始向后推导几个时刻后的输出,可得到如下的形式:
\[\begin{aligned} x_0 &= B u_0 \\ y_0 &= C x_0= CB x_0 \\ x_1 &= A x_0 + B u_1 = A B u_0 + B u_1 \\ y_1 &= C x_1 = CA B x_0 + CB x_1 \\ x_2 &= A x_1 + B u_2 = A (A B u_0 + B u_1) + B u_2= A^2 B u_0 +A B u_1 + B u_2 \\ y_1 &= C x_2 = CA^2 B u_0 +CA B u_1 + CB u_2 \\ & \cdots \\ x_k &= A^k B u_0 +A^{k-1} B u_1 + \cdots +A B u_{k-1} + B u_k \\ y_k &= C A^k B u_0 +CA^{k-1} B u_1 + \cdots +CA B u_{k-1} + CB u_k \\ \end{aligned}\]形式上,构造下列形式的卷积核$\overline{K}$,即可将序列运算转化为卷积运算:
\[\begin{aligned} \overline{K} &= (C\overline{B}, C\overline{A} \overline{B},...,C \overline{A}^k \overline{B}, ...) \\ y &= \overline{K} * u \end{aligned}\]def K_conv(Ab, Bb, Cb, L):
return np.array(
[(Cb @ matrix_power(Ab, l) @ Bb).reshape() for l in range(L)]
)
可以使用快速傅里叶变换(FFT)的卷积定理来快速计算该卷积运算:首先将输入序列的 FFT 相乘,然后应用逆 FFT 来有效地计算卷积的输出。要将此定理用于上述非循环卷积,需要用零填充等长的输入序列,然后丢弃输出序列的填充部分。
def causal_convolution(u, K):
assert K.shape[0] == u.shape[0]
ud = np.fft.rfft(np.pad(u, (0, K.shape[0])))
Kd = np.fft.rfft(np.pad(K, (0, u.shape[0])))
out = ud * Kd
return np.fft.irfft(out)[: u.shape[0]]
S4把$B,C$视为可训练参数,因此需要高效计算$\overline{A}^k$。S4把矩阵$\overline{A}$假设为复数空间下的对角低秩矩阵(Diagonal Plus Low-Rank, DPLR):$\overline{A}=\Lambda-PQ^T$,其中$P,Q\in \mathbb{R}^{N\times 1}$。基于此通过以下三个步骤克服了速度瓶颈:
- 不直接计算卷积核$\overline{K}$,而是计算卷积核$\overline{K}$的截断生成函数,把矩阵幂运算转化为矩阵逆运算。
- 假设$A$为对角矩阵$\Lambda$,矩阵逆运算等价于柯西核形式的点积运算。
- 低秩项可以用Woodbury恒等式来修正。
(1)SSM生成函数
卷积核$\overline{K}$的截断生成函数是将前$L$个分量当成幂级数$z$的系数来构建幂级数:
\[\hat{K}_L(z;\overline{A},\overline{B},\overline{C}) \in \mathbb{C} := \sum_{i=0}^{L-1} \overline{C}\overline{A}^i\overline{B} z^i\]def K_gen_simple(Ab, Bb, Cb, L):
K = K_conv(Ab, Bb, Cb, L)
def gen(z):
return np.sum(K * (z ** np.arange(L)))
return gen
生成函数把卷积核$\overline{K}$从时域转换到频域,该变换也叫做Z变换。一旦知道离散序列的 z 变换,就可以从单位根处的 z 变换求值($z \in \Omega = {\exp(2\pi \frac{k}{L}):k\in[L]}$)中获得滤波器的离散傅里叶变换,进而可以应用逆傅里叶变换来恢复卷积核$\overline{K}$。
def conv_from_gen(gen, L):
# Evaluate at roots of unity
# Generating function is (-)z-transform, so we evaluate at (-)root
Omega_L = np.exp((-2j * np.pi) * (np.arange(L) / L))
atRoots = np.vectorize(gen)(Omega_L)
# Inverse FFT
out = np.fft.ifft(atRoots, L).reshape(L)
return out.real
并且生成函数把矩阵幂运算替换成了矩阵逆运算。注意到:
\[\left(\sum_{i=0}^{L-1} \overline{A}^i z^i \right)(1-\overline{A} z) = \sum_{i=0}^{L-1} \overline{A}^i z^i - \sum_{i=1}^{L} \overline{A}^i z^i = I-\overline{A}^L z^L\]因此卷积核$\overline{K}$的截断生成函数计算为:
\[\hat{K}_L(z;\overline{A},\overline{B},\overline{C}) = \sum_{i=0}^{L-1} \overline{C}\overline{A}^i\overline{B} z^i = \overline{C} (I-\overline{A}^L z^L)(1-\overline{A} z)^{-1} \overline{B}\]注意到$z^L\propto \exp(2\pi k) = 1$,为简化表示合并常数项$\tilde{C}=\overline{C} (I-\overline{A}^L)$,该公式计算卷积核$\overline{K}$的生成函数时不需要显式调用卷积核$\overline{K}$。
from numpy.linalg import inv
from numpy.linalg import matrix_power
def K_gen_inverse(Ab, Bb, Cb, L):
I = np.eye(Ab.shape[0])
Ab_L = matrix_power(Ab, L)
Ct = Cb @ (I - Ab_L)
return lambda z: (Ct.conj() @ inv(I - Ab * z) @ Bb).reshape(-1)
(2)对角化
根据SSM矩阵的连续-离散转换关系:
\[\begin{aligned} \overline{A} &= (I-\Delta A/2)^{-1}(I+\Delta A/2) \\ \overline{B} &= (I-\Delta A/2)^{-1}\Delta B \\ \end{aligned}\]卷积核$\overline{K}$的截断生成函数可以表示为:
\[\begin{aligned} \hat{K}_L(z;\overline{A},\overline{B},\overline{C}) &= \tilde{C} (1-\overline{A} z)^{-1} \overline{B} \\ &= \tilde{C} \left(1-(I-\Delta A/2)^{-1}(I+\Delta A/2) z\right)^{-1} (I-\Delta A/2)^{-1}\Delta B \\ &= \tilde{C} \left((I-\Delta A/2)^{-1}(I-\Delta A/2)-(I-\Delta A/2)^{-1}(I+\Delta A/2) z\right)^{-1} (I-\Delta A/2)^{-1}\Delta B \\ &= \tilde{C} \left((I-\Delta A/2)^{-1}(I-\Delta A/2-(I+\Delta A/2) z)\right)^{-1} (I-\Delta A/2)^{-1}\Delta B \\ &= \tilde{C} \left(1-z-(1+z)\Delta A/2\right)^{-1} \Delta B \\ &= \frac{2}{1+z} \tilde{C} \left(\frac{2(1-z)}{\Delta(1+z)}-A\right)^{-1} B \\ \end{aligned}\]把矩阵$A$假设为对角矩阵:$A=\Lambda$,则上述生成函数可以写作如下形式:
\[\begin{aligned} \hat{K}_L(z;\overline{A},\overline{B},\overline{C}) &= \frac{2}{1+z} \tilde{C} \left(\frac{2(1-z)}{\Delta(1+z)}-\Lambda\right)^{-1} B \\ &= \frac{2}{1+z} \sum_i\frac{\tilde{C}_iB_i}{\frac{2(1-z)}{\Delta(1+z)}-\Lambda_i} \\ &:= c(z) \sum_i\frac{\tilde{C}_iB_i}{g(z)-\Lambda_i} = c(z) \cdot k_{z,\Lambda}(\tilde{C},B) \\ \end{aligned}\]其中$c(z)=\frac{2}{1+z},g(z)=\frac{2(1-z)}{\Delta(1+z)}$是$z$的函数。上式表明生成函数计算中的矩阵逆可以进一步表示为柯西核$k_{z,\Lambda}(\tilde{C},B)$形式的点积计算。
(3)低秩项修正
为了修正把矩阵$A$假设为对角矩阵$A=\Lambda$的假设,增加由$P,Q\in \mathbb{R}^{N\times 1}$构造的低秩部分:
\[A=\Lambda-PQ^T\]此时卷积核$\overline{K}$的截断生成函数可以表示为:
\[\begin{aligned} \hat{K}_L(z;\overline{A},\overline{B},\overline{C}) &= \frac{2}{1+z} \tilde{C} \left(\frac{2(1-z)}{\Delta(1+z)}-A\right)^{-1} B \\ &= \frac{2}{1+z} \tilde{C} \left(\frac{2(1-z)}{\Delta(1+z)}-\Lambda+PQ^T\right)^{-1} B \\ \end{aligned}\]根据Woodbury恒等式:
\[\begin{aligned} \left(A+UV^T\right)^{-1}&=\left(A(I+A^{-1}UV^T)\right)^{-1} \\ &=\left(I+A^{-1}UV^T\right)^{-1}A^{-1} \\ &=\sum_{n=0}^{\infty}(-1)^n(A^{-1}UV^T)^n A^{-1} \\ &=\sum_{n=0}^{\infty}(-1)^nA^{-1}U(V^TA^{-1}U)^{n-1}V^T A^{-1} \\ &=A^{-1}\left(\sum_{n=0}^{\infty}(-1)^nU(V^TA^{-1}U)^{n-1}V^T\right) A^{-1} \\ &=A^{-1}\left(A+U\sum_{n=0}^{\infty}(-1)^n(V^TA^{-1}U)^nV^T\right) A^{-1} \\ &=A^{-1}\left(A+U(I+V^TA^{-1}U)V^T\right) A^{-1} \\ &=A^{-1}-A^{-1}U\left(I+V^TA^{-1}U\right)^{-1}V^TA^{-1} \\ \end{aligned}\]且记$R^{-1}(z)=\frac{2(1-z)}{\Delta(1+z)}-\Lambda$,上式进一步化简为:
\[\begin{aligned} \hat{K}_L(z;\overline{A},\overline{B},\overline{C}) &= \frac{2}{1+z} \tilde{C} \left(\frac{2(1-z)}{\Delta(1+z)}-\Lambda+PQ^T\right)^{-1} B \\ &= \frac{2}{1+z} \tilde{C} \left(R^{-1}(z)+PQ^T\right)^{-1} B \\ &= \frac{2}{1+z} \tilde{C} \left(R(z)-R(z)P\left(I+Q^TR(z)P\right)^{-1}Q^TR(z)\right) B \\ \end{aligned}\]引入柯西核$k_{z,\Lambda}(A,B)=AB/(\frac{2(1-z)}{\Delta(1+z)}-\Lambda)=AR(z)B$的形式,上式进一步化简为四个点积运算:
\[\begin{aligned} \hat{K}_L(z;\overline{A},\overline{B},\overline{C}) &= \frac{2}{1+z} \tilde{C} \left(R(z)-R(z)P\left(I+Q^TR(z)P\right)^{-1}Q^TR(z)\right) B \\ &= c(z) \left[k_{z,\Lambda}(\tilde{C},B) -k_{z,\Lambda}(\tilde{C},P)\left(I+k_{z,\Lambda}(Q^T,P)\right)^{-1}k_{z,\Lambda}(Q^T,B)\right] \\ \end{aligned}\]def K_gen_DPLR(Lambda, P, Q, B, C, step, unmat=False):
aterm = (C.conj(), Q.conj())
bterm = (B, P)
def gen(o):
g = (2.0 / step) * ((1.0 - o) / (1.0 + o))
c = 2.0 / (1.0 + o)
def k(a):
return (a / (g - Lambda)).sum()
k00 = k(aterm[0] * bterm[0])
k01 = k(aterm[0] * bterm[1])
k10 = k(aterm[1] * bterm[0])
k11 = k(aterm[1] * bterm[1])
return c * (k00 - k01 * (1.0 / (1.0 + k11)) * k10)
return gen
此外,通过假设$A=\Lambda-PQ^T$能够简化离散形式的SSM的矩阵计算。离散形式的SSM矩阵表示为:
\[\begin{aligned} \overline{A} &= (I-\Delta A/2)^{-1}(I+\Delta A/2) \\ \overline{B}&= (I-\Delta A/2)^{-1}\Delta B \\ \end{aligned}\]注意到:
\[\begin{aligned} I+\Delta A/2 &= I+\frac{\Delta}{2}(\Lambda-PQ^T) \\ &= \frac{\Delta}{2}\left[\frac{2}{\Delta}I+(\Lambda-PQ^T)\right] \\ &:= \frac{\Delta}{2}A_0 \\ (I-\Delta A/2)^{-1} &= \left(I-\frac{\Delta}{2}(\Lambda-PQ^T)\right)^{-1} \\ &= \frac{2}{\Delta}\left[\frac{2}{\Delta}-\Lambda+PQ^T\right]^{-1} \\ &= \frac{2}{\Delta}\left[\left(\frac{2}{\Delta}-\Lambda\right)^{-1}-\left(\frac{2}{\Delta}-\Lambda\right)^{-1}P\left(I+Q^T\left(\frac{2}{\Delta}-\Lambda\right)^{-1}P\right)^{-1}Q^T\left(\frac{2}{\Delta}-\Lambda\right)^{-1}\right] \\ &:= \frac{2}{\Delta}A_1 \\ \end{aligned}\]此时离散形式的SSM矩阵表示为:
\[\begin{aligned} \overline{A} &= A_1A_0 \\ \overline{B}&= 2A_1 \\ \end{aligned}\](4)$A=\Lambda-PQ^T$
下面讨论$A=\Lambda-PQ^T$的假设是否成立。注意到:
\[\begin{aligned} A_{n,k} &= \begin{cases} -\sqrt{(2n+1)(2k+1)}, & k < n \\ -n-1, & k = n\\ 0, & k > n \end{cases} \\ \end{aligned}\]引入低秩矩阵$(PQ^T)_{n,k}=\sqrt{(2n+1)(2k+1)}$,有:
\[\begin{aligned} (A+\frac{1}{2}PQ^T)_{n,k} &= \begin{cases} -\frac{1}{2}\sqrt{(2n+1)(2k+1)}, & k < n \\ -\frac{1}{2}, & k = n\\ \frac{1}{2}\sqrt{(2n+1)(2k+1)}, & k > n \end{cases} \\ \end{aligned}\]进一步有:
\[\begin{aligned} (A+\frac{1}{2}PQ^T+\frac{1}{2}I)_{n,k} &= \begin{cases} -\frac{1}{2}\sqrt{(2n+1)(2k+1)}, & k < n \\ 0, & k = n\\ \frac{1}{2}\sqrt{(2n+1)(2k+1)}, & k > n \end{cases} \\ \end{aligned}\]此时上式是一个反对称矩阵,可以在复数域中被酉矩阵对角化。即存在对角矩阵$\Lambda$和酉矩阵$V$,满足:
\[\begin{aligned} A &=V^T\Lambda V - \frac{1}{2}I - \frac{1}{2}PQ^T \\ &=V^T(\Lambda - \frac{1}{2}I - \frac{1}{2}VP(VQ)^T)V \\ \end{aligned}\]其中对角矩阵$\Lambda\in\mathbb{R}^{N\times N}$、低秩矩阵$P,Q\in\mathbb{R}^{N\times r}$被设置为可学习参数,采用上述初始化。