重新回顾多实例神经网络.
0. TL; DR
这篇论文为多实例学习(MIL)提出了两种端到端神经网络架构:mi-Net(实例空间)和 MI-Net(嵌入空间)。作者发现,直接学习一个统一的“包表示”(Bag Representation)的 MI-Net 架构,在分类性能上通常优于先预测单个实例标签再聚合的 mi-Net 架构。
通过将深度监督(Deep Supervision)和残差连接(Residual Connections)引入到 MI-Net 中,性能得到了进一步的显著提升,在多个经典的MIL基准数据集(涵盖药物发现、图像和文本分类)上取得了极具竞争力的结果,在普通CPU上处理一个“包”仅需亚毫秒级时间。
1. 背景介绍
在多实例学习(Multiple Instance Learning, MIL)中,数据被组织成包(Bag),每个包包含若干个实例(Instances)。标签是在包级别给出的。
虽然神经网络在监督学习中取得了巨大成功,但将其直接应用于MIL面临两大挑战:
- 输入可变性:不同的包包含的实例数量不同,而标准NN通常要求输入是固定尺寸的。
- 监督信号模糊:只有包级别的标签,缺乏实例级别的标签来直接指导网络学习。
本文旨在设计出能够端到端处理MIL任务、充分利用现代深度学习优势的神经网络架构。
2. mi-Net vs. MI-Net
作者提出了两种核心的神经网络范式来解决MIL问题,它们在如何处理实例和包的关系上有着根本的不同。
2.1 mi-Net:实例空间范式 (Instance-Space)
mi-Net 先判断实例,再决定包。其工作流程如下:
- 实例特征学习:包中的每个实例(Instance)独立地通过一系列全连接层(FC layers)进行特征变换。
- 实例得分预测:在网络的较深层,为每个实例预测一个“为正”的概率得分(Instance Score),这个得分是一个0到1之间的标量。
- MIL池化聚合:一个 MIL池化层 (MIL Pooling Layer) 接收所有实例的得分,并将其聚合成一个包得分 (Bag Score)的标量。
- 分类与训练:使用包得分和真实的包标签计算交叉熵损失,并进行反向传播训练。
mi-Net 的数学形式可以表达为:
\[\begin{cases} \mathbf{x}_{ij}^{\ell} = H^\ell(\mathbf{x}_{ij}^{\ell-1}) \\ P_i^L = M^L(p_{ij| j=1...m_i}^{L-1} ) \end{cases}\]其中,$H^\ell$ 是第 $\ell$ 层的非线性变换,$p_{ij}^{L-1}$ 是第 $j$ 个实例的预测得分,$M^L$ 是MIL池化操作,$P_i^L$ 是最终的包得分。
2.2 MI-Net:嵌入空间范式 (Embedded-Space)
与 mi-Net 不同,MI-Net 先为整个包学习一个统一的表示。网络应该直接聚焦于学习对包分类最有用的信息,而不是被预测单个实例标签的中间任务所束缚。其工作流程如下:
- 实例特征学习:与 mi-Net 相同,每个实例首先通过若干FC层。
- MIL池化聚合:MIL池化层在这里作用于实例的特征向量,而不是实例的得分。它将一个包内所有实例的特征向量聚合成一个单一的、固定长度的包表示向量 (Bag Representation)。
- 包级别分类:这个包表示向量随后被送入FC层进行最终的包分类,得到包得分。
MI-Net 的数学形式为:
\[\begin{cases} \mathbf{x}_{ij}^{\ell} = H^\ell(\mathbf{x}_{ij}^{\ell-1}) \\ \mathbf{X}_i^{\ell} = M^\ell(\mathbf{x}_{ij| j=1...m_i}^{\ell-1} ) \end{cases}\]其中,$\mathbf{X}_i^{\ell}$ 是在第 $\ell$ 层聚合后得到的包表示。
2.3 让MI-Net更强大
为了进一步提升MI-Net的性能,作者引入了两种强大的深度学习技术。
MI-Net with Deep Supervision (DS)
在深度网络中,来自最终损失的梯度信号在反向传播时可能会变得很弱,导致浅层网络训练不充分。在网络的每个中间层都引出一个分支,通过MIL池化和分类器直接产生一个包预测。在训练时,总损失是所有这些分支损失的总和。在测试时,最终结果是所有分支预测的平均值。这相当于在训练过程中为网络的每一层都提供了监督信号,确保了各层都能学到有用的特征。
MI-Net with Residual Connections (RC)
借鉴ResNet的思想,帮助训练更深的网络,避免梯度消失。将深层学到的包表示,看作是浅层包表示的残差(residual)。即,将一个FC层+池化层的输出与前一层的输出相加。该方法使得信息流在网络中传递更顺畅,有助于学习更复杂的包表示。
2.4 MIL池化的选择
MIL池化是所有这些架构的核心,它必须是可微的,以便进行端到端训练。本文探讨了三种池化方法:
- Max Pooling: 符合MIL的“正包中至少有一个正实例”的直觉。
- Mean Pooling: 考虑了包中所有实例的贡献。
- Log-Sum-Exp (LSE) Pooling: 一种
max
函数的光滑、可微近似。参数r
控制其平滑程度,r
越大越接近max
,越小越接近mean
。
3. 实验分析
作者在多个经典的MIL基准数据集上进行了详尽的实验,涵盖了药物活性预测(MUSK)、图像分类(Elephant, Fox, Tiger)和文本分类(20 Newsgroups)。
在大多数数据集上,直接学习包表示的MI-Net性能都优于先预测实例得分的mi-Net。加入了深度监督的MI-Net with DS
在几乎所有数据集上都取得了最佳或接近最佳的成绩,尤其是在复杂的文本分类任务上,平均准确率(81.5%)显著高于其他方法。这证明了深度监督对于MIL任务的有效性。
为了深入理解各个组件的作用,作者进行了一系列消融实验。
- 池化方法的影响:大多数情况下,Max Pooling 的表现最好或非常接近最好。这个结果非常直观,因为它最直接地模拟了MIL的标准假设。
- 深度监督的有效性:对比有无深度监督的MI-Net,结果显示,在所有五个基准数据集上,加入深度监督都带来了性能提升。这再次证明了DS对于在弱监督下学习良好特征的重要性。
- 残差连接的有效性:引入残差连接在大部分情况下都有助于提高性能,并且有助于学习一个好的包表示。
一个有趣的发现是,与传统监督学习中“越深越好”的认知不同,在这些MIL任务上,盲目增加网络的深度和宽度并不能带来明显的性能提升,有时甚至会略微下降。作者推测这可能与MIL数据集的规模有限以及池化操作本身相对简单有关。