掩码难实例挖掘的全切片图像分类多实例学习.
0. TL; DR
这篇论文指出,在多实例学习(MIL)中,过度关注“易分类”样本会导致模型学习到的决策边界是次优的,而真正有助于模型提升泛化能力的是那些“难分类”的样本(hard instances)。
为了在缺乏实例标签的MIL框架下挖掘这些难样本,作者提出了掩码式难样本挖掘多实例学习(Masked Hard Instance Mining MIL, MHIM-MIL)框架。采用一个巧妙的教师-学生模型,教师模型稳定地为所有实例生成注意力分数,并执行掩码操作;学生模型使用未被掩码的难样本进行学习;通过一个一致性损失强制学生模型和教师模型的输出(包表示)保持一致,从而在迭代中相互促进、共同进步。
MHIM-MIL作为一个通用的训练框架,可以被应用在多种先进的MIL模型上,并一致性地、显著地提升了它们的性能。在CAMELYON-16和TCGA肺癌两大公开数据集上,搭载了MHIM框架的模型性能全面超越所有最新方法,相比于其他复杂的MIL框架,MHIM-MIL的教师模型无需梯度更新,因此几乎不增加额外的计算开销。
1. 背景介绍
机器学习的一个基本原理是,模型的决策边界是由那些位于类别边界附近的、最容易被混淆的样本所决定的。这些样本被称为难样本(hard examples)。利用这些难样本来学习一个更鲁棒、更精准的分类边界。
近年来几乎所有的先进多实例学习(MIL)方法都是识别并聚焦于最显著的实例。现有MIL方法对“显著实例”的过度关注,使模型反复学习易于分类的样本,导致模型学习到的决策边界可能只是局部最优的,泛化能力受到限制。
本文的出发点是如何在只有包级别标签、无法直接识别难样本的MIL框架下,有效地挖掘并利用这些宝贵的难实例。
2. MHIM-MIL
MHIM-MIL是一个教师-学生(Teacher-Student)训练框架,它通过“掩码”操作来间接地实现难样本挖掘。MHIM-MIL框架在训练阶段是一个孪生(Siamese)结构。
- 学生模型 (Student Model, $S$):可以是任何基于注意力的MIL模型,它的目标是学习对WSI进行分类。它接收的是经过教师模型“筛选”后的难样本,通过标准的反向传播和梯度下降进行更新。
- 教师模型 (Teacher Model, $T$):用来评估所有实例的重要性,并决定哪些实例是“容易的”,应该被过滤掉。它的参数是通过对学生模型的参数进行指数移动平均(Exponential Moving Average, EMA)来平滑更新的。
其中,$\theta_t, \theta_s$分别是教师和学生的参数,$λ$是动量系数(如0.9999)。这种方式使得教师模型比学生模型更“稳重”,其输出的注意力分数也更稳定,不易受到单次迭代中噪声的影响。
教师模型$T$接收完整的实例集$Z$,并为每个实例计算一个注意力分数$[a_1, \dots, a_N] = T(Z)$。将所有实例根据其注意力分数$a_i$从高到低排序。将注意力分数排在前$β_h\%$的实例标记为“掩盖”。这些被认为是模型最自信、最容易分类的样本。将所有被标记为“掩盖”的实例从原始实例集中移除,剩下的实例集$Ẑ$就构成了挖掘出的“难样本集”$\hat{Z} = \text{Mask}(Z, \hat{M})$。将这个只包含难样本的$Ẑ$作为输入,送入学生模型$S$进行训练$\hat{Y} = S(\hat{Z})$。
为了进一步提升效率和鲁棒性,作者还提出了几种混合掩盖策略。
- L-HAM (Low-Attention): 同时掩盖掉注意力分数最低的$β_l\%$的实例。这些通常是完全不相关的背景区域,移除它们可以减少计算量,提高训练效率。
- R-HAM (Random-Attention): 随机掩盖掉$β_r\%$的实例。这引入了随机性,类似于Dropout,可以有效降低过拟合的风险。
- LR-HAM: 将上述三种策略结合起来,同时掩盖最高、最低和随机的一部分实例。
为了让教师和学生能够相互促进,形成一个良性循环,MHIM-MIL引入了一致性损失(Consistency Loss)。学生模型的总损失:
\[\mathcal{L} = \mathcal{L}_{cls} + \alpha \mathcal{L}_{con}\]其中 $L_cls$ 是标准的分类损失(如交叉熵),计算学生模型对难样本集的预测$Ŷ$与真实包标签$Y$之间的误差。$L_con$是一致性损失,它要求学生模型在处理难样本集后得到的包表示$F_s$,与教师模型在处理完整实例集后得到的包表示$F_t$保持一致。
\[\mathcal{L}_{con} = -\text{softmax}(F_t / \tau) \log(F_s)\]$τ$是一个温度系数。
3. 实验分析
作者在CAMELYON-16和TCGA肺癌这两个极具挑战性的WSI数据集上,将MHIM-MIL框架应用到了AB-MIL, TransMIL, DSMIL等多个先进的MIL模型上,为所有基线模型都带来了显著的性能提升。即使是对于DTFD-MIL这样已经非常复杂的、专门挖掘显著实例的框架,MHIM-MIL依然能通过挖掘难样本带来性能提升。
MHIM-MIL的即插即用特性,证明了它是一个通用且有效的训练框架,而非一个特定的模型。教师模型通过EMA更新,没有额外的梯度计算,因此参数量和计算开销几乎不变。相比于其他复杂的MIL框架,MHIM-MIL几乎不增加额外开销,甚至在应用于TransMIL等计算复杂度与输入序列长度的平方相关的模型,掩盖掉部分实例意味着输入序列变短了,还能显著降低训练时间和内存占用。
消融研究:
- MHIM策略的重要性: 仅使用掩盖策略(没有教师-学生架构),就能为基线模型带来约2%的AUC提升,证明了难样本挖掘的核心思想是有效的。
- 教师-学生架构的重要性: 引入动量更新的教师模型,相比于让学生自己指导自己,带来了更稳定、更有效的性能提升。
- 一致性损失的重要性: 加入一致性损失后,性能得到进一步提升,证明了在教师和学生之间建立约束,能引导模型走向更好的优化方向。
- 混合掩盖策略的影响: 实验表明,不同的混合掩盖策略(如R-HAM, L-HAM)在不同数据集和模型上表现各异。例如,在实例冗余度高的TCGA数据集上,加入随机掩盖的R-HAM效果更好;而在关键信息更稀疏的CAMELYON-16上,同时去除低分实例的L-HAM表现更佳。这说明可以根据任务特性灵活选择掩盖策略。
通过可视化注意力热图,可以直观地看到MHIM-MIL的作用。基线模型往往会给非肿瘤区域(假阳性)分配较高的肿瘤概率,或者只关注最典型的肿瘤区域而忽略了其他同样是肿瘤但形态不典型的区域(假阴性)。MHIM-MIL模型对非肿瘤区域的误判显著减少,泛化能力更强。对于那些细微的、不典型的肿瘤区域,MHIM-MIL能够更准确地识别出来。
MHIM-MIL训练后的模型,其注意力分数(亮斑)看起来可能更“分散”,似乎关注了更多“不相关”的区域。但实际上,正是通过强迫模型去看这些难样本,它才学会了更全面、更鲁棒的判别规则。