使用Matilda进行多模态单细胞组学的多任务学习.
0. TL; DR
Matilda是一个用于多模态单细胞组学数据整合分析的多任务学习(multi-task learning)方法,利用其模块化的神经网络架构,将数据模拟、维度降低、细胞类型分类和特征选择四个关键任务整合到一个单一的、统一的框架中进行联合学习。
作者在多种来自不同技术平台(TEA-seq, CITE-seq, SHARE-seq)的多组学数据集上,对Matilda的各项功能进行了评估。结果表明,Matilda在数据模拟、细胞类型分类和特征选择等任务上的性能,均优于其他SOTA方法。
1. 背景介绍
单细胞多组学技术,如CITE-seq和TEA-seq,为我们提供了一扇前所未有的窗户,让我们能同时观察一个细胞的多个分子层面,例如基因表达(RNA)、表面蛋白(ADT)和染色质可及性(ATAC)。这为我们以一种更全面的视角来理解细胞系统创造了条件。
然而,尽管数据获取技术日新月异,我们的计算分析工具却相对滞后。现有的单细胞分析方法大多是被设计用来解决某个特定的任务,如数据模拟、细胞类型分类或特征选择。这种模式存在一个根本性的问题:它忽略了不同分析任务之间内在的、紧密的关联。例如:
- 一个好的数据模拟模型,能够生成逼真的数据,这些数据可以被用来增强训练集,从而提升细胞类型分类的准确性。
- 一个准确的细胞类型分类模型,其内部学习到的权重可以反过来帮助我们识别对分类最重要的特征。
因此,一个理想的分析框架,应该能够将这些相关的任务整合到一个统一的框架中,让它们相互促进、协同增益。
2. Matilda 方法
Matilda的核心是一个多任务神经网络,它巧妙地将一个变分自编码器和一个分类网络结合在一起,并通过一个统一的损失函数进行端到端的联合训练。

2.1 多任务学习架构
Matilda的架构主要由两部分构成:一个用于数据模拟和降维的VAE组件,以及一个用于细胞类型分类的分类器组件。
对于输入的N种模态的数据,Matilda为每一种模态都设计了一个独立的编码器,负责将对应模态的高维数据,编码到一个低维的特征空间中。来自所有模态的低维特征被拼接在一起,然后送入一个共享的全连接层进行联合学习。重参数化技巧从中采样得到最终的、融合了所有模态信息的低维潜变量(latent variable) $z$。
从潜变量 $z$ 开始,网络结构分为了两个并行的分支:
- 解码器分支(Decoder Branch): $z$被送入N个独立的解码器中,每个解码器负责重构其对应模态的原始数据。这部分构成了VAE的生成部分,用于数据模拟。
- 分类器分支(Classifier Branch): $z$同时也被送入一个简单的全连接分类网络中,该网络通过SoftMax激活函数,输出细胞属于每个类别的概率。
2.2 联合损失函数
Matilda的训练目标是最小化一个由两部分组成的联合损失函数:
\[L_{sum} = L_{sim} + \lambda \times L_{cla}\]数据模拟损失 ($L_{sim}$) 是VAE组件的损失,即标准的证据下界(ELBO)损失。它包含两项:
- 重构损失: 鼓励解码器生成的数据与原始数据尽可能相似。
- KL散度损失: 一个正则项,促使潜变量 $z$ 的后验分布 $q_\theta(z|X)$ 接近于标准正态分布先验 $p(z)$。
分类损失 ($L_{cla}$) 是分类器分支的损失,采用带标签平滑(label smoothing)的交叉熵损失。标签平滑是一种正则化技巧,可以防止模型对标签过于自信,从而提高泛化能力。
\[L_{cla} = - \sum_{i=1}^K y_i^{ls} \log y_i^{output}\]权重系数 $\lambda$ 是一个超参数,用于平衡数据模拟任务和细胞类型分类任务的重要性。
通过联合优化这个$L_{sum}$,Matilda实现了在一个模型内同时学习如何模拟数据和如何分类细胞。
2.3 模拟增强的训练策略
为了进一步提升性能,特别是在处理细胞类型不均衡的数据时,Matilda采用了一种模拟增强的训练策略。
在训练过程中,Matilda首先根据训练集中每个细胞类型的细胞数量进行排序,找到数量为中位数的细胞类型作为参考。对于那些细胞数量少于中位数的稀有类型,模型会利用其VAE组件生成模拟数据,将其增强到与参考类型相同的数量。对于那些细胞数量多于中位数的常见类型,则进行随机下采样。通过这种方式,模型在一个类别均衡的数据集上进行训练,这有助于模型更好地学习稀有细胞类型的分子特征。
2.4 跨模态联合特征选择
基于训练好的神经网络,Matilda实现了两种方法来同时从多个模态中识别对细胞类型分类最重要的特征。
⚪ 积分梯度 (Integrated Gradient, IG)
IG是一种用于解释深度学习模型预测结果的技术。它通过计算模型输出(分类概率)对每个输入特征的梯度积分,来量化该特征对预测的贡献度。
\[S_j = \int_{\tau=0}^1 X_j \times \frac{\partial F(\tau \times X)}{\partial X_j} d\tau\]通过对每个细胞类型反向传播梯度,Matilda可以为每个特征(无论是RNA、ADT还是ATAC)计算一个重要性分数。
⚪ 显著性图 (Saliency)
这是一种更简单的方法,直接使用模型输出对输入的梯度的大小,来作为特征的重要性分数。
\[S_j = \frac{\partial F(X)}{\partial X} \Big|_{X_j}\]通过这两种方法,Matilda能够为每个细胞类型,提供一个跨越所有模态的、排序的重要特征列表。
3. 实验分析
作者在五个来自TEA-seq、CITE-seq和SHARE-seq等主流平台的多组学数据集上,对Matilda的四项核心功能进行了全面的基准测试。
3.1 实验一:多模态数据模拟
UMAP可视化显示,Matilda能够生成与真实数据在潜在空间中分布高度一致的、细胞类型特异性的模拟数据。并且,模拟数据似乎还起到了去噪的效果,使得簇的边界更清晰。

作者比较了真实数据和模拟数据的特征相关性结构。热图显示,Matilda模拟的RNA数据,其基因-基因相关性矩阵与真实数据最为相似,显著优于scGAN, ACTIVA, SPARSim等专门为scRNA-seq设计的SOTA模拟方法(图A-C)。皮尔逊相关性的量化比较也证实了这一点(图D-F)。

3.2 实验二:多模态整合与降维
作者比较了Matilda与其他SOTA整合/降维方法(如Seurat, totalVI, Conos, MultiVI)生成的低维嵌入的质量。
Matilda生成的UMAP图具有最清晰的细胞簇边界和最少的簇间混杂(图A, B)。
作者使用k-means聚类,并通过ARI, NMI, FM, Jaccard等一系列指标来评估聚类结果与真实标签的一致性。结果显示,在所有数据集和所有指标上,Matilda的性能均显著优于所有其他方法(图C-E)。

3.3 实验三:多模态细胞类型分类
作者将Matilda与一系列专门为scRNA-seq设计的SOTA分类器(如CHETAH, scClassify, singleCellNet)以及一个多模态整合方法UMINT进行了比较。
结果显示,无论是在数据集内交叉验证还是跨批次测试中,Matilda的分类准确率都显著高于所有其他方法。

3.4 实验四:多模态联合特征选择
作者比较了Matilda与其他多种特征选择方法(如t-test, limma, MAST, PROPOSE)选出的top 100特征,在区分特定细胞类型时的性能。
结果显示,使用Matilda从多个模态中联合选出的特征,其分类能力平均而言优于所有那些只从RNA模态中选择特征的方法(图E)。
作者展示了Matilda为CD14单核细胞和B细胞选出的top特征,这些特征(如RNA模态的CD14基因,ADT模态的CD14蛋白)都与对应细胞类型的已知marker高度吻合,证明了其选择的特征具有明确的生物学意义(图A-D)。
