状态空间模型的无需状态推理: 转移函数方法。
S4提出了使用生成函数建模状态空间模型的高效计算和训练方式。离散形式的SSM系统可以建模为:
\[\begin{aligned} x_{k+1} &= \overline{A}x_k + \overline{B} u_k \\ y_{k+1} &= \overline{C} x_{k+1} \end{aligned}\]从0时刻开始向后推导几个时刻后的输出,可得到如下的形式:
\[\begin{aligned} x_0 &= B u_0 \\ y_0 &= C x_0= CB x_0 \\ & \cdots \\ x_k &= A^k B u_0 +A^{k-1} B u_1 + \cdots +A B u_{k-1} + B u_k \\ y_k &= C A^k B u_0 +CA^{k-1} B u_1 + \cdots +CA B u_{k-1} + CB u_k \\ \end{aligned}\]形式上,构造下列形式的卷积核$\overline{K}$,即可将序列运算转化为卷积运算:
\[\begin{aligned} \overline{K} &= (C\overline{B}, C\overline{A} \overline{B},...,C \overline{A}^k \overline{B}, ...) \\ y &= \overline{K} * u \end{aligned}\]可以使用快速傅里叶变换(FFT)的卷积定理来快速计算该卷积运算:首先将输入序列的 FFT 相乘,然后应用逆 FFT 来有效地计算卷积的输出。
S4把$B,C$视为可训练参数,因此需要高效计算$\overline{A}^k$。S4把矩阵$\overline{A}$假设为复数空间下的对角低秩矩阵(Diagonal Plus Low-Rank, DPLR):$\overline{A}=\Lambda-PQ^T$,其中$P,Q\in \mathbb{R}^{N\times 1}$。进而不直接计算卷积核$\overline{K}$,而是计算卷积核$\overline{K}$的生成函数,把矩阵幂运算转化为矩阵逆运算。
\[\hat{K}_L(z;\overline{A},\overline{B},\overline{C}) = \sum_{i=0}^{\infty} \overline{C}\overline{A}^i\overline{B} z^i = \overline{C} (1- z\overline{A})^{-1} \overline{B}\]作者观察到$\overline{K}$的生成函数实际上是有理函数,基于此提出了有理转移函数(Rational Transfer Function,RTF)把生成函数表示为:
\[\hat{K}_L(z;\overline{A},\overline{B},\overline{C}) = \overline{C} (1- z\overline{A})^{-1} \overline{B} = \frac{b_1+b_2z+b_3z^2+\cdots + b_dz^{d-1}}{1+a_1z+a_2z^2+\cdots +a_dz^d}\]注意到:
\[\begin{aligned} \hat{K}_L(z;\overline{A},\overline{B},\overline{C}) &= \overline{C} (1- z\overline{A})^{-1} \overline{B} \\ &= z^{-1}\left[ 1+ \overline{C} (z^{-1}I- \overline{A})^{-1} \overline{B} -1\right] \\ (\text{according to}\det&(I+UV)=\det(I+VU)) \\ &= z^{-1}\left[ \det \left(I+ (z^{-1}I- \overline{A})^{-1} \overline{B}\overline{C}\right) -1\right] \\ &= z^{-1}\left[ \det \left(\left(z^{-1}I- \overline{A}\right)^{-1}\left( z^{-1}I- \overline{A}+ \overline{B}\overline{C}\right)\right) -1\right] \\ &= z^{-1}\left[ \frac{\det \left( z^{-1}I- \overline{A}+ \overline{B}\overline{C}\right)}{\det \left(z^{-1}I- \overline{A}\right)} -1\right] \\ &= \frac{z^{d-1}\left[\det \left( z^{-1}I- \overline{A}+ \overline{B}\overline{C}\right)- \det \left(z^{-1}I- \overline{A}\right)\right] }{z^d\det \left(z^{-1}I- \overline{A}\right)} \end{aligned}\]上式分子为z的$d-1$阶多项式,其系数对应$(b_1,b_2,…,b_d)$;分母为z的$d$阶多项式,其系数对应$(a_1,a_2,…,a_d)$。
注意到满足:
\[\det(\lambda I-\overline{A}) = \lambda^d + a_1\lambda^{d-1} + \cdots + a_d\]的矩阵$\overline{A}$可以构造为:
\[\overline{A} = \begin{pmatrix} -a_1 & -a_2 & \cdots & -a_{d-1} & -a_d \\ 1 & 0 & \cdots & 0 & 0 \\ 0 & 1 & \cdots & 0 & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & \cdots & 1 & 0 \end{pmatrix}\]进一步有:
\[\det \left( \lambda I- \overline{A}+ \overline{B}\overline{C}\right)- \det \left(\lambda I- \overline{A}\right) = b_1\lambda^{d-1}+b_2\lambda^{d-2} + \cdots + b_d \\ \downarrow \\ \det \left( \lambda I- \overline{A}+ \overline{B}\overline{C}\right) = \lambda^d + (a_1+b_1)\lambda^{d-1} + \cdots + (a_d+b_d)\]满足条件的矩阵$\overline{A}-\overline{B}\overline{C}$可以构造为:
\[\overline{A}-\overline{B}\overline{C} = \begin{pmatrix} -a_1-b_1 & -a_2-b_2 & \cdots & -a_{d-1}-b_{d-1} & -a_d-b_d \\ 1 & 0 & \cdots & 0 & 0 \\ 0 & 1 & \cdots & 0 & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & \cdots & 1 & 0 \end{pmatrix}\]从而得到$\overline{B}$和$\overline{C}$的一组解:
\[\overline{B}\overline{C} = \begin{pmatrix} b_1 & b_2 & \cdots & b_{d-1} & b_d \\ 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cdots & 0 & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & \cdots & 0 & 0 \end{pmatrix} = \begin{pmatrix} 1 \\ 0 \\ \vdots \\ 0 \\ 0 \end{pmatrix}\begin{pmatrix} b_1 & b_2 & \cdots b_{d-1} & b_d \end{pmatrix}\]注意到:
\[\begin{aligned} DFT(\hat{K}_L) &= \frac{DFT(b_1+b_2z+b_3z^2+\cdots + b_dz^{d-1})}{DFT(1+a_1z+a_2z^2+\cdots +a_dz^d)}\\ &= \frac{DFT(b_1,b_2,b_3,\cdots, b_d,0,0,\cdots,0)}{DFT(1,a_1,a_2,\cdots,a_d,0,\cdots,0)} \end{aligned}\]上述DFT的计算复杂度为$O(L\log L)$,与隐状态维度$d$无关,因此可以通过增大$d$来提高模的性能: