AMMASurv:全切片图像与基因表达数据准确生存分析的非对称多模态注意力.
0. TL; DR
AMMASurv (Asymmetrical Multi-Modal Attention for Survival Analysis) 是一种非对称多模态方法。其核心创新在于一个asymmetrical multi-modal attention (AMMA)(非对称多模态注意力)机制,该机制能够:
- 有效利用模态内信息:通过模态内的信息传递,学习WSI patches之间的相关性和交互,从而捕捉WSI的整体信息;同时,通过避免在基因表达数据内部进行信息传递,来防止噪声的放大。
- 灵活适应不同重要性的模态:通过一种有向的跨模态信息传递(仅从重要模态到次要模态),利用更重要模态(如WSI)的信息来引导次要模态(如基因表达)的表示学习。
作者在两个广泛使用的数据集上进行了实验,结果表明,AMMASurv的性能显著优于其他state-of-the-art方法。
1. 背景介绍
利用多模态数据,如WSI和基因表达,进行生存分析可以提供更准确的预测。然而,现有的多模态方法,如DeepCorrSurv和MultiSurv,在处理这两种数据时存在局限性。
首先,它们未能有效挖掘每个模态内部的内在信息。对于WSI,这些方法通常只是从巨大的图像中随机选择一些patches来提取特征,而忽略了WSI作为一个整体所包含的完整信息。对于基因表达数据,它们直接编码高维的原始数据,其中包含了大量与任务不相关的信息,这会引入大量噪声。
其次,它们未能灵活地利用模态间的潜在联系。这些方法通常将来自不同模态的信息视为同等重要。但在很多情况下,不同模态对最终预测任务的贡献是不同的。如果一个模态本身含有大量噪声且信息量较少,将其与信息量丰富的模态同等对待,可能会“污染”最终的特征表示,从而损害模型性能。
为了应对这些挑战,作者提出了AMMASurv,一个基于Transformer的端到端模型。其核心是一个非对称的多模态注意力(AMMA)机制,旨在不均衡地融合来自不同重要性模态的信息。
2. AMMASurv 方法
AMMASurv是一个基于Transformer的端到端模型,由一个多模态编码器和一个MLP头组成。其核心是编码器中的Asymmetrical Multi-Modal Attention (AMMA) 机制。

2.1 图像特征 (Image Features)
从一张WSI的前景区域中随机选择$n$个patches。使用一个预训练的ResNet18作为骨干网络$B$,对patches序列$I$提取特征,得到特征序列$F_p \in \mathbb{R}^{n \times d_1}$。
为每个patch的特征附加一个位置嵌入(position embedding),以保留其空间信息。在特征序列的开头,添加一个可学习的embedding $Z_{token}^0$,其在Transformer编码器输出端的最终状态将作为整个WSI的表示。
2.2 基因表达特征 (Gene Expression Features)
对基因表达数据进行标准化,并将其分为$m$个组,每个组包含$d_2$个基因符号,形成一个基因特征序列$G \in \mathbb{R}^{m \times d_2}$。
将$G$的维度扩展至与图像特征相匹配,然后通过一个全连接网络$M$和ReLU激活函数进行映射,得到最终的基因表达特征$Z_2 \in \mathbb{R}^{m \times d}$。 $Z_2 = \text{ReLU}(M(F_g’))$
2.3 非对称多模态注意力 (AMMA)
AMMA是AMMASurv的核心,其设计基于对Transformer的图视角理解。
作者将多模态信息融合看作是在一个图中添加不同类型的节点和边。在这个图中,WSI patch特征和基因特征是两种异构的节点。作者认为,这两种模态的重要性是不对等的,基因表达数据通常含有更多噪声。因此,作者构建了一个非全连接的图:只允许从更重要的WSI节点到次重要的基因节点的有向信息流,而不允许反向流动。这意味着,基因表达节点的表示只能被WSI特征所“引导”和更新,而其自身的噪声表示不会影响到其他特征。
为了实现上述非对称的信息流,作者对标准的自注意力机制进行了修改。对于输入的图像特征$x^{img} \in \mathbb{R}^{n \times d}$和基因特征$x^{gene} \in \mathbb{R}^{m \times d}$,其输出$o^{img}$和$o^{gene}$的计算方式如下:
图像特征的更新只在图像特征内部进行自注意力计算。
\[\alpha_{ij}^{(1)} = \text{softmax}\left(\frac{(x_i^{img} W_Q)(x_j^{img} W_K)^T}{\sqrt{d}}\right)\\ o_i^{img} = \sum_{j=1}^n \alpha_{ij}^{(1)} (x_j^{img} W_V)\]基因特征的query来自于自身,但其key和value均来自于图像特征。这实现了WSI对基因表达的单向引导。
\[\alpha_{ij}^{(2)} = \text{softmax}\left(\frac{(x_i^{gene} W_Q)(x_j^{img} W_K)^T}{\sqrt{d}}\right)\\ o_i^{gene} = \sum_{j=1}^n \alpha_{ij}^{(2)} (x_j^{img} W_V)\]通过这种设计,AMMA实现了模态内的WSI patch间信息交互(用于学习WSI的整体信息)和模态间的单向引导(用于利用WSI信息对基因表达表示进行降噪和增强),同时避免了在噪声较多的基因数据内部进行信息传递,防止了噪声的放大。
2.4 多模态Transformer编码器
编码器由多个AMMA块和MLP块堆叠而成。每一层的编码过程如下:
\[Z'^{l}_1 = \text{AMMA}(\text{LN}(Z^{l-1}_1), \text{LN}(Z^{l-1}_2)) + Z^{l-1}_1\\ Z'^{l}_2 = \text{AMMA}(\text{LN}(Z^{l-1}_1), \text{LN}(Z^{l-1}_2)) + Z^{l-1}_2\\ Z^l_1 = \text{MLP}(\text{LN}(Z'^{l}_1)) + Z'^{l}_1\\ Z^l_2 = \text{MLP}(\text{LN}(Z'^{l}_2)) + Z'^{l}_2\]其中,$Z_1$和$Z_2$分别代表图像和基因的特征表示。最终,编码器输出端的WSI类别标记$Z_{token}^L$和基因特征序列的平均池化结果分别作为两种模态的最终表示$y_1$和$y_2$。
2.5 生存预测
将$y_1$和$y_2$拼接后,通过一个MLP Head模块,直接生成预测的风险评分$R$。损失函数采用negative Cox log partial likelihood。
3. 实验分析
3.1 实验设置
使用了TCGA中的两个公共数据集:LUSC(肺鳞癌)和OV(卵巢癌)。与八种SOTA模型进行比较,涵盖了仅使用WSI、仅使用基因表达和使用两种模态的三类方法。使用concordance index (C-index)。
3.2 与SOTA方法的比较
在LUSC和OV两个数据集上,AMMASurv的性能都显著优于所有比较方法。与表现最好的单模态WSI-only模型SeTranSurv相比,AMMASurv在LUSC和OV上的C-index分别提升了约5.8%和5.3%。与之前的多模态方法(DeepCorrSurv, MultiSurv)相比,性能提升更为巨大。这些结果证明了AMMASurv通过整合两种模态信息,取得了卓越的生存预测性能。

3.3 消融研究
消融实验有力地证明了AMMA中每一个设计(非对称性、单向引导、模态内WSI交互)的有效性和必要性。
- 替换为传统自注意力 (replace with self-attention):将AMMA替换为传统的多模态自注意力机制后,模型性能急剧下降,甚至不如WSI-only的模型。这表明,简单地将两种模态同等对待进行信息交互,会受到基因表达数据中噪声的严重干扰。AMMA的非对称设计成功地避免了这个问题。
- 替换为随机噪声 (replace with random noise):将输入的基因表达特征替换为随机向量后,模型性能下降到与WSI-only模型相似甚至更差的水平。这表明,AMMASurv确实从基因表达数据中挖掘并利用了对生存预测有益的信息。
- 无注意力引导 (without attention induce):如果直接将未经过引导的基因表示与编码后的WSI表示拼接起来,性能也会下降。这证明了通过WSI特征来引导基因表示学习的重要性。
