深度多实例学习中的基于损失的注意力.
0. TL; DR
这篇论文提出了一种基于损失的注意力(Loss-based Attention)机制,确保多实例学习(Attention MIL)中学习到的注意力权重,真正反映了实例对最终分类的贡献。该机制不再使用独立的网络来学习权重,而是让注意力权重直接从最终的分类损失函数(如Softmax+交叉熵)中计算出来。
在多个经典的MIL基准数据集(如MUSK, FOX等)上,Loss-Attention全面超越了包括MI-Net和Gated-Attention在内的所有SOTA方法,在包级别分类和实例级别检索(精度和召回率)任务上,均取得了最佳性能。当被应用于标准的图像分类任务时(将特征图上的点视为实例),Loss-Attention同样在分类和目标定位任务上超越了现有方法。
1. 背景介绍
多实例学习(MIL)旨在只有包级别标签的弱监督下,既能准确地预测包的类别,又能定位出是哪些关键实例导致了这个预测。以Attention-based MIL (ADMIL)为代表的方法通过引入注意力机制来为每个实例学习一个权重,然后进行加权聚合,这些权重可以被用作可解释性的依据。ADMIL的标准做法是:
- 用一个独立的辅助神经网络来为每个实例$h_i$计算注意力权重$α_i$。
- 用这些权重对实例表示进行加权求和,得到包表示。
- 再用另一个独立的分类网络对包表示进行分类。
这种设计权重网络和分类网络是相互独立的。权重网络的目标是给重要的实例高权重,但它并不知道分类网络最终是如何利用这些信息的,这可能导致一种权责分离的局面:一个实例可能因为某些特征(如纹理鲜明)而被权重网络赋予高权重,但它在分类网络看来可能是一个错误的、会产生误导的实例。最终模型可能会给错误的实例高权重,不仅降低了最终的分类性能,也使得基于注意力的可解释性变得不可靠。
2. Loss-Attention
标准的分类损失(Softmax+交叉熵)。对于一个$K$分类问题,一个样本的预测logits为$z$,其属于类别$k$的概率为:
\[q_k = \frac{\exp(z_k)}{\sum_c \exp(z_c)}\]其损失为$L = -\log(q_k)$。
Loss-Attention直接利用了计算最终分类logits的全连接层来同时定义实例的注意力权重。假设一个包$i$有$n_i$个实例,每个实例的特征为$h_{i,j}$。最终的分类层参数为权重矩阵$W$和偏置$b$。
- 实例级别的logits: $z_{i,j} = h_{i,j}W + b$
- 包级别的logits: $z_i = (\sum_j h_{i,j})W + b = \sum_j z_{i,j}$
Loss-Attention的权重公式计算为:
\[\alpha_{i,j} = \frac{\sum_{c=0}^{K-1} \exp(z_{i,j,c})}{\sum_{t=1}^{n_i} \sum_{c=0}^{K-1} \exp(z_{i,t,c})}\]其中$z_{i,j,c} = h_{i,j}w_c + b_c$ 是实例$j$对于类别$c$的logit。分子是实例$j$所有类别logits的指数和,可以理解为这个实例的“总能量”或“激活强度”。分母是整个包内所有实例、所有类别logits的指数总和。
Loss-Attention计算权重$α$和计算最终分类logits $z$,使用的是同一套参数$W$和$b$。一个实例要想获得高权重$α_{i,j}$,它的“总能量”$\sum_c \exp(z_{i,j,c})$就必须大。而这个“总能量”正是由决定最终分类的参数$W, b$计算得出的。这意味着一个实例必须对最终的分类决策做出强有力的贡献,才有可能获得高权重。权重学习和分类任务被内在地绑定了。

作者在理论上证明,如果只使用包级别的分类损失$L1$,模型可能会“偷懒”,只找到一个最显著的实例来满足分类任务,导致其他阳性实例被忽略(即低召回率)。为了解决这个问题,作者设计了一个包含三个部分的联合损失函数:
\[L_p = L_1 + L_2 + L_3\]- 包损失 ($L1$) 是标准的包级别交叉熵损失,用于监督包的分类。
- 加权实例损失 ($L2$) 将实例的权重$α$与实例自身的分类损失结合起来。要求所有被赋予了高权重$α_i,t$的实例,其自身的分类损失也必须很小。通过调整$λ$的大小,可以直接控制被正确分类的高权重实例的数量,从而有效提升实例级别的召回率。
- 一致性损失 ($L3$) 借鉴了半监督学习中的Temporal Ensembling思想,引入了一个一致性成本来平滑实例权重$α$的学习过程。
其中$\tilde{\alpha}_{i,t}$是$α_i,t$在历史训练迭代中的一个指数移动平均(EMA)值。这个损失项鼓励当前迭代的权重与历史的“共识”保持一致,避免权重在训练过程中剧烈波动,从而提升模型的泛化能力。
3. 实验分析
3.1 MIL数据集
Loss-Attention在所有五个经典MIL基准数据集(Musk1, Musk2, Fox, Tiger, Elephant)上,都取得了最佳的分类准确率,全面超越了包括所有强基线。
在这些小规模、特征固定的数据集上取得全面领先,充分证明了Loss-Attention机制本身在聚合信息和优化模型方面的根本优势。

3.2 MNIST-bags
对于基于MNIST构建的二分类和多分类MIL任务,Loss-Attention在二分类(AUC)和多分类(Accuracy)任务上均取得了与SOTA方法相当或更优的性能。

通过评估不同权重阈值$α$下,模型检索出的实例的精度(Precision)、召回率(Recall)和F-score,当只关心那些模型最“自信”的实例(即权重$α$较大时,如$α > 0.5$),Loss-Attention的精度、召回率和F-score都显著高。这意味着Loss-Attention赋予高权重的实例,更有可能是真正的阳性实例。

3.3 消融研究
消融验证“加权实例损失”$L2$是否真的能提升实例召回率。通过改变$L2$的系数$λ$(从0到100),观察实例检索性能的变化。$λ=0$意味着完全不使用$L2$。
随着$λ$的增大,高权重实例的召回率显著提升,而精度略有下降。在综合指标F-score上,任何$λ > 0$的设置都优于$λ=0$。实验结果表明,$L2$正则项是提升模型定位多个关键实例能力的关键所在,实现高实例召回率。

3.4 图像分类与定位
将标准的图像分类任务看作一个特殊的MIL问题。一张图片可以被看作一个“包”,而其最后一层卷积输出的特征图(feature map)上的每一个空间位置,都可以被看作一个实例。
将ResNet18的最后一个全局平均池化层,替换为Loss-Attention层。在CIFAR-10和Tiny ImageNet的图像分类任务上,Loss-Attention版本的ResNet18取得了更高的分类准确率。在目标定位任务上(评估权重最高的patch是否落在物体边界框内),Loss-Attention的平均精度(AP)同样是最高的。
这个实验展示了Loss-Attention的巨大潜力。它不仅仅是一个MIL池化算子,更可以被看作是一种通用的、优于全局池化的特征聚合与定位模块,可以被广泛应用于各种需要空间感知的视觉任务中。
