不完整多模态生存预测的蒸馏提示学习.
0. TL; DR
DisPro (Distilled Prompt Learning) 框架旨在利用Large Language Models (LLMs) 对模态缺失的强大鲁棒性,通过一个两阶段的提示学习过程,为缺失的模态补偿全面的信息。
- 第一阶段:单模态提示 (Unimodal Prompting, UniPro):该阶段通过蒸馏学习,为每个单模态学习其知识分布,并将其压缩到一组可学习的unimodal prompts中。这为后续补充缺失模态的模态特有知识做准备。
- 第二阶段:多模态提示 (Multimodal Prompting, MultiPro):该阶段利用可用的模态作为LLM的prompts,来推断缺失模态的表示,从而补偿模态共性信息。同时,第一阶段学习到的unimodal知识被注入到多模态推理过程中,以补偿缺失模态的modality-specific知识。
作者在涵盖各种缺失场景的广泛实验中证明了该方法的优越性。这项工作为利用LLM处理不完整多模态数据提供了新的思路和强大的框架。
1. 背景介绍
多模态生存预测,特别是整合病理学图像(提供定性的形态学信息)和基因组数据(提供定量的分子信息)的方法,已在精准预后分析中显示出巨大潜力。尽管现有模型取得了进展,但它们通常依赖于一个前提:所有模态的数据都是完整的。
然而,在真实的临床环境中,由于数据采集成本高(尤其是基因组数据)、隐私问题等多种因素,获取完整的多模态数据往往是不可行的。这极大地限制了现有模型在临床上的应用。因此,构建一个对不完整模态数据具有鲁棒性的多模态生存模型,是当前临床实践面临的一个紧迫问题。
目前处理不完整模态问题的主流方法可分为两类:
- 基于补全的方法 (Imputation-based):
- 生成式 (Generation-based):使用生成模型从可用模态中合成缺失模态的特征或原始数据。但如图(a)所示,这种方法通常只能补全modality-common(模态共性)信息,因为缺失模态独有的modality-specific(模态特有)信息无法凭空生成。
- 检索式 (Retrieval-based):从训练集中检索最相似的样本来填补缺失的模态。但如图(b)所示,单个检索样本的随机性很大,难以完全捕捉缺失模态的独特知识。
- 无需补全的方法 (Imputation-free):
- 这类方法旨在通过学习对模态缺失不敏感的鲁棒多模态表示,来最小化性能下降。近期,有研究尝试利用Large Language Models (LLMs) 对模态缺失的鲁棒性,通过设计复杂的prompts来告知LLM不同的输入分布。然而,这些方法同样只关注了modality-common信息,忽略了缺失模态的特有知识。

综上,现有方法在有效整合缺失模态的共性和特有知识方面仍存在不足。为此,作者提出了DisPro框架,旨在通过一个两阶段的提示学习过程,同时补偿缺失模态的共性和特有信息。
2. DisPro 框架
DisPro的核心是一个两阶段的提示学习框架,旨在利用LLM的强大推理能力来处理不完整的多模态数据。

2.1 问题定义
遵循multiple instance learning (MIL) 范式,一张WSI被表示为一个patch实例的包 \(X_n^p = \{x_{n,i}^p\}_{i=1}^{M_p}\) ,基因组数据被表示为一个生物学通路的包 \(X_n^g = \{x_{n,i}^g\}_{i=1}^{M_g}\) 。生存时间被离散化为 $I_t$ 个时间区间。模型的目标是预测在每个时间区间的风险概率 \(h_n^{(j)}\) ,并使用NLL loss进行优化。
2.2 第一阶段:单模态提示 (Unimodal Prompting, UniPro)
该阶段的目标是为每个单模态学习其知识分布,以便在后续阶段为缺失模态补充modality-specific知识。作者将CoOp的思想扩展到了MIL设置中。
使用预训练的编码器(如UNI,SNN)将WSI patches和基因组通路编码为特征嵌入。为每个风险等级(如“中等风险,死亡”)构建一个文本模板,在该模板中,嵌入一组可学习的context tokens $[P]_1…[P]_k$。通过一个LLM(BioBERT)将这些带有可学习prompts的文本编码成每个风险等级的全局文本表示$t_p^{(j)}$。
计算每个patch(或通路)的特征与每个类别文本表示$t_p^{(j)}$的相似度。使用Top-K Max-pooling聚合patch级的相似度分数,得到slide级的预测。通过最小化生存损失$L_{surv}$来优化模态适配器和可学习的context tokens。
经过第一阶段,模型为每个模态的每个风险等级都学习到了一组优化的prompts,这些prompts蒸馏了该模态的知识分布。
2.3 第二阶段:多模态提示 (Multimodal Prompting, MultiPro)
该阶段利用LLM的推理能力,从可用模态中推断缺失模态的信息。由于LLM的输入长度有限,需要从海量的WSI patches或通路中选择信息量最丰富的tokens。作者提出了UniPro Scoring模块,它重用第一阶段学习到的UniPro作为“评分器”。一个token的最终分数由三部分组成:
- 与该模态对应的UniPro的相似度分数(Uni)。
- 与另一模态对应的UniPro的相似度分数(Cross)。
- 一个可学习的自评分机制(Self-scoring)。
选择得分最高的Top-K个tokens作为可用模态的输入。为缺失模态设置一组可学习的占位符tokens。将可用模态的tokens、缺失模态的占位符以及一个[CLS] token拼接起来,作为LLM的输入。
UniPro Distillation是补偿modality-specific知识的关键。将LLM输出中对应于缺失模态的部分提取出来,记为\([\tilde{g}_n]_{1...K_g}\)。使用这部分推断出的表示,通过与第一阶段学习到的缺失模态的UniPro进行交互(计算相似度、Top-K Pooling),计算一个蒸馏损失$L_{ud}$。通过最小化$L_{ud}$,强制模型推断出的表示能够与缺失模态自身的知识分布对齐。
MultiPro阶段的总损失由最终的生存预测损失(来自[CLS] token)和两个模态的蒸馏损失$L_{ud}$组成。
\[L = L_{surv}^{cls} + \alpha_1 L_{ud}^p + \alpha_2 L_{ud}^g\]3. 实验分析
3.1 实验设置
在五个TCGA癌症数据集上进行评估:BLCA, BRCA, COADREAD, LUAD, UCEC。通过随机丢弃样本的模态来模拟不同的训练缺失率(总缺失率60%)。在推理时,分别在pathology-only, genomics-only和complete三种场景下进行评估。
与多种SOTA方法进行比较,包括单模态方法、完整多模态方法,以及专门为不完整模态设计的方法(如COM, M3Care, HGCN, MUSE, MAP)。
3.2 与SOTA方法的比较
- 在60%训练缺失率下:DisPro在所有五个数据集、三种测试场景下,其平均性能都一致地、大幅度地优于其他为不完整模态设计的SOTA方法。例如,在完整数据测试时,其平均C-Index比第二名高出2.33%。
- 与完整多模态SOTA的比较:令人惊讶的是,即使在60%的训练数据缺失的情况下,DisPro在完整数据测试时的性能,在5个数据集中的3个上,甚至超过了在所有数据上训练的SOTA完整多模态模型(SurvPath)。这表明,利用LLM的强大能力是解决经典生存分析问题的一个非常有前途的方向。
- 在0%训练缺失率下(上限性能):当使用完整数据训练时,DisPro的性能进一步提升,一致地在所有测试场景下取得最佳表现,比SOTA完整多模态方法高出1.31%。

3.3 消融研究
- 移除所有Prompts (baseline):如果移除所有prompting机制,只保留一个简单的LLM分类头,性能会急剧下降。这证明了为LLM设计合适的prompts至关重要。
- 加入UniPro Distillation (+ UD):在baseline基础上仅加入UniPro Distillation(用于补充modality-specific知识),模型在所有推理场景下的性能都得到了一致的提升(平均约+1%)。
- 加入UniPro Scoring (+ US):在baseline基础上仅加入UniPro Scoring(用于选择判别性tokens以捕捉modality-common知识),性能也得到了约+1%的提升。
- 完整模型 (DisPro):当两个模块都加入时,模型性能达到最佳。这证明了两个模块的协同作用。

该图详细展示了在不同的训练缺失率组合下,DisPro与其他方法的性能对比。结果显示,DisPro在各种缺失设置下,都一致地、大幅度地优于其他方法。
