TrajectoryNet:建模细胞动态的动态最优传输网络.

0. TL; DR

在生物医学领域,我们经常遇到这样一种数据:它们是在不同时间点对动态过程进行的静态、横截面测量。如何从这些数据中推断出个体的连续轨迹是一个核心挑战。最近的一些方法尝试使用最优传输(optimal transport)在时间点之间进行成对匹配,但它们无法模拟连续的动态和系统中个体可能遵循的非线性路径。

为了解决这个问题,作者建立了连续归一化流(continuous normalizing flows, CNF)和动态最优传输(dynamic optimal transport)之间的联系,从而能够对个体随时间变化的预期路径进行建模。标准的CNF通常是欠约束的,因为它允许从源分布到目标分布的路径是任意的。为此,作者提出了TrajectoryNet,一个通过控制分布之间的连续路径来实现动态最优传输的模型。

作者展示了TrajectoryNet在研究单细胞RNA测序(scRNA-seq)数据中的细胞动力学方面的特别适用性,并证明其性能优于近期提出的、基于静态最优传输的细胞分布插值模型。

1. 背景介绍

在数据科学中,尤其是生物医学领域,我们经常面对的是对随时间变化的现象进行的横截面采样数据。例如,不同年龄群体的健康指标,或疾病进展不同阶段的测量数据。在这些测量中,我们虽然在多个时间点进行了采样,但在每个时间点,我们只能观察到该时刻群体的分布(即一个横截面),而无法追踪单个实体(如一个细胞或一个个体)的连续变化,这导致了点对点的对应关系缺失。

从这些静态的“快照”测量中提取纵向的动态信息是一个巨大的挑战。现有的插值方法有限,而问题因“非配对”的采样而变得更加复杂。作者将此问题构建为一个非平衡动态传输(unbalanced dynamic transport)问题,目标是使用高效、平滑的路径将一个横截面测量中的实体“运输”到下一个。

近年来,单细胞RNA测序(scRNA-seq)技术的发展使得我们能够以前所未有的分辨率了解细胞的身份和行为。一个特别重要的应用是研究细胞如何从一种状态分化到另一种状态。然而,scRNA-seq技术在测量时会破坏细胞,因此我们只能获得静态的快照数据,无法监测单个细胞随时间的变化。此外,由于实验成本高昂,通常只能在少数几个离散的时间点进行采样。

现有的计算方法各有局限:一些方法试图在一个时间点内推断轨迹(即“伪时间”),但这并不能反映真实的实验时间。另一些方法在两个时间点之间进行线性插值,但这无法捕捉复杂的非线性路径。

TrajectoryNet旨在克服这些限制。它不仅能在观测到的时间点之间沿着数据流形进行非线性插值,还能为单个实体创建连续时间的轨迹,并构建一个能够通过in silico扰动来理解动态驱动因素的深度表征模型。

2. TrajectoryNet 方法

TrajectoryNet的核心思想是,通过对连续归一化流(Continuous Normalizing Flows, CNFs)施加特定的正则化,来近似求解高维空间中的动态最优传输问题。

2.1 从静态到动态最优传输

静态最优传输 (Static Optimal Transport, OT)也称为蒙日-坎托罗维奇问题(Monge-Kantorovich Problem),其目标是找到一个“运输方案”$\pi$,以最小的成本将一个概率分布$\mu$变换为另一个概率分布$\nu$。其数学形式是求解一个在所有联合分布$\Pi(\mu, \nu)$上的最小化问题。尽管在离散情况下有如Sinkhorn算法等快速近似解,但在高维连续空间中,直接求解OT非常困难。

动态最优传输 (Dynamic Optimal Transport, DOT)OT的一种动态形式,由Benamou & Brenier提出。它将OT问题与流体动力学联系起来。假设一个随时间$t$变化的密度场$P(x,t)$和一个速度场$f(x,t)$满足连续性方程(continuity equation):

\[\partial_t P + \nabla \cdot (Pf) = 0\]

该方程保证了质量在运输过程中是守恒的。那么,两个分布$\mu$和$\nu$之间的平方$L_2$ Wasserstein距离可以表示为寻找一个最优流$(P,f)$,使得在将$\mu$运输到$\nu$的过程中,总的动能最小:

\[W_2^2(\mu, \nu) = \inf_{(P,f)} (t_1-t_0) \int_{\mathbb{R}^d} \int_{t_0}^{t_1} P(x,t)|f(x,t)|^2 dt dx\]

这个公式为我们提供了一个通过最小化路径能量来构建动态轨迹的理论基础。然而,传统的数值解法需要对时空进行网格化,其计算复杂度随维度呈指数增长,不适用于高维单细胞数据。

2.2 连续归一化流 (CNFs)

CNF是一种生成模型,它通过一个由神经网络$f_\theta(x(t), t)$参数化的常微分方程(ODE)来将一个简单的基础分布(如高斯分布)连续地变换为一个复杂的目标分布。从$t_0$到$t_1$的变换可以表示为:

\[x(t_1) = x(t_0) + \int_{t_0}^{t_1} f_\theta(x(t),t) dt, \quad x(t_0) \sim P_{t_0}(x)\]

其概率密度的变化遵循瞬时变量替换公式(instantaneous change of variables formula),涉及到$f_\theta$的迹(trace),这比计算雅可比行列式要高效得多:

\[\log P_{t_1}(x(t_1)) = \log P_{t_0}(x(t_0)) - \int_{t_0}^{t_1} \text{Tr}\left(\frac{\partial f_\theta(x(t),t)}{\partial x(t)}\right) dt\]

CNF通过最大化数据的对数似然来进行训练。

2.3 TrajectoryNet:通过正则化CNF近似动态OT

作者的核心洞察是,CNFDOT之间存在深刻的联系。DOT要求在起始和终止时间点精确匹配分布,而CNF则通过KL散度来“软性”地匹配目标分布。作者证明通过在CNF的损失函数中加入一个对速度场$f$的$L_2$范数的惩罚项(即能量正则化),当惩罚系数$\lambda$足够大时,CNF的解会收敛到DOT的解。

因此,TrajectoryNet在标准的CNF对数似然损失的基础上,增加了一个能量损失项$L_{\text{energy}}$:

\[L_{\text{energy}}(x) = \lambda_e \int_t \|f(\tilde{x}, t)\|^2 + \lambda_j \int_t \|J_f(\tilde{x})\|_F^2\]

第一项惩罚速度场$f$的$L_2$范数,旨在最小化路径的“动能”,鼓励模型学习更“直”的路径,这与DOT的目标一致。第二项惩罚$f$的雅可比矩阵$J_f$的Frobenius范数,这相当于惩罚路径的二阶导数(加速度),鼓励路径更加平滑,避免剧烈的转弯。

通过这种正则化,TrajectoryNet不仅学习如何从一个分布变换到另一个,还学习如何以一种“能量最优”且“平滑”的方式进行变换。

2.4 针对单细胞数据的进一步适配

为了更好地模拟细胞系统的特性,TrajectoryNet引入了三个针对生物学先验的正则化项:

  1. 生长率正则化 ($L_{\text{growth}}$)

为了处理细胞增殖/死亡导致的非平衡传输问题,作者训练了一个独立的神经网络$G(x,t)$来预测在状态$x$和时间$t$的细胞生长率。这个网络的训练目标来自于离散非平衡OT的解。然后,在CNF的密度积分中,引入这个生长率项来调整质量的变化:

\[\log M_{t_i}(x) = \log M_{t_{i-1}}(x) - \int_{t_i}^{t_{i-1}} \text{Tr}(\dots) dt + \log G(x_{t_{i-1}}, t_{i-1})\]
  1. 密度罚项 ($L_{\text{density}}$)

由于细胞通常生活在一个低维流形上,作者希望插值的路径也能停留在这个流形上。为此,作者设计了一个密度罚项,惩罚那些远离任何观测数据点的预测点。

  1. 速度正则化 ($L_{\text{velocity}}$)

对于单细胞数据,可以利用RNA-velocity技术为每个观测到的细胞估计一个局部的、瞬时的速度向量$\hat{d}x/dt$。作者利用这个信息来正则化TrajectoryNet学习到的全局速度场$f(x,t)$,通过最小化它们方向上的余弦相似度(cosine-similarity):

\[L_{\text{velocity}}(x, t, \hat{d}x/dt) = \text{cosine-similarity}(f(x,t), \hat{d}x/dt)\]

这使得全局轨迹在经过观测点时,其方向能与局部的生物学动态保持一致。

最终,TrajectoryNet的总损失函数$L_T$是这些项的加权和: \(L_T = \sum_{i=1}^k -\log P_{t_i}(x_{t_i}) + L_{\text{energy}} + L_{\text{density}} + L_{\text{velocity}} + L_{\text{growth}}\)

3. 实验分析

作者在一系列人工数据集和真实的单细胞数据集上对TrajectoryNet进行了评估。

3.1 人工数据集

作者设计了三个具有已知路径的2D人工数据集:一个“拱形”(Arch)、一个“树形”(Tree)和一个“圆形”(Cycle),并在训练时隐藏中间时间点,以测试模型的插值能力。

作者比较了不同版本的TrajectoryNet与静态OT以及其他基线(如前一个/后一个时间点)的性能,使用了Wasserstein距离(EMD)和均方误差(MSE)作为评估指标。在ArchTree数据集上,基础的TrajectoryNetBase)和静态OT倾向于走直线“捷径”,而忽略了数据的流形结构。加入了密度正则化(+ D)或速度正则化(+ V)的TrajectoryNet能够更好地贴合数据流形,取得了更低的EMDMSE

可视化结果清晰地展示了密度或速度正则化如何引导TrajectoryNet的路径(橙色)沿着弯曲的流形前进,而不是像静态OT那样直接“抄近路”(蓝色)。Cycle数据集模拟了细胞周期,其分布不随时间改变,但细胞在圆周上运动。只有加入了速度正则化(+ V)的TrajectoryNet能够成功捕捉到这种纯粹的动态,而其他所有方法都因为看不到分布变化而失败。

3.2 单细胞数据

作者在两个真实的scRNA-seq数据集上评估了TrajectoryNet

小鼠皮层数据

这个数据集包含4个时间点的小鼠胚胎皮层细胞,展示了一个从神经干细胞到成熟神经元的分化过程,其数据流形呈现出弯曲的结构。

在留一法交叉验证中,TrajectoryNet的所有版本都显著优于静态OT和其他基线方法。这表明对于具有弯曲流形结构的真实数据,TrajectoryNet的连续、非线性插值能力至关重要。

可视化结果与静态OT的插值结果(图e)相比,TrajectoryNet的插值结果(图f)更好地保持了数据在低维空间的分布结构。插值得到的细胞在关键标志物基因(Pax6, Eomes, Tbr1)的表达上也符合预期的发育模式。

胚状体数据

这个数据集包含5个时间点的人类胚状体发育过程,从一个均一的干细胞群分化出四个不同的谱系,呈现出复杂的分叉结构。

在这个更复杂的数据集上,TrajectoryNet再次展示了其优越性。特别是加入了密度正则化(+ D)的模型取得了最好的平均性能。作者推测,速度正则化(+ V)表现不佳可能是由于该数据集中未剪接RNA计数较低,导致RNA-velocity估计不准。

在通过PHATE降维可视化的空间中,TrajectoryNet生成的轨迹清晰地描绘了从中心干细胞群到四个末端谱系的分叉过程。

作者展示了TrajectoryNet的一个强大应用:将学习到的轨迹投影回原始的基因表达空间。通过对四个终末谱系的细胞进行“反向积分”,可以追溯它们在分化早期的基因表达动态。例如,作者发现心脏谱系的标志基因HAND1在分化早期(第6天)就已经开始高表达,比其他谱系的标志物出现得更早。这证明了TrajectoryNet不仅能插值,还能为理解基因调控的时序性提供深刻洞见。