Radam:修正Adam算法中自适应学习率的早期方差.

对于Adam等自适应随机优化算法,学习率warmup(即在最初几次训练中使用较小的学习率)能够稳定训练、加速收敛、提高泛化能力。作者认为在这些自适应优化算法中,自适应学习率的方差在早期阶段比较大,使用warmup能够减少方差。作者进一步提出了一种修正的Adam算法(rectified Adam, Radam),显式地修正自适应学习率的方差。

1. 发现问题:Adam与warmup

下图展示了使用TransformerDE-EN WSLT’14数据集上训练神经机器翻译模型时的损失曲线,当移除学习率warmup时,训练损失从$3$增加到$10$,这说明自适应优化算法在没有warmup时可能会收敛到不好的局部最优解。

作者进一步绘制了每轮训练中梯度绝对值的直方图,并沿Y轴将所有轮数的梯度直方图堆叠起来。图中结果显示,当不采用warmup时在训练早期的梯度分布会发生比较明显的扭曲,这意味着经过几轮更新后可能陷入了糟糕的局部最优。当应用warmup后,梯度绝对值的分布相对稳定。

2. 分析问题:自适应学习率的方差

Adam算法使用了偏差修正过的指数滑动平均的动量$m_t$和自适应学习率$l_t$:

\[m_t = \frac{\beta_1m_{t-1}+(1-\beta_1)g_t}{1-\beta_1^t} = \frac{(1-\beta_1)\sum_{i=1}^{t}\beta_1^{t-i}g_{i}}{1-\beta_1^t}\] \[l_t = \sqrt{\frac{1-\beta_2^t}{\beta_2m_{t-1}+(1-\beta_2)g_t^2}} = \sqrt{\frac{1-\beta_2^t}{(1-\beta_2)\sum_{i=1}^{t}\beta_2^{t-i}g_{i}^2}}\] \[\theta_t = \theta_{t-1} - \alpha_t m_t l_t\]

作者认为,由于训练早期阶段处理的样本量较少,导致自适应学习率的方差过大。以$t=1$为例,对应的自适应学习率$l_1 = \sqrt{\frac{1}{g_{1}^2}}$。假设$g_1$服从高斯分布$\mathcal{N}(0,\sigma^2)$,则$\text{Var}[\sqrt{\frac{1}{g_{1}^2}}]$是发散的。作者假设Adam等自适应随机优化算法的不收敛问题来自更新早期阶段自适应学习率的无界方差。

为了进一步验证上述假设,作者设计了两种Adam的变体,均可以降低早期自适应学习率的方差。

上述两种方法均减小了训练过程中梯度分布的失真,这进一步证明了通过减小自适应学习率的方差可以缓解收敛问题。

3. 自适应学习率$l_t$的方差公式$\text{Var}[l_t]$

下面寻找自适应学习率$l_t$的方差公式$\text{Var}[l_t]$。为简化讨论,用简单平均SMA代替自适应学习率计算中的指数滑动平均EMA

\[l_t = \sqrt{\frac{1-\beta_2^t}{(1-\beta_2)\sum_{i=1}^{t}\beta_2^{t-i}g_{i}^2}} ≈\sqrt{\frac{t}{\sum_{i=1}^{t}g_{i}^2}}\]

仍然假设$g_t$服从高斯分布$\mathcal{N}(0,\sigma^2)$,则变量$\frac{t}{\sum_{i=1}^{t}g_{i}^2}$服从scaled inverse chi-square分布$\chi^2(t,\frac{1}{\sigma^2})$,作者假设变量$\frac{1-\beta_2^t}{(1-\beta_2)\sum_{i=1}^{t}\beta_2^{t-i}g_{i}^2}$也服从类似的分布$\chi^2(\rho_t,\frac{1}{\sigma^2})$,下面求解$\rho_t$。

对于$\chi^2(\rho_t,\frac{1}{\sigma^2})$,构造一个服从该分布的随机变量$\frac{\rho_t}{\sum_{i=1}^{\rho_t}g_{t+1-i}^2}$。若两个随机变量$\frac{1-\beta_2^t}{(1-\beta_2)\sum_{i=1}^{t}\beta_2^{t-i}g_{i}^2}$和$\frac{\rho_t}{\sum_{i=1}^{\rho_t}g_{t+1-i}^2}$具有相同的分布,则对于$g_i^2=i$,有:

\[\frac{1-\beta_2^t}{(1-\beta_2)\sum_{i=1}^{t}\beta_2^{t-i}{i}} = \frac{\rho_t}{\sum_{i=1}^{\rho_t}{t+1-i}}\]

解上式得:

\[\rho_t = \frac{2}{1-\beta_2} - 1 - \frac{2t\beta_2^t}{1-\beta_2^t}\]

特别地,有:

\[\rho_∞ = \frac{2}{1-\beta_2} - 1\]

对于分布$l_t^2$~$\chi^2(\rho_t,\frac{1}{\sigma^2})$,当$\rho_t>4$时,有如下结论:

\[\text{Var}[l_t] = \frac{1}{\sigma^2}(\frac{\rho_t}{\rho_t-2}-\frac{\rho_t 2^{2\rho_t-5}}{\pi}\mathcal{B}(\frac{\rho_t-1}{x},\frac{\rho_t-1}{x})^2)\]

其中$\mathcal{B}(\cdot)$是beta函数。注意到$\text{Var}[l_t]$随$\rho_t$增加而单调减少。当$\rho = \rho_∞$时取得最小方差:

\[\text{Var}[l_t]_{min} =\text{Var}[l_t]|_{\rho = \rho_∞}\]

4. 修正自适应学习率的方差

根据上述分析,自适应学习率$l_t$的方差具有最小值\(\text{Var}[l_t]\|_{\rho = \rho_∞}\)。在每轮更新中,为自适应学习率引入修正系数$r_t$,从而控制每轮的自适应学习率$r_tl_t$的方差均为最小方差:

\[\text{Var}[r_tl_t] = \text{Var}[l_t]|_{\rho = \rho_∞}\]

因此修正系数$r_t$计算为:

\[r_t = \sqrt{\frac{\text{Var}[l_t]|_{\rho = \rho_∞}}{\text{Var}[l_t]}}\]

方差$\text{Var}[l_t]$的计算使用一阶近似:

\[\text{Var}[l_t] = \frac{1}{\sigma^2}(\frac{\rho_t}{\rho_t-2}-\frac{\rho_t 2^{2\rho_t-5}}{\pi}\mathcal{B}(\frac{\rho_t-1}{x},\frac{\rho_t-1}{x})^2) \\ ≈ \frac{\rho_t}{2(\rho_t-2)(\rho_t-4)\sigma^2}\]

下图展示了解析形式和一阶近似的曲线,两者的差异远小于它们本身的值,这说明一阶近似是足够精确的。

则修正系数$r_t$计算为:

\[r_t = \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_∞}{(\rho_∞-4)(\rho_∞-2)\rho_t}}\]

注意到上述推导仅在$\rho_t>4$时成立。

作者进一步进行了仿真实验,从$\mathcal{N}(\mu,1)$中采样$g_t$,绘制方差$\text{Var}[l_t]$和$\text{Var}[r_tl_t]$随更新轮数的变化曲线。仿真结果显示自适应学习率在早期阶段具有较大的方差,而校正后的自适应学习率具有相对一致的方差。

5. Radam

综上所述,Radam算法的流程如下。

首先计算自适应学习率的平方$l_t^2$的最大自由度:

\[\rho_∞ = \frac{2}{1-\beta_2} - 1\]

对于第$t$轮更新,计算自适应学习率的平方$l_t^2$的自由度:

\[\rho_t = \rho_∞ - \frac{2t\beta_2^t}{1-\beta_2^t}\]

当$\rho_t>4$时,对自适应学习率进行修正。计算修正系数$r_t$:

\[r_t = \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_∞}{(\rho_∞-4)(\rho_∞-2)\rho_t}}\]

则参数更新为:

\[\theta_t = \theta_{t-1} - \alpha_t m_t r_tl_t\]

当$\rho_t\leq 4$时,只使用动量$m_t$进行更新:

\[\theta_t = \theta_{t-1} - \alpha_t m_t\]

5. 实验分析

作者在语言模型和图像分类任务上分别测试了Radam算法的性能。在语言模型上,虽然修正项使Radam在最初的几次更新中比Adam慢,但它允许Radam在之后更快地收敛,并获得更好的性能。在图像分类上,尽管Radam的测试精度并没有没有优于SGD,但它会带来更好的训练精度。

作者测试了在不同学习率设置下Radam的表现。实验结果表明,通过校正自适应学习率的方差,Radam提高了模型在较大范围内的学习率设置下训练的鲁棒性,实现了一致的模型性能;而AdamSGD对学习率更敏感。

作者也比较了Radamwarmup的效果。warmup具有更多的超参数,对预热轮数和学习率的选择比较敏感。例如,将学习率设置为$0.1$时,$100$轮预热的Adam精度为$90.13$,而Radam的精度为$91.06$,且具有更少的超参数。