Metric Learning with Adaptive Density Discrimination

Most conventional deep metric learning methods consider only the semantics of the labels: they form positive and negative pairs purely according to the labels and ignore the structure within each class. This paper proposes the Magnet loss, which takes both inter-class similarity and intra-class variation into account.

During training, the Magnet loss retrieves the local neighborhood formed by the nearest clusters and penalizes the regions where they overlap. Samples in the same cluster may carry different labels, but they share similar semantics.
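A rough sketch of how such a neighborhood batch can be assembled is given below. It assumes the class-wise K-means indexing described in the paper; the helper names (`build_cluster_index`, `sample_neighborhood`) and the use of scikit-learn's `KMeans` are illustrative choices, and the seed cluster is drawn uniformly here rather than in proportion to its current loss.

import numpy as np
from sklearn.cluster import KMeans

def build_cluster_index(embeddings, labels, k_per_class):
    """Run K-means separately within each class; return cluster centers,
    the cluster id of every example, and the class of every cluster."""
    centers, cluster_classes = [], []
    assignments = np.empty(len(labels), dtype=int)
    next_id = 0
    for c in np.unique(labels):
        idx = np.where(labels == c)[0]
        km = KMeans(n_clusters=k_per_class).fit(embeddings[idx])
        assignments[idx] = km.labels_ + next_id
        centers.append(km.cluster_centers_)
        cluster_classes.extend([c] * k_per_class)
        next_id += k_per_class
    return np.concatenate(centers), assignments, np.array(cluster_classes)

def sample_neighborhood(centers, cluster_classes, assignments, m, d, rng=np.random):
    """Pick a seed cluster, retrieve its m-1 nearest clusters of other classes,
    and draw d examples from each of the m clusters to form one batch."""
    seed = rng.randint(len(centers))
    dists = np.linalg.norm(centers - centers[seed], axis=1)
    impostors = np.where(cluster_classes != cluster_classes[seed])[0]
    chosen = np.concatenate(([seed], impostors[np.argsort(dists[impostors])][:m - 1]))
    batch_idx = []
    for cl in chosen:
        members = np.where(assignments == cl)[0]
        batch_idx.extend(rng.choice(members, size=d, replace=len(members) < d))
    return np.array(batch_idx)  # indices grouped cluster by cluster, d per cluster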

To construct the Magnet loss, $M$ clusters are sampled at random and $D$ examples are drawn from each cluster; the loss is then defined as:

\[\frac{1}{MD} \sum_{m=1}^M \sum_{d=1}^D \left\{ -\log \frac{e^{-\frac{1}{2\sigma^2}||f_{\theta}(x_d^m)-\mu_m||_2^2-\alpha}}{\sum_{\mu: c(\mu) \neq c(f_{\theta}(x_d^m))}e^{-\frac{1}{2\sigma^2}||f_{\theta}(x_d^m)-\mu||_2^2}} \right\}_+\]

The numerator pulls each sample towards the center of its own cluster, while the denominator pushes it away from the centers of clusters belonging to other classes; the hinge $\{\cdot\}_+$ clips the loss at zero once the margin $\alpha$ is satisfied. The mean $\mu_m$ of each cluster and the intra-class variance $\sigma^2$ are computed as:

\[\mu_m = \frac{1}{D} \sum_{d=1}^D f_{\theta}(x_d^m), \qquad \sigma^2 = \frac{1}{MD-1} \sum_{m=1}^M \sum_{d=1}^D ||f_{\theta}(x_d^m)-\mu_m||_2^2\]
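The PyTorch implementation below follows these formulas directly. It assumes the batch is laid out as $M$ contiguous groups of $D$ samples each, i.e. samples from the same cluster are adjacent, which matches the order produced by the sampling sketch above.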
import torch
import torch.nn as nn
import torch.nn.functional as F

class MagnetLoss(nn.Module):

    def __init__(self, alpha=1.0):
        super(MagnetLoss, self).__init__()
        self.alpha = alpha

    def forward(self, data, classes, n_clusters):
        """
        Args:
            data: A batch of embeddings of shape [batch_size, num_features],
                laid out as n_clusters contiguous groups of equal size.
            classes: Class label of each example, shape [batch_size].
            n_clusters: Total number of clusters in the batch.
        Returns:
            total_loss: The mean magnet loss over the batch.
        Note:
            The separation gap alpha is set in __init__.
        """
        
        batch_size = data.shape[0]
        # number of samples per cluster
        d = batch_size // n_clusters
        # cluster index of each sample: [0,...,0, 1,...,1, ...], d copies of each
        clusters = torch.repeat_interleave(
            torch.arange(n_clusters, device=data.device), d)
        # class label of each cluster (first sample of each contiguous group)
        cluster_classes = classes[0:n_clusters*d:d]
        # clusters:        [batch_size]
        # cluster_classes: [n_clusters]

        # mean (center) of each cluster
        cluster_examples = torch.chunk(data, n_clusters)
        cluster_means = torch.stack([torch.mean(x, dim=0) for x in cluster_examples])
        # cluster_means: [n_clusters, num_features]

        # squared distance from every sample to every cluster center
        sample_costs = torch.sum((cluster_means - data.unsqueeze(1))**2, dim=2)
        # sample_costs: [batch_size, n_clusters]

        # squared distance from each sample to its own cluster center
        cluster_idx = torch.arange(n_clusters, device=data.device)
        intra_cluster_mask = torch.eq(clusters.unsqueeze(1), cluster_idx.unsqueeze(0))
        intra_cluster_costs = torch.sum(intra_cluster_mask * sample_costs, dim=1)
        # intra_cluster_costs: [batch_size]

        # intra-class variance sigma^2, with the unbiased 1/(MD-1) normalization
        variance = torch.sum(intra_cluster_costs) / float(batch_size - 1)
        var_normalizer = -1 / (2 * variance)

        # numerator: exp(-||f(x) - own center||^2 / (2*sigma^2) - alpha)
        numerator = torch.exp(var_normalizer * intra_cluster_costs - self.alpha)

        # denominator: sum over clusters whose class differs from the sample's class
        diff_class_mask = ~torch.eq(classes.unsqueeze(1), cluster_classes.unsqueeze(0))
        # diff_class_mask: [batch_size, n_clusters]
        denom_sample_costs = torch.exp(var_normalizer * sample_costs)
        denominator = torch.sum(diff_class_mask * denom_sample_costs, dim=1)

        # per-sample magnet loss, hinged at zero (epsilon guards against log(0))
        epsilon = 1e-8
        losses = F.relu(-torch.log(numerator / (denominator + epsilon) + epsilon))
        total_loss = torch.mean(losses)

        return total_loss
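A minimal usage sketch follows. The embeddings would normally come from an encoder network; random tensors and dummy labels stand in for them here purely for illustration.

# Hypothetical setup: M = 8 clusters, D = 4 samples per cluster, 64-d embeddings.
M, D, feat_dim = 8, 4, 64
criterion = MagnetLoss(alpha=1.0)

# Samples must be grouped cluster by cluster, D per cluster.
embeddings = torch.randn(M * D, feat_dim, requires_grad=True)
labels = torch.repeat_interleave(torch.arange(M) % 2, D)  # dummy class labels

loss = criterion(embeddings, labels, n_clusters=M)
loss.backward()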