通过直方图损失学习深度嵌入.

直方图损失(Histogram Loss)是一种深度度量学习的损失函数,它首先估计正样本对和负样本对所对应的两个特征距离分布,然后计算正样本对之间的相似度比负样本对之间的相似度还要小的概率。

对于正样本对集合PP和负样本对集合NN,分别计算它们对应的特征距离集合:

S+={sij=<fθ(xi),fθ(xj)>|(i,j)P}S={sij=<fθ(xi),fθ(xj)>|(i,j)N}S+={sij=<fθ(xi),fθ(xj)>|(i,j)P}S={sij=<fθ(xi),fθ(xj)>|(i,j)N}

其中距离函数选用余弦相似度,取值范围是[1,1][1,1]。根据距离集合构造RRbins的直方图H+H+HH,则直方图的间隔为Δ=2/RΔ=2/R

直方图H+H+上每个位置rr处对应的值h+h+计算为:

h+=1|S+|(i,j)Pδi,j,rh+=1|S+|(i,j)Pδi,j,r

其中权重δi,j,rδi,j,r将一个样本对之间的距离以线性插值的方式分配到相邻的两个节点上,如果这个距离sijsij离节点rr越近,则这个权重越大:

δi,j,r={(sijtr1)/Δ,sij[tr1;tr](tr+1sij)/Δ,sij[tr;tr+1]0,otherwiseδi,j,r=(sijtr1)/Δ,sij[tr1;tr](tr+1sij)/Δ,sij[tr;tr+1]0,otherwise

直方图H+H+HH近似作为正样本对和负样本的特征距离分布,下面计算一个随机负样本对的距离比一个随机正样本对的距离小的概率:

p=11p(x)[x1p+(y)dy]dx=11p(x)Φ+(x)dx=Ex~p(x)[Φ+(x)]

上式可以通过直方图表示为离散形式:

pRr=1(hrrq=1h+q)=Rr=1hrϕ+r
import torch
from numpy.testing import assert_almost_equal

class HistogramLoss(torch.nn.Module):
    def __init__(self, num_bins, cuda=False):
        super(HistogramLoss, self).__init__()
        self.step = 2 / num_bins
        self.eps = 1 / num_bins
        self.cuda = cuda
        self.t = torch.arange(-1+self.eps, 1, self.step).view(-1, 1)
        self.tsize = self.t.size()[0]
        if self.cuda:
            self.t = self.t.cuda()
        
    def forward(self, features, classes):
        
        # 构造直方图分布
        def histogram(inds, size):
            s_repeat_ = s_repeat.clone()
            # 把样本对距离s分配给直方图的bins [t_{r-1}, t_r]
            indsa = (s_repeat_ - self.t >= 0) & (s_repeat_ - self.t < self.step) & inds
            assert indsa.nonzero().size()[0] == size, ('Another number of bins should be used')
            
            # 把样本对距离s分配给直方图的bins [t_r, t_{r+1}]
            zeros = torch.zeros((1, indsa.size()[1])).byte()
            if self.cuda:
                zeros = zeros.cuda()
            indsb = torch.cat((zeros, indsa))[:-1, :].to(dtype=torch.bool)
            
            # 根据权重delta计算直方图的取值
            s_repeat_[~(indsb|indsa)] = 0
            # indsa corresponds to the first condition of the second equation of the paper
            s_repeat_[indsa] = (s_repeat_ - self.t)[indsa] / self.step
            # indsb corresponds to the second condition of the second equation of the paper
            s_repeat_[indsb] =  (-s_repeat_ + self.t)[indsb] / self.step
            
            return s_repeat_.sum(1) / size
        
        batch_size = classes.size()[0]
        # 计算特征之间的余弦相似度
        features = 1. * features / (torch.norm(features, 2, dim=-1, keepdim=True).expand_as(features) + 1e-12)
        dists = torch.mm(features, features.transpose(0, 1))  # [batch_size, batch_size]
        
        # 构造全1上三角阵(用于mask掉重复的样本对和自身的样本对)
        s_inds = torch.triu(torch.ones(batch_size, batch_size), 1).byte() 
        if self.cuda:
            s_inds= s_inds.cuda()
            
        # 取出所有有效样本对的距离
        s = dists[s_inds].view(1, -1) # num_pairs = batch_size * (batch_size-1) / 2
        s_repeat = s.repeat(self.tsize, 1) # [num_bins, num_pairs]
        s_repeat_floor = (torch.floor(s_repeat.data / self.step) * self.step).float()
        # print(s_repeat_floor)
        
        # 匹配正负样本对
        classes_eq = (classes.repeat(batch_size, 1)  == classes.view(-1, 1).repeat(1, batch_size)).data
        pos_inds = classes_eq[s_inds].repeat(self.tsize, 1)
        neg_inds = ~classes_eq[s_inds].repeat(self.tsize, 1)
        pos_size = classes_eq[s_inds].sum().item()
        neg_size = (~classes_eq[s_inds]).sum().item()

        # 计算样本对的直方图
        histogram_pos = histogram(pos_inds, pos_size)
        assert_almost_equal(histogram_pos.sum().item(), 1, decimal=1, 
                            err_msg='Not good positive histogram', verbose=True)
        histogram_neg = histogram(neg_inds, neg_size)
        assert_almost_equal(histogram_neg.sum().item(), 1, decimal=1, 
                            err_msg='Not good negative histogram', verbose=True)
        
        # 计算正样本对直方图的累计密度函数
        histogram_pos_repeat = histogram_pos.view(-1, 1).repeat(1, histogram_pos.size()[0]) # [num_bins, num_bins]
        histogram_pos_inds = torch.tril(torch.ones(histogram_pos_repeat.size()), -1).byte()
        if self.cuda:
            histogram_pos_inds = histogram_pos_inds.cuda()
        histogram_pos_repeat[histogram_pos_inds] = 0
        histogram_pos_cdf = histogram_pos_repeat.sum(0)
        
        # 计算直方图损失
        loss = torch.sum(histogram_neg * histogram_pos_cdf)
        
        return loss