Transformer结构中的层归一化.
0. TL; DR
本文研究了Transformer架构中层归一化(Layer Normalization)的位置对模型训练的影响。
- 传统的Transformer(Post-LN)将层归一化放在残差块之间,训练时需要精心设计的学习率预热阶段(Warm-Up),这会减慢优化速度并增加超参数调整的难度。
- 本文指出将层归一化放在残差块内部的Pre-LN Transformer在初始化时梯度表现良好,无需学习率预热阶段即可稳定训练。
1. 理论分析
TL; DR: Pre-LN结构通常更容易训练,但最终效果比Post-LN差
(1) Post-LN
Post-LN的形式如下:
\[x_{t+1} = \text{LayerNorm}(x_t + F_t(x_t))\]其中,$x_t$表示输入,$F_t(x_t)$表示自注意力机制,$\text{LayerNorm}$表示层归一化。
⚪ Post-LN的主要优点
① 稳定了前向传播的方差
如果$x$的方差为$σ^2_1$而$F(x)$的方差为$σ_2^2$,并且假设两者相互独立,则$x+F(x)$的方差为$σ^2_1+σ_2^2$,即残差会进一步放大方差。通过引入Post-LN能够稳定前向传播的数值,并且保持了每个模块的一致性。
② 微调性能更好
在微调阶段,通常希望优先调整靠近输出层的参数,不要过度调整靠近输入层的参数,以免严重破坏预训练效果。
由于Post-LN会带来一定的梯度消失问题(本质是削弱了残差,见下文),越靠近输入层的结果对最终输出的影响越弱,这正是微调时所希望的。所以预训练好的Post-LN会比Pre-LN有更好的微调性能。
⚪ Post-LN的主要缺点
① 削弱了残差的恒等分支
不失一般性地假设$σ_1=σ_2=1$,则$x+F(x)$的方差为$2$。LayerNorm将方差重新缩放到$1$,相当于在初始阶段进行操作:
\[\begin{aligned} x_{t+1} &= \frac{x_t+F(x_t)}{\sqrt{2}} \\ &= \frac{x_{t-1}+F(x_{t-1})}{(\sqrt{2})^2} + \frac{F(x_t)}{\sqrt{2}} \\ &= \cdots \\ &= \frac{x_0}{(\sqrt{2})^{t+1}} + \sum_{i=0}^t \frac{F(x_i)}{(\sqrt{2})^{t+1-i}} \\ \end{aligned}\]此时残差的恒等分支以幂函数的形式被削弱了,因此Post-LN失去了残差“易于训练”的优点,通常需要Warm-Up并设置足够小的学习率才能使训练过程收敛。
(2) Pre-LN
Pre-LN的形式如下:
\[x_{t+1} = x_t + F_t(\text{LayerNorm}(x_t))\]⚪ Pre-LN的主要优点
① 突出了残差路径的作用
Pre-LN可以展开为:
\[\begin{aligned} x_{t+1} &= x_t + F_t(\text{LayerNorm}(x_t)) \\ &= x_{t-1} + F_{t-1}(\text{LayerNorm}(x_{t-1})) + F_t(\text{LayerNorm}(x_t))\\ &= \cdots \\ &= x_0 + \sum_{i=0}^t F_i(\text{LayerNorm}(x_i)) \\ \end{aligned}\]残差的作用会更加明显,所以Pre-LN更好优化。但是最后的$x_t$方差将会很大,所以在接预测层之前$x_t$还要加个LayerNorm。
⚪ Pre-LN的主要缺点
① 实际等效层数减少
随着层数的加深,$x_t$和$x_{t+1}$之间的差异减小,因此近似有:
\[\begin{aligned} & F_{t+1}(\text{LayerNorm}(x_{t+1})) + F_t(\text{LayerNorm}(x_t)) \\ \approx & F_{t+1}(\text{LayerNorm}(x_t)) + F_t(\text{LayerNorm}(x_t)) \\ = &\begin{pmatrix} 1 & 1 \end{pmatrix}\begin{pmatrix} F_{t+1} \\ F_t \end{pmatrix} \begin{pmatrix}\text{LayerNorm}(x_t) \end{pmatrix} \end{aligned}\]当$t$比较大时,$x_t,x_{t+1}$相差较小,所以原本一个$t$层的模型与$t+1$层的模型的和,近似等效于一个更宽的$t$层模型。
Pre-LN增加了模型的宽度而降低了模型的深度,由于深度通常比宽度更重要,因此Pre-LN的最终训练效果通常不如Post-LN。