Transformer结构中的层归一化.

0. TL; DR

本文研究了Transformer架构中层归一化(Layer Normalization)的位置对模型训练的影响。

1. 理论分析

TL; DR: Pre-LN结构通常更容易训练,但最终效果比Post-LN差

(1) Post-LN

Post-LN的形式如下:

\[x_{t+1} = \text{LayerNorm}(x_t + F_t(x_t))\]

其中,$x_t$表示输入,$F_t(x_t)$表示自注意力机制,$\text{LayerNorm}$表示层归一化。

⚪ Post-LN的主要优点

① 稳定了前向传播的方差

如果$x$的方差为$σ^2_1$而$F(x)$的方差为$σ_2^2$,并且假设两者相互独立,则$x+F(x)$的方差为$σ^2_1+σ_2^2$,即残差会进一步放大方差。通过引入Post-LN能够稳定前向传播的数值,并且保持了每个模块的一致性。

② 微调性能更好

在微调阶段,通常希望优先调整靠近输出层的参数,不要过度调整靠近输入层的参数,以免严重破坏预训练效果。

由于Post-LN会带来一定的梯度消失问题(本质是削弱了残差,见下文),越靠近输入层的结果对最终输出的影响越弱,这正是微调时所希望的。所以预训练好的Post-LN会比Pre-LN有更好的微调性能。

⚪ Post-LN的主要缺点

① 削弱了残差的恒等分支

不失一般性地假设$σ_1=σ_2=1$,则$x+F(x)$的方差为$2$。LayerNorm将方差重新缩放到$1$,相当于在初始阶段进行操作:

\[\begin{aligned} x_{t+1} &= \frac{x_t+F(x_t)}{\sqrt{2}} \\ &= \frac{x_{t-1}+F(x_{t-1})}{(\sqrt{2})^2} + \frac{F(x_t)}{\sqrt{2}} \\ &= \cdots \\ &= \frac{x_0}{(\sqrt{2})^{t+1}} + \sum_{i=0}^t \frac{F(x_i)}{(\sqrt{2})^{t+1-i}} \\ \end{aligned}\]

此时残差的恒等分支以幂函数的形式被削弱了,因此Post-LN失去了残差“易于训练”的优点,通常需要Warm-Up并设置足够小的学习率才能使训练过程收敛。

(2) Pre-LN

Pre-LN的形式如下:

\[x_{t+1} = x_t + F_t(\text{LayerNorm}(x_t))\]

⚪ Pre-LN的主要优点

① 突出了残差路径的作用

Pre-LN可以展开为:

\[\begin{aligned} x_{t+1} &= x_t + F_t(\text{LayerNorm}(x_t)) \\ &= x_{t-1} + F_{t-1}(\text{LayerNorm}(x_{t-1})) + F_t(\text{LayerNorm}(x_t))\\ &= \cdots \\ &= x_0 + \sum_{i=0}^t F_i(\text{LayerNorm}(x_i)) \\ \end{aligned}\]

残差的作用会更加明显,所以Pre-LN更好优化。但是最后的$x_t$方差将会很大,所以在接预测层之前$x_t$还要加个LayerNorm

⚪ Pre-LN的主要缺点

① 实际等效层数减少

随着层数的加深,$x_t$和$x_{t+1}$之间的差异减小,因此近似有:

\[\begin{aligned} & F_{t+1}(\text{LayerNorm}(x_{t+1})) + F_t(\text{LayerNorm}(x_t)) \\ \approx & F_{t+1}(\text{LayerNorm}(x_t)) + F_t(\text{LayerNorm}(x_t)) \\ = &\begin{pmatrix} 1 & 1 \end{pmatrix}\begin{pmatrix} F_{t+1} \\ F_t \end{pmatrix} \begin{pmatrix}\text{LayerNorm}(x_t) \end{pmatrix} \end{aligned}\]

当$t$比较大时,$x_t,x_{t+1}$相差较小,所以原本一个$t$层的模型与$t+1$层的模型的和,近似等效于一个更宽的$t$层模型。

Pre-LN增加了模型的宽度而降低了模型的深度,由于深度通常比宽度更重要,因此Pre-LN的最终训练效果通常不如Post-LN