使用图神经网络进行多实例学习.
0. TL; DR
这篇论文将图神经网络(Graph Neural Network, GNN)应用于MIL,将每个“包”显式地构建成一个“图”(Graph),实例成为图的节点(nodes),实例间的关系(如距离)成为图的边(edges),把MIL问题转化为一个图分类(Graph Classification)问题。
论文利用GNN强大的信息传播和聚合能力来学习包级别(即图级别)的嵌入表示。它采用了可微分池化(Differentiable Pooling)技术,以一种层级化的方式对图进行聚类和压缩,从而得到一个固定长度、信息丰富的图表示,用于最终的分类。
在多个经典的MIL基准数据集(如Musk, Fox, Elephant等)上,基于GNN的方法全面超越了所有SOTA深度学习方法。该方法通过分析可微分池化学习到的节点分配矩阵,保留了模型的可解释性,能够有效定位对分类结果起决定性作用的关键实例。
1. 背景介绍
多实例学习(MIL)作为一种弱监督学习框架,其核心设定是处理由一组“实例”构成的“包”,并根据包级别的标签进行学习。然而长期以来被广泛采纳的假设是包内的实例是独立同分布的(i.i.d.),忽略这些实例间的内在结构和关系,丢失了关键的上下文和结构信息。
图神经网络(GNN)在处理图结构数据方面展现出了无与伦比的能力。GNN能够通过在图的边上传播和聚合信息(即消息传递),自动地学习到节点和图的强大表示。既然MIL的包天然具有图结构,可以直接用端到端的GNN来学习包的表示,把MIL问题被重新定义为一个图分类(Graph Classification)。
2. GNN-based MIL
GNN-MIL框架是一个端到端的流程,主要包括三个步骤:图的构建、图的嵌入学习,以及最终的分类。

2.1 图的构建
将一个由实例特征向量构成的“包”转换成一个图$G = (A, V)$,包中的每个实例$x_i$自然地成为图中的一个节点 (Nodes)。所有实例的特征构成了节点特征矩阵 $V$。
计算任意两个实例$x_m$和$x_n$之间的欧氏距离$dist(x_m, x_n)$。设定一个阈值$η$。如果两个实例的距离小于$η$,则在它们之间添加一条边 (Edges)。由此得到图的邻接矩阵 $A$:
\[A_{mn} = \begin{cases} 1 & \text{if } \text{dist}(x_m, x_n) < \eta \\ 0 & \text{otherwise} \end{cases}\]$η$是一个可调的超参数,允许根据任务控制图的稀疏度。$η = 0$: 图中没有任何边,等价于传统的i.i.d.假设。$η = +∞$: 图是一个全连接图,所有实例都相互关联。
2.2 图的嵌入学习
在得到图之后,需要为这个图学习一个固定长度的、能够代表整个包信息的嵌入向量。首先使用一个GNN模型(如GraphSAGE)在图上进行消息传递,来更新每个节点的嵌入,使每个节点的表示融合其邻域的结构信息。
\[Z_i = \text{GNN}_{\text{embd}}(A_i, V_i)\]其中$V_i$是原始的节点特征矩阵(来自实例特征),$A_i$是邻接矩阵,$\text{GNN}_{\text{embd}}$是一个GNN层,通过聚合邻居节点的信息来更新每个节点的表示,$Z_i$是更新后的节点嵌入矩阵。
对于一组信息更丰富的节点嵌入$Z_i$,需要将它们聚合成一个单一的图级别表示。作者采用了可微分池化(Differentiable Pooling, DiffPool)技术,以一种可学习的方式,对图进行层级化的软聚类。
用另一个GNN($GNN_{cluster}$)来为每个节点学习一个到$C$个簇的分配概率。$C$是预定义的簇的数量。
\[S_i = \text{softmax}(\text{GNN}_{\text{cluster}}(A_i, V_i))\]其中$S_i$是一个$K \times C$的矩阵,$S_{ij}$表示节点$i$属于簇$j$的概率。利用分配矩阵$S_i$,将原始图的节点“软分配”到$C$个簇中,形成一个更小的、粗化的新图。新图的节点特征矩阵$V^$和邻接矩阵$A^$通过以下方式计算:
\[V_i^* = S_i^T Z_i \quad (\text{维度为 } C \times D') \\ A_i^* = S_i^T A_i S_i \quad (\text{维度为 } C \times C)\]作者将簇的数量$C$设置为1或2。如果$C=1$,那么$V^{*}$就是一个$1 \times D’$的向量,它直接就是图嵌入。如果$C > 1$,可以对$V^{*}$的所有行向量(每个簇的表示)进行max-pooling或拼接,得到最终的图嵌入。
与简单的mean/max池化相比,DiffPool是一种数据驱动的、可学习的聚合方法。它能够自动发现图中的社群结构,并以一种更智能的方式汇总信息。
2.3 图的分类与深度监督
在得到固定长度的图(包)嵌入后,将其送入一个标准的多层感知机(MLP)分类器,进行最终的标签预测。
此外,为了稳定训练过程并提升性能,作者还引入了深度监督(Deep Supervision)技术。这意味着不仅在最终的输出层计算损失,还在模型的中间层(如$GNN_{embd}$的输出和DiffPool的输出)也引出分支进行预测并计算损失。最终的总损失是所有这些损失的加权和。
3. 实验分析
作者在三类数据集上对GNN-MIL进行了全面的评估。
3.1 经典MIL基准数据集
在数据集Musk1, Musk2, Fox, Tiger, Elephant上,GNN-MIL在五个数据集中有四个取得了SOTA性能,全面超越了包括mi-Graph(基于图核)、MI-Net和Attention-MIL在内的所有强基线。这个结果强有力地证明了显式地建模和利用实例间的结构信息是极其有效的。GNN端到端的表示学习能力,比基于人工设计的图核的mi-Graph和假设实例独立的Attention-MIL都要强大。

3.2 文本分类数据集
数据集为20个从20 Newsgroups语料库中构建的文本分类任务。GNN-MIL在20个任务中的11个上击败了当时性能最好的MI-Net with DS,并且平均准确率也略微胜出(81.6% vs 81.5%)。尽管性能提升不大,但这可能与数据集的构建方式有关(实例是随机抽取的,可能破坏了部分原始的句子顺序)。即便如此,GNN-MIL依然展现了其竞争力,证明了即使在结构信息不那么完美的场景下,它也能学习到有用的关系。

3.3 视网膜图像分类 (Messidor)
Messidor是一个真实的医疗影像任务,目标是根据视网膜图像诊断糖尿病。在该数据集上,GNN-MIL(基于DiffPool)取得了SOTA性能,准确率达到74.2%,F1分数为0.77,相比之前的最佳方法mi-Graph,错误率相对降低了超过6%。与不考虑图结构(即$η=0$,邻接矩阵为零矩阵)的模型相比,GNN-MIL的性能有明显提升(74.2% vs 72.4%)。

通过分析混淆矩阵发现,性能的提升主要来自于假阴性(False Negatives)的大幅减少。假阴性的减少意味着模型更少地漏诊了真正有病的样本。这可能是因为GNN通过聚合邻域信息,使得一些特征不那么典型的阳性实例也得到了增强,从而被模型成功识别。这恰恰是max-pooling等方法容易忽略的。

3.4 模型可解释性分析
作者可视化了可微分池化(DiffPool)学习到的簇分配矩阵 $S$。这个矩阵的每一行代表一个簇,每一列代表一个实例,矩阵中的值表示每个实例被分配到每个簇的概率。
当设置簇数量为2时,模型倾向于将正实例和负实例分配到不同的簇中。当设置簇数量为1时,模型学习到的分配概率就如同一个注意力分数:正实例(关键实例)被赋予了显著更高的概率值。这表明GNN-MIL保留了良好的可解释性。通过检查这个学到的分配矩阵,可以轻松地识别出哪些实例对最终的分类决策贡献最大,从而实现了关键实例的定位。
