GradNorm: 使用梯度标准化调整多任务损失权重.

在多任务学习的损失优化过程中,可能会遇到两个问题:

  1. 不同任务的损失量级(magnitude)不同,损失量级较大的任务在梯度反向传播中占主导地位,导致模型忽视其他任务;
  2. 不同任务的学习难度不同,导致收敛速度不同。

多任务损失通常表示为所有任务损失的加权和:

\[L(t) = \sum_{i}^{}w_i(t)L_i(t)\]

作者将上述损失$L(t)$称为label loss,用于更新模型参数$W$;额外引入gradient loss,即将损失权重$w_i$也看做优化参数,在每轮更新中构造损失权重$w_i$的损失$L_{\text{grad}}$,并进行梯度更新。因此第$t$轮中参数更新如下:

\[W(t+1) \gets W(t)-\alpha \nabla_{W}L(t)\] \[w_i(t+1) \gets w_i(t)-\lambda \nabla_{w_i}L_{\text{grad}}\]

Gradient Normalization的出发点如上图所示。不同任务反向传播的梯度量级不同,因此额外构造损失权重的损失$L_{\text{grad}}$,用于调整梯度量级。$L_{\text{grad}}$同时考虑了不同任务的梯度量级和训练速度。

定义$G_W^{(i)}(t)$为第$t$轮训练中第$i$个任务上梯度标准化的值,用于衡量该任务损失对应梯度的量级,计算为该任务的加权损失梯度的L2范数;$G_W^{(i)}(t)$越大,表示该任务的梯度占主导地位。

\[G_W^{(i)}(t) = || \nabla_W w_i(t)L_i(t) ||_2\]

定义$\overline{G}_W(t)$为第$t$轮训练中所有任务全局梯度标准化的值,计算为所有任务梯度标准化的均值:

\[\overline{G}_W(t) = E_{\text{task}}[G_W^{(i)}(t)]\]

定义$\tilde{L}_i(t)$为第$t$轮训练中第$i$个任务的训练速度,计算为当前损失$L_i(t)$与初始损失$L_i(0)$之比;$\tilde{L}_i(t)$越大,表示该任务的训练速度越慢。

\[\tilde{L}_i(t) = \frac{L_i(t)}{L_i(0)}\]

定义$r_i(t)$为第$t$轮训练中第$i$个任务的相对训练速度,计算为该任务的训练速度\(\tilde{L}_i(t)\)与平均训练速度(所有任务训练速度的均值\(E_{\text{task}}[\tilde{L}_i(t)]\))之比;$r_i(t)$越大,表示该任务在所有任务中训练速度较慢。

\[r_i(t) = \frac{\tilde{L}_i(t)}{E_{\text{task}}[\tilde{L}_i(t)]}\]

$L_{\text{grad}}$可以表示为:

\[L_{\text{grad}}(t;w_i(t)) = \sum_{i}^{} | G_W^{(i)}(t)-\overline{G}_W(t) \times [r_i(t)]^{\alpha} |_1\]

直观上,第$i$个任务的梯度量级$G_W^{(i)}(t)$与平均梯度量级\(\overline{G}_W(t)\)差距越大,说明该任务占据主导地位,导致\(L_{\text{grad}}\)增大,则应该降低该任务的损失权重。

$r_i(t)$衡量该任务的学习难度。任务学习难度越小,对应$r_i(t)$越小,则该任务的梯度量级应该减小,对应该任务的损失权重减小。