基于多模态最优传输的协同注意力Transformer模型及其全局结构一致性用于生存预测.
0. TL; DR
MOTCat (Multimodal Optimal Transport-based Co-Attention Transformer) 框架的核心创新在于引入了最优传输理论,提出一种基于OT的协同注意力机制,从全局视角匹配病理图像patches和基因embeddings。这不仅能够筛选出信息最丰富的patches来代表WSI,更重要的是,它强制模型在匹配时考虑每个模态的内在结构,从而有效捕捉对生存预测至关重要的TME交互。
为了解决OT在处理海量WSI patches时面临的巨大计算复杂度问题,作者提出了一种鲁棒且高效的实现方式,即通过对WSI patches进行micro-batch处理,并使用非平衡mini-batch OT来近似原始OT问题。
作者在五个基准数据集上进行了广泛实验,结果表明,MOTCat的性能显著优于state-of-the-art方法,证明了其在多模态生存分析中的优越性。
1. 背景介绍
生存预测是癌症预后研究中的一个复杂序数回归任务,旨在估计患者的相对死亡风险。在临床实践中,整合病理学的定性形态学信息和基因组学的定量分子谱,对于精准预测至关重要。
然而,现有的多模态学习方法在整合组织学和基因组学数据时面临一些悬而未决的问题:
- WSI的有效表示:WSI的尺寸巨大,如何从中提取关键信息而不丢失重要细节是一个巨大挑战。
- TME交互的建模:肿瘤微环境(TME)内的交互,如肿瘤细胞与肿瘤浸润淋巴细胞(TILs)的共现,是重要的预后指标。现有的co-attention方法通过计算病理patch和基因instance之间的密集局部相似性来指导信息选择。但是,这种“局部视角”忽略了每个模态内部的“全局结构”,例如WSI中空间上分散但功能上协作的TME组分,以及基因组数据中固有的共表达模式。
作者认为,TME的形态学结构与基因的共表达网络之间可能存在内在的结构一致性。利用这种全局结构一致性来指导模态间的匹配,更有可能识别出与TME相关的、对生存预测有价值的病理patches。
基于此,作者引入了Optimal Transport (OT) 理论。OT是一种结构匹配方法,旨在寻找两个分布之间的最优“传输方案”,使得总体传输成本最小。作者将其应用于多模态协同注意力,提出MOTCat框架,希望通过OT的全局视角来解决上述挑战。

2. MOTCat 框架
MOTCat的整体框架如图所示。其核心是将两种模态(WSI和基因组)都表述为“bag”(包)结构,然后通过OT-based Co-Attention来识别信息最丰富的instances(实例),最后将筛选后的特征输入Transformer编码器进行融合和预测。

2.1 多模态包的构建 (Multimodal Bags Formulation)
遵循multiple instance learning (MIL) 的范式,一张WSI被视为一个包含多个patch实例的“包”。
\[B_n^p = \{f_p(x_{n,i}^p) : x_{n,i}^p \in X_n^p\} = \{b_{n,i}^p\}_{i=1}^{M_p}\]其中,$X_n^p$是第$n$个患者的WSI,$x_{n,i}^p$是其中的第$i$个patch,$f_p(\cdot)$是一个CNN编码器(ResNet-50)用于提取patch特征$b_{n,i}^p$。
基因组数据也根据生物学功能影响被组织成一个“包”。
\[B_n^g = \{f_g^j(x_{n,j}^g) : x_{n,j}^g \in X_n^g\} = \{b_{n,j}^g\}_{j=1}^{M_g}\]其中,$X_n^g$是第$n$个患者的基因组数据,它被分为$M_g$个功能类别(如肿瘤抑制、致癌等),$f_g^j(\cdot)$是用于编码第$j$个功能类别的网络。
2.2 基于最优传输的协同注意力 (Optimal Transport-based Co-Attention)
这是MOTCat的核心。它旨在通过寻找病理学bag $B_n^p$和基因组学bag $B_n^g$之间的最优匹配方案,来识别信息最丰富的病理patches。
寻找一个传输方案(optimal matching flow)$P_n$,使得在满足边际约束的条件下,总的传输成本最小。其离散的Kantorovich形式为:
\[W(B_n^p, B_n^g) = \min_{P_n \in \Pi(\mu_p, \mu_g)} <P_n, C_n>_F\]其中成本矩阵$C_n$其元素$C_{n}^{u,v} = c(b_{n,u}^p, b_{n,v}^g)$是病理patch $u$和基因组实例$v$之间的局部距离(如L2距离)。传输方案的集合$\Pi(\mu_p, \mu_g)$要求$P_n$的行和与列和分别等于病理学和基因组学的边际分布$\mu_p$和$\mu_g$。
这个OT公式强制模型在寻找最优匹配时进行全局权衡,而不仅仅是考虑局部相似性。这使得模型能够捕捉病理学的空间交互结构和基因组学的共表达结构之间的一致性。
得到最优传输方案 $P_n^*$ 后,信息量最丰富的WSI实例(patches)可以通过 $B_n^p$ 与 \(P_n^*\) 的乘积来识别,即 \(\hat{B}_n^p = P_n^{*\top} B_n^p\) 。
2.3 基于微批次的优化 (Optimization over Micro-Batch)
由于一张WSI的patch数量巨大($M_p$可能大于50,000),直接求解上述OT问题的计算复杂度极高。作者为此提出了一种高效的近似方法。
将一张WSI的patch bag分割成多个子集,称为Micro-Batch,每次只在一个Micro-Batch上进行计算。作者采用UMBOT公式来近似原始的OT问题。UMBOT通过引入熵正则化和边际惩罚项,使得在大规模数据集上的计算更为高效和稳定。其优化问题变为:
\[W_m(B_{n,m}^p, B_n^g) = \min_{P_n^m \in \Pi} <P_n^m, C_n^m>_F + \epsilon KL(P_n^m | \mu_p^m \otimes \mu_g) \\+ \tau(D_\phi(P_{n,p}^m || \mu_p^m)+P_{n,g}^m || \mu_g)\]其中,$B_{n,m}^p$是一个大小为$m$的micro-batch。这种方法将计算复杂度从$O(M^3\log(M))$降低到$O(M \times m)$。
2.4 生存预测
对于每个模态(病理学和基因组学),使用一个Transformer编码器来聚合筛选出的实例特征,得到bag-level的表示$H_n^p$和$H_n^g$。
将两个模态的bag-level表示拼接起来,送入一个最终的分类器,以预测离散化时间区间上的风险概率。
使用NLL-loss(负对数似然损失)作为生存预测的损失函数,该损失函数能够处理删失数据。
3. 实验分析
3.1 实验设置
使用了来自TCGA的五个公共癌症数据集:BLCA, BRCA, UCEC, GBMLGG, LUAD。使用c-index(一致性指数)作为主要性能指标。
与多种单模态(SNN, TransMIL等)和多模态(MCAT, Porpoise等)的SOTA方法进行比较。
3.2 实验结果
MOTCat在5个数据集中的4个上取得了最高性能,证明了其有效融合了多模态数据。值得注意的是,在UCEC数据集上,大多数多模态方法的性能甚至不如单模态的基因组学模型,这突显了多模态融合的挑战性。然而,MOTCat在该数据集上仍然取得了与基因组模型相当的性能。
在一对一的比较中,MOTCat在所有基准数据集上都取得了优于其他多模态SOTA方法的性能,c-index提升了1.0%-2.6%。特别是与最相似的MCAT相比,MOTCat在所有数据集上都表现更佳,这证明了从全局视角学习模态间结构一致性的有效性。

3.3 消融研究
作者比较了三个变体:(a) 原始的MCAT(无OT,无MB),(b) 在MCAT基础上加入Micro-Batch策略(有MB,无OT),(c) 完整的MOTCat(有OT,有MB)。
从(a)到(b)的性能提升表明,Micro-Batch策略本身就能带来好处,这可能是因为它增加了训练中“bag”的数量。从(b)到(c)的性能提升则证明了OT-based co-attention的有效性。
作者比较了三种方法在不同micro-batch大小(128, 256, 512)下的性能和鲁棒性。MOTCat(蓝色)在各种大小下都取得了最高的平均性能,特别是在样本量最大的BRCA数据集上。在UCEC和LUAD数据集上,MOTCat的结果最鲁棒(方差最小)。MOTCat在性能和鲁棒性之间取得了最佳的权衡。

3.4 统计分析与可视化
作者使用MOTCat预测的风险评分中位数将患者分为高风险(红色)和低风险(绿色)组。Kaplan-Meier曲线显示,在所有五个数据集中,两组患者的生存曲线都有显著的统计学差异(p-value均<0.05)。这直观地证明了MOTCat在患者分层方面的有效性。
