通过Diff Pruning实现参数高效的迁移学习.
针对特定任务对预训练的模型进行微调是当代NLP的主流模式,在一系列自然语言理解任务中取得了最先进的结果。虽然这种方法简单明了,在经验上也很有效,但很难扩展到多任务、内存受限的情况下,因为它需要为每个任务存储一整套模型参数。
本文提出Diff pruning,通过一个特定任务的 diff 向量扩展基础模型。只需要微调$0.5\%$的预训练参数。为了学习这个向量,将特定任务的模型参数重参数化为$\theta_{task}=\theta_{pretrained}+\delta_{task}$,其中预训练的参数向量$\theta_{pretrained}$是固定的,特定任务的 diff 向量$\delta_{task}$是需要微调的,进而构造以下经验风险最小化:
\[\min_{\delta} L(\theta+\delta) + \lambda R(\theta+\delta)\]微调模型的成本是diff 向量$\delta$。如果能将$\delta$正则化,使其稀疏,从而使\(\|\delta_0\|\leq \|\theta\|\),那么随着任务数量的增加,这种方法可以变得更具有参数效率。可以用$\delta$的L0-norm惩罚来指定这一目标:
\[R(\theta+\delta) = ||\delta||_0\]考虑到L0-norm的计算是不可微的,为了近似这个L0目标,作者采用了一种基于梯度的学习方法,即使用一个宽松的掩码向量进行L0稀疏度学习。这种方法将binary vector放宽到连续空间,然后与密集的权重向量相乘,以确定在训练中应用多少权重向量。为了应用这种方法,将$\delta$分解成一个二进制掩码向量$z$乘以一个密集向量$w$:
\[\delta = z \odot w, z\in \{0,1\},w\in R^d\]把$z$初始化为参数$\alpha$控制的伯努利分布$p(z;\alpha)$,则目标函数可以写作$z$的期望形式:
\[\min_{\alpha,w} \mathbb{E}_{z \sim p(z;\alpha)}\left[ L(\theta+\delta) + \lambda ||\delta||_0 \right]\]在上述目标中,$z$仍然是离散的。将$z$放宽到连续空间$[0,1]^d$,并采用拉伸的Hard-Concrete分布,这样就可以使用路径梯度估计器。具体来说,$z$被定义为来自均匀分布的样本$u$的一个确定性和可微函数。
\[\begin{aligned} u & \sim U[0,1] \\ s &= \sigma(\log(u) - \log(1-u)+\alpha) \\ \hat{s} &= s\times(r-l) + l \\ z &= \min(1, \max(0, \hat{s})) \end{aligned}\]此处$l<0,r>1$是两个常数,用来将$s$拉伸到区间$(l,r)^d$,然后用min-max操作将它夹在$[0,1]^d$中。此时得到一个L0-norm的可微闭式表达:
\[\mathbb{E}\left[ ||\delta||_0 \right] = \sum_{i=1}^d \sigma\left( \alpha_{i}-\log\frac{-l}{r} \right)\]最终的优化问题可以表示为:
\[\min_{\alpha,w} \mathbb{E}_{u \sim U[0,1]}\left[ L(\theta+z \odot w)\right] + \lambda \sum_{i=1}^d \sigma\left( \alpha_{i}-\log\frac{-l}{r} \right)\]作者对不同任务设置了不同的diff 向量进行微调。下图显示了每个任务中不同层的非零 diff 参数的百分比。结果表明不同任务的微调确实修改了网络的不同部分,尽管有些任务之间存在一些质量上的相似性,例如QNLI和QQP(都必须对问题进行编码),以及MRPC和STS-B(都必须预测句子间的相似性)。嵌入层(最上面一层)在所有任务中的修改都很稀疏。
为了实现设置一个精确的稀疏率,作者对 diff 向量$\delta$使用幅度修剪(magnitude pruning),通过在$\delta$中只保留前$t\% \times d$的值来达到稀疏率$t\%$。结果表明,应用幅度修剪来投影到L0-ball上能够实现精确的稀疏目标(稀疏率设置为$0.5\%$),并且在性能上几乎没有损失。
为了使diff pruning能够适应模型结构,考虑对其进行结构化的扩展。structured diff pruning可以让模型学会在局部区域修改参数,而不是独立处理每个参数。首先将参数索引分为$G$组$g(1),…,g(G)$,然后为每个组$g(j)$引入二进制掩码向量$z^j$,则把增量参数分解为:
\[\delta^j_i = z_i \cdot z^j \cdot w\]对应的L0-norm的可微闭式表达:
\[\mathbb{E}\left[ ||\delta||_0 \right] = \sum_{j=1}^G\sum_{i=1}^d\mathbb{E}[1\{z_i\cdot z^j > 0\}] \\ = \sum_{j=1}^G\sum_{i=1}^d \sigma\left( \alpha_{i}-\log\frac{-l}{r} \right) \cdot \sigma\left( \alpha^j-\log\frac{-l}{r} \right)\]结构化Diff Pruning 为每个组引入了一个额外的掩码,这鼓励了对整个组进行 pruning。结果发现结构化Diff Pruning导致的微调模型更有可能使整个组与它们的预训练值(零差异)没有变化。