多模态单细胞数据整合的图神经网络.

0. TL; DR

作者提出了 scMoGNN,一个通用的图神经网络(GNN)框架,旨在一站式解决单细胞数据的模态预测(用一种数据预测另一种)、模态匹配(为不同批次的细胞“配对”)和联合嵌入(学习一个融合多模态信息的细胞表征)。

scMoGNN将单细胞数据矩阵转换为一个细胞-特征二分图(细胞和特征都是节点,细胞与它所拥有的特征之间有边相连)。这使得GNN可以通过消息传递机制让细胞从相似的细胞(通过共享特征连接)和相关的特征中学习。GNN的卷积层以及任务特定的输出头可以被无缝适配到不同任务中。

scMoGNN不仅是 NeurIPS 2021 多模态单细胞数据整合竞赛“模态预测”赛道的官方总分第一名,在后续的扩展实验中,它在“模态匹配”和“联合嵌入”两个任务上也全面超越了当时的SOTA方法,成为了一个新的强大基线。

1. 背景介绍

单细胞多组学技术描绘了一幅前所未有的细胞高清分子图谱。然而,如何从这些高维、稀疏且充满噪声的数据中淘金,是一个巨大的挑战。学术界和工业界共同定义了三个核心的计算任务,以期充分利用这些宝贵数据:

  1. 模态预测 (Modality Prediction): 类似于“翻译”,目标是利用一种模态(如基因表达)来预测另一种模态(如蛋白丰度)的完整谱图。
  2. 模态匹配 (Modality Matching): 类似于“配对”,目标是为来自两个不同单组学实验的细胞找到它们原本的一一对应关系。
  3. 联合嵌入 (Joint Embedding): 类似于“融合”,目标是学习一个统一的低维空间,能够同时捕捉多种模态的信息,以便进行更精准的细胞类型鉴定或轨迹分析。

现有的方法,无论是基于矩阵分解(如MOFA+)还是深度学习(如Autoencoder变体),大多存在一个共同的盲点:它们将每个细胞视为一个独立的样本。这种处理方式忽略了一个重要的事实:在生物系统中,细胞并非孤立存在,功能相似的细胞在分子特征上具有相似性,不同的分子特征之间也存在复杂的调控网络。这些高阶的、网络状的结构信息,在传统的机器学习方法中被很大程度上忽略了。

图神经网络 (Graph Neural Networks, GNNs) 核心思想是在图结构上进行学习,通过迭代式地聚合邻居节点的信息来更新自身表征。其天然的优势在于:

作者设计了 scMoGNN,旨在构建一个通用的GNN框架,系统性地解决单细胞多组学的三大核心任务。

2. scMoGNN 框架

scMoGNN 框架可以被分解为三个阶段:图构建、细胞-特征图卷积 和 任务特异性头。

2.1 阶段一:图构建 (Graph Construction)

作者提出了一种细胞-特征二分图 (cell-feature bipartite graph) 的构建方法。

在这个图中,有两种类型的节点:

边只存在于细胞节点和特征节点之间。如果细胞 $i$ 的特征 $j$ 的表达值不为零,那么就在细胞节点 $u_i$ 和特征节点 $v_j$ 之间连接一条边。边的权重就是该特征的表达值。

这个二分图可以用一个邻接矩阵 $A$ 来表示:

\[A = \begin{pmatrix} O & M \\ M^T & O \end{pmatrix} \in \mathbb{R}^{(N+k) \times (N+k)}\]

其中 $M$ 是 $N \times k$ 的原始细胞-特征矩阵(N个细胞,k个特征),$O$ 是零矩阵。

这种构图方式的妙处在于,它建立了一个信息流动的桥梁:细胞的信息可以传递给它表达的特征,特征的信息也可以回传给表达它的细胞。通过GNN的多层传播,一个细胞最终可以聚合到与它共享相似特征谱的其他细胞的信息。

在特定任务中(如模态预测),还可以进一步丰富图结构。例如,如果有基因通路(pathway)的先验知识,可以在相关的特征节点之间也加上边,这样邻接矩阵就变为:

\[A = \begin{pmatrix} O & M \\ M^T & P \end{pmatrix}\]

其中 $P$ 是 $k \times k$ 的特征-特征邻接矩阵,代表了基因间的已知关联。

2.2 阶段二:细胞-特征图卷积 (Cell-Feature Graph Convolution)

在构建好的图上使用特制的GNN层进行信息传播和节点表征学习。由于图中有细胞和特征两种不同类型的节点,作者采用了类似异构图神经网络的思想,对不同类型的节点和边使用不同的参数进行处理。

对于GNN的每一层 $l$,一个节点 $i$ 的新表征 $h_i^{l+1}$ 是通过聚合其邻居信息并结合自身旧表征 $h_i^l$ 来更新的。以一个特征节点 $v_i$ 为例,它的邻居都是细胞节点。其收到的聚合“消息”可以表示为:

\[m_{i, l}^{U \to V} = \sigma \left( b_{l}^{U \to V} + \sum_{j \in \mathcal{N}_i, u_j \in U} \frac{e_{ji}}{c_{ji}} h_j^l W_{l}^{U \to V} \right)\]

其中$h_j^l$ 是邻居细胞节点 $j$ 在第 $l$ 层的表征。$W_l^{U \to V}$ 是从细胞到特征的可学习变换矩阵。$e_{ji}$ 是边权重,$c_{ji}$ 是用于图卷积的归一化系数。

同理也可以计算从特征节点到细胞节点的消息 $m_{j, l}^{V \to U}$。对于包含特征-特征边的图,还会有 $m_{i, l}^{V \to V}$。

最终的节点更新方式可以是一个简单的残差连接:

\[h_{i}^{l+1} = h_{i}^{l} + \alpha \cdot m_{i, l}^{V \to V} + (1-\alpha) \cdot m_{i, l}^{U \to V}\]

其中 $\alpha$ 是一个超参数,用于平衡来自同类邻居和异类邻居的信息。通过堆叠 $L$ 个这样的卷积层,每个节点的最终表征就编码了其 $L$ 跳邻域内的丰富结构信息。

2.3 阶段三:任务特异性头 (Task-specific Head)

在经过多层GNN卷积后,得到了每个细胞节点在每一层的表征 $H_U^1, \dots, H_U^L$。最后一步是根据具体任务,设计不同的“输出头”来处理这些表征并计算损失。

⚪ 模态预测

将所有层的细胞表征通过一个可学习的加权平均进行聚合,然后通过一个简单的全连接层映射到目标模态的维度。

\[\hat{H} = \sum_{i=1}^L w_i \cdot H_U^i \quad \rightarrow \quad \hat{Y} = \hat{H}W + b\]

损失函数采用标准的均方根误差(RMSE)损失。

\[\mathcal{L} = \sqrt{\frac{1}{N}\sum_{i=1}^N (Y_i - \hat{Y}_i)^2}\]

⚪ 模态匹配

对两种待匹配的模态分别构建图并运行scMoGNN,得到两组细胞嵌入 $H_1$ 和 $H_2$。然后计算两组嵌入两两之间的余弦相似度,得到一个 $N x N$ 的相似度矩阵 $S$。

\[S = \hat{H}'_{m1} \cdot \hat{H}'^T_{m2}\]

损失函数采用一个对称的交叉熵损失,目标是让相似度矩阵 $S$ 在真实配对的位置上概率最大。

\[\mathcal{L}_{match} = - \sum_{c_1, c_2} Y_{c_1, c_2} \log(P^r_{c_1, c_2}) + Y_{c_1, c_2} \log(P^c_{c_1, c_2})\]

其中 $P^r$ 和 $P^c$ 分别是行和列的Softmax概率。作者还引入了辅助的重构和预测损失来增强模型的鲁棒性。

在推理阶段,将相似度矩阵 $S$ 视为一个二分图,并使用匈牙利算法求解最大权重匹配,以得到最终的硬匹配结果。

⚪ 联合嵌入

将两种模态的特征拼接起来,构建一个统一的细胞-特征图。运行scMoGNN后得到最终的细胞嵌入 $H$。

损失函数采用一个多任务损失,包括:

3. 实验分析

作者在NeurIPS 2021竞赛的官方数据集上,对scMoGNN在三大任务上的表现进行了全面评估。

5.1 实验一:模态预测

在与众多强手的激烈竞争中,scMoGNN的集成版本(Ensemble)取得了最低的总体RMSE(0.2780),荣获该赛道总分第一名。

在四个子任务中,scMoGNN(无论是单模型还是集成模型)的性能都非常稳定且名列前茅。特别是在GEX-to-ADT这个最困难的子任务上,最终版本取得了0.3809的最低loss,超越了竞赛中的所有对手。这充分证明了GNN框架在处理高维、异构数据上的强大能力。

5.2 实验二:模态匹配

scMoGNN在所有四个子任务上均显著超越了竞赛的官方冠军GLUE。例如,在GEX-ADT子任务上,scMoGNN的匹配得分(0.0810)比GLUE(0.0495)高出了约64%。最终,scMoGNN的总匹配分(0.0720)比冠军GLUE(0.0539)高出约33.6%,树立了该任务新的SOTA

消融实验发现如果去掉GNN的图传播层(即不利用图结构信息),模型的性能会发生显著下降。这直接证明了GNN所捕捉到的高阶结构信息是提升匹配性能的关键。

5.3 实验三:联合嵌入

GEX-ADT的联合嵌入任务中,scMoGNN在综合了生物学保留和批次校正的6个指标后,取得了最高的平均分(0.8168),再次超越了竞赛冠军Amateur(JAE)和强大的对手GLUE。这表明GNN学习到的联合表征质量非常高。