scMoE:使用稀疏混合专家进行单细胞多模态多任务学习.

0. TL; DR

scMoE (single-cell Sparse Mixture-of-Experts)是一个单细胞多组学分析模型,其核心是将稀疏混合专家(SMoE)层嵌入到Transformer模块中。SMoE通过一个路由(Router)为每个输入动态选择一小部分专家(Experts)进行处理。在模拟数据和多个真实的单细胞多组学数据集上,scMoE在联合细胞群识别和跨模态预测两大任务中均显著优于现有SOTA方法。

1. 背景介绍

单细胞多组学技术能够在同一个细胞中同时观察转录组、蛋白质组、表观组等多个分子层面的信息。UnitedNet这样的工作开始探索使用一个统一的模型同时完成联合细胞群识别(即跨模态的细胞聚类)和跨模态预测(即用一种数据预测另一种数据)等多种任务。它们通常采用“编码器-融合器-解码器”的架构,将所有模态的信息压缩到一个共享的潜在空间中。

作者指出这类框架在单细胞领域的一个致命缺陷:优化冲突。在 UnitedNet 的实验中观察到一个反直觉的现象:当使用全部四种模态(蛋白、mRNApre-mRNADNA)时,其细胞聚类性能(ARI指标)是所有组合中最差的,甚至不如只用两种模态。这个问题的根源在于共享参数空间中的梯度冲突:在训练过程中,来自不同模态(如DNAProtein)的梯度方向可能存在显著差异甚至相反。当这些相互冲突的梯度作用于同一个共享参数(如融合器)时,就会导致模型无法有效学习,最终导致性能下降。

除了优化冲突,可解释性也是一大障碍。在生物医学领域,模型的可解释性至关重要,它能帮助研究人员发现新的生物学机制。但现有方法依赖的SHAP等工具,计算复杂度极高,不适用于需要快速迭代探索的单细胞研究场景。

2. scMoE 模型

scMoE框架在Transformer架构的基础上引入了稀疏混合专家(Sparse Mixture-of-Experts, SMoE),遵循“编码器-Transformer-解码器”模式。

  1. 模态特异性编码器 (Encoder): 对于每一种输入的组学数据(如RNA、ATAC等),使用一个独立的编码器将其处理成token嵌入。将每个细胞的特征向量切分成若干个patch,每个patch作为一个token。这使得不同维度、不同类型的原始数据都能被转换成统一格式的token序列 $h(ν) ∈ R^{(B×P×D)}$。
  2. 带SMoE的Transformer: 将所有模态的token序列在patch维度上拼接起来,作为MHA的输入,注意力机制可以同时学习模态内的依赖关系(自注意力)和模态间的交互关系(交叉注意力)。MHA的输出被送入SMoE层,通过其稀疏激活的专家们来处理信息,不同的专家可以特化于处理不同模态、不同细胞类型或不同任务的特征,从而根本上解决了优化冲突问题。
  3. 模态特异性解码器 (Decoder): Transformer层的输出嵌入 $h̃$ 会被送入多个解码器。每个解码器负责从这个富含多模态信息的嵌入中重构出原始的某一种组学数据。

稀疏混合专家 (SMoE)表达如下:

\[y = \sum_{i=1}^{k} R(x)_i \cdot f_i(x) \\ R(x) = \text{Top-K}(\text{softmax}(g(x)), k)\]

其中 $f_i(x)$ 是第 $i$ 个专家的输出。$g(x)$ 是可训练的路由器网络,它为每个专家生成一个权重。Top-K 操作是实现“稀疏”的关键:它只保留权重最高的 $k$ 个专家(通常k很小,如2),其余专家的权重被置为0。$y$ 是SMoE层的最终输出。

scMoE同时优化两个核心任务的损失:

  1. 联合细胞群识别损失 $L_{\text{DDC}}$: 直接在Transformer输出的共享嵌入 $h̃$ 上计算一个聚类损失,采用深度散度聚类 (Deep Divergence-based Clustering, DDC) 损失。在有监督场景下,可替换为交叉熵损失。
  2. 跨模态重构损失 $L_{\text{Recon}}$: 所有解码器重构误差的总和,对应跨模态预测任务。

scMoE 具有双重可解释性:

  1. 机理可解释性 (Mechanistic Interpretability): 可以在模型推理时,直接观察路由器的决策(哪个专家被激活了)和注意力图谱。路由器的决策可以统计不同模态、不同细胞类型激活各个专家的频率,从而理解每个专家的“专长”。注意力图谱的注意力分数直接揭示了不同模态特征之间的关联强度。
  2. 轻量级后验可解释性 (Post-hoc Interpretability via TCAV): 为了提供更符合生物学直觉的解释,引入了概念激活向量(TCAV)TCAV通过训练一个简单的线性分类器,来量化一个高层“概念”(如“某个基因是marker gene”、“输入数据来自电生理模态”、“某个细胞缺失了80%的RNA信息”)对模型最终预测的贡献度。TCAV计算效率极高,且能回答更具生物学意义的问题。

3. 实验分析

作者在一个模拟数据集(Dyngen)和三个真实的、涵盖多种技术和生物场景的多组学数据集(DBiT-seq, Patch-seq, ATAC+gene)上,对scMoE进行了全面的评估。

在包含四种模态的Dyngen数据集上,对于scMoE,整合全部四种模态(Pre, Pro, D, m)后的聚类性能(ARI=0.72)优于大部分双模态组合,而 UnitedNet 四模态时性能骤降至0.56,证明了scMoE通过SMoE机制成功缓解了梯度冲突。在跨模态预测任务中scMoE的性能也超越了所有基线方法。

在多个真实数据集上scMoE展示了强大实力:

一系列消融实验证明了scMoE设计的合理性:

作者通过机理和后验两种方式直观地展示了scMoE的可解释性: