层归一化和动态激活函数之间的数学关系.

语言模型中常用的归一化函数是LayerNormRMSNormLayerNorm针对每个训练样本计算所有特征的均值和方差,对输入归一化后进行re-scalingre-shiftingRMSNorm去掉了LayerNorm中的均值计算和re-shifting操作,进一步减少了计算负担。本文从梯度近似的角度设计了RMSNorm归一化的替代函数DyISRU,从而在网络中去掉了归一化层。

1. 分析RMSNorm的梯度

RMSNorm的公式如下:

\[\mathbf{y} = \frac{\mathbf{x}}{||\mathbf{x}||/\sqrt{d}} = \sqrt{d}\frac{\mathbf{x}}{||\mathbf{x}||}\]

其中$\mathbf{x} \in R^d$,$|\mathbf{x}|$表示$\mathbf{x}$的L2范数:

\[||\mathbf{x}|| = \sqrt{\sum_{i=1}^d x_i^2}\]

L2范数的导数可以通过链式法则求解:

\[\begin{aligned} \nabla_{\mathbf{x}} \|\mathbf{x}\| &= \nabla_{\mathbf{x}} \left(\sqrt{\mathbf{x}^\top \mathbf{x}}\right) \\ &= \frac{1}{2\sqrt{\mathbf{x}^\top \mathbf{x}}} \nabla_{\mathbf{x}} \left(\mathbf{x}^\top \mathbf{x}\right) \\ &= \frac{\mathbf{x}}{\|\mathbf{x}\|} \end{aligned}\]

进而有:

\[\begin{aligned} \nabla_{\mathbf{x}} \frac{\mathbf{x}}{||\mathbf{x}||} &= \frac{||\mathbf{x}|| - \mathbf{x} \cdot \frac{\mathbf{x}}{\|\mathbf{x}\|}}{||\mathbf{x}||^2} \\ &= \frac{I}{||\mathbf{x}||} - \frac{\mathbf{x} \mathbf{x}^\top}{||\mathbf{x}||^3} \end{aligned}\]

RMSNorm的梯度可以表示为:

\[\begin{aligned} \nabla_{\mathbf{x}} \mathbf{y} &= \sqrt{d}\nabla_{\mathbf{x}} \frac{\mathbf{x}}{||\mathbf{x}||} \\ &= \sqrt{d}\left( \frac{I}{||\mathbf{x}||} - \frac{\mathbf{x} \mathbf{x}^\top}{||\mathbf{x}||^3} \right) \\ &= \frac{\sqrt{d}}{||\mathbf{x}||}\left( I - \frac{\mathbf{y} \mathbf{y}^\top}{d} \right) \end{aligned}\]

2. 寻找RMSNorm的替代

现在的目标的寻找一个函数$\mathbf{y}=f(\mathbf{x})$,若该函数能够近似RMSNorm的梯度,则$f$能够替代归一化层的使用,从而实现在网络中去掉归一化层的目标。

假设$\mathbf{y}=f(\mathbf{x})$是逐元素操作,即$y_i=f(x_i)$,则$f$的梯度需满足:

\[\frac{d y_i}{d x_i} = \rho\left( 1 - \frac{y_i^2}{d} \right)\]

其中$\rho=\sqrt{d} / ||\mathbf{x}||$。若假设$\rho$为常数,直接求解上述微分方程得到:

\[y_i = \sqrt{d} \tanh \left( \frac{x_i}{\rho \sqrt{d}} \right)\]

因此取$f(x) \sim \tanh(x)$,可以实现RMSNorm的替代,这便是Dynamic Tanh的结论。

3. DyISRU函数

注意到:

\[\rho=\frac{\sqrt{d}}{||\mathbf{x}||}= \frac{\mathbf{y}}{\mathbf{x}} = \frac{y_i}{x_i}\]

则$f$的梯度需满足:

\[\frac{d y_i}{d x_i} = \frac{y_i}{x_i} \left( 1 - \frac{y_i^2}{d} \right)\]

求解上述微分方程得到:

\[y_i = \frac{\sqrt{d} x_i}{\sqrt{x_i^2+C}}\]

其中$C$为常数。上式形如Inverse Square Root Unit (ISRU),因此被称为DyISRU