弱监督多实例学习图像分类中的核自注意力.
0. TL; DR
这篇论文为多实例学习(MIL)设计了自注意力-注意力池化(Self-Attention Attention-based MIL Pooling, SA-AbMILP)聚合方法。在进行AbMILP聚合之前,先通过一个自注意力(Self-Attention, SA)层来对所有实例的嵌入表示进行一次“预处理”,从而将实例间的依赖关系编码到新的表示中。
在一个专门设计的、需要捕捉多个实例共存和数量的MNIST数据集上,SA-AbMILP的性能显著优于基线AbMILP。在真实的乳腺癌、结肠癌组织学、微生物学和视网膜病变等多个医疗影像数据集上,SA-AbMILP及其核函数变体同样展现了全面超越现有方法的性能。
1. 背景介绍
多实例学习(MIL)的核心是处理只有“包”级别标签的数据。AbMILP通过一个小型神经网络为每个实例计算一个注意力权重,然后进行加权平均。这个过程存在一个根本性的局限:每个实例的注意力权重,是独立计算的,只依赖于该实例自身的特征
\[a_i = \frac{\exp(w^T \tanh(V h_i))}{\sum_j \exp(w^T \tanh(V h_j))}\]这意味着AbMILP本质上还是在假设实例是独立的。对于需要识别实例共现或计数的复杂MIL任务,需要在保留AbMILP强大聚合能力的同时,让模型能够捕捉到实例与实例之间的相互依赖关系。
2. SA-AbMILP
SA-AbMILP在AbMILP之前,插入一个自注意力(Self-Attention, SA)层。
2.1 SA-AbMILP的流程
SA-AbMILP的整个流程:
将包中的每个实例(如图像patch)通过一个CNN主干网络,得到它们的初始嵌入表示 ${h_i}_{i=1}^n$。自注意力层负责建模实例间的依赖关系。对于每个实例嵌入$h_i$,通过三个可学习的线性变换,分别得到其查询向量$q_i$、键向量$k_i$和值向量$v_i$。
\[q_i = W_q h_i, \quad k_i = W_k h_i, \quad v_i = W_v h_i\]计算任意两个实例$i$和$j$之间的“关联度”$\beta_{j,i}$。这是通过计算实例$j$的查询$q_j$与实例$i$的键$k_i$的相似度得到的。
\[\beta_{j,i} = \text{softmax}_i(\langle k_i, q_j \rangle)\]这里的$\beta_{j,i}$可以被理解为:“在生成实例$j$的新表示时,应该对实例$i$关注多少”。每个实例的新表示$\hat{h}_j$是通过对所有实例的值向量$v_i$进行加权求和得到的,并加入了残差连接。
\[\hat{h}_j = h_j + \gamma \sum_{i=1}^n \beta_{j,i} v_i\]这里的$\gamma$是一个可学习的标量,初始化为0。经过SA层后,每个实例的新表示$\hat{h}_j$不再仅仅是它自己的信息,而是融合了整个包内所有其他实例信息的上下文表示。
将经过SA层更新后的实例表示 ${\hat{h}i}{i=1}^n$ 送入一个标准的AbMILP层,进行最终的加权聚合,得到一个固定长度的包嵌入向量。将包嵌入向量送入一个全连接(FC)分类器,得到最终的包标签预测。
2.2 自注意力中的核函数
标准的自注意力使用点积(dot product)来衡量Query和Key的相似度。作者进一步探索了用其他核函数 (Kernel) 来代替点积的可能性,以期在特征空间更复杂或小样本场景下获得更好的效果:
- 径向基函数核 (RBF, GSA-AbMILP):
- 逆二次核 (Inverse Quadratic, IQSA-AbMILP):
- 拉普拉斯核 (Laplace, LSA-AbMILP):
- 模块核 (Module, MSA-AbMILP): $α$是一个可学习的参数
这些核函数提供了不同的相似度度量方式,为模型带来了更多的灵活性。
3. 实验分析
作者在涵盖多种MIL假设和应用场景的五个数据集上进行了详尽的实验。
3.1 MNIST数据集
为了检验模型在标准假设之外的、更复杂的MIL假设下的性能,作者构建了三种MNIST-bags:
- 标准假设: 包里有”9”即为正。
- 存在性假设: 包里同时有”9”和”7”才为正。
- 阈值假设: 包里至少有两个”9”才为正。
在简单的标准假设下,SA-AbMILP与基线性能相当。在需要识别实例共现(存在性假设)和计数(阈值假设)的复杂场景下,SA-AbMILP及其核变体的性能显著优于基线。在小样本时,标准点积的SA-AbMILP表现最好;随着样本量增加,使用核函数的版本(如LSA、IQSA)开始展现优势。
在存在性假设下,SA-AbMILP能够同时给“9”和“7”都分配较高的注意力权重,而AbMILP往往只能关注其中一个。这直观地证明了SA层确实帮助模型理解了实例间的“合作”关系。
3.2 组织病理学数据集(乳腺癌 & 结肠癌)
验证模型在真实、复杂的医疗影像任务上的性能:SA-AbMILP及其核变体在两个数据集上的AUC和召回率(Recall)等关键指标上,均优于现有方法。特别是在召回率上的提升,对于减少医疗诊断中的漏诊(假阴性)具有重要临床意义。
可视化结果显示,SA-AbMILP高亮的阳性区域(patches)比AbMILP更少、更集中。这有助于病理学家更快地定位到最关键的区域。进一步分析SA层本身的注意力图谱发现,当模型关注一个关键的细胞核时,它的注意力也会自然地扩散到其组织学上相关的邻近细胞核,这证明模型确实在学习有意义的生物学结构。
3.3 微生物学数据集 (DIFaS) & 视网膜病变数据集 (Messidor)
在更多样的医疗影像任务上验证模型的泛化能力。
- 在DIFaS真菌分类任务中,SA-AbMILP在两种不同的主干网络(ResNet-18, AlexNet)下,都取得了最好的分类性能。
- 在Messidor糖尿病视网膜病变筛查任务中,SA-AbMILP的核变体版本(LSA-AbMILP)取得了76.3%的准确率,同样超越了包括GNN-MIL在内的所有对比方法。
跨多个不同类型的数据集取得一致的性能提升,充分证明了SA-AbMILP的鲁棒性和泛化能力。实验还发现,核函数的选择是数据依赖的。在DIFaS上,RBF核和IQ核表现更好;而在Messidor上,拉普拉斯核表现最好。这表明,为特定任务选择合适的核函数是一个值得探索的调优方向。