使用代理的无融合距离度量学习.

本文设计了一种基于代理(proxy)的深度度量学习损失,为每一个类别赋予了一个proxy,将一个批次的数据和所有的proxy之间计算距离,并拉近每个类别的数据和该类别对应的proxy之间的距离,且拉远与其他类别的proxy之间的距离。

记样本$(x,y)$的proxy为$p_y$,$p_x$既可以初始化为随机向量;也可以通过随机采样构造:随机采样该类别中的一小部分数据,选择数据集中距离这部分数据最近的一个样本作为$p_x$。

Proxy-NCA采用Neighborhood Component Analysis (NCA)中的损失函数,即通过指数加权使得样本$x$更接近$y$而不是集合$Z$中的任意元素:

\[-\log (\frac{\exp(-d(x,y))}{\sum_{z \in Z}\exp(-d(x,z))})\]

对于样本$(x,y)$,Proxy-NCA损失定义为:

\[-\log (\frac{\exp(-d(x,p_y))}{\sum_{z \neq y} \exp(-d(x,p_z))})\]

Proxy-NCA不需要采样三元组,只需要采样anchor样本,能够加速收敛,且可以较好地缓解噪声标签和离群点的负面影响。

import torch
import torch.nn as nn
import torch.nn.functional as F

class ProxyNCA(nn.Module):
    def __init__(self, 
        nb_classes,
        sz_embedding,
    ):
        torch.nn.Module.__init__(self)
        self.proxies = nn.Parameter(torch.randn(nb_classes, sz_embedding))
        self.nb_classes = nb_classes
 
    def forward(self, X, Y):
        # 计算余弦相似度
        P = F.normalize(self.proxies, p = 2, dim = -1)
        X = F.normalize(X, p = 2, dim = -1)
        D = 1-torch.cdist(X, P, p = 2) ** 2 # [batch_size, num_proxy]

        # 生成one-hot标签
        labels = torch.FloatTensor(Y.shape[0], self.nb_classes).zero_()
        pos_onehot = labels.scatter_(1, Y.data, 1)
        neg_onehot = 1 - pos_onehot

        exp = torch.exp(-D)
        pos_exp = torch.sum(pos_onehot*exp, -1)
        neg_exp = torch.sum(neg_onehot*exp, -1)
        loss = torch.mean(-torch.log(pos_exp/neg_exp), -1)
        return loss
    
if __name__ == '__main__':
    nb_classes = 100
    sz_batch = 32
    sz_embedding = 64
    X = torch.randn(sz_batch, sz_embedding)
    Y = torch.randint(low=0, high=nb_classes, size=[sz_batch])
    pnca = ProxyNCA(nb_classes, sz_embedding)
    print(pnca(X, Y.view(sz_batch, 1)))