为什么GPT可以上下文学习?语言模型隐式地作为元优化器执行梯度下降.

0. TL; DR

本文探讨了大型预训练语言模型(如GPT)在上下文学习(ICL)中的工作机制。研究发现,GPT可以通过隐式优化(即不更新参数)来实现类似微调(finetuning)的效果。具体来说,GPT通过演示示例生成“元梯度”(meta-gradients),并通过注意力机制将这些元梯度应用于原始模型,从而构建一个上下文学习模型。实验结果表明,上下文学习的行为与显式微调非常相似。此外,本文还提出了一种基于动量的注意力机制,进一步验证了对上下文学习的理解,并展示了其在模型设计中的潜力。

1. 背景介绍

近年来,大型预训练语言模型(如GPT)在自然语言处理(NLP)领域取得了显著进展。这些模型通过上下文学习(ICL)在新任务上表现出色,即通过推理而非微调来完成任务。与微调需要额外的参数更新不同,ICL只需要几个演示示例,模型就可以预测未见过的输入的标签。尽管ICL在性能上取得了巨大成功,但其工作机制仍是一个未解之谜。

本文旨在解释GPT如何实现上下文学习,并将其视为隐式优化过程。具体来说,本文将ICL视为隐式微调,并通过理论分析和实验验证了这一观点。研究发现,Transformer模型中的注意力机制与梯度下降具有对偶形式,这为理解ICL提供了新的视角。

2. 方法介绍

本文探讨了Transformer模型中的注意力机制与梯度下降之间的对偶关系。具体来说,注意力机制可以被视为一种隐式的优化过程,其中注意力值被视为“元梯度”,用于更新模型的参数。

梯度下降在进行优化时可以表示为:

F(x)=(W0+ΔW)x

其中W0 是初始化的参数矩阵,ΔW 是通过梯度下降更新的参数矩阵,x 是输入表示。

在反向传播中,参数更新 ΔW 是通过历史输入表示 xi 和对应输出的误差信号 ei 的外积累加得到的:

ΔW=ieixi

结合上述两个公式,可以得到梯度下降的输出:

F(x)=W0x+ΔWx=W0x+ieixix=W0x+iei(xix)=W0x+LinearAttn(E,X,x)

其中E 是历史输出误差信号,作为值(values);X 是历史输入,作为键(keys);x 是当前输入,作为查询(query)。

在上下文学习(ICL)中,Transformer的注意力机制可以被视为一种隐式的优化过程。具体来说,给定一个查询输入 x,其注意力查询向量为 q=WQx,注意力结果可以表示为:

FICL(q)=Attn(V,K,q)=WV[X;X]softmax((WK[X;X])qd)

为了简化分析,作者将标准注意力近似为线性注意力,去除softmax操作和缩放因子:

FICL(q)WV[X;X](WK[X;X])q=WVX(WKX)q+WVX(WKX)q

定义 WZSL=WVX(WKX) 作为零样本学习(ZSL)中初始化的参数,因为 WZSLq 是没有演示示例时的注意力结果。根据线性层优化的对偶形式,可以推导出Transformer注意力的对偶形式:

FICL(q)=WZSLq+WVX(WKX)q=WZSLq+LinearAttn(WVX,WKX,q)=WZSLq+iWVxi(WKxi)q=WZSLq+ΔWICLq=(WZSL+ΔWICL)q

其中ΔWICL 是通过演示示例计算的参数更新,类似于梯度下降中的 ΔWWVX 被视为元梯度(meta-gradients),用于计算更新矩阵 ΔWICL

基于上述对Transformer注意力的分析,作者进一步比较了上下文学习(ICL)和显式微调(finetuning)之间的关系。显式微调的注意力结果可以表示为:

FFT(q)=(WV+ΔWV)XX(WK+ΔWK)q=(WZSL+ΔWFT)q

其中ΔWKΔWV 是通过反向传播从任务特定的训练目标中获得的参数更新;ΔWFT 是微调引入的对 WZSL 的更新。

通过比较上下文学习和显式微调,作者发现它们在以下方面具有相似性:

  1. 梯度下降:两者都通过隐式或显式的梯度下降更新 WZSL
  2. 相同的训练信息:ICL的元梯度和显式微调的梯度都来源于相同的训练示例。
  3. 相同的因果顺序:ICL和显式微调都遵循相同的训练示例顺序。
  4. 目标相同:两者都直接影响注意力键和值的计算。

3. 实验分析

本文使用两个预训练的GPT模型(1.3亿和2.7亿参数)进行实验。对于每个任务,使用相同的模板来格式化零样本学习(ZSL)、显式微调(FT)和上下文学习(ICL)的示例。实验中,ICL固定演示示例的数量为32,显式微调使用与ICL相同的演示示例作为训练示例,并使用SGD作为优化器。

3.1 ICL与FT的实验结果

⚪ ICL覆盖显式微调的正确预测

下表展示了六个分类数据集上的验证准确率。结果表明,ICL和显式微调都能显著提高性能,表明它们的优化对下游任务都有帮助。

为了比较ICL和显式微调的模型预测,本文定义了一个召回率指标(Rec2FTP),用于衡量ICL能够覆盖多少显式微调的正确预测:

Rec2FTP=N(FT>ZSL)(ICL>ZSL)N(FT>ZSL)

其中,N(FT>ZSL) 是显式微调能够正确预测但零样本学习(ZSL)不能的查询示例数量,N(FT>ZSL)(ICL>ZSL)ICL也能够正确预测的示例数量。

下表展示了两个GPT模型在六个数据集上的Rec2FTP分数。结果显示,ICL能够覆盖显式微调超过85%的正确预测。这表明从模型预测的角度来看,ICL可以覆盖显式微调的大部分正确行为。

⚪ ICL倾向于与显式微调相同方向更新注意力输出

为了比较ICL和显式微调对注意力输出的影响,本文定义了一个相似度指标(SimAOU),用于衡量ICL和显式微调对注意力输出的更新是否相似:

SimAOU(ΔFT)=cos(h(ICL)h(ZSL),h(FT)h(ZSL))

其中,h(ICL)h(FT) 分别是ICL和显式微调的注意力输出,h(ZSL) 是零样本学习的注意力输出。

下表展示了两个GPT模型在六个数据集上的SimAOU分数。结果显示,ICL的更新与显式微调的更新相似,而与随机更新的相似度接近零。这表明从表示的角度来看,ICL倾向于与显式微调相同方向更新注意力输出。

⚪ ICL倾向于生成与显式微调相似的注意力权重

下表展示了两个GPT模型在六个数据集上的SimAM分数。结果显示,与显式微调前的注意力权重相比,ICL更倾向于生成与显式微调后的注意力权重相似的权重。这表明从注意力行为的角度来看,ICL与显式微调相似。

⚪ ICL和显式微调倾向于对训练示例分配相似的注意力

为了比较ICL和显式微调对训练示例的注意力权重,本文使用Kendall秩相关系数来衡量它们的相似性:

Kendall(ICL,FT)=PcPdN(N1)/2

其中,Pc 是一致对的数量,Pd 是不一致对的数量,N 是训练示例的数量。

下表展示了两个GPT模型在六个数据集上的Kendall秩相关系数。结果显示,ICL和显式微调对训练示例的注意力权重的顺序相似,而与随机注意力权重的相似度接近零。这表明ICL和显式微调倾向于对训练示例分配相似的注意力。

3.2 动量注意力机制的实验验证

Transformer注意力与梯度下降的对偶形式启发,本文提出了一种基于动量的注意力机制。具体来说,动量注意力机制通过指数移动平均(EMA)来平均注意力值,从而引入动量机制:

MoAttn(V,K,qt)=Attn(V,K,qt)+EMA(V)

其中,V 是值,K 是键,qt 是查询,EMA(V) 是动量项。

下表展示了两个GPT模型在训练集和不同输入长度的验证集上的困惑度。结果表明,应用动量注意力的模型在所有验证集上都取得了比普通Transformer更低的困惑度。

下表展示了两个GPT模型在六个上下文学习数据集上的准确率。结果表明,应用动量注意力的模型在所有数据集上都取得了比普通Transformer更高的准确率。这表明引入动量机制可以提高Transformer注意力的性能。