HiPPO:具有最优多项式投影的循环记忆.
状态空间模型通常表示为线性常微分方程的形式:
\[\mathbf{x}^\prime(t) = A \mathbf{x}(t) + B \mathbf{u}(t) \\ \mathbf{y}(t) = C \mathbf{x}(t) + D \mathbf{u}(t)\]其中$\mathbf{x}(t)$是状态向量,$\mathbf{u}(t)$是输入向量,$\mathbf{y}(t)$是输出向量,$A,B,C,D$是系统矩阵。
状态空间模型的目标是用一个向量$\mathbf{x}(t)$存储输入$\mathbf{u}(t)$的信息。为了使该目标可行,把向量$\mathbf{x}(t)$拆解成一系列基函数\(\{g_n(t)\}_{n=0}^N\)的线性组合,目标函数写作两者的L2距离最小化:
\[\mathop{\arg \min}_{c_0,...,c_N} \int_a^b \left[ \mathbf{u}(t) - \sum_{n=0}^N c_n g_n(t) \right]^2 dt\]该目标可以展开为:
\[\int_a^b \mathbf{u}^2(t)dt - 2\sum_{n=0}^N c_n\int_a^b \mathbf{u}(t)g_n(t)dt + \sum_{m=0}^N \sum_{n=0}^N c_mc_n \int_a^bg_m(t)g_n(t)dt\]指定基函数\(\{g_n(t)\}_{n=0}^N\)为标准正交基函数,即有:
\[\int_a^bg_m(t)g_n(t)dt =\begin{cases} 0 & m \neq n \\ 1 & m = n \end{cases}\]则目标函数简化为:
\[\int_a^b \mathbf{u}^2(t)dt - 2\sum_{n=0}^N c_n\int_a^b \mathbf{u}(t)g_n(t)dt + \sum_{n=0}^N c_n ^2\]上述目标函数的最优解是:
\[c_n = \int_a^b \mathbf{u}(t)g_n(t)dt\]上述目标函数的最优解与随机给定的静态区间$[a,b]$挂钩,而输入$\mathbf{u}(t)$代表持续采集的信号,不妨设$t\in[0,T]$,取$s \to t_{\leq T}(s)$是$[a,b]$到$[0,T]$的一个映射,则将上述最优解表示为随着参数$T$而变化的形式:
\[c_n(T) = \int_a^b \mathbf{u}(t_{\leq T}(s))g_n(s)ds\]上式两侧对$T$求导为:
\[\begin{aligned} \frac{d}{dT}c_n(T) =& \int_a^b \mathbf{u}^\prime(t_{\leq T}(s))\frac{\partial t_{\leq T}(s)}{\partial T}g_n(s)ds \\ =& \int_a^b \left( \frac{\partial t_{\leq T}(s)}{\partial T} / \frac{\partial t_{\leq T}(s)}{\partial s} \right) g_n(s)d\mathbf{u}(t_{\leq T}(s)) \\ =& \left[\mathbf{u}(t_{\leq T}(s))\left( \frac{\partial t_{\leq T}(s)}{\partial T} / \frac{\partial t_{\leq T}(s)}{\partial s} \right) g_n(s) \right]_{s=a}^{s=b} \\ & - \int_a^b\mathbf{u}(t_{\leq T}(s)) d \left[ \left( \frac{\partial t_{\leq T}(s)}{\partial T} / \frac{\partial t_{\leq T}(s)}{\partial s} \right) g_n(s) \right] \end{aligned}\]对于基函数\(\{g_n(t)\}_{n=0}^N\),指定为勒让德多项式构造的标准正交基:
\[g_n(t) = \sqrt{\frac{2n+1}{2}}L_n(t)\]而对于映射函数$t_{\leq T}(s)$,原文给出了两种情况:
⚪ HiPPO-LegT(Translated Legendre)
设置:
\[t_{\leq T}(s) = (s+1)w/2+T-w\]映射函数此时将$[-1,1]$映射到$[T-w,T]$,即只保留最邻近窗口的输入信息。结合$L_n(1)=1,L_n(-1)=(-1)^n$,代入前式得:
\[\begin{aligned} \frac{d}{dT}c_n(T) =& \left[\mathbf{u}((s+1)w/2+T-w)\frac{2}{w} \sqrt{\frac{2n+1}{2}}L_n(s) \right]_{s=-1}^{s=1} \\ & - \int_{-1}^1\mathbf{u}((s+1)w/2+T-w) d \left[ \frac{2}{w} \sqrt{\frac{2n+1}{2}}L_n(s) \right] \\ =& \frac{\sqrt{2(2n+1)}}{w}\left[\mathbf{u}(T) -(-1)^n\mathbf{u}(T-w) \right] \\ & - \frac{2}{w}\int_{-1}^1\mathbf{u}((s+1)w/2+T-w) \sqrt{\frac{2n+1}{2}}L_n^\prime(s) ds \\ \end{aligned}\]下面对上式做一些化简。首先对$u(t_{\leq T}(s))$的近似做$n\leq N$截断:
\[\mathbf{u}((s+1)w/2+T-w) \approx \sum_{k=0}^N c_k(T) g_k(s)\]此时有:
\[\mathbf{u}(T-w) \approx \sum_{k=0}^N (-1)^kc_k(T) \sqrt{\frac{2k+1}{2}}\]其次根据勒让德多项式的性质:
\[L_{n+1}^\prime(s) = \sum_{k=0}^n (2k+1)[1-(n-k)\%2]L_k(s)\]此时有:
\[\begin{aligned} &\int_{-1}^1\mathbf{u}((s+1)w/2+T-w) \sqrt{\frac{2n+1}{2}}L_n^\prime(s) ds \\ = & \int_{-1}^1\mathbf{u}((s+1)w/2+T-w) \sqrt{\frac{2n+1}{2}}\left[\sum_{k=0}^{n-1} (2k+1)[1-(n-1-k)\%2]L_k(s)\right] ds \\ = & \int_{-1}^1\mathbf{u}((s+1)w/2+T-w) \sqrt{\frac{2n+1}{2}}\left[\sum_{k=0}^{n-1} \sqrt{2(2k+1)}[1-(n-1-k)\%2]g_k(s)\right] ds \\ = & \sqrt{2n+1} \sum_{k=0}^{n-1} \sqrt{2k+1}[1-(n-1-k)\%2] \int_{-1}^1\mathbf{u}((s+1)w/2+T-w) g_k(s)ds \\ = & \sqrt{2n+1} \sum_{k=0}^{n-1} \sqrt{2k+1}[1-(n-1-k)\%2] c_k(T) \\ \end{aligned}\]整合上述结果,得到:
\[\begin{aligned} \frac{d}{dT}c_n(T) =& \frac{\sqrt{2(2n+1)}}{w}\left[\mathbf{u}(T) -(-1)^n\mathbf{u}(T-w) \right] \\ & - \frac{2}{w}\int_{-1}^1\mathbf{u}((s+1)w/2+T-w) \sqrt{\frac{2n+1}{2}}L_n^\prime(s) ds \\ =& \frac{\sqrt{2(2n+1)}}{w}\left[\mathbf{u}(T) -(-1)^n\sum_{k=0}^N (-1)^kc_k(T) \sqrt{\frac{2k+1}{2}} \right] \\ & - \frac{2}{w}\sqrt{2n+1} \sum_{k=0}^{n-1} \sqrt{2k+1}[1-(n-1-k)\%2] c_k(T) \\ =& \frac{\sqrt{2(2n+1)}}{w}\mathbf{u}(T)-\frac{\sqrt{2n+1}}{w}\sum_{k=0}^N (-1)^{n-k}c_k(T) \sqrt{2k+1} \\ & - \frac{\sqrt{2n+1}}{w} \sum_{k=0}^{n-1} \sqrt{2k+1}2[1-(n-1-k)\%2] c_k(T) \\ =& \frac{\sqrt{2(2n+1)}}{w}\mathbf{u}(T)-\frac{\sqrt{2n+1}}{w}\sum_{k=n}^N (-1)^{n-k}c_k(T) \sqrt{2k+1} \\ & - \frac{\sqrt{2n+1}}{w} \sum_{k=0}^{n-1} \sqrt{2k+1}\left(2[1-(n-1-k)\%2]+(-1)^{n-k}\right) c_k(T) \\ =& \frac{\sqrt{2(2n+1)}}{w}\mathbf{u}(T)-\frac{\sqrt{2n+1}}{w}\sum_{k=n}^N (-1)^{n-k}c_k(T) \sqrt{2k+1} \\ & - \frac{\sqrt{2n+1}}{w} \sum_{k=0}^{n-1} \sqrt{2k+1} c_k(T) \\ \end{aligned}\]根据上式可以得到一组状态空间模型:
\[\begin{aligned} \frac{d}{dt}c_n(t) &= A_{n,k}c_n(t) + B_n \mathbf{u}(t) \\ A_{n,k} &= -\frac{1}{w}\begin{cases} \sqrt{(2n+1)(2k+1)}, & k < n \\ (-1)^{n-k} \sqrt{(2n+1)(2k+1)}, & k \geq n \end{cases} \\ B_n &= \frac{\sqrt{2(2n+1)}}{w} \\ \end{aligned}\]可以通过为上式引入缩放因子$\lambda_n$使得结果更加一般化。
在实际场景中,输入$\mathbf{u}(t)$通常是离散的:$u_0,u_1,…,u_t,…$,比如文本输入,若希望用上述状态空间模型来实时记忆这些离散点,需要对常微分方程离散化。设定离散化步长$\epsilon$,将输入$\mathbf{u}(t)$表示为$u(t)=u_k,t\in[k\epsilon,(k+1)\epsilon)$,即通过阶梯函数把离散输入$\mathbf{u}(t)$连续化。基于此,对上述ODE两端积分:
\[\begin{aligned} c_n(t+\epsilon)-c_n(t) &= A_{n,k}\int_t^{t+\epsilon}c_n(s)ds + B_n \int_t^{t+\epsilon}\mathbf{u}(s)ds \\ &= A_{n,k}\int_t^{t+\epsilon}c_n(s)ds + \epsilon B_n u_k \\ \end{aligned}\]其中$t=k\epsilon$。
若假设在$[t,t+\epsilon)$区间内$c_n(s)$近似等于$c_n(t)$,则得到前向欧拉格式:
\[\begin{aligned} c_n(t+\epsilon)-c_n(t) &= \epsilon A_{n,k}c_n(t) + \epsilon B_n u_k \\ &\downarrow \\ c_n(t+\epsilon) &= (I+\epsilon A_{n,k})c_n(t) + \epsilon B_n u_k \\ \end{aligned}\]若假设在$[t,t+\epsilon)$区间内$c_n(s)$近似等于$c_n(t+\epsilon)$,则得到后向欧拉格式:
\[\begin{aligned} c_n(t+\epsilon)-c_n(t) &= \epsilon A_{n,k}c_n(t+\epsilon) + \epsilon B_n u_k \\ &\downarrow \\ c_n(t+\epsilon) &= (I-\epsilon A_{n,k})^{-1}\left(c_n(t) + \epsilon B_n u_k\right) \\ \end{aligned}\]前后向欧拉都具有相同的理论精度,但后向通常会有更好的数值稳定性。若假设在$[t,t+\epsilon)$区间内$c_n(s)$近似等于$(c_n(t)+c_n(t+\epsilon))/2$,则得到双线性格式:
\[\begin{aligned} c_n(t+\epsilon)-c_n(t) &= \frac{1}{2}\epsilon A_{n,k}(c_n(t)+c_n(t+\epsilon)) + \epsilon B_n u_k \\ &\downarrow \\ c_n(t+\epsilon) &= (I-\epsilon A_{n,k}/2)^{-1}\left((I+\epsilon A_{n,k}/2)c_n(t) + \epsilon B_n u_k\right) \\ \end{aligned}\]⚪ LegS(Scaled Legendre)
设置:
\[t_{\leq T}(s) = (s+1)T/2\]映射函数此时将$[-1,1]$均匀映射到$[0,T]$,即没有舍弃任何历史输入信息,并且平等地对待所有输入。代入前式得:
\[\begin{aligned} \frac{d}{dT}c_n(T) =& \left[\mathbf{u}((s+1)T/2)\left( \frac{s+1}{T} \right) \sqrt{\frac{2n+1}{2}}L_n(s) \right]_{s=-1}^{s=1} \\ & - \int_{-1}^1\mathbf{u}((s+1)T/2) d \left[ \left( \frac{s+1}{T} \right) \sqrt{\frac{2n+1}{2}}L_n(s) \right] \\ =& \frac{\sqrt{2(2n+1)}}{T}\mathbf{u}(T) \\ & - \frac{1}{T}\sqrt{\frac{2n+1}{2}}\int_{-1}^1\mathbf{u}((s+1)T/2) \left[ L_n(s)+( s+1 ) L_n^\prime(s) \right]ds \\ \end{aligned}\]根据勒让德多项式的性质:
\[(s+1)L_n^\prime(s) = -(n+1)L_n(s)+\sum_{k=0}^n (2k+1)L_k(s)\]此时有:
\[\begin{aligned} &\int_{-1}^1\mathbf{u}((s+1)T/2) \left[ L_n(s)+( s+1 ) L_n^\prime(s) \right]ds \\ =&\int_{-1}^1\mathbf{u}((s+1)T/2) \left[ L_n(s)-(n+1)L_n(s)+\sum_{k=0}^n (2k+1)L_k(s) \right]ds \\ =&\int_{-1}^1\mathbf{u}((s+1)T/2) \left[ -nL_n(s)+\sum_{k=0}^n (2k+1)L_k(s) \right]ds \\ =&\int_{-1}^1\mathbf{u}((s+1)T/2) \left[ -n\sqrt{\frac{2}{2n+1}}g_n(s)+\sum_{k=0}^n (2k+1)\sqrt{\frac{2}{2k+1}}g_k(s) \right]ds \\ =&-n\sqrt{\frac{2}{2n+1}}\int_{-1}^1\mathbf{u}((s+1)T/2) g_n(s)ds +\sum_{k=0}^n \sqrt{2(2k+1)}\int_{-1}^1\mathbf{u}((s+1)T/2)g_k(s) ds \\ =&-n\sqrt{\frac{2}{2n+1}}c_n(T) +\sum_{k=0}^n \sqrt{2(2k+1)}c_k(T) \\ \end{aligned}\]整合上述结果,得到:
\[\begin{aligned} \frac{d}{dT}c_n(T) =& \frac{\sqrt{2(2n+1)}}{T}\mathbf{u}(T) \\ & - \frac{1}{T}\sqrt{\frac{2n+1}{2}}\int_{-1}^1\mathbf{u}((s+1)T/2) \left[ L_n(s)+( s+1 ) L_n^\prime(s) \right]ds \\ =& \frac{\sqrt{2(2n+1)}}{T}\mathbf{u}(T) \\ & - \frac{1}{T}\sqrt{\frac{2n+1}{2}}\left[ -n\sqrt{\frac{2}{2n+1}}c_n(T) +\sum_{k=0}^n \sqrt{2(2k+1)}c_k(T) \right] \\ =& \frac{\sqrt{2(2n+1)}}{T}\mathbf{u}(T) - \frac{1}{T}\left[ -nc_n(T) +\sum_{k=0}^n \sqrt{(2n+1)(2k+1)}c_k(T) \right] \\ \end{aligned}\]根据上式可以得到一组状态空间模型:
\[\begin{aligned} \frac{d}{dt}c_n(t) &= \frac{A_{n,k}}{t}c_n(t) + \frac{B_n}{t} \mathbf{u}(t) \\ A_{n,k} &= -\begin{cases} \sqrt{(2n+1)(2k+1)}, & k < n \\ n+1, & k = n\\ 0, & k > n \end{cases} \\ B_n &= \sqrt{2(2n+1)} \\ \end{aligned}\]def make_HiPPO(N):
P = np.sqrt(1 + 2 * np.arange(N))
A = P[:, np.newaxis] * P[np.newaxis, :]
A = np.tril(A) - np.diag(np.arange(N))
return -A
下面对上式离散化。将输入$\mathbf{u}(t)$表示为$u(t)=u_k,t\in[k\epsilon,(k+1)\epsilon)$,对上述ODE两端积分:
\[\begin{aligned} c_n(t+\epsilon)-c_n(t) &= A_{n,k}\int_t^{t+\epsilon}\frac{c_n(s)}{s}ds + B_n\int_t^{t+\epsilon}\frac{\mathbf{u}(s)}{s}ds \\ &= A_{n,k}\int_t^{t+\epsilon}\frac{c_n(s)}{s}ds + B_nu_k\ln(\frac{t+\epsilon}{t}) \\ \end{aligned}\]假设在$[t,t+\epsilon)$区间内$c_n(s)$近似等于$(c_n(t)+c_n(t+\epsilon))/2$,并且取近似$\ln(1+x)\approx x$,则有:
\[\begin{aligned} c_n(t+\epsilon)-c_n(t) &= \frac{1}{2}\epsilon A_{n,k}\left(\frac{c_n(t)}{t}+\frac{c_n(t+\epsilon)}{t+\epsilon}\right) + \frac{\epsilon}{t} B_nu_k \\ &\downarrow \\ c_n(t+\epsilon) &= \left(I-\frac{\epsilon A_{n,k}}{2(t+\epsilon)}\right)\left[\left(I+\frac{\epsilon A_{n,k}}{2t}\right)c_n(t) + \frac{\epsilon}{t} B_nu_k\right] \\ \end{aligned}\]HiPPO-LegS具有一系列优良性质:
- 时间尺度等变性(Timescale equivariance):将$t=k\epsilon$代入上式并化简,发现更新公式与步长$\epsilon$无关。
- 多项式衰减(Polynomial decay):HiPPO-LegS对历史信号的记忆满足多项式衰减,衰减函数是$1/n$的$1,2,⋯,d$次函数的线性组合,比RNN的指数衰减更缓慢,从而理论上能记忆更长的历史,更不容易梯度消失。
- 计算高效(Computational efficiency):HiPPO-LegS的$A$矩阵执行矩阵与向量的乘法时是计算高效的,因为$A$是一个下三角矩阵,可以利用前缀和的相关算法并行计算,计算复杂度为$O(d)$。