隐式梯度正则化.

梯度下降法是优化神经网络时最常用的数值方法。本文作者发现梯度下降过程会在损失函数中隐式地引入损失梯度的梯度惩罚项,从而具有正则化模型的效果,作者称之为隐式梯度正则化(implicit gradient regularization)。可以通过后向误差分析来计算这种正则化的大小,也可以将其调整为显式正则化项。隐式正则化能够使得梯度下降过程倾向于平坦的极小值,并对噪声扰动具有鲁棒性。

1. 梯度下降的后向误差分析

梯度下降法是指沿着损失函数$L(\theta)$的梯度$g$的负方向(数值下降最快的方向)更新参数$\theta$,可以写作如下常微分方程ODE

\[\dot{\theta} = -\nabla_{\theta}L(\theta) = -g(\theta)\]

上式可通过数值方法求解,如使用一阶显式欧拉方法将其转化为差分方程:

\[\theta_{t+\gamma} = \theta_t - \gamma g(\theta_t)\]

通过迭代求解上式可以得到一组解$\theta_{\gamma},\theta_{2\gamma},\theta_{3\gamma},…$。注意到这组解与微分方程的精确解$\theta_{\gamma}^{*},\theta_{2\gamma}^{*},\theta_{3\gamma}^{*},…$存在一定的误差。直接衡量这两组解之间的误差比较困难,因此采用后向误差分析(backward error analysis)

解$\theta_{\gamma},\theta_{2\gamma},\theta_{3\gamma},…$理论上应该是另一个微分方程$\dot{\theta} = -\tilde{g}(\theta)$的一组精确解,因此可以分析$g(\theta)$与$\tilde{g}(\theta)$的误差。

对$\theta_{t+\gamma}$进行Taylor展开:

\[\theta_{t+\gamma} = \theta_t+\gamma \nabla_t \theta_{t}+\frac{1}{2}\gamma^2 \nabla^2_t \theta_{t} + \frac{1}{6}\gamma^3 \nabla^3_t \theta_{t} + \cdots \\ = (1+\gamma \nabla_t+\frac{1}{2}\gamma^2 \nabla^2_t +\frac{1}{6}\gamma^3 \nabla^3_t +\cdots)\theta_{t} = e^{\gamma \nabla_t}\theta_t\]

因此解$\theta_{\gamma},\theta_{2\gamma},\theta_{3\gamma},…$对应的差分方程可写作:

\[(e^{\gamma \nabla_t}-1)\theta_t = - \gamma g(\theta_t)\]

上式可以通过算符运算调整为:

\[\theta_t = - \gamma (\frac{1}{e^{\gamma \nabla_t}-1}) g(\theta_t) \\ \nabla_t \theta_t = - (\frac{\gamma \nabla_t}{e^{\gamma \nabla_t}-1}) g(\theta_t) \\ \dot{\theta}_t = - (1-\frac{\gamma \nabla_t}{2}+\frac{\gamma^2 \nabla_t^2}{12}-\frac{\gamma^4 \nabla_t^4}{720}+\cdots) g(\theta_t)\]

对上式保留一阶项:

\[\dot{\theta}_t = - (1-\frac{\gamma \nabla_t}{2}) g(\theta_t) = -g(\theta_t)+\frac{1}{2}\gamma \nabla_tg(\theta_t) \\ = -g(\theta_t)+\frac{1}{2}\gamma \nabla_{\theta}\dot{\theta}_tg(\theta_t) \\ = -g(\theta_t)+\frac{1}{2}\gamma \nabla_{\theta}[-g(\theta_t)+\frac{1}{2}\gamma \nabla_{\theta}\dot{\theta}_tg(\theta_t)]g(\theta_t) \\ ≈ -g(\theta_t)+\frac{1}{2}\gamma \nabla_{\theta}[-g(\theta_t)]g(\theta_t) = -g(\theta_t)-\frac{1}{2}\gamma \nabla_{\theta}g(\theta_t)g(\theta_t) \\ =-g(\theta_t)-\frac{1}{4}\gamma \nabla_{\theta}||g(\theta_t)||^2\]

因此解$\theta_{\gamma},\theta_{2\gamma},\theta_{3\gamma},…$是微分方程\(\dot{\theta}_t=-g(\theta_t)-\frac{1}{4}\gamma \nabla_{\theta}\|g(\theta_t)\|^2\)的近似精确解。对照$\dot{\theta} = -\tilde{g}(\theta)$,则有:

\[\tilde{g}(\theta_t) = g(\theta_t)+\frac{1}{4}\gamma \nabla_{\theta}||g(\theta_t)||^2 \\ = \nabla_{\theta}L(\theta_t)+\frac{1}{4}\gamma \nabla_{\theta}||\nabla_{\theta}L(\theta_t)||^2 \\ = \nabla_{\theta}[L(\theta_t)+\frac{1}{4}\gamma ||\nabla_{\theta}L(\theta_t)||^2]\]

2. 梯度下降引入的损失修正

观察上式发现,梯度下降并不是直接沿梯度$g$的负方向更新参数,而是沿着修正梯度$\tilde{g}$的负方向更新参数。相当于修正的损失函数$\tilde{L}$:

\[\tilde{L}(\theta) = L(\theta)+\frac{1}{4}\gamma ||\nabla_{\theta}L(\theta)||^2\]

其中\(\frac{1}{4}\gamma\|\nabla_{\theta}L(\theta)\|^2\)相当于正则化项,其值取决于学习率$\gamma$和模型大小。该正则化项对具有较大梯度值的损失平面进行惩罚,有助于模型到达更加平缓的区域,有利于提高泛化性能。

作者通过分析和实验,预测隐式梯度正则化应具有如下性质:

  1. 隐式梯度正则化鼓励比损失函数更小的梯度惩罚项,更大的学习率和模型尺寸将使得梯度惩罚更小;
  2. 隐式梯度正则化鼓励更平稳的优化;
  3. 隐式梯度正则化鼓励更高的测试精度;
  4. 隐式梯度正则化鼓励寻找对参数的噪声扰动更鲁棒的最优解。

当学习率过小时,会弱化这种隐式的梯度正则化,从而导致模型的泛化性能不佳。然而学习率过大将会导致训练的不稳定,因此可以将该梯度正则化项显式地添加到损失函数中,即构造显式梯度正则化(explicit gradient regularization)

\[\tilde{L}(\theta) = L(\theta)+ \mu ||\nabla_{\theta}L(\theta)||^2\]

3. 实验分析

作者发现,学习率和模型参数越大,隐式正则化越强,测试精度也越高(性质1,3):

作者发现,学习率越大(对应正则化程度越强),损失曲面的斜率越小(性质2),对参数扰动的鲁棒性越强(性质4)。