批归一化使深度网络中的残差块偏向于恒等函数.
0. TL; DR
本文探讨了批归一化(Batch Normalization)在深度残差网络中的作用,指出其关键优势在于初始化时将残差分支相对于跳跃连接(skip connection)缩小,使得网络函数接近恒等函数,从而便于训练。基于此,作者提出了一种名为“SkipInit”的简单初始化方法,无需归一化即可训练深度残差网络。此外,文章还对批归一化在不同批量大小下的学习率优势进行了详细分析,指出其在小批量训练中并无显著优势。
1. 背景介绍
深度残差网络(ResNet)在图像识别等任务中取得了巨大成功,而批归一化(Batch Normalization)是其不可或缺的组成部分。批归一化通过归一化隐藏层的激活值,显著提高了可训练的最大网络深度,从而推动了深度学习在多个领域的进步。然而,关于批归一化为何能如此有效地提升训练效果,其背后的原理尚未完全清晰。
本文的核心目标是揭示批归一化在深度残差网络中的关键作用,并探索在不使用归一化的情况下如何训练深度残差网络。作者通过理论分析和实验验证,提出了一个简单而有效的初始化方案——SkipInit,该方案能够在不使用归一化的情况下实现与批归一化网络相当的训练效果。
2. SkipInit 方法
在残差网络中,每个残差块包含一个残差分支和一个跳跃连接。残差分支由多个卷积层、归一化层和非线性激活函数组成,而跳跃连接通常是恒等函数。作者通过分析隐藏层激活值的方差,揭示了批量归一化在初始化时对残差分支的影响。
\[x_l = x_0 + \sum_{i=0}^{l-1} F_i(BN(x_i))\]对于未归一化的网络,残差分支的输出方差与输入方差相近,导致残差块的输出方差呈指数增长,这使得网络难以训练。而对于批量归一化的网络,由于归一化操作,残差分支的输出方差被抑制到接近1,从而使得残差块的输出主要由跳跃连接决定,即网络函数接近恒等函数。这种特性确保了网络在初始化时具有良好的梯度传播,便于训练。
基于上述分析,作者提出了SkipInit初始化方法。该方法的核心思想是在每个残差分支的末尾引入一个可学习的标量乘数$α$,并在初始化时将其设置为$0$或一个较小的常数$1/\sqrt{d}$($d$是残差块的数量)。这样,在初始化时,残差分支的贡献被显著缩小,使得残差块的输出接近跳跃连接,从而实现了与批归一化类似的效果。
\[x_{t+1} = x_t + \alpha \cdot F_t(x_t)\]SkipInit的具体实现非常简单,只需在残差分支的末尾添加一个标量乘数,这一改动可以在不使用归一化的情况下训练深度残差网络,并且在实验中取得了与批量归一化网络相当的性能。
3. 实验分析
作者在CIFAR-10数据集上对不同深度的残差网络进行了实验,验证了SkipInit的有效性。实验结果表明,使用SkipInit的网络能够在高达1000层的深度下稳定训练,并且在测试精度上与批归一化网络相当。这表明SkipInit能够在不使用归一化的情况下,实现与批归一化网络相同的训练效果。
作者进一步研究了批归一化在不同批量大小下的性能。实验结果表明,当批量大小较小时,批归一化和SkipInit的最优学习率相似,且均远小于最大稳定学习率。这表明在小批量训练中,使用大学习率并无显著优势。而当批量大小较大时,批归一化网络能够使用更大的学习率,从而提高训练效率。这表明批归一化的主要优势在于其能够提高最大稳定学习率,从而在大批量训练中发挥优势。
作者还研究了批归一化的正则化效果。实验结果表明,批归一化在小批量训练中具有一定的正则化效果,能够提高测试精度。通过引入额外的正则化方法(如Dropout),SkipInit也能够在小批量训练中取得与批归一化相当的性能。
在ImageNet数据集上,作者对批归一化、SkipInit和Fixup进行了比较。实验结果表明,在标准批量大小下,SkipInit和Fixup能够与批归一化取得相当的性能。然而,当批量大小较大时,批归一化仍然表现最佳。这进一步验证了批归一化在大批量训练中的优势。