动态池化的深度多实例学习.
0. TL; DR
这篇论文提出了一种名为动态池化(Dynamic Pooling)的多实例学习MIL聚合范式。不再一次性地决定实例的权重,而是根据当前生成的“包”嵌入,反过来评估每个实例与这个“包”嵌入的相似度,来动态地调整每个实例对“包”表示的贡献度。
搭载了动态池化的多实例神经网络(DP-MINN)在涵盖药物预测、图像检索、文本分类和医疗诊断的多个MIL基准任务上,性能全面超越了包括MI-Net和Attention-Net在内的所有SOTA方法。
1. 背景介绍
在深度多实例学习(Deep MIL)中,一个标准的流程是:将包中的每个实例通过一个神经网络得到一个高维的嵌入向量,用一个置换不变(permutation-invariant)的池化函数将一个包内所有可变数量的实例嵌入聚合成一个单一的、固定长度的“包嵌入”,将包嵌入送入一个分类器,得到最终的包标签。
这个流程中的MIL池化决定了信息如何从“多”汇聚到“一”。现有的主流池化方法的计算过程是静态的、前馈的。这种“静态”的聚合方式,使得模型很难捕捉实例之间复杂的上下文关系。
本文的出发点是能否设计一种池化机制,使其能够动态地、迭代地学习实例的贡献度,从而在聚合过程中融入实例间的上下文信息。
2. 动态池化 (Dynamic Pooling)
动态池化 (Dynamic Pooling)的灵感来源于胶囊网络中的“动态路由”,其核心是“协议路由”(routing-by-agreement)。作者将其思想迁移到MIL领域,提出了“协议池化”(pooling-by-agreement)。
动态池化的本质是一个迭代的权重更新过程。它不再一次性计算出所有实例的最终权重,而是在$T$次迭代中,不断地优化这些权重。DP-MINN的整体架构是一个标准的端到端网络。输入实例先经过一个实例嵌入网络(如MLP),然后送入动态池化层得到包嵌入,最后计算损失进行训练。
假设一个包的实例嵌入为${f(x_1), f(x_2), \dots, f(x_K)}$。在第1次迭代($t=1$)时,初始化所有实例的临时权重$b_i$为0。这意味着初始时,所有实例被认为是同等重要的。
\[b_i^1 \leftarrow 0 \quad \text{for all } i \in [1, K]\]将临时权重$b_i$通过一个$\text{softmax}$函数,得到归一化的贡献度权重$c_i$。
\[c_i^t = \text{softmax}(b_i^t) = \frac{\exp(b_i^t)}{\sum_j \exp(b_j^t)}\]使用当前的贡献度权重$c_i$,对所有实例嵌入进行加权求和,得到当前的包嵌入$σ^t(X)$。
\[\sigma^t(X) = \sum_i c_i^t f(x_i)\]将包嵌入$σ$通过一个“挤压”函数进行非线性变换,得到$s^t(X)$。这个函数的作用是:将短向量(模长小)几乎压缩到0,将长向量的模长压缩到接近1。这使得包嵌入的模长可以被用作概率。
\[s^t(X) = \frac{\|\sigma^t(X)\|^2}{1 + \|\sigma^t(X)\|^2} \frac{\sigma^t(X)}{\|\sigma^t(X)\|}\]计算每个实例嵌入$f(x_i)$与当前整个包的嵌入$s^t(X)$的点积(相似度)。将这个相似度累加到临时权重$b_i$上,用于下一次迭代。
\[b_i^{t+1} \leftarrow b_i^t + f(x_i) \cdot s^t(X)\]经过$T$次迭代后,返回最终的包嵌入$s^T(X)$。
最终得到的包嵌入$s^T(X)$的L2范数(模长)被用作包为正的概率$|s|$。作者发现,使用间隔损失(Margin Loss)比传统的交叉熵损失效果更好。
\[L(X) = Y \cdot \max(0, m^+ - \|s\|)^2 + (1-Y) \cdot \max(0, \|s\| - m^-)^2\]其中$m^+$和$m^-$是预设的间隔,如0.9和0.1。
3. 实验分析
作者在涵盖四大类任务的多个数据集上对DP-MINN进行了广泛的验证。
3.1 经典MIL基准(药物预测与图像检索)
在数据集Musk1, Musk2, Fox, Tiger, Elephant上,DP-MINN取得了全面的SOTA性能。这些数据集包含了不同的MIL假设场景。DP-MINN的普适性成功证明了其动态建模上下文关系的能力,使其能够适应比标准max-pooling或attention更复杂的任务。
3.2 文本分类 (20 Newsgroups)
在20个文本分类任务上,DP-MINN在绝大多数任务上都击败了竞争对手。文本数据是典型的上下文依赖场景,句子的重要性往往取决于它与前后文的关系。DP-MINN的迭代更新机制非常适合捕捉这种依赖,因此取得了优异的性能。
3.3 医疗诊断 (UCSB breast, Messidor)
在乳腺癌和糖尿病视网膜病变两个医疗影像数据集上,DP-MINN均取得了最佳性能。医疗诊断的决策往往需要综合多个病灶区域的信息。DP-MINN能够有效地建模这种复杂的、非“关键实例”的场景,展现了其在真实世界复杂应用中的巨大潜力。
3.4 消融研究
- 迭代次数$T$的影响: 实验表明,当$T=1$时,动态池化退化为平均池化,性能一般。随着$T$的增加,性能普遍提升,在$T=3$时达到最佳平衡点。过多的迭代次数并没有带来进一步的显著提升。
- 损失函数的影响: 对比间隔损失和交叉熵损失,实验发现在大多数数据集上,间隔损失的效果更好。