SM3:内存高效的自适应优化算法.

Adam等基于自适应梯度的优化器需要维护模型参数的二阶统计信息,带来大量内存开销,从而限制所使用模型的大小以及批量大小。本文提出了一种减少内存开销的自适应优化方法,允许训练更大的模型和使用更大的批量,实验表明该方法将训练大型翻译和语言模型的速度提高了两倍。

假设待优化的模型参数共有$d$个,自适应梯度方法计算并累积每一个参数截止到当前训练轮数$t$的梯度二阶矩:

\[\gamma_t(i) = \sum_{s=1}^{t} g_s^2(i), \quad \forall i \in [d]\]

在更新参数时使用二阶矩调整梯度的大小:

\[w_{t+1}(i) = w_t(i) - \eta \frac{g_t(i)}{\sqrt{\gamma_t(i)}}, \quad \forall i \in [d]\]

注意到在上述更新过程中,需要$O(d)$的内存空间存储梯度二阶矩,从而加重内存负担。

在作者提出的内存高效自适应算法中,引入了$k$个非空集合\(\{S_r\}_{r=1}^k\),每个参数索引$i \in [d]$存储在其中的若干个集合$S_r$中。在参数更新中,对于每个集合计算一个标量值,因此该算法需要$O(k)$的内存空间。通常$k «d$,从而达到节省内存的目的。

对于每一个集合$S_r$,算法存储一个移动求和项$\mu(r)$,用于累积每轮训练中该集合对应参数的最大梯度方差:

\[\mu_{t}(r) = \mu_{t-1}(r)+\mathop{\max}_{j \in S_r}g_t^2(j)\]

对于每一个参数$w(i)$,将其所存在集合的方差累计的最小值作为更新量$\nu_t(i)$:

\[\nu_t(i) = \mathop{\min}_{r: S_r\ni j} \mu_{t}(r)\] \[w_{t+1}(i) = w_t(i) - \eta \frac{g_t(i)}{\sqrt{\nu_t(i)}}, \quad \forall i \in [d]\]

由于该算法计算的是平方梯度(squared-gradient)的最大值(maxima)的和(sums)的最小值(minima)的平方根(square-root),因此称其为SM3算法。

当$k=d$,\(S_i=\{i\}\)时,该算法退化为标准的梯度自适应算法,计算复杂度最大。

SM3算法可以改写成如下更高效的形式。

对于尺寸为$d=m \times n$的模型参数,可以按照其行和列划分集合,共得到$m+n$个集合,从而把内存占用从$O(mn)$减小至$O(m+n)$。此时该算法与Adafactor类似。

下图展示了在机器翻译任务上不同优化器的性能表现。当批量从$384$提高到$768$时,SM3算法的表现更加突出。

作者也在图像分类任务上进行实验,结果表明SM3算法也能更快地收敛到极值。