胃癌生存预测的全切片图像与基因表达多模态学习.
0. TL; DR
GC-SPLeM是一个用于预测患者风险评分的模型。该模型包含三个部分:WSI特征提取、模态融合网络和基于GNN的预测器。
作者在一个自建的GC数据集和一个公开数据集上进行了生存预测实验。在两个数据集上,GC-SPLeM的性能都显著优于state-of-the-art的单模态和多模态学习方法。与其他方法相比,GC-SPLeM不仅改善了生存预测结果,在处理不完整数据方面也显示出优势。
1. 背景介绍
胃癌(GC)是全球癌症相关死亡的主要原因之一。近年来,计算机辅助GC诊断方法,特别是基于deep learning和WSI的方法,取得了巨大成功。这些方法通常遵循一个框架:首先将WSI分割成patches,然后提取patch级别的特征,最后将这些特征聚合成WSI级别的表示用于下游任务。
然而,仅依赖WSI进行诊断存在局限性,原因有二:1) GC的发展机制特别复杂;2) WSI本身常常质量不佳、尺寸巨大且充满噪声,这会损害模型性能。
由于GC的发生是一个涉及遗传改变的多步骤、多因素过程,基因表达也被认为是GC诊断的重要数据源。多模态学习作为一种新兴技术,能够整合WSI和基因表达数据,为诊断过程提供补充信息,从而提升模型性能。已有研究如AMMA, PG-TFNet, MCAT等,通过使用类似transformer或cross-modal attention的机制融合多模态数据,取得了优于单模态方法的性能。
尽管如此,当前多模态学习方法的一个主要局限性在于,它们大多只适用于完整数据,即数据集中所有样本的所有模态都必须可用。但在真实的临床场景中,数据缺失非常普遍。如何解决数据缺失问题,并充分利用手头的数据,已成为多模态学习模型的主要挑战。
因此,作者提出了GC-SPLeM (Gastric Cancer Survival Prediction via Learning Multimodal data) 方法。该方法以一个multimodal attention模块来聚合多模态信息,并以一个基于graph neural network (GNN) 的模型来利用患者间的关联,从而缓解缺失数据的影响。
2. GC-SPLeM 方法
GC-SPLeM的总体流程如图所示,由三个主要部分组成:WSI特征提取器、模态融合网络和基于GNN的预测器。其目标是为每个患者预测一个指示生存风险的风险评分$r$。

2.1 WSI特征提取
如图(a)所示,WSI特征提取流程如下:
- 组织分割:基于饱和度通道的阈值分割,从WSI中提取前景组织区域。
- 图像块化 (Patching):在组织区域内裁剪不重叠的256x256的patches。
- 特征提取:将patches输入一个在ImageNet上预训练的ResNet50模型,将每个patch转换为一个1024维的病理学特征向量。
由于WSI尺寸巨大,作者采用非端到端的训练方案,即特征提取过程在模型训练前完成。
2.2 融合多模态特征
模态融合网络(图c)旨在将病理学特征矩阵和基因表达数据矩阵融合成一个统一的特征嵌入。
根据功能将基因分为$N$个基因集。每个基因集的表达数据通过一个两层的MLP,最终得到一个基因嵌入矩阵$G \in \mathbb{R}^{N \times d_k}$。
将预提取的WSI特征矩阵$H_o \in \mathbb{R}^{M \times 1024}$通过一个全连接层进行降维,得到新的WSI特征矩阵$H \in \mathbb{R}^{M \times d_k}$。
计算基因嵌入矩阵$G$和WSI嵌入矩阵$H$之间的协同注意力分数,以得到一个融合了基因表达信息的、新的WSI嵌入矩阵$\hat{H}$。
\[\text{co-attention}(G, H) = \text{softmax}\left(\frac{W_q G H^T W_k^T}{\sqrt{d_k}}\right) \\ \hat{H} = \text{co-attention}(G, H) \cdot (W_v H)\]其中$W_q, W_K, W_v \in \mathbb{R}^{d_k \times d_k}$是可学习的参数。$\hat{H}$本质上是一个经过注意力增强的WSI特征嵌入,其中与所选基因集表达强相关的特征被赋予了更大的权重。
将特征矩阵$\hat{H}$和$G$分别送入六个transformer编码器层和一个attention-based pooling层,聚合成患者级的特征向量$h$(病理学)和$g$(基因组学)。
\[a_{i_h} = \frac{\exp\{W_{\rho h}[\tanh(V_{\rho h}e_{i_h}^T) \cdot \text{sign}(U_{\rho h}e_{i_h}^T)]\}}{\sum_{i_h=1}^N \exp\{...\}} \\ h = \text{ReLU}(W_{\zeta h} \sum_{i_h=1}^N a_{i_h} e_{i_h})\]基因组学特征$g$的计算方式与此类似。将聚合后的特征嵌入$h$和$g$拼接起来,送入一个MLP,得到最终的64维特征向量$h_{final}$。
\[h_{final} = \text{MLP}(h \oplus g)\]将$h_{final}$送入一个全连接层以拟合患者的生存风险评分。损失函数采用离散生存模型的log-likelihood损失。
2.3 基于GNN的生存预测
在得到每个患者的64维特征向量后,作者构建了一个KNN亲和图(KNN-affinity graph)来表示患者间的关联。
图中的每个节点代表一个患者。节点间的距离由它们特征向量的每个节点与其$K$个最近的邻居相连(实验中$K=4$)。
将患者的特征向量(作为节点特征)和构建的图(作为邻接矩阵)输入一个双层graph convolution network。每个图卷积层包含一次邻接矩阵与特征矩阵的逐点乘法、一次线性计算和一次ReLU激活。最后,通过一个与模态融合网络中相同的生存预测层来得到最终的风险评分。
其背后的动机是,具有相似WSI和基因表达数据的患者很可能有相似的生存结局。从患者的邻居节点中学习信息,有助于增强其自身表示的鲁棒性,尤其是在数据有噪声甚至缺失的情况下。
3. 实验分析
3.1 数据集与实现细节
数据集:
- Ruijin dataset:一个自建的包含63名GC患者的数据集,包含WSI和来自9个肿瘤相关通路的231个基因的表达数据。
- LUAD dataset:一个来自TCGA的包含437名肺腺癌患者的公共数据集。
数据集按6:2:2划分为训练、验证和测试集,进行5次独立的划分并取平均结果。训练时使用了Adam优化器、梯度累积和早停等技术。
3.2 在完整数据上的性能比较
在两个数据集上,GC-SPLeM的C-index都比其他方法高出超过5%。单模态方法(CLAM (WSI-only))的性能远低于多模态方法,显示了仅使用WSI进行生存分析的局限性。与同样使用跨模态注意力的MCAT相比,GC-SPLeM的性能提升显著。GC-SPLeM与GC-SPLeM w/o GNN(即去掉了GNN预测器)之间的性能差异,突显了基于GNN的预测器对性能的巨大贡献。
| Method | Ruijin | LUAD |
|---|---|---|
| MCAT | 0.618 ± 0.088 | 0.573 ± 0.069 |
| GC-SPLeM w/o GNN | 0.618 ± 0.100 | 0.570 ± 0.071 |
| CLAM (WSI-only) | 0.467 ± 0.085 | 0.559 ± 0.060 |
| GC-SPLeM (Ours) | 0.678 ± 0.069 | 0.622 ± 0.053 |
3.3 在不完整数据上的性能比较
作者通过在训练集中随机将50%样本的基因表达数据替换为随机值,来模拟数据缺失场景。在不完整数据上,三种方法的预测准确率都有所下降。然而,GC-SPLeM仍然表现最好,其C-index比MCAT高出4.4%。
这证明了GC-SPLeM在处理缺失数据问题上的鲁棒性和实用价值。其性能优势主要归功于GNN模块,它通过从邻居节点学习信息,有效地弥补了部分节点自身信息不完整(或有噪声)的缺陷。
