使用距离感知自注意力的深度多实例学习.

0. TL; DR

这篇论文提出了距离感知自注意力多实例学习(Distance-Aware Self-Attention MIL, DAS-MIL)模型。该模型在计算任意两个实例之间的注意力权重时,加入了一个与它们欧氏距离相关的、可学习的偏置项。通过一个可学习的Sigmoid函数,将连续的距离映射为一个插值系数,用于在两个可学习的嵌入向量之间进行插值。

在一个需要距离感知才能解决的MNIST-COLLAGE数据集上,DAS-MIL取得了95.8%的测试准确率,超越了所有不考虑相对位置或使用绝对位置编码的MIL模型。在公开的CAMELYON16乳腺癌转移检测任务上,DAS-MIL也超越了强基线模型,证明了其在真实世界复杂场景下的有效性。

1. 背景介绍

多实例学习(MIL)的经典假设是,包内的实例是独立同分布(i.i.d.)的。然而,在许多现实应用中,这个假设并不成立。尤其是在计算病理学中,WSI被切分成的大量patches之间存在着复杂的空间依赖关系。例如肿瘤细胞与免疫细胞的空间邻近关系。

为了打破i.i.d.假设,捕捉实例间的关联,近期的许多研究开始将自注意力(Self-Attention)机制引入MIL的聚合阶段。自注意力允许模型在计算每个实例的权重时,综合考虑包内所有其他实例的信息,从而动态地建模实例间的依赖关系。

现有的MIL模型在计算注意力时,只考虑了实例的特征相似度,而完全忽略了它们的相对空间位置。绝对位置编码可以打破置换不变性,但是无法有效表示相对关系,并且缺乏旋转不变性。本文的出发点是能否设计一种自注意力机制,使其在计算注意力权重时,能够直接、显式地考虑任意两个实例之间的相对距离。

2. DAS-MIL

DAS-MIL的整体框架是一个标准的嵌入式MIL模型,并引入了距离感知自注意力(Distance-Aware Self-Attention)DAS-MIL整体流程:

  1. 特征提取: 用CNN为每个patch提取特征向量。
  2. 距离感知自注意力: 将所有patch的特征向量和它们之间的成对欧氏距离矩阵,送入DAS-Attn层,得到一组经过空间关系加权的、新的特征向量。
  3. 聚合: 对新的特征向量进行max-pooling,得到最终的包(WSI)嵌入。
  4. 分类: 将包嵌入送入一个线性分类器,得到最终的分类结果。

对于一个实例$i$和一个实例$j$,它们之间的注意力兼容性(compatibility)$e_{ij}$由它们的查询(Query)、键(Key)向量的点积决定:

\[e_{ij} = \frac{(x_i W^Q)(x_j W^K)^T}{\sqrt{d_z}}\]

其中,$x_i, x_j$是实例的输入特征,$W^Q, W^K$是可学习的权重矩阵。这个$e_{ij}$决定了在生成实例$i$的输出时,应该对实例$j$“关注”多少。

DAS-MIL的核心思想是,在计算$e_{ij}$和最终的输出时,加入与实例$i$和$j$之间相对距离$δ_ij$相关的可学习偏置项 :

\[e_{ij} = \frac{(x_i W^Q + b_{ij}^Q)(x_j W^K + b_{ij}^K)^T - (b_{ij}^Q)(b_{ij}^K)^T}{\sqrt{d_z}}\]

这里的 $b_{ij}^Q$ 和 $b_{ij}^K$ 就是距离相关的偏置向量。它们被加到原始的QueryKey向量上,从而在点积计算中引入了距离信息。

同样,在最终的加权求和中,也加入了一个距离相关的偏置项 $b_{ij}^V$:

\[z_i = \sum_{j=1}^n \alpha_{ij} (x_j W^V + b_{ij}^V)\]

上述偏置项$b_{ij}$必须是距离$\delta_{ij}$的函数。但$\delta_{ij}$是一个连续值,而以往的相对位置表示大多处理的是离散的词间距。作者提出了一种插值方案来对连续距离进行编码。以Key偏置项$b_{ij}^K$为例:

\[b_{ij}^K = \phi(\delta_{ij}) u^K + (1 - \phi(\delta_{ij})) v^K\]

其中$u^K, v^K$是两个可学习的、固定大小的嵌入向量,它们分别代表了“近距离关系”和“远距离关系”的编码原型。$\phi(\delta_{ij})$是一个可学习的插值函数,它将连续的距离$\delta_{ij}$映射到一个$[0, 1]$区间的标量,这个标量决定了最终的偏置项$b_{ij}^K$在$u^K$和$v^K$之间的插值比例。 $\phi(\cdot)$参数化为一个经过缩放和平移的Sigmoid函数,其中的缩放$β$和平移$θ$参数都是可学习的。

\[\phi(\delta) = \text{sigmoid}(\beta \cdot \delta + \theta)\]

3. 实验分析

为了验证模型是否真的能理解和利用相对距离,作者构建了一个名为MNIST-COLLAGE的新数据集。每个“包”是一张大画布,上面随机散布着MNIST数字。一个包被标记为“正”,当且仅当它同时包含一个“0”和一个“1”,并且它俩的距离小于某个阈值。还有一个反向版本MNIST-COLLAGE-INV,要求距离大于阈值。

MNIST-COLLAGE上,DAS-MIL取得了95.8%的测试准确率,在MNIST-COLLAGE-INV上也达到了90.6%。这两个结果都显著优于所有对比方法。不考虑任何位置信息的模型(如AB-MIL)的性能上限被卡在88%左右,因为它们无法区分那些包含了“0”和“1”但距离不达标的负包。使用绝对位置编码的模型同样表现不佳,证明了这种编码方式难以有效推断相对距离。与DAS-MIL最接近的是离散相对位置编码,它取得了第二好的成绩。但这进一步凸显了DAS-MIL处理连续距离的优势。

可视化注意力热力图显示,DAS-MIL确实对那些满足距离条件的“0”和“1”对产生了强烈的注意力响应,证明了其机制的有效性。

为了测试真实世界的病理学挑战,使用了公开的乳腺癌淋巴结转移检测数据集CAMELYON16。为了减少训练时间,作者使用了在大量病理图像上预训练好的Transformer模型来提取patch特征。

在测试集上,DAS-MIL的AUROC达到了0.914,平衡准确率为86.4%,超越了所有对比的MIL模型。虽然TransMIL在训练集上表现略好,但DAS-MIL在测试集上实现了反超,说明DAS-MIL的泛化能力更强。这可能是因为它学习到的空间关系是一种更本质、更不易过拟合的模式。结果证明在病理图像分析中,patch间的相对空间关系确实是一种至关重要的、不可或缺的信息。