生存分析的跨模态翻译与对齐.
0. TL; DR
CMTA (Cross-Modal Translation and Alignment) 框架的核心思想是为多模态数据构建两个并行的编码器-解码器结构,用于整合模态内信息并生成跨模态表示。设计一个跨模态注意力模块作为不同模态之间的信息桥梁,以探索内在的跨模态相关性,并传递互补信息。利用生成的跨模态表示来增强和校准模态内的表示,从而显著提高其在综合生存分析中的判别能力。
作者在五个公开的TCGA数据集上进行了广泛的实验。结果表明,CMTA框架的性能优于state-of-the-art方法,证明了其在探索和利用多模态数据互补信息方面的有效性。
1. 背景介绍
生存分析是临床预后研究中的一个关键课题,其目标是预测从某个已知起点到某个感兴趣事件(如死亡、疾病复发)发生所经过的时间。准确的生存预测对于医生评估疾病进展和治疗效果至关重要。
传统上,生存分析依赖于临床指标和长期随访报告,但这既耗时又在临床应用中不切实际。近年来,随着deep learning技术的成功,医学图像分析取得了显著进展,研究者们开始致力于建模影像特征与生存事件之间的联系。
- Radiology(放射学)图像(如CT, MRI)可以提供宏观信息,如病变位置、形态纹理等。
- Pathological images(病理图像,即WSI)则能提供关于肿瘤细胞及其微环境的微观信息。基于multiple-instance learning (MIL) 的病理学生存分析方法能够识别和高亮对生存事件有贡献的重要图像区域。
- Genomic profiles(基因组谱)则从分子层面为我们提供了理解生存事件的全新视角。
尽管单模态的生存分析已取得可喜的成果,但结合来自不同视角的多模态数据可以提供互补信息,从而提高生存分析的敏感性。然而,现有的多模态方法存在一些问题:
- 简单融合:直接拼接多模态特征会忽略模态之间潜在的相关性和交互作用。
- 单向指导:利用一种模态(如基因组谱)作为指导来关注另一种模态(如病理图像)的相关部分,虽然合理,但可能存在问题。如果作为指导的模态本身预测能力较弱,反而会“污染”信息量更丰富的模态。此外,这种方法会丢弃与指导模态无关但可能对预后很重要的信息,这违背了整合多模态互补信息的初衷。
基于这些观察,作者提出了CMTA框架,旨在探索内在的跨模态相关性,并传递潜在的互补信息。
2. CMTA 框架
CMTA框架的整体示意图如图所示。其核心是为两种模态(病理学和基因组学)构建并行的encoder-decoder结构,并通过一个cross-modal attention module进行信息交互。

2.1 问题定义与数据处理
目标是开发一个生存预测模型$F$,它整合病理图像$P$和基因组谱$G$,以估计风险函数$f_{hazard}(T=t|…)$。
病理图像处理:
- 使用CLAM将WSI裁剪成一系列不重叠的512x512的patches。
- 使用在ImageNet上预训练的ResNet-50提取每个patch的1024维特征。
- 通过一个全连接层将特征降维到$d$维。
- 最终,一个患者的病理图像被表示为$P = {p_1, …, p_M} \in \mathbb{R}^{M \times d}$。
基因组谱处理:
- 将基因组谱(RNA-seq, CNV, SNV)根据功能分为6组(如肿瘤抑制、致癌、蛋白激酶等)。
- 通过一个全连接层将每个基因组的特征降维到$d$维。
- 最终,一个患者的基因组谱被表示为$G = {g_1, …, g_K} \in \mathbb{R}^{K \times d}$。
2.2 病理学与基因组学编码器
作者引入自注意力机制来为每个模态构建编码器,以整合模态内的信息。
- Pathology Encoder:输入为patch集$P$。作者使用了一个可学习的class token $p^{(0)}$来聚合所有patch的信息。网络包含两层Multi-head Self-attention (MSA)和一个Pyramid Position Encoding Generator (PPEG) 模块(用于探索patches间的位置相关性)。最终,编码器输出的class token $p$被作为病理学的intra-modal representation(模态内表示)。
- Genomics Encoder:结构与Pathology Encoder类似,但没有PPEG模块。输入为基因组谱$G$,输出的class token $g$被作为基因组学的intra-modal representation。
2.3 跨模态注意力模块 (Cross-Modal Attention Module)

跨模态注意力模块是CMTA框架的核心,旨在探索和传递跨模态的互补信息。
输入来自两个编码器的实例tokens,即$\mathcal{P} \in \mathbb{R}^{M \times d}$和$\mathcal{G} \in \mathbb{R}^{K \times d}$。该模块计算两个注意力图$H_p$和$H_g$。
\[H_p = \text{softmax}\left(\frac{(\mathcal{G}U) \times (\mathcal{P}V)^T}{\sqrt{d}}\right) \in \mathbb{R}^{K \times M} \\ H_g = \text{softmax}\left(\frac{(\mathcal{P}V) \times (\mathcal{G}U)^T}{\sqrt{d}}\right) \in \mathbb{R}^{M \times K}\]其中,$U$和$V$是可学习的参数矩阵。$H_p$表示从基因组tokens到病理学tokens的关联状态,而$H_g$则表示从病理学tokens到基因组tokens的关联状态。
利用这两个注意力图,可以从一种模态的tokens中提取与另一种模态相关的信息。
\[P^* = H_p \times (\mathcal{P}W_p)\\ G^* = H_g \times (\mathcal{G}W_g)\]其中,$P^$是病理学tokens中与基因组相关的信息,而$G^$是基因组tokens中与病理学相关的信息。
2.4 基因组学与病理学解码器
直接将提取出的跨模态信息$P^$和$G^$与模态内表示$p$和$g$叠加是不合理的,因为存在数据异质性。因此,作者构建了两个解码器,将这些“翻译”过来的信息转换为与原始模态对齐的跨模态表示。
- Pathology Decoder:输入为$P^*$,输出一个跨模态表示$\hat{g}$。
- Genomics Decoder:输入为$G^*$,输出一个跨模态表示$\hat{p}$。
2.5 特征对齐与融合
利用跨模态表示来增强和校准模态内表示。将增强后的两个模态的表示进行拼接,送入一个MLP进行最终的生存预测。
\[T_1, ..., T_t = \text{sigmoid}\left( \text{MLP}\left( \frac{p+\hat{p}}{2} \oplus \frac{g+\hat{g}}{2} \right) \right)\]为了确保信息转换的质量,作者引入了一个对齐约束,即最小化跨模态表示与模态内表示之间的L1距离。
\[L_{sim} = \frac{1}{d} (||p - \hat{p}||_1 + ||g - \hat{g}||_1)\]一个关键的技术细节是,在计算$L_{sim}$时,模态内表示$p$和$g$必须从计算图中分离(detach)。这确保了对齐是单向的(即$\hat{p} \to p$, $\hat{g} \to g$),避免模型陷入学习冗余共享信息的困境。
最终的损失函数是生存预测损失$L_{sur}$和对齐损失$L_{sim}$的加权和。
\[L_{total} = L_{sur} + \alpha L_{sim}\]3. 实验分析
3.1 实验设置
使用了TCGA中的五个癌症数据集:BLCA, BRCA, GBMLGG, LUAD, UCEC。评估指标使用c-index(一致性指数)来评估模型的性能。
作者将CMTA与多种state-of-the-art的单模态和多模态生存预测方法进行了比较,包括SNN, TransMIL, MCAT, M3IF, GPDBN, Porpoise, HFBSurv等。
3.2 与SOTA方法的比较
- 与单模态模型比较:CMTA在所有五个TCGA数据集上都一致地取得了最佳性能。与表现最好的单模态模型相比,c-index的提升幅度在0.75%到5.29%之间。这证明了多模态数据在生存预测中的优势。
- 与多模态模型比较:与之前的SOTA多模态方法MCAT相比,CMTA在BLCA, BRCA, GBMLGG, LUAD, UCEC五个数据集上分别提升了1.83%, 0.89%, 1.81%, 2.67%, 6.39%。作者认为,这是因为MCAT丢弃了与基因表达无关的病理学信息,而CMTA则充分利用了两种模态的互补信息。
- 与基线模型比较:与一个简单的多模态基线DualTrans(直接拼接两个编码器的输出)相比,CMTA的性能在五个数据集上分别提升了3.03%, 0.42%, 1.38%, 1.58%, 2.51%。这证明了通过跨模态表示来增强和校准模态内表示的策略是有效的。

3.3 消融研究
- 移除跨模态注意力模块:性能在所有数据集上都出现了下降,特别是在LUAD上降低了4.93%。这表明,在进行信息转换时,高亮相关信息是必要的。
- 移除对齐约束:性能同样下降,表明无约束的信息转换会严重损害模态内表示的判别能力。
- 移除Tensor分离:这是影响最大的一个改动。如果不进行detach操作,模型会陷入学习冗余共享信息的困境,导致性能大幅下降,在BLCA上降低了9.08%。
- 移除PPEG模块:性能也有所下降,证明了在病理学编码器中探索patches间位置相关性的重要性。

作者比较了L1损失、MSE损失、KL散度和余弦相似度作为对齐约束的度量。结果显示,L1损失在大多数数据集上都表现最好。

3.4 生存分析
作者使用CMTA预测的风险评分中位数,将患者分为高风险和低风险组,并绘制了Kaplan-Meier生存曲线。在所有五个数据集中,高风险组和低风险组的生存曲线都显示出统计学上的显著差异(p-value均远小于0.05)。这进一步验证了CMTA框架在生存分析中的有效性。
