PolyLoss:一种分类损失函数的多项式展开视角.

本文作者设计了Poly Loss,这是一种受多项式泰勒展开启发的分类损失函数。Poly Loss将各种损失函数视为一系列多项式函数的线性组合。其中低阶部分倾向于得到正确的预测结果,高阶部分倾向于防止预测结果出错(如缓解类别不平衡问题)。

Poly Loss在统一多种损失函数的同时,引入了新的参数来为更精确的任务目标进行更精确的调整。让损失倾向于高阶项会使得模型尽可能减少出错,但不容易预测高置信度结果;让损失倾向于低阶项会使得模型预测出高置信度结果,但可能出现错误。

对于分类任务常用的交叉熵损失(Cross-Entropy loss)。给定网络输出的预测类别概率分布$p=(p_1,…,p_K)$和标签类别$t$,则损失计算为:

\[\begin{aligned} \mathcal{L}_{CE}(p,t) &= -\log p_t \\ \end{aligned}\]

对其进行泰勒展开:

\[\begin{aligned} \mathcal{L}_{CE}(p,t) &= -\log p_t = \sum_{j=1}^\infty \frac{1}{j} (1-p_t)^j \\ &= (1-p_t)+ \frac{1}{2} (1-p_t)^2 +\frac{1}{3} (1-p_t)^3 + \cdots \\ \end{aligned}\]

计算损失函数相对于预测概率$p_t$的负梯度(对应负梯度越大,则预测概率倾向于越大):

\[\begin{aligned} -\frac{\partial \mathcal{L}_{CE}(p,t)}{\partial p_t} &= \sum_{j=1}^\infty (1-p_t)^{j-1} \\ &= 1+(1-p_t)+ (1-p_t)^2 + \cdots \\ \end{aligned}\]

观察到损失函数的低阶项贡献了更大的负梯度,有助于预测正确的结果;而高阶项的梯度趋近于$0$。

Focal Loss是一种擅长处理类别不平衡的分类损失,显式地引入了权重因子$(1-p_t)^{\gamma},\gamma \geq 0$,使得$p_t$(目标类别的预测置信度)越大时权重越小,即对容易分类的样本减少权重。

\[\mathcal{L}_{\text{focal}}(p,t) = -(1-p_t)^\gamma \log p_t\]

同样地,对Focal Loss进行泰勒展开与负梯度计算:

\[\begin{aligned} \mathcal{L}_{\text{focal}}(p,t) &= -(1-p_t)^\gamma \log p_t = \sum_{j=1}^\infty \frac{1}{j} (1-p_t)^{j+\gamma} \\ &= (1-p_t)^{1+\gamma}+ \frac{1}{2} (1-p_t)^{2+\gamma} +\frac{1}{3} (1-p_t)^{3+\gamma} + \cdots \\ -\frac{\partial \mathcal{L}_{\text{focal}}(p,t)}{\partial p_t} &= \sum_{j=1}^\infty (1+\frac{\gamma}{j})(1-p_t)^{j+\gamma-1} \\ &= (1+\gamma)(1-p_t)^{\gamma}+ (1+\frac{\gamma}{2})(1-p_t)^{1+\gamma} + \cdots \\ \end{aligned}\]

对比Focal Loss与交叉熵损失,发现前者相当于在标准的分类损失的每一个多项式上乘以$(1-p_t)^\gamma$,相当于调整了每一项的系数,从而改善了分类损失对于类别不平衡问题的适应性。

基于上述分析,作者尝试为交叉熵损失的每一个多项式项引入一个扰动,用于更精确的任务目标进行更精确的调整。考虑到可实现性,为前$N$项引入$\epsilon_1,…,\epsilon_N$:

\[\begin{aligned} \mathcal{L}_{\text{Poly-N}}(p,t) &= (\epsilon_1+1)(1-p_t)+ \cdots +\left(\epsilon_N+\frac{1}{N}\right) (1-p_t)^N + \sum_{j=N+1}^\infty \frac{1}{j} (1-p_t)^j \\ &= -\log p_t + \sum_{j=1}^N \epsilon_j (1-p_t)^j \end{aligned}\]

特别地,只为第一项引入$\epsilon_1$:

\[\begin{aligned} \mathcal{L}_{\text{Poly-1}}(p,t) &= (\epsilon_1+1)(1-p_t)+ \sum_{j=2}^\infty \frac{1}{j} (1-p_t)^j \\ &= -\log p_t + \epsilon_1 (1-p_t) \end{aligned}\]

通过实验发现仅仅对多项式第一项做出扰动就可以提升绝大多数任务表现。并且模型的平均精度随着扰动增加而增加,这说明增大一阶多项式鼓励模型去给出置信度更高的结果。

coco数据集上,模型却因为扰动减小而增大。这是因为coco数据集相比于人工整理的ImageNet,存在更严重的类别不平衡等问题。通过施加负扰动让模型不再给出过度自信的预测,具有一定的正则化效果。

通过灵活地为不同多项式项设置不同的扰动程度,能够实现更精确的任务表现提升: