Amos:一种面向模型的自适应权重衰减的Adam风格优化器.

本文作者提出了Amos优化器,旨在自适应地调整梯度更新过程中的学习率和权重衰减系数。一般地,参数$\theta$的梯度下降更新公式为:

\[\theta_{t+1} = \theta_t - \alpha_tu_t\]

其中$u_t$代表$t$时刻的更新向量,标量$α_t>0$代表$t$时刻的学习率。额外引入权重衰减项$\rho_t>0$:

\[\theta_{t+1} = \theta_t - (\alpha_tu_t+\rho_t\theta_t)\]

因此需要选择合适的学习率$\alpha_t$和权重衰减项$\rho_t$。通常权重衰减项作为正则化目标,应该比主目标对参数更新的影响更小。因此引入约束:权重衰减带来的更新量始终要比目标相关的更新量高一阶:

\[\mathcal{O}(\alpha_t^2) = \mathcal{O}(\rho_t)\]

记最优参数$\theta^{*}$,则当前时刻$t$的参数误差为$\epsilon_t=\theta_t-\theta^{*}$,并且有以下关系:

\[\begin{aligned} ||\epsilon_{t+1}||^2 & = ||\theta_{t+1}-\theta^*||^2 \\ & = ||\theta_t - (\alpha_tu_t+\rho_t\theta_t)-\theta^*||^2 \\ & = ||\epsilon_t - (\alpha_tu_t+\rho_t\theta_t)||^2 \\ & = ||\epsilon_t||^2 - 2\epsilon_t(\alpha_tu_t+\rho_t\theta_t)+||\alpha_tu_t+\rho_t\theta_t||^2 \\& \approx ||\epsilon_t||^2 - 2\alpha_tu_t \cdot \epsilon_t +(\alpha_t^2||u_t||^2-2\rho_t\theta_t \cdot \epsilon_t) \end{aligned}\]

为使参数误差减小,不妨考察:

\[\begin{aligned} \alpha_t^2||u_t||^2&=2\rho_t\theta_t \cdot \epsilon_t \\ &= 2\rho_t(\epsilon_t+\theta^*) \\ &= 2\rho_t ||\epsilon_t||^2+2\rho_t\theta^* \cdot \epsilon_t \\ & \approx 2\rho_tq ||\epsilon_t||^2 \end{aligned}\]

通过引入参数$q$使得上式近似满足。此时参数误差的关系为:

\[\begin{aligned} ||\epsilon_{t+1}||^2 & \approx ||\epsilon_t||^2 - 2\alpha_tu_t \cdot \epsilon_t \\ & = ||\epsilon_t||^2 - 2\alpha_t||u_t|| \cdot || \epsilon_t || \cos(u_t,\epsilon_t) \\ & \approx ||\epsilon_t||^2 - 2\alpha_t p||u_t|| \cdot || \epsilon_t || \\ & = ||\epsilon_t||^2 - 2 p \sqrt{2\rho_tq} || \epsilon_t ||^2 \\ & = ||\epsilon_t||^2(1 - 2 p \sqrt{2\rho_tq} ) \\ & \approx ||\epsilon_t||^2 \exp(- 2 p \sqrt{2\rho_tq}) \end{aligned}\]

进一步递推得到:

\[||\epsilon_{t}||^2 \approx ||\epsilon_0||^2 \exp(- 2 \sum_{i=1}^{t-1} p \sqrt{2\rho_iq})\]

注意到上式满足的条件是\(\alpha_t^2\|u_t\|^2= 2\rho_tq \|\epsilon_t\|^2\),为使得$\alpha_t$和$\rho_t$具有同等程度地衰减,不妨设\(2\rho_tq = \lambda^2\|\epsilon\|^2\),联立解得:

\[\begin{aligned} \alpha_t &\approx \frac{\lambda ||\epsilon_t||^2}{||u_t||} \approx \frac{\lambda ||\epsilon_0||^2}{||u_t||} \exp(- 2 \sum_{i=1}^{t-1} p \sqrt{2\rho_iq}) \\ \rho_t &\approx \frac{\lambda^2 ||\epsilon_t||^2}{2q} \approx \frac{\lambda^2 ||\epsilon_0||^2}{2q} \exp(- 2 \sum_{i=1}^{t-1} p \sqrt{2\rho_iq}) \end{aligned}\]

此外由假设\(\alpha_t^2\|u_t\|^2= 2\rho_tq \|\epsilon_t\|^2\)和\(2\rho_tq = \lambda^2\|\epsilon\|^2\)可得\(\alpha_0\|u_0\|=\lambda \|\epsilon_0\|^2\),不妨假设初始时刻的更新向量$|u_0|=|\epsilon_0|$,则有$\lambda = \alpha_0 / |\epsilon_0|$。代入得:

\[\begin{aligned} \alpha_t &\approx \frac{\alpha_0 ||\epsilon_0||}{||u_t||} \exp(- 2 \sum_{i=1}^{t-1} p \sqrt{2\rho_iq}) \\ \rho_t &\approx \frac{\alpha_0^2}{2q} \exp(- 2 \sum_{i=1}^{t-1} p \sqrt{2\rho_iq}) \end{aligned}\]

因此,若自适应地设置$\alpha_t$和$\rho_t$,需要选择参数$\alpha_0, ||\epsilon_0||, p, q$。其中$\alpha_0$代表了每一步的相对更新幅度(全局学习率)一般取$1e−3$,任务简单也可以取到$1e−2$。$q=1$,其余参数的设置思路如下。

$||\epsilon_0||$定义为初始化参数与最优参数的距离,代表参数的变化尺度:

\[||\epsilon_0||=||\theta_0-\theta^*||\]

网络参数通常初始化为$0$均值, $\sigma^2$方差。因此对于参数\(\theta \Bbb{R}^k\),有$||\theta_0||^2\approx k\sigma^2$。在合理的初始化下,训练完成后参数的均值方差也不会有太大变化,因此有$||\theta^{*}||^2\approx k\sigma^2$。则有:

\[||\epsilon_0||^2=||\theta_0-\theta^*||^2 \approx k\sigma^2\]

此外对于全零初始化的参数(如偏置项或归一化层的reshift参数)或全一初始化的参数(如归一化层的rescale参数),不妨假设预测训练好的模型参数都在$±σ$附近,则也有$||\epsilon_0||^2 \approx k\sigma^2$。此时可设置$\sigma=0.5$。

至于参数$p$的取值,不妨分析参数$p$如何影响权重衰减函数,即求出$\rho_t$的解析近似。假设参数$p$是步数$t$的函数$p_t$,则有:

\[\begin{aligned} \sqrt{ 2\rho_t q} &\approx \alpha_0 \exp(- \sum_{i=1}^{t-1} p_i \sqrt{2\rho_iq}) \end{aligned}\]

把指数求和$\sum_{i=1}^{t-1}p_i\sqrt{2\rho_iq}$记为$S_t$,则上式对应一个差分方程:

\[S_{t+1}-S_t \approx \alpha_0 p_t \exp(-S_t)\]

近似于微分方程:

\[\frac{d S_t}{dt} \approx \alpha_0 p_t \exp(-S_t)\]

对上式进行调整并两端积分得:

\[\begin{aligned} \exp(S_t) \frac{d S_t}{dt} &= \alpha_0 p_t \\ \int_0^t \exp(S_t) \frac{d S_t}{dt} dt &= \int_0^t \alpha_0 p_t dt \\ \exp(S_t) - \exp(S_0) &= \alpha_0 \int_0^t p_t dt \end{aligned}\]

不妨取$p_t = p_0 \exp(-S_t)$,代入上式得:

\[\exp(S_t) - \exp(S_0) = \alpha_0 p_0 \int_0^t \exp(-S_t) dt\]

假设$S_0=0$,解上述微分方程得:

\[\exp(-2S_t) = \frac{1}{2 \alpha_0 p_0 t+1}\]

至此可以得到学习率$\alpha_t$和权重衰减项$\rho_t$的设置公式:

\[\begin{aligned} \alpha_t &\approx \frac{\alpha_0 ||\epsilon_0||}{||u_t||} \frac{1}{2 \alpha_0 p_0 t+1} \\ \rho_t &\approx \frac{\alpha_0^2}{2q} \frac{1}{2 \alpha_0 p_0 t+1} \end{aligned}\]

学习率和权重衰减都采用逆时间衰减(Inverse Time Decay)的形式。