递归序列的深度符号回归.
本文讨论符号回归(Symbolic Regression)问题,即给定一个数字序列$[u_1,u_2,…,u_n]$,寻找它们满足的函数或递归关系$u_n=f(n,u_{i<n})$,并预测下一个项$u_{n+1}$。
作者使用Transformer来推断递归关系,并在随机生成的序列数据集上训练模型。结果表明,该符号回归模型可以预测复杂的递归关系,并以较高的精度外推序列的下一项。为了测试模型的域外泛化能力,作者在在线整数序列百科全书(Online Encyclopedia of Integer Sequences, OEIS)上评估模型,其表现优于内置的Mathematica函数。
此外,该模型还能够预测常见常数和特殊函数的近似表达式:
1. 问题建模
给定长度为$n_{input}$的输入序列$[u_0,…,u_{n_{input}-1}]$,寻找函数$f$使得$u_i=f(i,u_{i-1},…,u_{i-d})$,其中$d$为递归度(recursion degree),表示当前元素只和前$d$个元素相关。这类问题通常是欠定的,给定有限项$n\in \Bbb{N}$,可能存在无限多个递归关系。在本文中,如果给定序列的前$n$项,函数$f$能够正确地预测下一项$n_{pred}$,则认为能解决问题。
序列中的数字采用两种设置:
- 整数(Integer):整数序列是代数学中非常感兴趣的一个领域,递归公式仅使用在$\Bbb{Z}$中闭合的运算符(例如加法、乘法、绝对值、模和整数除法)。
- 浮点数(Float):浮点数运算可以使用更大的操作符集(如实数除法、指数和三角函数),从而使问题更具挑战性。
本文设置两个任务格式:
- 符号(Symbolic)回归:模型在训练过程中使用序列的前$n_{input}$个输入项预测递归关系。在测试时通过模型预测的递归关系与实际序列中下一项$n_{pred}$的近似程度来评估性能。
- 数值(Numeric)回归:模型也在训练过程中使用序列的前$n_{input}$项,但直接预测后续项$n_{pred}$。在测试时将模型预测与序列的真实值进行比较。
2. 数据生成
本文使用的训练数据是通过随机生成递归关系来构造的。具体地,随机抽样一个初始项,然后使用递归关系创建后续项,步骤如下:
- 采样运算符的数量$1 \leq o \leq o_{max}$,并构建具有$o$个结点的一元二叉树。运算符的数量决定了表达式的难度;
- 从下面的运算符列表中为二叉树的每个非叶结点采样一个运算。注意到浮点数情况比整数情况使用更多运算符,通过扩展问题空间使任务更具挑战性;
- 采样递归深度$1\leq d \leq d_{max}$,深度$d$表明每个当前元素只和前$d$个元素相关;
- 为二叉树的每个叶节点采样,以$p_{const}$的概率取常数,以$p_n$的概率取当前索引$n$,以$p_{var}$的概率取前面的项$u_{n-i},i\in [1,d]$;
- 根据上一项采样得到的最深的叶结点$u_{n-i}$计算实际的递归深度$d_{eff}$,然后从随机分布$\mathcal{P}$中采样$d_{eff}$个初始化项;
- 采样序列长度$l_{min}\leq l \leq_{max}$,根据初始化项和递归关系计算后续$l$项。序列总长度为$n_{input}=d_{eff}+l$。
在最后一步中,如果遇到超出$10^{100}$的数值,或超出运算符的范围(如除以零或负平方根),则中断计算。上述过程中的超参数取值如下:
3. 序列嵌入
模型的输入是整数序列或浮点数序列,输出是递归关系(符号回归)或数值序列(数值回归)。由于Transformer的输入和输出应为来自固定词表(vocabulary)的token序列,因此需要对进行序列嵌入。
⚪ 整数
整数表示为以$b$为基的序列,序列长度为$\lfloor \log_b |x| \rfloor +2$,包括$1$位符号位和$\lfloor \log_b |x| \rfloor +1$位范围为$[0,b-1]$的数字。
比如$x=-325$以$b=10$为基表示为$[-,3,2,5]$,以$b=30$为基表示为$[-,10,25]$。
基数$b$的选择需要权衡序列长度和词汇表大小。实验选择$b=10000$,序列中的整数绝对值限制在$10^{100}$以内,因此每个整数最多有$26$个token,词汇表大小为$10^4$。
⚪ 浮点数
作者将浮点数保留$4$位有效数字,并用$3$个token表示,第一位表示符号位,第二位表示有效数字,第三位表示指数($E-100$到$E100$)。此时词汇表大小为$10^4$。
如$1/3$表示为$[+,3333,E-4]$。
⚪ 递归关系
递归关系采用直接波兰表示法(direct Polish notation),即按前缀顺序枚举树。
如$\cos(3x)$表示为$[\cos, mul, 3, x]$。
4. 实验分析
作者使用了一个简单的Transformer架构,包含$8$个隐藏层、$8$个注意力head和$512$的嵌入维度。
预测精度定义为预测表达式计算序列\(\{\hat{u}_i\}\)的后续$n_{pred}$项与正确序列\(\{u_i\}\)的后续项之间的比较关系:
\[acc(n_{pred}, \tau) = \Bbb{P}(\mathop{\max}_{1 \leq i \leq n_{pred}} |\frac{\hat{u}_i-u_i}{u_i}| \leq \tau )\]通过选择足够小的$\tau$和足够大的$n_{pred}$来保证预测公式与真实公示的匹配性。实验中$\tau$不能设置为$0$,因为浮点数运算中计算机精度有限,不同表达式的等效解的计算结果可能不同。实验选择$\tau=10^{-10}$,$n_{pred}=10$。
(1)In-domain generalization
下表展示了在生成数据集中的平均分布内精度。虽然浮点数设置比整数设置困难得多,但符号模型在这两种情况下都达到了很好的精度;而数值模型得到的结果稍差。
下图展示了不同超参数的消融实验。
- 图1表示不同公差水平$\tau$下的平均精度变化。符号模型在较低公差下的性能比数值模型好得多。在较高公差下,符号模型在整数设置中仍保持优势,在浮点数设置中的性能类似。这表明符号方法对于高精度预测更具优势。
- 图2表示不同预测数量$n_{pred}$下的平均精度变化。随着$n_{pred}$增加,两种模型的准确性都会下降。符号模型的精度下降不明显,尤其是浮点数设置基本平坦。这表明符号方法一旦找到正确的公式,就可以预测整个序列;而数值模型的精度随着外推而下降。
- 图3表示不同操作符数量$o$下的平均精度变化。随着操作符数量的增加,精度会迅速下降,尤其是在操作符更加多样化的浮点数设置中。
- 图4表示不同真实递归度$d_{eff}$下的平均精度变化。增加递归度具有类似操作符数量$o$但更温和的效果。
- 图5表示不同输入序列长度$l$下的平均精度变化。较短的序列更难预测,因为它们提供的关于底层表达式的信息较少。
作者对不同的操作符进行消融实验,结果表明主要困难在于除法和三角函数算子。
作者对整数模型中整数的token嵌入和浮点数模型中数值指数的token嵌入进行t-SNE可视化。两者都显示为序列结构,表现为顺序组织,其中指数嵌入在$0$左右对称。
(2)Out-of-domain generalization
作者评估了模型的域外推广能力。由于符号回归中的递归预测是一个完全未经探索的分支。因此没有官方基准。对于整数序列,作者使用OEIS的子集作为域外基准;对于浮点数序列,作者使用带有词汇表之外的常量和运算符的生成器。
⚪整数序列
整数序列使用在线整数序列百科全书(Online Encyclopedia of Integer Sequences, OEIS)作为评估基准,该数据库包含$30$万个整数序列,然而其中有一些序列没有解析递推式。作者选择标记为“容易”中至少包含$35$项的前$10000$个序列作为测试集,通过给出前$15$或$25$项,并预测后面的$10$项。
下表给出了模型的预测结果。只给出前$15$项时,数值模型预测后$1$项的准确率为$53\%$,预测后$10$项的准确率为$27\%$;而符号模型的精度分别是$33\%$和$19\%$。这是因为测试集中包含大量的非解析公式,即使如此该模型也能为五分之一的序列构造正确的分析公式。当给出前$25$项时性能略微提高。作者还与两个内置的Mathematica函数:FindSequenceFunction(查找非递归表达式)和FindLinearRecurrence(查找线性递归关系)进行比较,这些函数的性能都比符号模型差。
⚪浮点数序列
在处理浮点数序列时,使用带有词汇表之外的常量和运算符的生成器生成测试数据。其中符号模型的主要困难是处理词汇表之外的常量和运算符,模型被迫使用已知的词汇表来近似它们。
对于浮点数序列,数字模型的性能不会受到影响,因为它不会遇到词汇表外的问题。而符号模型的近似精度下降比较明显。
作者给出了对于未知运算符的近似情况。结果表明数值和符号模型都能很好地处理这些新运算符;符号模型的结果更为多样,尤其是多项式等可以通过词汇表轻松构建的函数。
作者还测试了在序列中增加噪声$\xi_n$~$\mathcal{N}(0,\sigma u_n)$对结果的影响。结果表明当训练过程没有噪声时,在测试集中引入噪声会导致精度骤降。如果训练过程引入噪声,模型能够对噪声具有一定的鲁棒性。