FlatNCE: 避免浮点数误差的小批量对比学习损失函数.

对比学习是无监督学习中常用的方法。标准的对比学习的一个缺点是依赖于较大的batch_size(如SimCLRbatch_size=4096时效果最佳),当减小batch_size时效果会明显下降。本文讨论了标准的对比学习在小batch_size时效果差的原因,作者认为其原因是计算损失和梯度时的浮点误差,并提出了一个改进的损失函数FlatNCE,提高了对比学习在较小batch_size时的学习效果。

标准的对比学习(contrastive learning)的过程如下。对于某个样本$x$,构造$K$个配对样本$y_1,y_2,…,y_K$,其中$K$即为训练batch_size。$y_t$是正样本,其余样本都是负样本。对每一个样本对$(x,y_i)$进行打分,得到$s_1,s_2,…,s_K$。对比学习的原理是尽可能增大正负样本对之间的得分差距,通常用负交叉熵作为损失函数:

\[-\log\frac{e^{s_t}}{\sum_{i}^{}e^{s_i}} = \log(\sum_{i}^{}e^{s_i})-s_t = \log(1+\sum_{i≠t}^{}e^{s_i-s_t})\]

通常正样本$y_t$是对原样本$x$进行数据增强得到的相似度较高的样本,而$K-1$个负样本是随机选择的其他样本。由于正负样本的差距比较明显,因此通常有$s_t»s_i(i≠t)$,即$e^{s_i-s_t}≈0$。当batch_size较小时$K$比较小,上式会相当接近$0$,使得原损失函数也非常接近$0$。

由于$e^{s_i-s_t}$本身较小,计算会存在浮点误差,甚至浮点误差比数值本身还大。上述误差导致损失函数计算也含有浮点误差,进而梯度计算也含有浮点误差,最终使得反向传播梯度被随机噪声掩盖,使得学习效果变差。

本文提出了一种改进的损失函数,能够减轻小batch_size对学习过程的影响。由于$e^{s_i-s_t}$较小,对损失函数做一阶泰勒展开:

\[\log(1+\sum_{i≠t}^{}e^{s_i-s_t}) ≈ \sum_{i≠t}^{}e^{s_i-s_t}\]

注意到$\log(1+x) < x$,因此$\sum_{i≠t}^{}e^{s_i-s_t}$是原损失函数的一个上界。为了减小浮点误差的影响,可以对损失函数乘以一个常数,进行等比例放大,也不会对优化过程产生影响。作者提出的FlatNCE是将该损失函数乘以它的倒数:

\[\frac{\sum_{i≠t}^{}e^{s_i-s_t}}{\text{detach}(\sum_{i≠t}^{}e^{s_i-s_t})}\]

其中$\text{detach}(\cdot)$表示不计算梯度,而是将括号内的项作为常数。注意到:

\[\nabla_{\theta} \frac{\sum_{i≠t}^{}e^{s_i-s_t}}{\text{detach}(\sum_{i≠t}^{}e^{s_i-s_t})} = \frac{\nabla_{\theta}\sum_{i≠t}^{}e^{s_i-s_t}}{\sum_{i≠t}^{}e^{s_i-s_t}} = \nabla_{\theta}\log(\sum_{i≠t}^{}e^{s_i-s_t})\]

因此上式左端和右端提供的梯度是相同的,故可以把损失函数设置为:

\[\log(\sum_{i≠t}^{}e^{s_i-s_t}) = \log(\sum_{i≠t}^{}e^{s_i})-s_t\]

注意到该损失函数$\log(\sum_{i≠t}^{}e^{s_i})-s_t$与原交叉熵损失$\log(\sum_{i}^{}e^{s_i})-s_t$相比,仅仅是logsumexp运算中去掉了正样本对的得分$s_t$。

作者用FlatNCE损失替换了SimCLR中的损失函数,将对应的结果称为FlatCLR。实验表面该损失在小的batch_size下也具有较好的学习能力:

下表表面FlatCLR相较于SimCLR获得性能的提高: