多模态知识增强全切片病理基础模型.

0. TL; DR

mSTAR (Multimodal Self-TAught PRetraining)是一个病理学基础模型,它在一个统一的框架内整合了三种模态:病理切片、专家撰写的报告和基因表达数据,使用了来自32个癌种的26,169个切片级模态对(包含超过1.16亿个图像块)进行开发。

mSTAR采用全切片预训练范式:

  1. 第一阶段:通过切片级对比学习预训练一个切片聚合器,旨在将多模态知识注入该聚合器。
  2. 第二阶段:利用这个预训练好的聚合器作为“教师模型”,通过自教式训练(Self-Taught Training)将全切片上下文信息教给图像块特征提取器。

这种方法将建模从单模态扩展到多模态,从patch-level分析扩展到slide-level分析。在一个涵盖97个任务的肿瘤学基准测试中,mSTAR的性能优于先前的state-of-the-art模型,特别是在分子预测和多模态任务中。研究结果还揭示,多模态整合比简单地扩大纯视觉数据集更能带来性能上的提升。

1. 背景介绍

Foundation models (FMs) 在计算病理学领域取得了巨大成功,推动了癌症诊断、治疗和预后等临床任务的进步。然而,尽管取得了令人鼓舞的性能,现有的病理学FMs仍面临几个未解决的挑战:

  1. 多模态数据利用不足:临床实践中常用的大量多模态数据,如pathology reportsgene expression profiles,在预训练中未得到充分利用。现有的病理学FMs要么只关注vision-only数据,要么使用image-caption数据,但简单的字幕信息不足以提供真实肿瘤学任务所需的whole-slide(全切片)上下文。而病理报告和基因表达谱分别能提供最相关的临床信息和定量的分子动力学信息,整合这些数据能够建立更全面、更整体的视角。
  2. 局限于图像块级建模:现有的病理学FMs主要在patch/ROI-level数据上进行建模。它们通常预训练一个patch extractor,然后在下游任务中,使用multiple instance learning (MIL) 来聚合patch特征以进行slide-level的建模。这种两阶段方法存在一个固有局限:patch特征的质量上限决定了最终模型的性能上限。由于slide-level的多模态自监督信号未能指导patch-level的特征提取,这两个独立阶段的预训练目标不一致,不可避免地导致次优性能。

为了应对这些挑战,作者提出了mSTAR (Multimodal Self-TAught PRetraining)。这是一个whole-slide pretraining paradigm,旨在将多模态知识注入到病理学基础模型中,从而将上下文理解从patch-level扩展到slide-level,从单模态扩展到多模态。

2. mSTAR 框架

mSTAR的预训练范式包含两个阶段,旨在将多模态知识无缝地嵌入到基础模型中。

2.1 第一阶段:预训练切片聚合器 (Pretrain slide aggregator)

此阶段的目标是通过slide-level contrastive learning,将多模态知识注入到一个slide aggregator中。

WSI使用一个预训练好的patch extractorUNI)将每个patch编码为特征,然后将这些patch特征输入到一个slide aggregator(一个2层的TransMIL)中,整合成一个slide-level的表示$P_i$。

Pathology Reports使用一个Bert-like的文本编码器(BioBERT)来编码报告,得到报告的[CLS] 标记嵌入$T_i$。

RNA-Seq基因名通过Gene2Vec生成embedding,基因表达值离散化后也生成embedding,两者相加后输入一个Performer编码器,得到基因的[CLS] 标记嵌入$G_i$。

模态间对比学习 (Inter-modality contrastive learning)遵循CLIP的思想,对齐来自同一个样本的不同模态的表示。具体来说,对于一个mini-batch中的$N$个样本,最大化匹配的模态对(如$P_i, T_i$)之间的相似度,同时最小化不匹配的模-态对(如$P_i, T_j$)之间的相似度。该损失函数对三种模态对(WSI-报告、WSI-基因、报告-基因)分别计算并相加。

癌症间对比学习 (Inter-cancer contrastive learning):为了缓解不同癌种的异质性,作者利用TCGA中固有的癌症类型标签。将一个样本的所有可用模态的[CLS] 标记拼接成一个锚点表示$a_i$,然后在mini-batch内通过hard sample mining技术,使用triplet loss来拉近同种癌症样本的表示,推远不同癌症样本的表示。

\[L_{triplet} = \frac{1}{N} \sum_{i=1}^N \max(d(a_i, a^+) - d(a_i, a^-) + \epsilon, 0)\]

经过第一阶段,slide aggregator就学习到了多模态知识,并将在下一阶段扮演“教师”的角色。

2.2 第二阶段:预训练图像块提取器 (Pretrain patch extractor)

此阶段的目标是利用在第一阶段预训练好的slide aggregator,通过一种名为Self-Taught training的方法,将slide-level的多模-态知识无缝地传播到patch extractor中。

教师网络是第一阶段预训练好的slide aggregator。学生网络是一个需要预训练的patch extractor(如ViT-L)。

对于一张WSI,将其所有patch特征$P_i$输入“教师”聚合器,得到被重新嵌入的、包含了多模态知识的新特征$\hat{P}_i$。训练“学生”提取器,使其提取的patch特征$p_i^m$尽可能地接近“教师”给出的目标特征$\hat{p}_i^m$。这通过最小化两者之间的L1损失来实现。

为了避免灾难性遗忘(catastrophic forgetting),作者采用了siamese结构。patch extractor包含两个分支,一个分支通过梯度下降更新,另一个分支的参数则通过前一个分支参数的Exponential Moving Average (EMA) 进行更新。通过在两个分支的输出之间施加相似性约束,来保持模型的稳定性。

最终损失函数是上述两个目标的加权和。

\[\min \sum_{i}^M \sum_{m} \lambda ||f(p_i^m) - \hat{p}_i^m||_1 + (1-\lambda)||p_i^m - \tilde{p}_i^m||_1\]

通过这两个阶段的预训练,mSTARpatch extractor被注入了whole-slide上下文的多模态知识,使其能够更好地理解patches和整个WSI,从而在各种下游任务中表现更佳。

3. 实验分析

作者在一系列广泛的肿瘤学基准测试中对mSTAR进行了评估,涵盖了7大类、15种类型、共97个任务。

3.1 病理学诊断 (Pathological diagnosis)

mSTAR在各种病理学诊断任务中均表现出卓越的性能和泛化性。

3.2 分子预测 (Molecular prediction)

通过在病理图像和基因表达数据上的联合预训练,mSTAR在仅使用病理图像进行分子预测的任务上,表现出卓越的性能和强大的泛化能力。

3.3 视觉-语言评估 (Vision-language evaluation)

由于在预训练中引入了病理报告,mSTAR获得了强大的语言相关能力,使其在零样本分类、检索和报告生成等任务中表现出色。

3.4 生存预测与多模态融合

在16个生存预测任务中,mSTARheld-out和外部队列上都表现出了一致的优越性和强大的泛化能力。

作者将mSTAR提取的病理特征与其他FMs提取的特征,分别输入到4个现有的多模态融合模型中进行比较。结果显示,使用mSTAR特征的模型性能全面领先。其平均排名为1.47,远超第二名UNI(2.68)。在平均C-Index上,mSTAR也取得了1.8%的显著提升。

mSTAR的多模态预训练使其能够学习到与其他模态更对齐的病理特征,从而极大地促进了下游的多模态融合任务。

3.5 消融研究

多模态整合比单纯的单模态数据量扩展,能带来更高的效率和性能回报,为资源受限的医学AI开发提供了一条更实用的路径。