通过多类别N-pair损失改进深度度量学习.

度量学习旨在学习特征的嵌入空间,使得相似数据点靠的近,不相似数据离得远。深度度量学习通过深度网络学习到一种非线性的嵌入表达,但学习过程会遇到收敛缓慢甚至陷入局部最优解的问题,这是因为在每次更新网络权重时,损失函数通常仅仅考虑了一个负样本,并没有将其他类的负样本距离考虑进来。

Triplet Loss想办法拉近同类样本距离,拉远异类样本距离。痛点在于每次只看一个负类的距离,没有考虑其他所有负类的情况,这就导致了在随机产生的数据对中,每一个数据对并不能有效的保证当前优化的方向能够拉远所有负类样本的距离,这就导致了训练过程中的收敛不稳定或者陷入局部最优。本文提出了N-pair损失,把Triplet损失扩展到比较所有负类样本的距离。

对于每一个样本$x$,选择一个正样本$x^+$和所有其他类别的负样本$x_1^-,…,x_{N-1}^-$,构造$(N+1)$元组\(\{x,x^+,x_1^-,...,x_{N-1}^-\}\),则N-pair损失定义为:

\[\log (1+\sum_{i=1}^{N-1} \exp(f(x)^Tf(x_i^-)-f(x)^Tf(x^+))) \\= - \log\frac{\exp(f(x)^Tf(x^+))}{\exp(f(x)^Tf(x^+))+ \sum_{i=1}^{N-1} \exp(f(x)^Tf(x_i^-))}\]

如果每个类别的负样本只采样一个,则N-pair损失等价于多类别交叉熵损失。直接构造N-pair损失需要存储$(B,N+1)$的数据矩阵,其中$B$是批量大小,需要占用较大的显存。作者提出一种高效的批量构造(Batch construction)策略,每次输入$(B,2)$的成对数据,每对数据具有相同的类别,然后通过深度网络映射为anchor向量$A \in \Bbb{R}^{B \times d}$和positive向量$P \in \Bbb{R}^{B \times d}$,通过矩阵乘法$AP^T \in \Bbb{R}^{B \times B}$可以得到任意两个数据之间的特征内积$f(x_i)^Tf(x_j)$,正负样本对的区分可以通过传入标签$y \in \Bbb{R}^{B}$实现。

import torch
import torch.nn as nn
import torch.nn.functional as F


def cross_entropy(logits, target, size_average=True):
    if size_average:
        return torch.mean(torch.sum(- target * F.log_softmax(logits, -1), -1))
    else:
        return torch.sum(torch.sum(- target * F.log_softmax(logits, -1), -1))


class NpairLoss(nn.Module):
    """the multi-class n-pair loss"""
    def __init__(self):
        super(NpairLoss, self).__init__()

    def forward(self, anchor, positive, target):
        '''  
        构造训练数据集:anchors,positives,targets
        其中anchors和positives代表着成对的数据,每一行(一对数据)取自同一个类,target代表对应成对数据的类别,
        一个标准的batch是有N(类别数)对的样本(当然也可以不是,比如N太大了)
        '''
        batch_size = anchor.size(0)
        target = target.view(batch_size, 1)

        # 根据整数标签构造正负样本索引
        target = (target == torch.transpose(target, 0, 1)).float()
        target = target / torch.sum(target, dim=1, keepdim=True).float()

        logit = torch.matmul(anchor, torch.transpose(positive, 0, 1))
        loss = cross_entropy(logit, target)

        return loss