多模态癌症生存预测的原型信息瓶颈与解耦.

0. TL; DR

PIBD (Prototypical Information Bottlenecking and Disentangling) 框架包含两个核心模块:

作者在五个癌症基准数据集上进行了广泛的实验,结果表明,与现有state-of-the-art方法相比,PIBD取得了优越的性能。

1. 背景介绍

癌症生存分析旨在评估患者的死亡风险,而整合组织学信息和基因组分子谱的多模态学习已成为该领域的一个重要方向。组织学图像提供了肿瘤微环境的视觉表型信息,而基因组数据则提供了全局的分子亚型信息。

然而,多模态数据中的大量冗余信息给有效融合带来了巨大挑战。

受信息论中用于缓解冗余思想的启发,作者提出了PIBD框架,旨在从信息论的角度解决多模态数据中的“模态内”和“模态间”冗余问题。

2. 方法介绍

PIBD的整体框架如图所示。它首先使用PIB模块为每个模态选择有判别力的实例,然后通过PID模块将这些实例解耦为模态共性和模态特有的表示,最后将这些紧凑的表示用于生存预测。

2.1 原型信息瓶颈 (Prototypical Information Bottleneck, PIB)

为了解决“模态内冗余”,作者提出了PIB,这是对经典信息瓶颈(Information Bottleneck, IB)理论的一种新颖变体。

IB的目标是学习一个既能最大程度压缩输入$X$的信息,又能最大程度保留关于目标$Y$信息的中间表示$Z$。其目标函数是$R_{IB} = I(Z, Y) - \beta I(Z, X)$,其中,$I(\cdot, \cdot)$是互信息。

直接在WSI的海量patches上应用IB会面临巨大的计算挑战。因此,PIB不再为每个实例单独建模后验分布$p(z|x)$,而是直接用一组可学习的“prototypes”(原型)$P = {\mathcal{N}(\hat{z}; \mu_y, \Sigma_y)}_{y=1}^{2N_t}$来近似整个bag的后验分布$p(z|x)$。每个原型代表一个特定风险等级$y$的条件概率分布$p(\hat{z}|y)$。

对于一个bag中的所有实例,计算它们与每个原型的相似度。然后,只保留与每个原型最相似的一部分实例(由超参数Irr, 即information retention rate控制),其余的被视为冗余信息并丢弃。通过一个对比学习式的损失函数$L_{pro}$来优化原型。该损失函数旨在拉近被保留的实例与它们对应标签的“正原型”之间的距离,同时推远它们与“负原型”之间的距离。最终,PIB的损失函数$L_{PIB}$结合了任务损失(生存预测损失)、KL散度(正则化项,约束原型分布接近先验分布)和上述的原型学习损失$L_{pro}$。

\[L_{PIB} = \frac{1}{2N_t} \sum_{n=1}^{2N_t} \{\alpha L_{surv}(\hat{z}^{(n)}, t^{(n)}, c^{(n)}) + \beta KL[\mathcal{N}(\hat{z};\mu_n, \Sigma_n), r(z)]\} + \gamma L_{pro}\]

通过PIB,模型能够为每个模态筛选出与任务最相关的判别性特征,同时去除大量冗余信息。

2.2 原型信息解耦 (Prototypical Information Disentanglement, PID)

在消除了模态内冗余之后,PID模块旨在解决“模态间冗余”。它将纠缠的多模态特征解耦为理想中相互独立的modality-common(模态共性)特征$C$和modality-specific(模态特有)特征$S_h, S_g$。

通过最小化共性特征与特有特征之间、以及不同模态的特有特征之间的互信息(Mutual Information, MI),来强制它们解耦和独立。其损失函数为:

\[L_{PID} = I(S, C) + I(S_h, S_g)\]

作者设计了一个disentangled transformer来实现这一目标。作者重用了PIB中学习到的原型分布。通过Product-of-Experts (PoE) 技术,将病理学和基因组学的正原型分布相乘,得到一个联合原型分布$p(z|x_h, x_g)$。从这个联合分布中采样一个token,作为指导共性信息提取的query,通过cross-attention机制从两个模态中提取共性特征$C$。

transformer中,通过self-attention机制建模每个模态内部的交互(如pathway-to-pathwaypatch-to-patch),得到各自的特有特征$S_h$和$S_g$。使用CLUB(一个互信息上界估计器)来近似和最小化$L_{PID}$。

2.3 总体损失函数与推理

PIBD的总损失函数是生存预测损失、两个模态的PIB损失以及PID损失的加权和。

\[L = L_{surv} + L_{PIB}^h + L_{PIB}^g + \lambda L_{PID}\]

在推理阶段,由于真实标签未知,需要先从原型集中确定“正原型”。作者通过一个简单的策略实现:对于一个待测样本,选择那个与之相似度最高的实例比例最大的原型作为正原型。

3. 实验分析

3.1 实验设置

使用了来自TCGA的五个公共癌症数据集:BRCA, BLCA, COADREAD, STAD, HNSC

与三类SOTA方法进行比较:(1) 单模态方法;(2) 多模态方法;(3) 基于信息论的方法。使用C-indexKaplan-Meier (KM) 分析。

3.2 与SOTA方法的比较

PIBD取得了最佳的总体性能(平均C-index=0.699),在5个基准数据集中的4个上表现最优。与单模态方法相比,大多数多模态方法(包括PIBD)都显示出更高的总体C-index,证明了多模态信息的价值。与多模态SOTA方法(如MOTCat, SurvPath)相比,PIBD的总体性能高出1.6%,显示了其在处理模态内和模态间冗余方面的优势。与基于信息论的方法(如DeepIMV)相比,PIBD的性能提升了0.5%-4.9%,证明了PIBD框架设计的优越性。

PIBD能够将患者显著地分层为高风险和低风险组,其p-value在所有数据集上都非常低。与第二好的方法SurvPath相比,PIBD的分层效果在BRCA, COADREADHNSC数据集上尤为突出。

3.3 消融研究

3.4 PIB模块的可解释性与干预实验

作者在推理阶段进行了干预实验。当移除正确的“正原型”时,模型的C-Index急剧下降到0.5以下,意味着预测能力完全丧失。而当随机移除一个“负原型”时,性能仅有轻微下降。这进一步证实了PIB能够有效为不同风险等级建模判别性分布。

作者对学习到的原型进行采样,并使用t-SNE降维可视化。结果显示,代表不同风险等级的原型在二维平面上形成了清晰可分的簇。这证明了PIB学习到的原型确实为不同的风险等级建模了有判别力的潜在分布。