不完整多模态生存预测的蒸馏提示学习.

0. TL; DR

DisPro (Distilled Prompt Learning) 框架旨在利用Large Language Models (LLMs) 对模态缺失的强大鲁棒性,通过一个两阶段的提示学习过程,为缺失的模态补偿全面的信息。

  1. 第一阶段:单模态提示 (Unimodal Prompting, UniPro):该阶段通过蒸馏学习,为每个单模态学习其知识分布,并将其压缩到一组可学习的unimodal prompts中。这为后续补充缺失模态的模态特有知识做准备。
  2. 第二阶段:多模态提示 (Multimodal Prompting, MultiPro):该阶段利用可用的模态作为LLMprompts,来推断缺失模态的表示,从而补偿模态共性信息。同时,第一阶段学习到的unimodal知识被注入到多模态推理过程中,以补偿缺失模态的modality-specific知识。

作者在涵盖各种缺失场景的广泛实验中证明了该方法的优越性。这项工作为利用LLM处理不完整多模态数据提供了新的思路和强大的框架。

1. 背景介绍

多模态生存预测,特别是整合病理学图像(提供定性的形态学信息)和基因组数据(提供定量的分子信息)的方法,已在精准预后分析中显示出巨大潜力。尽管现有模型取得了进展,但它们通常依赖于一个前提:所有模态的数据都是完整的。

然而,在真实的临床环境中,由于数据采集成本高(尤其是基因组数据)、隐私问题等多种因素,获取完整的多模态数据往往是不可行的。这极大地限制了现有模型在临床上的应用。因此,构建一个对不完整模态数据具有鲁棒性的多模态生存模型,是当前临床实践面临的一个紧迫问题。

目前处理不完整模态问题的主流方法可分为两类:

综上,现有方法在有效整合缺失模态的共性和特有知识方面仍存在不足。为此,作者提出了DisPro框架,旨在通过一个两阶段的提示学习过程,同时补偿缺失模态的共性和特有信息。

2. DisPro 框架

DisPro的核心是一个两阶段的提示学习框架,旨在利用LLM的强大推理能力来处理不完整的多模态数据。

2.1 问题定义

遵循multiple instance learning (MIL) 范式,一张WSI被表示为一个patch实例的包 \(X_n^p = \{x_{n,i}^p\}_{i=1}^{M_p}\) ,基因组数据被表示为一个生物学通路的包 \(X_n^g = \{x_{n,i}^g\}_{i=1}^{M_g}\) 。生存时间被离散化为 $I_t$ 个时间区间。模型的目标是预测在每个时间区间的风险概率 \(h_n^{(j)}\) ,并使用NLL loss进行优化。

2.2 第一阶段:单模态提示 (Unimodal Prompting, UniPro)

该阶段的目标是为每个单模态学习其知识分布,以便在后续阶段为缺失模态补充modality-specific知识。作者将CoOp的思想扩展到了MIL设置中。

使用预训练的编码器(如UNISNN)将WSI patches和基因组通路编码为特征嵌入。为每个风险等级(如“中等风险,死亡”)构建一个文本模板,在该模板中,嵌入一组可学习的context tokens $[P]_1…[P]_k$。通过一个LLMBioBERT)将这些带有可学习prompts的文本编码成每个风险等级的全局文本表示$t_p^{(j)}$。

计算每个patch(或通路)的特征与每个类别文本表示$t_p^{(j)}$的相似度。使用Top-K Max-pooling聚合patch级的相似度分数,得到slide级的预测。通过最小化生存损失$L_{surv}$来优化模态适配器和可学习的context tokens

经过第一阶段,模型为每个模态的每个风险等级都学习到了一组优化的prompts,这些prompts蒸馏了该模态的知识分布。

2.3 第二阶段:多模态提示 (Multimodal Prompting, MultiPro)

该阶段利用LLM的推理能力,从可用模态中推断缺失模态的信息。由于LLM的输入长度有限,需要从海量的WSI patches或通路中选择信息量最丰富的tokens。作者提出了UniPro Scoring模块,它重用第一阶段学习到的UniPro作为“评分器”。一个token的最终分数由三部分组成:

  1. 与该模态对应的UniPro的相似度分数(Uni)。
  2. 与另一模态对应的UniPro的相似度分数(Cross)。
  3. 一个可学习的自评分机制(Self-scoring)。
\[s_{n,\#}^{(i)} = s_{n,p}^{(i, \tau)} + s_{n,g}^{(i, \tau)} + a_{n,p}^{(i)}\]

选择得分最高的Top-Ktokens作为可用模态的输入。为缺失模态设置一组可学习的占位符tokens。将可用模态的tokens、缺失模态的占位符以及一个[CLS] token拼接起来,作为LLM的输入。

UniPro Distillation是补偿modality-specific知识的关键。将LLM输出中对应于缺失模态的部分提取出来,记为\([\tilde{g}_n]_{1...K_g}\)。使用这部分推断出的表示,通过与第一阶段学习到的缺失模态的UniPro进行交互(计算相似度、Top-K Pooling),计算一个蒸馏损失$L_{ud}$。通过最小化$L_{ud}$,强制模型推断出的表示能够与缺失模态自身的知识分布对齐。

MultiPro阶段的总损失由最终的生存预测损失(来自[CLS] token)和两个模态的蒸馏损失$L_{ud}$组成。

\[L = L_{surv}^{cls} + \alpha_1 L_{ud}^p + \alpha_2 L_{ud}^g\]

3. 实验分析

3.1 实验设置

在五个TCGA癌症数据集上进行评估:BLCA, BRCA, COADREAD, LUAD, UCEC。通过随机丢弃样本的模态来模拟不同的训练缺失率(总缺失率60%)。在推理时,分别在pathology-only, genomics-onlycomplete三种场景下进行评估。

与多种SOTA方法进行比较,包括单模态方法、完整多模态方法,以及专门为不完整模态设计的方法(如COM, M3Care, HGCN, MUSE, MAP)。

3.2 与SOTA方法的比较

3.3 消融研究

该图详细展示了在不同的训练缺失率组合下,DisPro与其他方法的性能对比。结果显示,DisPro在各种缺失设置下,都一致地、大幅度地优于其他方法。