生存分析的跨模态翻译与对齐.

0. TL; DR

CMTA (Cross-Modal Translation and Alignment) 框架的核心思想是为多模态数据构建两个并行的编码器-解码器结构,用于整合模态内信息并生成跨模态表示。设计一个跨模态注意力模块作为不同模态之间的信息桥梁,以探索内在的跨模态相关性,并传递互补信息。利用生成的跨模态表示来增强和校准模态内的表示,从而显著提高其在综合生存分析中的判别能力。

作者在五个公开的TCGA数据集上进行了广泛的实验。结果表明,CMTA框架的性能优于state-of-the-art方法,证明了其在探索和利用多模态数据互补信息方面的有效性。

1. 背景介绍

生存分析是临床预后研究中的一个关键课题,其目标是预测从某个已知起点到某个感兴趣事件(如死亡、疾病复发)发生所经过的时间。准确的生存预测对于医生评估疾病进展和治疗效果至关重要。

传统上,生存分析依赖于临床指标和长期随访报告,但这既耗时又在临床应用中不切实际。近年来,随着deep learning技术的成功,医学图像分析取得了显著进展,研究者们开始致力于建模影像特征与生存事件之间的联系。

尽管单模态的生存分析已取得可喜的成果,但结合来自不同视角的多模态数据可以提供互补信息,从而提高生存分析的敏感性。然而,现有的多模态方法存在一些问题:

基于这些观察,作者提出了CMTA框架,旨在探索内在的跨模态相关性,并传递潜在的互补信息。

2. CMTA 框架

CMTA框架的整体示意图如图所示。其核心是为两种模态(病理学和基因组学)构建并行的encoder-decoder结构,并通过一个cross-modal attention module进行信息交互。

2.1 问题定义与数据处理

目标是开发一个生存预测模型$F$,它整合病理图像$P$和基因组谱$G$,以估计风险函数$f_{hazard}(T=t|…)$。

病理图像处理:

  1. 使用CLAMWSI裁剪成一系列不重叠的512x512的patches
  2. 使用在ImageNet上预训练的ResNet-50提取每个patch的1024维特征。
  3. 通过一个全连接层将特征降维到$d$维。
  4. 最终,一个患者的病理图像被表示为$P = {p_1, …, p_M} \in \mathbb{R}^{M \times d}$。

基因组谱处理:

  1. 将基因组谱(RNA-seq, CNV, SNV)根据功能分为6组(如肿瘤抑制、致癌、蛋白激酶等)。
  2. 通过一个全连接层将每个基因组的特征降维到$d$维。
  3. 最终,一个患者的基因组谱被表示为$G = {g_1, …, g_K} \in \mathbb{R}^{K \times d}$。

2.2 病理学与基因组学编码器

作者引入自注意力机制来为每个模态构建编码器,以整合模态内的信息。

2.3 跨模态注意力模块 (Cross-Modal Attention Module)

跨模态注意力模块是CMTA框架的核心,旨在探索和传递跨模态的互补信息。

输入来自两个编码器的实例tokens,即$\mathcal{P} \in \mathbb{R}^{M \times d}$和$\mathcal{G} \in \mathbb{R}^{K \times d}$。该模块计算两个注意力图$H_p$和$H_g$。

\[H_p = \text{softmax}\left(\frac{(\mathcal{G}U) \times (\mathcal{P}V)^T}{\sqrt{d}}\right) \in \mathbb{R}^{K \times M} \\ H_g = \text{softmax}\left(\frac{(\mathcal{P}V) \times (\mathcal{G}U)^T}{\sqrt{d}}\right) \in \mathbb{R}^{M \times K}\]

其中,$U$和$V$是可学习的参数矩阵。$H_p$表示从基因组tokens到病理学tokens的关联状态,而$H_g$则表示从病理学tokens到基因组tokens的关联状态。

利用这两个注意力图,可以从一种模态的tokens中提取与另一种模态相关的信息。

\[P^* = H_p \times (\mathcal{P}W_p)\\ G^* = H_g \times (\mathcal{G}W_g)\]

其中,$P^$是病理学tokens中与基因组相关的信息,而$G^$是基因组tokens中与病理学相关的信息。

2.4 基因组学与病理学解码器

直接将提取出的跨模态信息$P^$和$G^$与模态内表示$p$和$g$叠加是不合理的,因为存在数据异质性。因此,作者构建了两个解码器,将这些“翻译”过来的信息转换为与原始模态对齐的跨模态表示。

2.5 特征对齐与融合

利用跨模态表示来增强和校准模态内表示。将增强后的两个模态的表示进行拼接,送入一个MLP进行最终的生存预测。

\[T_1, ..., T_t = \text{sigmoid}\left( \text{MLP}\left( \frac{p+\hat{p}}{2} \oplus \frac{g+\hat{g}}{2} \right) \right)\]

为了确保信息转换的质量,作者引入了一个对齐约束,即最小化跨模态表示与模态内表示之间的L1距离。

\[L_{sim} = \frac{1}{d} (||p - \hat{p}||_1 + ||g - \hat{g}||_1)\]

一个关键的技术细节是,在计算$L_{sim}$时,模态内表示$p$和$g$必须从计算图中分离(detach)。这确保了对齐是单向的(即$\hat{p} \to p$, $\hat{g} \to g$),避免模型陷入学习冗余共享信息的困境。

最终的损失函数是生存预测损失$L_{sur}$和对齐损失$L_{sim}$的加权和。

\[L_{total} = L_{sur} + \alpha L_{sim}\]

3. 实验分析

3.1 实验设置

使用了TCGA中的五个癌症数据集:BLCA, BRCA, GBMLGG, LUAD, UCEC。评估指标使用c-index(一致性指数)来评估模型的性能。

作者将CMTA与多种state-of-the-art的单模态和多模态生存预测方法进行了比较,包括SNN, TransMIL, MCAT, M3IF, GPDBN, Porpoise, HFBSurv等。

3.2 与SOTA方法的比较

3.3 消融研究

作者比较了L1损失、MSE损失、KL散度和余弦相似度作为对齐约束的度量。结果显示,L1损失在大多数数据集上都表现最好。

3.4 生存分析

作者使用CMTA预测的风险评分中位数,将患者分为高风险和低风险组,并绘制了Kaplan-Meier生存曲线。在所有五个数据集中,高风险组和低风险组的生存曲线都显示出统计学上的显著差异(p-value均远小于0.05)。这进一步验证了CMTA框架在生存分析中的有效性。