勒让德记忆单元:循环神经网络中的连续时间表示.
0. TL;DR
本文提出了一种名为Legendre Memory Unit(LMU)的新型循环神经网络(RNN)记忆单元,旨在使用较少的资源在长时间窗口内动态维持信息。LMU通过解决一组耦合的常微分方程(ODE)来实现连续时间历史的正交化,这些方程的相空间通过勒让德多项式线性映射到时间滑动窗口。
LMU是首个能够处理跨越100,000个时间步长的时间依赖性的循环架构。在置换序列MNIST等任务上,LMU的性能超过了现有的RNN。这些结果归因于网络能够独立于步长学习尺度不变特征的能力,通过ODE求解器的反向传播允许每一层调整其内部时间步长。
1. 背景介绍
循环神经网络(RNN)已被广泛用于需要学习长距离时间依赖性的任务,如机器翻译、图像字幕生成和语音识别。尽管长短期记忆网络(LSTM)等架构在建模复杂时间关系方面取得了显著成功,但它们在处理连续时间数据和长距离依赖方面仍存在限制。特别是在内存受限的情况下,实时运行并利用连续时间数据流中的长距离依赖关系对于模型来说是一个挑战。
生物神经系统则自然地配备了解决连续时间信息处理问题的机制。神经元通过突触连接连续地过滤随时间变化的尖峰信号。受这些机制的启发,本文提出了一种新的RNN架构——Legendre Memory Unit(LMU),它能够理论上保证学习长距离依赖性,即使离散时间步长趋近于零。
2. Legendre Memory Unit
LMU的主要组件是一个记忆单元,它使用勒让德多项式正交化输入信号的连续时间历史。记忆单元基于连续时间延迟的线性传递函数$F(s) = e^{-θs}$导出,其中$θ$是时间窗口的长度。该传递函数可以由常微分方程表示:
\[\theta \dot{\mathbf{m}}(t) = A\mathbf{m}(t) + B\mathbf{u}(t)\]这个函数可以通过勒让德多项式\(\{g_n(t)\}_{n=0}^N\)的线性组合来最佳近似。目标函数写作两者的L2距离最小化:
\[\mathop{\arg \min}_{c_0,...,c_N} \int_a^b \left[ \mathbf{u}(t) - \sum_{n=0}^N c_n g_n(t) \right]^2 dt\]其中基函数\(\{g_n(t)\}_{n=0}^N\)为勒让德多项式构造的标准正交基:
\[g_n(t) = \sqrt{\frac{2n+1}{2}}L_n(t)\]该目标可以展开为:
\[\int_a^b \mathbf{u}^2(t)dt - 2\sum_{n=0}^N c_n\int_a^b \mathbf{u}(t)g_n(t)dt + \sum_{m=0}^N \sum_{n=0}^N c_mc_n \int_a^bg_m(t)g_n(t)dt \\ \downarrow\\ \int_a^b \mathbf{u}^2(t)dt - 2\sum_{n=0}^N c_n\int_a^b \mathbf{u}(t)g_n(t)dt + \sum_{n=0}^N c_n ^2\]上述目标函数的最优解是:
\[c_n = \int_a^b \mathbf{u}(t)g_n(t)dt\]上述目标函数的最优解与随机给定的静态区间$[a,b]$挂钩,而输入$\mathbf{u}(t)$代表持续采集的信号,不妨设$t\in[0,T]$,取$s \to t_{\leq T}(s)$是$[a,b]$到$[0,T]$的一个映射,则将上述最优解表示为随着参数$T$而变化的形式:
\[c_n(T) = \int_a^b \mathbf{u}(t_{\leq T}(s))g_n(s)ds\]上式两侧对$T$求导为:
\[\begin{aligned} \frac{d}{dT}c_n(T) =& \int_a^b \mathbf{u}^\prime(t_{\leq T}(s))\frac{\partial t_{\leq T}(s)}{\partial T}g_n(s)ds \\ =& \int_a^b \left( \frac{\partial t_{\leq T}(s)}{\partial T} / \frac{\partial t_{\leq T}(s)}{\partial s} \right) g_n(s)d\mathbf{u}(t_{\leq T}(s)) \\ =& \left[\mathbf{u}(t_{\leq T}(s))\left( \frac{\partial t_{\leq T}(s)}{\partial T} / \frac{\partial t_{\leq T}(s)}{\partial s} \right) g_n(s) \right]_{s=a}^{s=b} \\ & - \int_a^b\mathbf{u}(t_{\leq T}(s)) d \left[ \left( \frac{\partial t_{\leq T}(s)}{\partial T} / \frac{\partial t_{\leq T}(s)}{\partial s} \right) g_n(s) \right] \end{aligned}\]而对于映射函数$t_{\leq T}(s)$,设置:
\[t_{\leq T}(s) = (s+1)\theta /2+T-\theta\]映射函数此时将$[-1,1]$映射到$[T-\theta ,T]$,即只保留时间窗口$\theta$的输入信息。结合$L_n(1)=1,L_n(-1)=(-1)^n$,代入前式得:
\[\begin{aligned} \frac{d}{dT}c_n(T) =& \left[\mathbf{u}((s+1)\theta /2+T-\theta )\frac{2}{\theta } \sqrt{\frac{2n+1}{2}}L_n(s) \right]_{s=-1}^{s=1} \\ & - \int_{-1}^1\mathbf{u}((s+1)\theta /2+T-\theta ) d \left[ \frac{2}{\theta } \sqrt{\frac{2n+1}{2}}L_n(s) \right] \\ =& \frac{\sqrt{2(2n+1)}}{\theta }\left[\mathbf{u}(T) -(-1)^n\mathbf{u}(T-\theta ) \right] \\ & - \frac{2}{\theta }\int_{-1}^1\mathbf{u}((s+1)\theta /2+T-\theta ) \sqrt{\frac{2n+1}{2}}L_n^\prime(s) ds \\ \end{aligned}\]下面对上式做一些化简。首先对$u(t_{\leq T}(s))$的近似做$n\leq N$截断:
\[\mathbf{u}((s+1)\theta /2+T-\theta ) \approx \sum_{k=0}^N c_k(T) g_k(s)\]此时有:
\[\mathbf{u}(T-\theta ) \approx \sum_{k=0}^N (-1)^kc_k(T) \sqrt{\frac{2k+1}{2}}\]其次根据勒让德多项式的性质:
\[L_{n+1}^\prime(s) = \sum_{k=0}^n (2k+1)[1-(n-k)\%2]L_k(s)\]此时有:
\[\begin{aligned} &\int_{-1}^1\mathbf{u}((s+1)\theta /2+T-\theta ) \sqrt{\frac{2n+1}{2}}L_n^\prime(s) ds \\ = & \int_{-1}^1\mathbf{u}((s+1)\theta /2+T-\theta ) \sqrt{\frac{2n+1}{2}}\left[\sum_{k=0}^{n-1} (2k+1)[1-(n-1-k)\%2]L_k(s)\right] ds \\ = & \int_{-1}^1\mathbf{u}((s+1)\theta /2+T-\theta ) \sqrt{\frac{2n+1}{2}}\left[\sum_{k=0}^{n-1} \sqrt{2(2k+1)}[1-(n-1-k)\%2]g_k(s)\right] ds \\ = & \sqrt{2n+1} \sum_{k=0}^{n-1} \sqrt{2k+1}[1-(n-1-k)\%2] \int_{-1}^1\mathbf{u}((s+1)\theta /2+T-\theta ) g_k(s)ds \\ = & \sqrt{2n+1} \sum_{k=0}^{n-1} \sqrt{2k+1}[1-(n-1-k)\%2] c_k(T) \\ \end{aligned}\]整合上述结果,得到:
\[\begin{aligned} \frac{d}{dT}c_n(T) =& \frac{\sqrt{2(2n+1)}}{\theta }\left[\mathbf{u}(T) -(-1)^n\mathbf{u}(T-\theta ) \right] \\ & - \frac{2}{\theta }\int_{-1}^1\mathbf{u}((s+1)\theta /2+T-\theta ) \sqrt{\frac{2n+1}{2}}L_n^\prime(s) ds \\ =& \frac{\sqrt{2(2n+1)}}{\theta }\left[\mathbf{u}(T) -(-1)^n\sum_{k=0}^N (-1)^kc_k(T) \sqrt{\frac{2k+1}{2}} \right] \\ & - \frac{2}{\theta }\sqrt{2n+1} \sum_{k=0}^{n-1} \sqrt{2k+1}[1-(n-1-k)\%2] c_k(T) \\ =& \frac{\sqrt{2(2n+1)}}{\theta }\mathbf{u}(T)-\frac{\sqrt{2n+1}}{\theta }\sum_{k=0}^N (-1)^{n-k}c_k(T) \sqrt{2k+1} \\ & - \frac{\sqrt{2n+1}}{\theta } \sum_{k=0}^{n-1} \sqrt{2k+1}2[1-(n-1-k)\%2] c_k(T) \\ =& \frac{\sqrt{2(2n+1)}}{\theta }\mathbf{u}(T)-\frac{\sqrt{2n+1}}{\theta }\sum_{k=n}^N (-1)^{n-k}c_k(T) \sqrt{2k+1} \\ & - \frac{\sqrt{2n+1}}{\theta } \sum_{k=0}^{n-1} \sqrt{2k+1}\left(2[1-(n-1-k)\%2]+(-1)^{n-k}\right) c_k(T) \\ =& \frac{\sqrt{2(2n+1)}}{\theta }\mathbf{u}(T)-\frac{\sqrt{2n+1}}{\theta }\sum_{k=n}^N (-1)^{n-k}c_k(T) \sqrt{2k+1} \\ & - \frac{\sqrt{2n+1}}{\theta } \sum_{k=0}^{n-1} \sqrt{2k+1} c_k(T) \\ \end{aligned}\]根据上式可以得到:
\[\begin{aligned} \theta \frac{d}{dt}c_n(t) &= A_{n,k}c_n(t) + B_n \mathbf{u}(t) \\ A_{n,k} &= -\begin{cases} \sqrt{(2n+1)(2k+1)}, & k < n \\ (-1)^{n-k} \sqrt{(2n+1)(2k+1)}, & k \geq n \end{cases} \\ B_n &= \sqrt{2(2n+1)} \\ \end{aligned}\]可以通过为上式引入缩放因子$\lambda_n$使得结果更加一般化。若指定$\lambda_n=(-1)^{n}\sqrt{\frac{2n+1}{2}}$,则近似导出原文的矩阵系数:
\[\begin{aligned} \mathbf{A}=[A_{n,k}]\in\mathbb{R}^{d\times d},& A_{n,k} = (2n+1)\begin{cases} -1, & n<k \\ (-1)^{n-k+1}, & n\geq k\end{cases} \\ \mathbf{B}=[B_{n}]\in\mathbb{R}^{d\times 1},&B_n = (2n+1)(-1)^n \\ \end{aligned}\]3. LMU模型
LMU的记忆表示支持存储相对于时间尺度的更高频率输入($\theta$时间窗口)。通过将上述方程映射到RNN的内存上,给定离散时间点的输入,可以计算出内存状态。
LMU层的设计包括一个$n$维状态向量$h_t$和一个$d$维内存向量$m_t$,它们通过线性和非线性变换动态耦合:
\[h_t = f(W_xx_t+W_hh_{t-1}+W_mm_t)\]其中$W_x,W_h,W_m$是可学习参数,内存向量$m_t$构造为:
\[\begin{aligned} m_t &= Am_{t-1}+Bu_t \\ \mathbf{A}&=[A_{n,k}]\in\mathbb{R}^{d\times d}, A_{n,k} = (2n+1)\begin{cases} -1, & n<k \\ (-1)^{n-k+1}, & n\geq k\end{cases} \\ \mathbf{B}&=[B_{n}]\in\mathbb{R}^{d\times 1},B_n = (2n+1)(-1)^n \\ \end{aligned}\]内存向量$m_t$表示输入$u_t$在勒让德多项式上投影的滑动窗口。输入$u_t$构造为:
\[u_t = e_x^Tx_t+e_h^Th_{t-1}+e_m^Tm_{t-1}\]其中$e_x,e_h,e_m$是可学习编码向量。通过勒让德多项式将连续时间历史映射到正交基函数上,从而实现信息的压缩和高效存储。
4. 实验分析
实验表明,LMU成功地在不使用任何训练的情况下,仅使用$105$个内部状态变量持久地保持信息($10^5$时间步长)。相比之下,LSTM则需要更多的参数和状态变量。
此外,在Mackey-Glass时间序列预测任务上,LMU在归一化均方根误差(NRMSE)方面优于LSTM,训练时间更短,预测精度更高。在置换序列MNIST任务上,LMU在没有依赖高级正则化技术的情况下,性能超过了当前最先进的RNN,达到了$97.15\%$的准确率。
这些结果表明,LMU在处理长时间依赖性和连续时间数据方面具有显著优势。