AMMASurv:全切片图像与基因表达数据准确生存分析的非对称多模态注意力.

0. TL; DR

AMMASurv (Asymmetrical Multi-Modal Attention for Survival Analysis) 是一种非对称多模态方法。其核心创新在于一个asymmetrical multi-modal attention (AMMA)(非对称多模态注意力)机制,该机制能够:

作者在两个广泛使用的数据集上进行了实验,结果表明,AMMASurv的性能显著优于其他state-of-the-art方法。

1. 背景介绍

利用多模态数据,如WSI和基因表达,进行生存分析可以提供更准确的预测。然而,现有的多模态方法,如DeepCorrSurvMultiSurv,在处理这两种数据时存在局限性。

首先,它们未能有效挖掘每个模态内部的内在信息。对于WSI,这些方法通常只是从巨大的图像中随机选择一些patches来提取特征,而忽略了WSI作为一个整体所包含的完整信息。对于基因表达数据,它们直接编码高维的原始数据,其中包含了大量与任务不相关的信息,这会引入大量噪声。

其次,它们未能灵活地利用模态间的潜在联系。这些方法通常将来自不同模态的信息视为同等重要。但在很多情况下,不同模态对最终预测任务的贡献是不同的。如果一个模态本身含有大量噪声且信息量较少,将其与信息量丰富的模态同等对待,可能会“污染”最终的特征表示,从而损害模型性能。

为了应对这些挑战,作者提出了AMMASurv,一个基于Transformer的端到端模型。其核心是一个非对称的多模态注意力(AMMA)机制,旨在不均衡地融合来自不同重要性模态的信息。

2. AMMASurv 方法

AMMASurv是一个基于Transformer的端到端模型,由一个多模态编码器和一个MLP头组成。其核心是编码器中的Asymmetrical Multi-Modal Attention (AMMA) 机制。

2.1 图像特征 (Image Features)

从一张WSI的前景区域中随机选择$n$个patches。使用一个预训练的ResNet18作为骨干网络$B$,对patches序列$I$提取特征,得到特征序列$F_p \in \mathbb{R}^{n \times d_1}$。

为每个patch的特征附加一个位置嵌入(position embedding),以保留其空间信息。在特征序列的开头,添加一个可学习的embedding $Z_{token}^0$,其在Transformer编码器输出端的最终状态将作为整个WSI的表示。

2.2 基因表达特征 (Gene Expression Features)

对基因表达数据进行标准化,并将其分为$m$个组,每个组包含$d_2$个基因符号,形成一个基因特征序列$G \in \mathbb{R}^{m \times d_2}$。

将$G$的维度扩展至与图像特征相匹配,然后通过一个全连接网络$M$和ReLU激活函数进行映射,得到最终的基因表达特征$Z_2 \in \mathbb{R}^{m \times d}$。 $Z_2 = \text{ReLU}(M(F_g’))$

2.3 非对称多模态注意力 (AMMA)

AMMAAMMASurv的核心,其设计基于对Transformer的图视角理解。

作者将多模态信息融合看作是在一个图中添加不同类型的节点和边。在这个图中,WSI patch特征和基因特征是两种异构的节点。作者认为,这两种模态的重要性是不对等的,基因表达数据通常含有更多噪声。因此,作者构建了一个非全连接的图:只允许从更重要的WSI节点到次重要的基因节点的有向信息流,而不允许反向流动。这意味着,基因表达节点的表示只能被WSI特征所“引导”和更新,而其自身的噪声表示不会影响到其他特征。

为了实现上述非对称的信息流,作者对标准的自注意力机制进行了修改。对于输入的图像特征$x^{img} \in \mathbb{R}^{n \times d}$和基因特征$x^{gene} \in \mathbb{R}^{m \times d}$,其输出$o^{img}$和$o^{gene}$的计算方式如下:

图像特征的更新只在图像特征内部进行自注意力计算。

\[\alpha_{ij}^{(1)} = \text{softmax}\left(\frac{(x_i^{img} W_Q)(x_j^{img} W_K)^T}{\sqrt{d}}\right)\\ o_i^{img} = \sum_{j=1}^n \alpha_{ij}^{(1)} (x_j^{img} W_V)\]

基因特征的query来自于自身,但其keyvalue均来自于图像特征。这实现了WSI对基因表达的单向引导。

\[\alpha_{ij}^{(2)} = \text{softmax}\left(\frac{(x_i^{gene} W_Q)(x_j^{img} W_K)^T}{\sqrt{d}}\right)\\ o_i^{gene} = \sum_{j=1}^n \alpha_{ij}^{(2)} (x_j^{img} W_V)\]

通过这种设计,AMMA实现了模态内的WSI patch间信息交互(用于学习WSI的整体信息)和模态间的单向引导(用于利用WSI信息对基因表达表示进行降噪和增强),同时避免了在噪声较多的基因数据内部进行信息传递,防止了噪声的放大。

2.4 多模态Transformer编码器

编码器由多个AMMA块和MLP块堆叠而成。每一层的编码过程如下:

\[Z'^{l}_1 = \text{AMMA}(\text{LN}(Z^{l-1}_1), \text{LN}(Z^{l-1}_2)) + Z^{l-1}_1\\ Z'^{l}_2 = \text{AMMA}(\text{LN}(Z^{l-1}_1), \text{LN}(Z^{l-1}_2)) + Z^{l-1}_2\\ Z^l_1 = \text{MLP}(\text{LN}(Z'^{l}_1)) + Z'^{l}_1\\ Z^l_2 = \text{MLP}(\text{LN}(Z'^{l}_2)) + Z'^{l}_2\]

其中,$Z_1$和$Z_2$分别代表图像和基因的特征表示。最终,编码器输出端的WSI类别标记$Z_{token}^L$和基因特征序列的平均池化结果分别作为两种模态的最终表示$y_1$和$y_2$。

2.5 生存预测

将$y_1$和$y_2$拼接后,通过一个MLP Head模块,直接生成预测的风险评分$R$。损失函数采用negative Cox log partial likelihood

3. 实验分析

3.1 实验设置

使用了TCGA中的两个公共数据集:LUSC(肺鳞癌)和OV(卵巢癌)。与八种SOTA模型进行比较,涵盖了仅使用WSI、仅使用基因表达和使用两种模态的三类方法。使用concordance index (C-index)

3.2 与SOTA方法的比较

LUSCOV两个数据集上,AMMASurv的性能都显著优于所有比较方法。与表现最好的单模态WSI-only模型SeTranSurv相比,AMMASurvLUSCOV上的C-index分别提升了约5.8%和5.3%。与之前的多模态方法(DeepCorrSurv, MultiSurv)相比,性能提升更为巨大。这些结果证明了AMMASurv通过整合两种模态信息,取得了卓越的生存预测性能。

3.3 消融研究

消融实验有力地证明了AMMA中每一个设计(非对称性、单向引导、模态内WSI交互)的有效性和必要性。