Metric Learning with Adaptive Density Discrimination
Most conventional deep metric learning methods consider only the semantics of the labels, splitting examples into positive and negative pairs according to those labels while ignoring distinctions within a class. The paper proposes the Magnet loss, which accounts for both inter-class similarity and intra-class variation.
During training, the Magnet loss retrieves a local neighborhood formed by the nearest clusters and penalizes their overlap. Samples within such a neighborhood may carry different labels, yet they share similar semantic information.
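As a rough sketch of this neighborhood retrieval step (assuming cluster assignments are available, e.g. from running k-means per class on the current embeddings; the function and variable names below are illustrative and not from the paper):

import numpy as np

def sample_magnet_batch(cluster_centers, cluster_to_indices, m, d, rng=None):
    """Pick a seed cluster, retrieve its m nearest clusters, draw d samples from each."""
    rng = rng or np.random.default_rng()
    # choose a seed cluster (the paper biases this choice towards clusters with high recent loss)
    seed = rng.integers(len(cluster_centers))
    # the m nearest clusters to the seed (including the seed itself), measured between centers
    dists = np.linalg.norm(cluster_centers - cluster_centers[seed], axis=1)
    neighborhood = np.argsort(dists)[:m]
    # draw d examples from each retrieved cluster; the indices are concatenated
    # cluster by cluster, which is the batch layout the loss implementation below expects
    batch = [idx for c in neighborhood
             for idx in rng.choice(cluster_to_indices[c], size=d, replace=False)]
    return batch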
To construct the Magnet loss, $M$ clusters are sampled and $D$ examples are drawn from each cluster; the loss is then
\[\mathcal{L}(\theta) = \frac{1}{MD} \sum_{m=1}^{M} \sum_{d=1}^{D} \left\{ -\log \frac{e^{-\frac{1}{2\sigma^2}\|f_{\theta}(x_d^m)-\mu_m\|_2^2-\alpha}}{\sum_{\mu: c(\mu) \neq c(f_{\theta}(x_d^m))}e^{-\frac{1}{2\sigma^2}\|f_{\theta}(x_d^m)-\mu\|_2^2}} \right\}_{+}\]
where $\{\cdot\}_+$ denotes the hinge function. The numerator pulls each sample towards the center of its own cluster, while the denominator pushes it away from the centers of clusters belonging to other classes. The mean $\mu_m$ of each cluster and the intra-class variance $\sigma^2$ are computed as
\[\mu_m = \frac{1}{D} \sum_{d=1}^{D} f_{\theta}(x_d^m), \qquad \sigma^2 = \frac{1}{MD-1} \sum_{m=1}^{M} \sum_{d=1}^{D} \|f_{\theta}(x_d^m)-\mu_m\|_2^2\]
The loss can be implemented in PyTorch as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
class MagnetLoss(nn.Module):
    def __init__(self, alpha=1.0):
        super(MagnetLoss, self).__init__()
        self.alpha = alpha  # cluster separation gap α

    def forward(self, data, classes, n_clusters):
        """
        Args:
            data: A batch of features, laid out as n_clusters consecutive groups,
                each containing the samples of one cluster.
            classes: Class label of each example in the batch.
            n_clusters: Total number of clusters in the batch.
        Returns:
            total_loss: The mean magnet loss over the batch.
            losses: The loss of each example in the batch.
        """
        batch_size = data.shape[0]
        # number of samples per cluster (D)
        d = batch_size // n_clusters
        # cluster index of each sample: [0, ..., 0, 1, ..., 1, ..., n_clusters-1, ...]
        clusters = torch.arange(n_clusters, device=data.device).repeat_interleave(d)
        # class label of each cluster (taken from the first sample of every chunk)
        cluster_classes = classes[0:n_clusters * d:d]
        # clusters: [batch_size], cluster_classes: [n_clusters]
        # cluster centers (means)
        cluster_examples = torch.chunk(data, n_clusters)
        cluster_means = torch.stack([torch.mean(x, dim=0) for x in cluster_examples])
        # cluster_means: [n_clusters, num_features]
        # squared distance from every sample to every cluster center
        sample_costs = torch.sum((cluster_means - data.unsqueeze(1)) ** 2, dim=2)
        # sample_costs: [batch_size, n_clusters]
        # squared distance from every sample to its own cluster center
        cluster_idx = torch.arange(n_clusters, device=data.device)
        intra_cluster_mask = torch.eq(clusters.unsqueeze(1), cluster_idx.unsqueeze(0)).float()
        # intra_cluster_mask: [batch_size, n_clusters]
        intra_cluster_costs = torch.sum(intra_cluster_mask * sample_costs, dim=1)
        # intra_cluster_costs: [batch_size]
        # intra-cluster variance sigma^2 and the normalizer -1 / (2 * sigma^2);
        # variance already holds sigma^2, so it must not be squared again
        variance = torch.sum(intra_cluster_costs) / float(batch_size - 1)
        var_normalizer = -1 / (2 * variance)
        # numerator: exp(-||f(x) - mu_m||^2 / (2 * sigma^2) - alpha)
        numerator = torch.exp(var_normalizer * intra_cluster_costs - self.alpha)
        # denominator: sum over the centers of clusters belonging to other classes
        diff_class_mask = (~torch.eq(classes.unsqueeze(1), cluster_classes.unsqueeze(0))).float()
        # diff_class_mask: [batch_size, n_clusters]
        denom_sample_costs = torch.exp(var_normalizer * sample_costs)
        denominator = torch.sum(diff_class_mask * denom_sample_costs, dim=1)
        # hinged Magnet loss per example
        epsilon = 1e-8
        losses = F.relu(-torch.log(numerator / (denominator + epsilon) + epsilon))
        total_loss = torch.mean(losses)
        return total_loss, losses
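A hypothetical usage sketch (the embedding network f_theta, its inputs images/labels, and the batch layout of 8 clusters with 4 samples each are illustrative placeholders, not part of the original code):

# the batch is assumed to be sampled as n_clusters consecutive groups of d examples each
criterion = MagnetLoss(alpha=1.0)
features = f_theta(images)            # [n_clusters * d, num_features]
total_loss, losses = criterion(features, labels, n_clusters=8)
total_loss.backward()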