用流匹配联合建模单细胞的速度与生长.

0. TL; DR

从静态的“快照”数据中学习单细胞的潜在动态,已成为科学研究和机器学习领域日益关注的焦点。然而,破坏性的测量技术导致了快照之间的数据是“非配对”的,而细胞的增殖/死亡则导致了数据是“非平衡”的,这给学习底层动态带来了巨大挑战。

本文提出了一种名为VGFMjoint Velocity-Growth Flow Matching)的新范式,它通过流匹配(flow matching)来联合学习单细胞群体的状态转变(state transition)和质量增长(mass growth)。VGFM的核心思想源于对静态半松弛最优传输(static semi-relaxed optimal transport)的一种新的动态理解。作者将这种数学工具(用于寻找非配对、非平衡数据间的耦合)重新诠释为一个两阶段的动态过程,从而构建了一个包含状态速度(velocity)和质量增长(growth)的理想单细胞动力学模型。

为了在实际中应用,作者使用神经网络来逼近这个理想动力学,从而形成了联合的速度和生长流匹配框架。此外,VGFM还采用了一个分布拟合损失(distribution fitting loss)来进一步提升对快照数据的拟合性能。在合成和真实数据集上的大量实验表明,VGFM能够准确捕捉考虑到质量和状态随时间变化的底层生物学动态,其性能优于现有的单细胞动力学建模方法。

1. 背景介绍

从稀疏和含噪的数据中推断复杂系统的潜在动态,是科学和工程领域的一个根本性挑战。在许多领域,如金融市场、气候系统和生物过程,我们很少能完整地观察到连续的轨迹。相反,我们通常只能获得在离散时间点上收集的横断面“快照”数据。

这个问题在单细胞RNA测序(single-cell RNA sequencing)领域尤为突出。由于测序过程会破坏细胞,我们得到的是一系列在时间上非配对的群体水平快照,无法追踪单个细胞的命运。更复杂的是,在细胞发育或响应过程中,细胞会经历增殖和死亡,导致不同时间点的细胞总数发生变化,即数据是“非平衡”(unbalanced)的,违背了质量守恒。因此,如何从有限的样本中重构随时间演化的、非标准化的密度函数,已成为一个重要的研究问题。

基于深度学习的动力学推断模型已显示出巨大潜力。这些模型通常使用由神经网络参数化的常微分方程(ordinary or stochastic differential equations, ODEs or SDEs)来近似控制密度演化的速度场。现有方法大致可分为两类:

  1. 基于模拟的方法(Simulation-based methods):这类方法通过将初始数据输入神经网络并数值求解ODE/SDE来生成合成轨迹,然后将模拟结果与观测数据进行比较来计算损失。然而,这种方法在训练过程中严重依赖数值求解器,计算成本高昂。在高维情况下,巨大的搜索空间进一步加剧了训练的不稳定性。
  2. 无模拟方法(Simulation-free approaches):以流匹配(flow matching)为代表,这类方法通过构建条件概率路径来高效地训练速度场,而无需模拟轨迹。这使得训练过程更高效、更稳定。

然而,上述大多数无模拟方法只考虑了速度场v,忽略了观测数据的非平衡性,这可能导致动力学重构的错误。细胞增殖和死亡是单细胞动态中的普遍现象,因此必须引入一个生长项(growth termg来描述质量的变化。一些研究尝试联合学习vg,但它们或者施加了缺乏生物学依据的数学约束(如$v = ∇g$),或者严重依赖计算成本高昂的模拟过程,限制了它们在高维数据上的应用。

为了解决这些问题,作者提出了VGFM,一种新颖的单细胞动力学建模方法。VGFM旨在通过流匹配联合学习状态转变和质量增长。该方法的核心是基于半松弛最优传输(semi-relaxed optimal transport)理论,并对其进行动态的重新诠释,从而构建一个理想的、解耦的、同时包含速度和生长的动力学模型。然后,通过神经网络来匹配这个理想模型,并辅以一个分布拟合损失,以实现对单细胞快照数据的精确建模。

2. VGFM 方法

VGFM框架的核心是基于半松弛最优传输理论,构建一个理想的、包含速度和生长的非平衡动力学模型,然后利用流匹配的思想让神经网络去学习这个模型。整个流程如图所示,可以分解为三个步骤:动态诠释、构建联合动力学、以及速度-生长流匹配。

2.1 单细胞的非平衡动力学

首先,作者们用一个包含速度项v和生长项g的系统来描述单细胞的非平衡动态。对于一个细胞在时间$t$的状态$x_t$,其演化由以下方程组控制:

\[\begin{cases} \frac{dx_t}{dt} = v_t(x_t) \\ \frac{d \log w_t(x_t)}{dt} = g_t(x_t) \end{cases}\]

其中,$v_t$控制细胞状态(基因表达)的转变,$g_t$控制与细胞状态$x_t$相关的权重$w_t$的变化,从而模拟细胞的增殖或死亡。这个动态过程对应的群体密度$p_t$的演化由以下连续性方程描述:

\[\partial_t p_t = -\nabla \cdot (p_t v_t) + g_t p_t\]

2.2 基于半松弛最优传输构建非平衡动力学

上述方程虽然给出了非平衡动力学的形式,但并未指定$v_t$和$g_t$的具体形式。为此,作者们求助于半松弛最优传输(semi-relaxed optimal transport)。

动态诠释半松弛最优传输

经典的半松弛最优传输问题定义如下:

\[\min_{\pi \ge 0} J_{sot}(\pi) \triangleq \int_{\Omega^2} c(x_0, x_1) d\pi(x_0, x_1) + \text{KL}(P_{0\#}\pi \| p_0) \\ \text{subject to } P_{1\#}\pi = p_1\]

这个公式在寻找从分布$p_0$到$p_1$的耦合$\pi$时,通过KL散度项允许了源边际 \(P_{0 \#}\pi\) 与原始分布$p_0$不一致,从而实现了质量的变化。

作者们创新性地将这个静态的优化问题诠释为一个两阶段的动态过程:

  1. 质量增长阶段 (t ∈ [0, λ]):在第一个时间段,只发生质量变化。细胞状态不变,但其权重(或密度)根据一个生长函数$g_t$演化,即$\partial_t p_t = g_t p_t$。
  2. 状态转换阶段 (t ∈ (λ, 1]):在第二个时间段,只发生状态变化。细胞质量守恒,状态根据一个速度场$v_t$演化,即$\partial_t p_t = -\nabla \cdot (p_t v_t)$。

作者证明了,这个两阶段动态模型的总“成本”(包含运输成本和生长成本)与原始的半松弛最优传输问题的最优值是相等的。这意味着,任何一个半松弛最优传输解,都可以被等价地分解为这样一个先“生长”后“平移”的过程。

构建联合动力学

虽然两阶段模型为解耦速度和生长提供了理论基础,但它与生物系统中两者同时发生的现实不符。因此,作者基于这个解耦的模型,构建了一个等价的联合动力学模型。作者定义了一个新的速度场$\tilde{v}_t$和生长场$\tilde{g}_t$,使得它们在整个时间区间$[0, 1]$上共同作用:

\[\partial_t \tilde{p}_t = -\nabla \cdot (\tilde{p}_t \tilde{v}_t) + \tilde{g}_t \tilde{p}_t\]

作者证明了,在相同的初始分布$p_0$下,这个联合动力学模型在$t=1$时刻得到的最终分布$\tilde{p}_1$,与两阶段模型得到的最终分布$p_1$是完全相同的。这个联合动力学模型为作者提供了一个理想的、可用于流匹配的目标。

2.3 速度与生长流匹配

有了理想的联合动力学模型,作者就可以推导出理想的速度场 \(\tilde{v}_t\) 和生长场 \(\tilde{g}_t\) 的解析形式。给定一个从$p_0$采样的初始点$x_0$,其沿着流线 \(\psi_{\tilde{v},t}(x_0)\) 演化时,理想的速度和生长函数可以表示为:

\[\tilde{v}_t(\psi_{\tilde{v},t}(x_0)) = T^*(x_0) - x_0 \\ \tilde{g}_t(\psi_{\tilde{v},t}(x_0)) = \log P_{0\#}\pi^*(x_0) - \log p_0(x_0)\]

其中,$T^$ 是Monge map,$\pi^$ 是最优的半松弛传输方案。

现在,作者的目标就是训练两个神经网络$v_\theta(x,t)$和$g_\omega(x,t)$来分别逼近这两个理想的函数。这就是VGFM的核心——联合速度-生长流匹配。其损失函数$L_{VGFM}$定义为:

\[L_{VGFM}(\theta, \omega) = \sum_{i=1}^n \sum_{j=1}^m \pi_{ij}^{0\to1} E_t \left[ \|v_\theta(x_t, t) - (x_j^1 - x_i^0)\|^2 + |g_\omega(x_t, t) - \log([\pi^{0\to1}\mathbf{1}_m]_i)|^2 \right]\]

其中,$x_t = x_i^0 + t(x_j^1 - x_i^0)$是沿直线路径的插值点,$\pi^{0\to1}$是通过Sinkhorn算法从数据样本中计算出的半松弛最优传输方案。这个损失函数可以直接通过小批量采样进行高效优化,避免了任何ODE的数值求解。

2.4 训练过程与分布拟合损失

为了进一步提高模型对真实快照数据的拟合能力,作者在流匹配损失的基础上,额外引入了一个分布拟合损失$L_{OT}$。

在训练过程中,作者从初始数据$X_0$出发,使用当前学习到的$v_\theta$和$g_\omega$通过ODE求解器模拟出在后续观测时间点的预测细胞分布$\hat{X}_t$及其权重$\hat{w}(\hat{x})$。然后,计算预测分布与真实观测分布$X_t$之间的Wasserstein距离,作为额外的损失项:

\[L_{OT}(\theta, \omega) = \sum_{t=1}^{T-1} W_1\left(\frac{1}{N_t}p(X_t), \frac{1}{\sum_{\hat{x}\in\hat{X}_t}\hat{w}(\hat{x})} p_{\hat{w}}(\hat{X}_t)\right)\]

最终的总损失是流匹配损失和分布拟合损失的加权和:$L(\theta, \omega) = L_{VGFM}(\theta, \omega) + L_{OT}(\theta, \omega)$。训练采用了一个预热(warm-up)策略:先只用$L_{VGFM}$进行训练,为$v_\theta$和$g_\omega$提供一个良好的初始化;然后再加入$L_{OT}$进行联合训练,以进一步优化对数据分布的拟合。

通过这种结合流匹配和分布拟合的策略,VGFM既享受了流匹配带来的训练稳定性和高效率,又通过直接拟合数据分布确保了生成结果的保真度。

3. 实验分析

作者在一系列合成数据集和真实的单细胞数据集上对VGFM进行了广泛的评估,并与多种先进方法进行了比较。

3.1 在合成数据集上的表现

作者使用了三个合成数据集:一个经典的Simulation Gene三基因调控网络,一个通过Dyngen模拟的具有复杂分叉和不平衡性的scRNA-seq数据,以及一个更具挑战性的1000维高斯混合模型。

在所有三个数据集上,VGFMW1距离(衡量分布拟合精度)和相对质量误差(Relative Mass Error, RME,衡量生长预测精度)两个指标上均取得了最优性能,显著优于其他所有基线方法,包括同样考虑非平衡性的UDSB, TIGON, 和 DeepRUOT,以及不考虑非平衡性的OT-CFMOT-MFM

Simulation Gene数据上,VGFM不仅重构了准确的细胞动态轨迹(图a),其预测的生长率(图b)也与真实生长率(图c)高度吻合。比较了不同方法预测的相对细胞总数随时间的变化。结果清晰地显示,VGFM的预测曲线(红色)与真实的观测曲线(黑色)几乎完美重合,而其他方法则存在不同程度的偏差。特别是在EB (50D)数据上,TIGON的预测出现了显著偏离,反映了其在高维数据处理上的局限性,而VGFM依然表现稳健。

在1000维高斯数据集上,几种基于模拟的非平衡方法(UDSB, TIGON, DeepRUOT)都因为计算挑战而无法收敛。相比之下,VGFM得益于流匹配的框架,能够有效地扩展到高维空间,并取得了优异的结果。

3.2 在真实世界数据集上的表现

作者在三个真实的scRNA-seq数据集(EB, CITE-seq, Pancreas)上进行了评估。

作者采用了“留一法”策略,即在训练时隐藏一个中间时间点,然后评估模型对该时间点的预测能力。在EB (5D), CITE (5D)CITE (50D) 三个设置下,VGFM的平均W1距离均优于或持平于其他基线方法。这表明VGFM能够准确地插值生成未观测时间点的细胞分布。

EB (5D)数据集上,作者通过可视化比较了VGFM(完整模型)和VGFM (w/o LOT)(即只使用流匹配损失,不使用分布拟合损失)的预测结果。可以清楚地看到,完整的VGFM模型(图b)生成的轨迹和分布更接近真实情况,证明了分布拟合损失$L_{OT}$对于提升最终预测精度的重要性。

EB (50D)数据集上进行了更详细的消融研究。