Noether网络:通过元学习学习有用的守恒量.

机器学习中许多方法通过归纳偏置(inductive bias)利用所解决问题中的对称性(symmetry),比如图像分类中的卷积神经网络具有平移不变性(translation invariance),用于药物设计的图神经网络具有置换不变性(permutation invariance),用于蛋白质结构预测的Transformer具有旋转平移等变性(roto-translational equivariance)。

挖掘任务中的这种对称性归纳偏置能够提高机器学习系统的性能,然而对于感兴趣的数据分布可能存在未知或难以利用的对称性;本质上这是因为对称性描述了数据扰动的反事实效应,而数据扰动是无法直接观察到的。本文作者受Noether定理启发,通过元学习间接地利用序列预测问题中的对称性,为发现序列问题中的归纳偏差提供了一个通用框架。

1. Noether定理

Noether定理指出:

比如对于一个通过重力相互作用的行星系统:该系统在空间中三个方向上都是平移不变的(即在$x$,$y$或$z$轴上平移整个系统会保持运动定律),在这些方向上系统具有的守恒量是线性动量(linear momentum)。同地样,系统具有时间不变性(即运动定律在任意时刻相同),相应的守恒量是系统的总能量。

通常数据中的对称性是很难发现的,因为它们是与数据中无法观测到的扰动相关联的全局属性;而守恒量可以在真实数据中直接观察到,并可以通过机器学习算法捕捉。因此作者指出,可以通过元学习来近似问题中有用的归纳偏差。

直接从数据中学习守恒量需要解决两个问题:

  1. 真实数据经常含有噪音,从而违背精确的守恒定律。许多守恒定律仅在给出系统部分信息的情况下近似满足。例如,在真实耗散系统中能量守恒并不完全满足;在遮挡情况下从像素中估计的质量守恒必然是不精确的。
  2. 仅对守恒量进行优化可能会产生没有价值的结果,例如预测一个与输入无关的常数$C$。

作者指出寻找有用的守恒量并不需要在训练过程中精确保存它的数值,而只需要利用其守恒性来提供数据流形的信息。这使得在具有噪声的环境中也能捕捉有用的守恒量,并改善模型任务性能。

2. Noether网络

以序列预测问题中的视频预测模型为例,作者设计了如下Noether网络结构。

视频预测任务是指给定初始的状态帧$x_0$,预测之后$T$个状态帧\(\hat{x}_1,...,\hat{x}_T\)。上图所示系统为一个斜面上的物体下滑后撞击平面上的物体,根据物理学可知撞击过程会发生动量交换,然而由于系统缺乏相应的归纳偏置,故直接通过神经网络$f_{\theta}$预测的结果\(\tilde{x}_T\)不符合实际情况。

Noether网络采用一种裁剪(tailoring)框架,该框架通过微调神经网络$f_{\theta}$来减少预测偏差,该过程是通过构造一个无监督的守恒损失实现的。具体地,使用一个参数化的神经网络$g_{\phi}$来学习每个数据帧\(\tilde{x}_t\)中的守恒量\(g_{\phi}(\tilde{x}_t)\),并构造一个守恒的Noether损失:

\[\mathcal{L}_{\text{Noether}} (x_0,\tilde{x}_{1:T};g_{\phi}) = \sum_{t=1}^{T}|g_{\phi}(x_0)-g_{\phi}(\tilde{x}_t)|^2 ≈ \sum_{t=1}^{T}|g_{\phi}(x_{t-1})-g_{\phi}(\tilde{x}_t)|^2\]

该守恒损失编码了一种元归纳偏差(即针对归纳偏差的归纳偏差),它以指数方式缩小了搜索空间,并简化了参数化过程。

通过Noether损失对神经网络$f_{\theta}$的参数进行裁剪:

\[\theta(x_0;\phi) = \theta - \lambda_{in} \nabla_{\theta} \mathcal{L}_{\text{Noether}} (x_0,\tilde{x}_{1:T}(\theta);g_{\phi})\]

完成参数裁剪后,通过神经网络$f_{\theta}$给出最后的预测结果\(\hat{x}_t = f_{\theta(x_0;\phi)}(\hat{x}_{t-1})\),使用该结果计算任务损失\(\mathcal{L}_{\text{task}} (x_{1:T},\hat{x}_{1:T})\),并通过反向传播算法进一步更新神经网络$f_{\theta}$和神经网络$g_{\phi}$的参数。

完整的训练过程如下:

3. 实验分析

(1) Noether网络能否恢复数据中已知的守恒定律?

作者选择理想弹簧和理想单摆的状态坐标数据集,其中具有的守恒量为能量。对于输入为$x=(p,q)$的单摆,其中$q$是角度,$p$是动量,则能量守恒量为$3(1-\cos q) + p^2 \sim p^2-3\cos q$。对于输入为$x=(p,q)$的弹簧,其中$q$是位移,$p$是动量,则能量守恒量为$1/2(q^2 + p^2)\sim q^2 + p^2$。

实验表明,Noether网络能够正确地挑出与真实守恒量具有几乎精确参数的相同形式的方程。

(2) Noether网络适用于受控动力学的设置吗?

动作条件视频预测是指根据之前的帧和代理动作预测未来的帧。作者收集了一个受控单摆像素数据集,在OpenAI健身房环境中录制了单摆摆动的视频片段。模型从当前时间步开始接收$4$个历史帧和$26$个策略操作序列,并预测$26$个未来帧。

作者使用一个随机视频生成(SVG)模型预测受控单摆实验的像素帧,Noether网络改善了生成图像的总体均方误差MSE和结构相似性SSIM

(3) Noether网络能从原始像素中学习有用的守恒量吗?

作者进一步收集了一个随机视频生成(SVG)数据集,包含各种真实对象滑下斜坡并与第二个对象碰撞的视频。模型以前两帧为条件,预测随后的$20$帧;斜坡的倾斜角度不超过$20$度。在这种设置下,普通的SVG模型难以泛化,因为运动中的对象与目标对象作用后会变形。Noether网络改善了生成图像的学习感知图像块相似性LPIPS、均方误差MSE、峰值信噪比PSNR和结构相似性SSIM这表明Noether网络能够从原始视频数据中学习有用的归纳偏置。

作者展示了一些预测序列的显著性热图,并按照网络$g_{\phi}$输出的PCA维度排序。图中红色表示高重要性,蓝色表示低重要性。第一个维度解释了不同帧之间的绝大多数差异,在这两个示例中主要关注滑动的对象,也关注静止的对象和坡道的边缘。第四维关注的是滑动对象以及人手和桌子。第六维度跟踪蓝色滑动对象。

(4) 守恒程度如何影响预测任务的性能?

Noether网络使用单次内部梯度更新对网络参数进行裁剪,通过增加内部更新的次数能够增强模型学习到守恒量的程度。作者给出了增加内部更新次数对内部Noether损失和外部任务损失的影响。结果表明,内部更新次数为$150$附近时外部损失取得最小值,这表明Noether网络学习到的近似守恒可以推广到更精确的守恒设置。