用于轨迹推断的流形插值最优传输流.

0. TL; DR

作者们提出了一种名为MIOFlowManifold Interpolating Optimal-Transport Flow)的新方法,用于从在零散时间点采集的静态“快照”样本中,学习随机、连续的群体动态。MIOFlow通过训练一个神经普通微分方程(Neural ODE),使其在静态群体快照之间进行插值,这种插值过程受到基于流形距离的最优传输的惩罚。

为了确保流动(flow)遵循数据的内在几何结构,作者们在一个自编码器的隐空间中进行操作,这个自编码器被称为“测地线自编码器”(geodesic autoencoder, GAE)。在GAE中,隐空间中点与点之间的距离被正则化,以匹配作者们定义的一种新的多尺度流形距离。

作者们证明,在群体间的插值任务上,MIOFlow优于归一化流、薛定谔桥以及其他设计用于从噪声生成数据的生成模型。理论上,作者们将这些轨迹与动态最优传输联系起来。作者们在具有分叉和合并的模拟数据,以及在胚状体分化和急性髓系白血病治疗的单细胞RNA测序(scRNA-seq)数据上评估了该方法。

1. 背景介绍

这项工作要解决的问题是,如何从在离散时间点采集的静态、横截面样本中,学习到一个概率分布的连续动态。在自然系统中,高维数据通常被认为起源于一个嵌入在高维测量空间中的低维流形(manifold),这被称为“流形假设”(manifold hypothesis)。这一假设在生物、化学和物理系统中催生了许多成功的模型。

在单细胞生物学中,尽管测量的维度极高(例如,每个基因是一个维度),但由于基因之间的信息冗余,数据的内在维度是低维的。黎曼流形(Riemannian manifold)是模拟这类系统的一个很好的数学工具。

然而,现有的技术(如scRNA-seq)在测量时会破坏细胞,使得我们无法追踪单个细胞随时间的变化。因此,作者们只能得到不同时间点的细胞群体“快照”。如何从这些快照中学习连续的群体动态,特别是在未测量的中间时间点进行插值,以及推断遵循流形结构的单个轨迹,是一个核心挑战。

最近,一些基于神经网络的“流”(flows)或“传输”(transports)模型被提出,但它们大多专注于生成式建模,即从一个简单的噪声分布(如高斯分布)流动到一个复杂的数据分布,以生成新数据。例如,基于分数的生成式匹配、扩散模型(diffusion models)、薛定谔桥(Schrödinger bridges)和连续归一化流(continuous normalizing flows, CNF)。

与这些方法不同,本文的目标是学习系统的连续动态,并进行插值。为了在具有流形结构的数据上连续地插值群体,作者们提出了MIOFlow,一个基于动态最优传输和流形嵌入的新框架。MIOFlow使用一个神经ODE来在时间点之间传输数据点,并满足以下三个关键特性:

  1. 传输发生在由样本定义的流形上。
  2. 传输过程通过Wasserstein距离来惩罚,以确保与观测到的时间点数据一致。
  3. 传输过程是内在随机的,以模拟生物过程的随机性。

与之前的工作(如TrajectoryNet)相比,MIOFlow有几个关键优势:它不要求从高斯分布开始,而是直接在观测到的数据分布之间流动;它通过引入扩散项来自然地建模随机性;并且,它通过一个新颖的测地线自编码器(Geodesic Autoencoder, GAE)和基于流形距离的最优传输惩罚,显式地强制流动发生在数据流形上,而不是在环境空间中。

2. MIOFlow 方法

为了从多个横截面的细胞群体快照中学习单个轨迹,作者们提出了MIOFlow。该方法包含两个主要步骤:首先,通过一个测地线自编码器(Geodesic Autoencoder)学习一个能够保持扩散测地线距离的嵌入空间;然后,在这个嵌入空间中,通过一个基于最优传输损失的神经ODE来学习连续的轨迹。

2.1 测地线自编码器嵌入 (Geodesic Autoencoder Embedding)

动态最优传输的公式通常假设底层的距离是欧几里得距离。然而,高维单细胞数据通常位于一个低维流形上,直接在环境空间中使用欧几里得距离会忽略数据的内在几何结构。为了解决这个问题,作者们的目标是学习一个嵌入空间$Z$,使得该空间中的欧几里得距离能够近似原始数据流形上的测地线距离(geodesic distance)。

扩散测地线距离 (Diffusion Geodesic Distance)

作者们首先定义了一种新的多尺度流形距离,称为“扩散测地线距离”(diffusion geodesic distance)。理论上,在一个闭合的黎曼流形$(M, d_M)$上,一种称为“扩散地距离”(diffusion ground distance)$D_\alpha(x,y)$被证明与测地线距离$d_M^{2\alpha}(x,y)$是等价的。这个距离是通过在不同时间尺度上比较两个点$x$和$y$的热核(heat kernel)分布来定义的。

在实践中,作者们无法直接计算热核,但可以使用扩散矩阵(diffusion matrix)$P_\epsilon$来近似。扩散矩阵是通过对数据构建一个k-NN图,然后进行密度归一化和行归一化得到的,它描述了一个在数据图上的马尔可夫随机游走过程。基于这个扩散矩阵,作者们定义了扩散测地线距离$G_\alpha(x_i, x_j)$:

\[G_\alpha(x_i, x_j) := \sum_{k=0}^{K} 2^{-(K-k)\alpha} \|(P_\epsilon)^{2^k}_{i:} - (P_\epsilon)^{2^k}_{j:} \|_1 + 2^{-(K+1)/2} \|\pi_i - \pi_j \|_1\]

这个距离通过在指数级增长的时间尺度($2^k$)上比较从点$x_i$和$x_j$开始的随机游走概率分布,来捕捉它们在流形上的距离。作者们证明,当数据点足够多时,这个离散的$G_\alpha$会收敛到连续的$D_\alpha$,从而近似于测地线距离。

测地线自编码器 (Geodesic Autoencoder, GAE) 的训练

GAE是一个标准的自编码器,但其训练过程包含一个特殊的正则化项,即强制其隐空间$Z$中的欧几里得距离与上述计算出的扩散测地线距离$G_\alpha$相匹配。其损失函数$L(\phi)$为:

\[L(\phi) := \frac{2}{N}\sum_{i=1}^N \sum_{j>i} (\|\phi(x_i) - \phi(x_j)\|_2 - G_\alpha(x_i, x_j))^2\]

其中$\phi$是编码器。同时,解码器$\phi^{-1}$通过一个标准的重构损失$L_r = \sum_x ||\phi^{-1} \circ \phi(x) - x||_2$来训练。通过这种方式,GAE学习到了一个“等距嵌入”,使得在隐空间中可以直接使用欧几里得距离进行动态最优传输,而这等价于在原始流形上使用测地线距离。

2.2 推断轨迹 (Inferring Trajectories)

在获得了保持测地线距离的隐空间$Z$后,作者们使用一个神经ODE来建模细胞轨迹。目标是学习一个参数化的函数$f_\theta(x,t)$,使得由它定义的轨迹$X_t = X_0 + \int_0^t f_\theta(X_u, u)du$能够满足在各个观测时间点$i$的分布约束$X_i \sim \mu_i$。

基于最优传输的损失函数

作者们利用了动态最优传输理论,将这个问题转化为一个带正则化的优化问题。其损失函数包含三个部分:

  1. 边际匹配损失 ($L_m$):在每个观测时间点$i$,计算由ODE模型预测出的分布$\hat{\mu}_i$与真实观测分布$\mu_i$之间的Wasserstein距离,并将其加和。
\[L_m := \sum_{i=1}^{T-1} W_2(\hat{\mu}_i, \mu_i)\]
  1. 能量损失 ($L_e$):惩罚速度场$f_\theta$的$L_2$范数,即路径的“动能”,以鼓励模型学习更短、更“经济”的路径。
\[L_e := \lambda_e \sum_{i=0}^{T-2} \int_i^{i+1} \|f_\theta(x_t, t)\|_2^2 dt\]
  1. 密度损失 ($L_d$):惩罚那些偏离数据流形的轨迹点,即鼓励插值点靠近观测到的数据点。
\[L_d := \lambda_d \sum_{t=1}^{T-1} \sum_{x\in\hat{X}_t} \sum_{i=1}^k \max(0, \text{min-k}(\{\|x-y\| : y \in X_t\}) - h)\]

建模随机性 (Modeling Diffusion)

为了模拟生物过程的内在随机性,作者们在ODE中加入了一个扩散项,使其变为一个随机微分方程(SDE):$dXt = f(Xt, t)dt + \sqrt{\sigma_t} dB_t$。在实践中,作者们通过学习一个随时间变化的噪声尺度$\sigma_t$来实现。这使得从同一个初始细胞出发,可以生成多条不同的随机轨迹。

整个MIOFlow的训练过程分为“局部训练”和“全局训练”两个阶段,以确保模型既能准确匹配相邻时间点的分布,又能学习到全局一致的连续轨迹。

3. 实验分析

作者们在模拟数据和真实的scRNA-seq数据集上,对MIOFlow的性能进行了评估,并与TrajectoryNetDiffusion Schrödinger’s Bridge (DSB)等方法进行了比较。

3.1 人工数据集

Petal数据集

这个数据集模拟了一个从中心开始,分叉成多个“花瓣”,最后又合并的动态过程,包含了分叉和合并两种复杂的拓扑结构。

在隐藏中间时间点进行插值的任务中,MIOFlowW1距离和MMD距离两个指标上都取得了最好的结果,并且训练时间远少于TrajectoryNetDSB

Dyngen数据集

这是一个更具挑战性的模拟scRNA-seq数据集,包含一个不对称的分叉,并且不同分支的细胞数量不均衡。

3.2 单细胞数据

胚状体 (Embryoid Body, EB) 数据

这是一个人类胚状体分化的时间序列数据集,展示了从干细胞到多个谱系的分化过程。

作者展示了MIOFlow的一项强大能力——将学习到的轨迹解码回原始的基因空间,从而观察单个基因的表达动态。与TrajectoryNet(下图)相比,MIOFlow(上图)生成的基因表达轨迹更加平滑且符合生物学预期。例如,对于已知在神经前体细胞中表达先升后降的HAND2基因,MIOFlow清晰地展示了这一非单调趋势,而TrajectoryNet的轨迹则显得混乱。对于神经谱系的标志物ONECUT2MIOFlow也正确地显示了其在神经谱系轨迹中的特异性高表达。

作者比较了使用GAE和不使用GAE时的模型性能。结果显示,无论使用哪种核函数(高斯核或$\alpha$-decay核),使用了GAEMIOFlow在几乎所有指标上都表现得更好。这证明了GAE学习到的测地线嵌入对于准确建模流形上的动态至关重要。

急性髓系白血病 (AML) 数据

这是一个AML小鼠模型在接受化疗过程中的数据集。作者用MIOFlow来研究癌细胞的耐药性机制。