Grokking:在小规模算术数据集上的过拟合外泛化.
0. TL; DR
本文研究了神经网络在小规模算法生成数据集上的泛化能力,揭示了一种名为“Grokking”的现象,即神经网络在严重过拟合之后,验证准确率会突然从随机水平提升到完美泛化。这种现象表明,即使在过参数化的神经网络中,泛化能力也可能在训练数据完全记忆之后很久才出现。此外,作者还发现数据集大小对泛化的影响,以及优化细节(如权重衰减)对数据效率的显著提升作用。
1. 背景介绍
深度学习中,过参数化的神经网络(即参数数量远超训练样本数量的网络)的泛化能力一直是研究热点。根据经典学习理论,这类网络本应过拟合训练数据而无法泛化到新的数据,但实际中它们却常常表现出良好的泛化性能。这种现象在自然数据集上已经得到了广泛研究,但在小规模算法生成数据集上,泛化行为可能更加明显且易于分析。
本文作者通过在小规模算法生成数据集上训练神经网络,研究了数据效率、记忆、泛化和学习速度等问题。这些数据集由二元运算表构成,例如加法、乘法或置换群的组合运算。作者发现,在某些情况下,神经网络在严重过拟合训练数据之后,验证准确率会突然提升,这种现象被命名为“Grokking”。
2. 方法介绍
本文使用的数据集是二元运算表,形式为 $a \circ b = c$,其中 $a, b, c$ 是离散符号,$\circ$ 是二元运算。这些运算包括模加法、模乘法、置换群的组合运算等。训练神经网络时,仅使用部分可能的等式作为训练集,其余作为验证集。这种设置类似于解数独谜题,网络需要“填补空白”以完成整个二元运算表。
作者使用了一个小型Transformer模型进行实验。该模型是一个两层的解码器结构,宽度为128,包含4个注意力头,总参数量约为 $4 \times 10^5$ 个非嵌入参数。模型的输入是等式的各个符号,输出是等式的结果符号。
训练过程中,作者使用了AdamW优化器,学习率为 $10^{-3}$,权重衰减系数为1。训练时采用因果注意力掩码,并仅在等式的答案部分计算损失和准确率。为了研究优化细节对泛化的影响,作者还尝试了不同的优化方法,包括全批量梯度下降、随机梯度下降、不同的学习率、残差dropout和权重衰减等。
3. 实验分析
实验中,作者观察到在某些二元运算任务上,神经网络在训练准确率达到接近完美的水平后,验证准确率会在经过大量优化步骤后突然提升。例如,在模除法任务中,训练准确率在不到 $10^3$ 步时接近完美,但验证准确率直到接近 $10^6$ 步时才达到高水平。这种现象表明,即使在严重过拟合之后,网络仍然能够学习到数据中的模式并实现泛化。
作者发现,随着训练数据集大小的减小,达到99%验证准确率所需的优化步骤数量急剧增加。例如,在使用抽象群 $S_5$ 的乘法任务中,当训练数据比例从30%降至25%时,达到99%验证准确率所需的中位优化步骤数增加了40-50%。这表明,对于小规模数据集,优化难度随着数据量的减少而迅速增加。
作者在多种二元运算任务上测试了Grokking现象。结果显示,对称运算(如 $x + y$ 和 $x \cdot y$)通常比非对称运算(如 $x - y$ 和 $x / y$)更容易泛化。此外,一些复杂的混合运算(如 $[x/y \mod p \text{ if } y \text{ is odd, otherwise } x - y \mod p]$)也表现出Grokking现象,表明网络能够学习到多种简单运算的组合。
作者尝试了多种正则化方法,包括权重衰减、残差dropout和梯度噪声等。实验结果表明,权重衰减对数据效率的提升最为显著,能够将所需样本量减少一半以上。此外,添加噪声(如梯度噪声)也有助于泛化,这与噪声可能诱导优化过程找到更平坦的最小值有关。
为了深入了解网络的泛化能力,作者对训练后的网络输出层权重进行了可视化。例如,在模加法任务中,网络的嵌入空间呈现出“数轴”的结构,反映了模加法的循环拓扑。在 $S_5$ 任务中,网络的嵌入空间形成了置换群的陪集结构。这些可视化结果表明,网络能够学习到数据中隐藏的数学结构。