Proxy-Anchor Loss for Deep Metric Learning

The paper proposes the Proxy-Anchor loss, which assigns one learnable proxy to each class. For every batch, distances are computed between all samples and all proxies; the loss then pulls each sample toward the proxy of its own class and pushes it away from the proxies of all other classes.
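To make the shapes concrete: with a batch of $B$ embeddings of dimension $d$ and $C$ classes, a single matrix product yields all $B \times C$ sample-to-proxy similarities. A minimal sketch (tensor names and toy sizes here are illustrative, not from the paper):

import torch
import torch.nn.functional as F

batch_size, embed_dim, num_classes = 4, 8, 3            # toy sizes, for illustration only
embeddings = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)   # one row per sample
proxies = F.normalize(torch.randn(num_classes, embed_dim), dim=-1)     # one proxy per class

# Every sample is compared with every proxy in a single matrix product:
# sim[i, j] is the cosine similarity between sample i and the proxy of class j.
sim = embeddings @ proxies.t()           # shape: (batch_size, num_classes)
print(sim.shape)                         # torch.Size([4, 3])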

The main difference between the Proxy-Anchor loss and the Proxy-NCA loss lies in what is iterated over. Proxy-NCA iterates over samples: for each sample it decreases the distance to the proxy of its own class and increases the distance to the proxies of other classes. Proxy-Anchor instead iterates over proxies, treating each proxy as an anchor: it decreases the distance between that proxy and all samples of its class and increases the distance to samples of every other class.

\[\begin{aligned}\mathcal{L} = {}& \frac{1}{|P^{+}|} \sum_{p \in P^{+}} \log \Big(1+\sum_{x \in X_{p}^{+}} e^{\alpha\left(D[f_{\theta}(x),\,p]+\delta\right)}\Big) \\ &+ \frac{1}{|P|} \sum_{p \in P} \log \Big(1+\sum_{x \in X_{p}^{-}} e^{-\alpha\left(D[f_{\theta}(x),\,p]-\delta\right)}\Big)\end{aligned}\]

Here $P$ is the set of all proxies and $P^+$ is the set of positive proxies, i.e. the proxies whose classes appear in the current batch; $X_p^+$ denotes the batch samples belonging to the class of proxy $p$ and $X_p^-$ the remaining samples, $D[\cdot,\cdot]$ is the distance between an embedding and a proxy, $\alpha$ a scaling factor and $\delta$ a margin. Proxy-NCA does not exploit data-to-data relations: each data point interacts only with proxies. Proxy-Anchor improves on this by letting every proxy, acting as an anchor, interact with all data points in the batch at once.
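For comparison, a minimal sketch of the per-sample Proxy-NCA objective described above (written only to contrast it with Proxy-Anchor; the function name and cosine-distance choice are assumptions, not the original implementation):

import torch
import torch.nn.functional as F

def proxy_nca_loss(X, Y, proxies):
    # Per-sample view: each sample is pulled toward its own class proxy and
    # pushed away from all other proxies (cosine distance, as in the code below).
    dist = 1 - F.normalize(X, dim=-1) @ F.normalize(proxies, dim=-1).t()   # (batch, classes)
    pos = dist.gather(1, Y.view(-1, 1)).squeeze(1)                         # distance to own-class proxy
    neg_mask = torch.ones_like(dist).scatter_(1, Y.view(-1, 1), 0)         # 1 for every other class
    # -log( exp(-d_pos) / sum_neg exp(-d_neg) ), averaged over samples rather than proxies
    return (pos + torch.log((neg_mask * torch.exp(-dist)).sum(dim=1))).mean()

X = torch.randn(32, 64)
Y = torch.randint(0, 100, (32,))
proxies = torch.randn(100, 64)
print(proxy_nca_loss(X, Y, proxies))

The Proxy-Anchor loss below replaces this per-sample iteration with a per-proxy one, so that every proxy sees all of its positives and negatives in the batch at once.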

import torch
import torch.nn as nn
import torch.nn.functional as F

class Proxy_Anchor(torch.nn.Module):
    def __init__(self, nb_classes, sz_embed, delta=0.1, alpha=32):
        super().__init__()
        # Proxy Anchor Initialization
        self.proxies = torch.nn.Parameter(torch.randn(nb_classes, sz_embed))
        nn.init.kaiming_normal_(self.proxies, mode='fan_out')

        self.nb_classes = nb_classes
        self.sz_embed = sz_embed
        self.delta = delta
        self.alpha = alpha
        
    def forward(self, X, Y):
        P = self.proxies
        
        # Cosine distance between every embedding in the batch and every proxy
        # (1 - cosine similarity, so smaller values mean more similar)
        dist = 1 - torch.mm(F.normalize(X, dim=-1), F.normalize(P, dim=-1).t())

        # One-hot label matrix: P_one_hot[i, c] = 1 if sample i belongs to class c
        P_one_hot = torch.zeros(X.shape[0], self.nb_classes, device=X.device)
        P_one_hot.scatter_(1, Y.view(-1, 1), 1)
        N_one_hot = 1 - P_one_hot
 
        # Positive proxies P+: proxies whose class appears in the current batch
        with_pos_proxies = torch.nonzero(P_one_hot.sum(dim=0) != 0).squeeze(dim=1)
        num_valid_proxies = len(with_pos_proxies)
    
        # Per-proxy sums of exponentials, one for the positive and one for the negative term
        pos_exp = torch.exp(self.alpha * (dist + self.delta))
        neg_exp = torch.exp(-self.alpha * (dist - self.delta))

        P_sim_sum = torch.mul(P_one_hot, pos_exp).sum(dim=0)   # sum over each proxy's positive samples
        N_sim_sum = torch.mul(N_one_hot, neg_exp).sum(dim=0)   # sum over each proxy's negative samples

        # Positive term is averaged over |P+|, negative term over all |P| proxies
        pos_term = torch.log(1 + P_sim_sum).sum() / num_valid_proxies
        neg_term = torch.log(1 + N_sim_sum).sum() / self.nb_classes
        loss = pos_term + neg_term
        
        return loss
    
if __name__ == '__main__':
    nb_classes = 100
    sz_batch = 32
    sz_embedding = 64
    X = torch.randn(sz_batch, sz_embedding)
    Y = torch.randint(low=0, high=nb_classes, size=[sz_batch])
    criterion = Proxy_Anchor(nb_classes, sz_embedding)
    print(criterion(X, Y))
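In training, the proxies are learnable parameters of the loss module and must be optimized together with the embedding network; the paper trains the proxies with a larger learning rate than the backbone. A minimal sketch, assuming a stand-in linear backbone and illustrative AdamW learning rates (not the authors' exact training setup):

import torch
import torch.nn as nn

nb_classes, sz_embedding, sz_batch = 100, 64, 32
backbone = nn.Linear(128, sz_embedding)                # stand-in for a real embedding network
criterion = Proxy_Anchor(nb_classes, sz_embedding)     # the loss module defined above

# The proxies are Parameters of the loss module, so they must be passed to the optimizer;
# the proxies get a larger learning rate than the backbone (values here are illustrative).
optimizer = torch.optim.AdamW([
    {'params': backbone.parameters(), 'lr': 1e-4},
    {'params': criterion.parameters(), 'lr': 1e-2},
])

features = torch.randn(sz_batch, 128)                  # stand-in for a batch of inputs
labels = torch.randint(0, nb_classes, (sz_batch,))
loss = criterion(backbone(features), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()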