Fixup初始化: 没有归一化的残差学习.
0. TL; DR
本文提出了一种名为Fixup的初始化方法,用于训练深度残差网络,而无需依赖于归一化层(如批量归一化)。通过重新调整标准初始化的尺度,Fixup能够解决训练初期的梯度爆炸和消失问题。实验表明,使用Fixup的残差网络在训练稳定性、收敛速度和泛化能力上与使用归一化的网络相当,甚至在某些任务上表现更好。该方法在图像分类和机器翻译任务上均取得了优异的性能。
1. 背景介绍
近年来,深度神经网络在人工智能领域取得了显著进展,尤其是在图像识别、机器翻译等任务上。残差网络(ResNet)作为一种重要的网络架构,通过引入残差连接解决了深层网络训练中的梯度消失问题。然而,残差网络通常依赖于归一化层(如批量归一化、层归一化等)来稳定训练、提高学习率、加速收敛并改善泛化能力。尽管归一化技术在实践中取得了巨大成功,但其背后的原理仍不完全清楚。
本文的核心目标是挑战“归一化是训练深度残差网络不可或缺的”这一普遍观点。作者通过理论分析和实验验证,提出了一种新的初始化方法——Fixup,它能够在不使用任何归一化层的情况下,实现与归一化网络相当的训练效果和性能。
2. Fixup初始化方法
在标准初始化方法下,残差网络的输出方差会随着网络深度呈指数增长,导致梯度爆炸。作者通过分析残差网络的梯度范数,发现其在初始化时的梯度范数下界与交叉熵损失相关,这表明在没有归一化的情况下,标准初始化会导致梯度爆炸。具体来说,对于一个残差网络,其输出可以表示为:
\[x_l = x_0 + \sum_{i=0}^{l-1} F_i(x_i)\]其中,$F_i$ 是第 $i$ 个残差分支。假设每个残差分支的输出方差与输入方差相近,那么在没有归一化的情况下,输出方差会随着深度呈指数增长,从而导致梯度爆炸。
Fixup的核心思想是通过重新调整残差分支的权重初始化,使得每个残差分支对网络输出的更新幅度与网络深度无关。具体步骤如下:
- 初始化分类层和残差分支的最后一层权重为$0$:这有助于稳定训练初期的输出。
- 对残差分支内的权重层进行重新缩放:具体来说,将残差分支内的权重层按 $L^{-\frac{1}{2(m-2)}}$ 缩放,其中 $L$ 是网络深度,$m$ 是残差分支内的层数。这种缩放方式可以确保每个残差分支对网络输出的更新幅度为 $Θ(η/L)$,从而使得整个网络的更新幅度为 $Θ(η)$。
- 添加标量乘数和偏置:在每个残差分支中添加一个标量乘数(初始化为$1$),并在每个卷积层、线性层和激活层前添加一个标量偏置(初始化为$0$)。这些参数有助于进一步调整网络的表示能力。
3. 实验分析
作者通过实验验证了Fixup在训练深度残差网络时的有效性。在CIFAR-10数据集上,使用宽残差网络(WRN)架构,Fixup能够在高达10,000层的网络中稳定训练,并且在第一个epoch后的测试精度与批量归一化相当。这表明Fixup能够在不使用归一化的情况下,实现与归一化网络相同的训练速度和稳定性。
在图像分类任务中,作者在CIFAR-10和ImageNet数据集上进行了实验。在CIFAR-10上,使用ResNet-110架构,Fixup在标准初始化的基础上取得了7%的相对改进。然而,与批量归一化相比,Fixup在训练初期存在过拟合问题。通过引入更强的正则化方法(如Mixup),Fixup能够取得与批量归一化相当的性能。在ImageNet上,使用ResNet-50和ResNet-101架构,Fixup在加入Mixup正则化后,测试误差与批量归一化相当,甚至在某些情况下表现更好。
作者还将Fixup应用于机器翻译任务,使用Transformer模型在IWSLT德英翻译和WMT英德翻译数据集上进行实验。实验结果表明,Fixup在不使用层归一化的情况下,能够稳定训练,并且在IWSLT德英翻译任务上取得了34.5的BLEU分数,优于层归一化的34.2;在WMT英德翻译任务上,Fixup取得了29.3的BLEU分数,与层归一化的29.3相当。这表明Fixup在机器翻译任务上也具有良好的泛化能力。