基于注意力的深度多实例学习.

0. TL; DR

这篇论文提出了一种基于注意力机制(Attention-based)的、可学习的、置换不变的MIL池化算子,为每个实例学习一个注意力权重(attention weight),然后对实例的特征表示进行加权平均,从而得到整个包的表示。这个权重的大小直观地反映了该实例对最终包标签的贡献度。

在多个MIL基准数据集、合成的MNIST-bags以及真实的乳腺癌和结肠癌病理图像数据集上,基于注意力的模型均取得了与SOTA方法相当或更优的性能,尤其是在小样本和真实世界复杂场景下。通过可视化注意力权重,可以清晰地看到模型在做决策时“关注”了哪些关键实例。

1. 背景介绍

多实例学习(MIL)是一种弱监督学习范式,它的设定是有一堆“包”(Bags),每个包里装着若干“实例”(Instances),只知道整个包的标签(例如,这张病理切片是“癌变”的),却不知道里面每个实例(细胞图块)的具体标签。

MIL的目标主要有两个:包级别预测: 准确地判断一个新包的标签。;实例级别定位: 找出是哪些“关键实例”(导致了包的标签(例如,定位出癌变的细胞)。

为了实现这些目标,现有的MIL方法大致可以分为两类:

  1. 实例空间方法 (Instance-level approach): 先训练一个分类器来预测每个实例的分数,然后通过一个聚合函数(如maxmean)将所有实例的分数合并成包的分数。这类方法的优点是天然具有可解释性,因为实例的分数可以直接用来定位关键实例。缺点是由于缺乏实例级别的真值标签,实例分类器本身可能训练不充分,从而引入误差,影响最终的包预测性能。
  2. 嵌入空间方法 (Embedding-level approach): 先将所有实例的特征向量通过一个聚合函数(如summax)融合成一个统一的、固定长度的“包表示”,然后再用这个包表示去训练一个包级别的分类器。这类方法的优点是直接面向包级别的分类任务,通常能获得更好的包预测性能。缺点是实例级别的原始信息在聚合过程中丢失了,模型丧失了可解释性。

无论是哪种方法,都离不开一个核心组件——MIL池化(MIL Pooling),负责聚合实例信息的函数。传统方法(如summax)最大的问题在于它们是预先定义好的、不可学习的。本文的出发点是能否设计一种可学习的、自适应的池化方法,让模型自己决定如何聚合信息,并同时兼顾“性能”与“可解释性”。

2. Attention-based MIL

作者首先从理论上阐述了MIL建模的通用框架。由于包内的实例没有顺序关系,任何对包进行打分的函数$S(X)$必须是一个对称函数(即对实例的顺序不敏感,permutation-invariant)。根据对称函数基本定理,任何对称函数都可以被分解为如下形式:

\[S(X) = g\left( \sum_{x \in X} f(x) \right)\]

或者近似分解为:

\[S(X) \approx g\left( \max_{x \in X} f(x) \right)\]

这个分解给出了一个通用的三步MIL流程:

  1. 转换 (f): 对每个实例进行特征转换。
  2. 聚合 (σ): 用一个对称函数(如summax)聚合所有转换后的实例。
  3. 再转换 (g): 对聚合后的结果进行最终的变换,得到包的分数。

在深度学习框架下可以用神经网络来参数化fg,使模型变得极其灵活。而聚合函数σ用一个可学习的加权平均来代替固定的summax

\[z = \sum_{k=1}^{K} a_k h_k\]

其中,$h_k$是第$k$个实例经过神经网络f转换后的嵌入表示,而$a_k$是模型为该实例学习到的注意力权重。权重$a_k$本身也是通过一个小型神经网络(即注意力网络)计算得出的,并经过softmax归一化,以确保所有权重之和为1。

\[a_k = \frac{\exp\{ \mathbf{w}^T \tanh(\mathbf{V} \mathbf{h}_k^T) \}}{\sum_{j=1}^{K} \exp\{ \mathbf{w}^T \tanh(\mathbf{V} \mathbf{h}_j^T) \}}\]

其中$\mathbf{V}$ 和 $\mathbf{w}$是注意力网络中可学习的参数矩阵。$\tanh$是激活函数,引入非线性。此时聚合方式不再是固定的,而是通过反向传播,根据数据和任务自适应地学习。模型可以学习到从近似mean(所有权重相似)到近似max(只有一个权重接近1)之间的任何聚合策略。

作者进一步指出,单独使用$\tanh$激活函数可能在某些情况下(输入接近0时)表现为线性,限制了模型的表达能力。因此进一步引入了门控机制(Gating Mechanism)来增强非线性,提出了Gated-Attention

\[a_k = \frac{\exp\{\mathbf{w}^T (\tanh(\mathbf{V} \mathbf{h}_k^T) \odot \text{sigm}(\mathbf{U} \mathbf{h}_k^T))\}}{\sum_{j=1}^{K} \exp\{\mathbf{w}^T (\tanh(\mathbf{V} \mathbf{h}_j^T) \odot \text{sigm}(\mathbf{U} \mathbf{h}_j^T))\}}\]

其中$\mathbf{U}$是门控机制引入的新的可学习参数。$\text{sigm}$是 Sigmoid 函数,输出一个0到1之间的“门”,用来控制$\tanh$输出的信息流。

Attention-based MIL遵循嵌入空间的范式,通过加权平均得到一个信息丰富的包表示,有利于提升包级别预测性能;同时,学习到的注意力权重$a_k$本身就是一种实例级别的分数,直接反映了每个实例的重要性。通过可视化这些权重,可以实现可解释性。

3. 实验分析

作者在三大类数据集上对提出的方法进行了验证。

3.1 经典MIL基准数据集

数据集 Musk1, Musk2, Fox, Tiger, Elephant 是小规模、预计算特征的数据集。在这些数据集上Attention-based方法与当时最好的传统MIL方法(如mi-Graph, miFV)性能相当。尽管深度模型的优势并不明显。但能取得相当的性能,已经证明了该方法的基本有效性。

3.2 合成MNIST-bags数据集

作者在一个更可控、更像真实图像任务的场景下,检验模型的性能,特别是处理原始像素和不同样本规模的能力。具体地,将MNIST数字图片构造成“包”,如果包里包含数字“9”,则为正包。通过调整训练包的数量和每个包平均包含的实例数,来测试模型在不同难度下的表现。

在训练数据较少时(如只有50-150个包),Attention-based方法的性能远超其他所有方法,实验再次验证了Embedding-based模型通常优于Instance-based模型,并且Max-pooling的表现始终优于Mean-pooling

论文展示了一个包含多个数字的包,Attention模型准确地为所有“9”分配了最高的注意力权重,而其他数字的权重几乎为0。这直观地证明了其定位关键实例的能力。

3.3 真实世界组织病理学数据集

作者使用了两个数据集:

  1. Breast Cancer: 58张乳腺癌H&E染色全切片图像。
  2. Colon Cancer: 100张结肠癌H&E染色全切片图像。

结果表明,Attention-based方法(尤其是Gated-Attention)在所有评估指标(准确率、精确率、召回率、F1分数、AUC)上,均取得了最佳性能。在医疗诊断中,高召回率(Recall)至关重要,因为它意味着更少的漏诊(False Negatives)。Attention-based方法在召回率上表现出色,这对于临床应用非常有价值。

论文通过将注意力权重可视化为热力图,清晰地展示了模型在仅使用图像级标签训练后,能够准确地高亮出与病理学家标注高度重合的癌细胞区域。这强有力地证明了该方法在提供可解释的诊断依据方面的巨大潜力。