Analytic-DPM: an Analytic Estimate of the Optimal Reverse Variance in Diffusion Probabilistic Models

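A diffusion model is a type of deep generative model. It first defines a forward diffusion process that gradually adds random noise to the data, then learns a reverse diffusion process that constructs the desired data samples from noise.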

In a diffusion model, the forward diffusion process can be defined as:

\[\begin{array}{rlr} \mathbf{x}_t & =\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon} \\ q\left(\mathbf{x}_t \mid \mathbf{x}_0\right) & =\mathcal{N}\left(\mathbf{x}_t ; \sqrt{\bar{\alpha}_t} \mathbf{x}_0,\left(1-\bar{\alpha}_t\right) \mathbf{I}\right) & \end{array}\]
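
The forward process can be tried out numerically. The sketch below is a minimal standard-library illustration, not the implementation used later in this post; the function names `linear_alpha_bars` and `q_sample` and the linear schedule end points (1e-4 and 0.02, the usual DDPM defaults) are assumptions made for the example.

```python
import math
import random

def linear_alpha_bars(num_steps, beta_start=1e-4, beta_end=0.02):
    """Return [alpha_bar_1, ..., alpha_bar_T] for a linear beta schedule."""
    alpha_bars, running = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / max(num_steps - 1, 1)
        running *= 1.0 - beta          # alpha_bar_t is the running product of (1 - beta_s)
        alpha_bars.append(running)
    return alpha_bars

def q_sample(x0, alpha_bar, noise=None, rng=random):
    """Draw x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps, element-wise."""
    if noise is None:
        noise = [rng.gauss(0.0, 1.0) for _ in x0]
    a, b = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [a * x + b * e for x, e in zip(x0, noise)]

alpha_bars = linear_alpha_bars(1000)
x_t = q_sample([0.5, -0.25], alpha_bars[499])   # a noisy version of a toy 2-d sample
```

As the step index grows, `alpha_bars` decays towards zero, so `x_t` carries less of the data and more of the noise.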

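Following the assumption made in DDIM, \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t},\mathbf{x}_0\right)\) is modeled as a Gaussian distribution, where $\sigma_t$ is an adjustable standard-deviation parameter: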

\[\begin{aligned} q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)&=\mathcal{N}\left(\mathbf{x}_{t-1} ; \sqrt{\frac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_{t}}}\mathbf{x}_{t}+\left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right) \mathbf{x}_0, \sigma_t^2 \mathbf{I}\right) \end{aligned}\]

Equivalently, with \(\boldsymbol{\epsilon}_t \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\):

\[\begin{aligned} \mathbf{x}_{t-1} = \sqrt{\frac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_{t}}}\mathbf{x}_{t}+\left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right) \mathbf{x}_0+ \sigma_t \boldsymbol{\epsilon}_t \end{aligned}\]
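
A quick numerical check makes the role of the two coefficients clearer: whatever value $\sigma_t$ takes, marginalising over \(\mathbf{x}_t\) has to give back \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)\). The sketch below assumes scalar inputs and uses a hypothetical helper name, `ddim_posterior_coeffs`; the numbers are illustrative only.

```python
import math

def ddim_posterior_coeffs(alpha_bar_t, alpha_bar_prev, sigma_t):
    """Coefficients (on x_t, on x_0) of the mean of q(x_{t-1} | x_t, x_0)."""
    coef_xt = math.sqrt((1.0 - alpha_bar_prev - sigma_t ** 2) / (1.0 - alpha_bar_t))
    coef_x0 = math.sqrt(alpha_bar_prev) - math.sqrt(alpha_bar_t) * coef_xt
    return coef_xt, coef_x0

# Illustrative values only.
alpha_bar_t, alpha_bar_prev, sigma_t = 0.5, 0.6, 0.1
coef_xt, coef_x0 = ddim_posterior_coeffs(alpha_bar_t, alpha_bar_prev, sigma_t)

# Marginalising x_t ~ q(x_t | x_0) must reproduce q(x_{t-1} | x_0).
mean_scale = coef_xt * math.sqrt(alpha_bar_t) + coef_x0       # factor multiplying x_0
variance = coef_xt ** 2 * (1.0 - alpha_bar_t) + sigma_t ** 2  # isotropic variance
```

Here `mean_scale` matches \(\sqrt{\bar{\alpha}_{t-1}}\) and `variance` matches \(1-\bar{\alpha}_{t-1}\), which is exactly the constraint the DDIM parameterisation is built to satisfy.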

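In the reverse diffusion process, we want to find \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)\). A common approach is to first construct \(\mathbf{x}_{0}\) from \(\mathbf{x}_{t}\), using the noise \(\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\) predicted by a network, and then compute:

\[\begin{aligned} q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) \approx q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t},\mathbf{x}_0=\frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}}\right) \end{aligned}\]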

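However, the \(\mathbf{x}_{0}\) constructed from \(\mathbf{x}_{t}\) is not exact, so it should be described by a probability distribution rather than a deterministic function. In fact, the following holds exactly: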

\[\begin{aligned} q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) = \int q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t},\mathbf{x}_0\right) q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right) d \mathbf{x}_0 \end{aligned}\]

Since \(q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)\) is unknown, it is approximated by a normal distribution \(\mathcal{N}\left(\mathbf{x}_0 ; \bar{\mu}(\mathbf{x}_{t}),\bar{\sigma}_t^2 \mathbf{I}\right)\). This comes down to approximating the mean and the variance of \(q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)\) separately.
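
With a Gaussian in place of \(q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)\), the integral above has a closed form. The short derivation below is a sketch of that step; the shorthand coefficients \(a_t\) and \(b_t\) and the independent standard normal variables \(\boldsymbol{\epsilon}_1, \boldsymbol{\epsilon}_2\) are introduced here only for illustration.

\[\begin{aligned} a_t &= \sqrt{\frac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_{t}}}, \qquad b_t = \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\bar{\alpha}_{t}}\, a_t \\ \mathbf{x}_{t-1} &= a_t \mathbf{x}_{t} + b_t \mathbf{x}_0 + \sigma_t \boldsymbol{\epsilon}_1, \qquad \mathbf{x}_0 = \bar{\mu}(\mathbf{x}_{t}) + \bar{\sigma}_t \boldsymbol{\epsilon}_2 \\ \Rightarrow \quad q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) &\approx \mathcal{N}\left(\mathbf{x}_{t-1} ; a_t \mathbf{x}_{t} + b_t \bar{\mu}(\mathbf{x}_{t}), \left(\sigma_t^2 + b_t^2 \bar{\sigma}_t^2\right) \mathbf{I}\right) \end{aligned}\]

The mean is the usual DDIM mean evaluated at \(\bar{\mu}(\mathbf{x}_{t})\), while the uncertainty about \(\mathbf{x}_0\) adds \(b_t^2 \bar{\sigma}_t^2\) to the variance.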

⚪ Approximating the mean

The mean is approximated as:

\[\begin{aligned} \bar{\mu}(\mathbf{x}_{t}) &= \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \mathbf{x}_0 \right] \\ &= \mathop{\arg\min}_{\mu} \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ ||\mathbf{x}_0 -\mu||^2 \right] \\ &= \mathop{\arg\min}_{\boldsymbol{\mu}_\theta} \mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)}\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ ||\mathbf{x}_0 -\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)||^2 \right] \\ &= \mathop{\arg\min}_{\boldsymbol{\mu}_\theta} \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_{0}\right), \mathbf{x}_t \sim q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)} \left[ ||\mathbf{x}_0 -\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)||^2 \right] \end{aligned}\]

Without loss of generality, express \(\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)\) as a function of \(\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\):

\[\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right) = \frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}}\]

The loss function is then equivalent to the loss of a standard diffusion model:

\[\begin{aligned} L_t & =\mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[\left\| \frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}}{\sqrt{\bar{\alpha}_t}} - \frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}} \right\|^2\right] \\ & \propto \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right] \\ & =\mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}, t\right)\right\|^2\right] \end{aligned}\]
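
The proportionality constant hidden in the \(\propto\) step is \((1-\bar{\alpha}_t)/\bar{\alpha}_t\). The scalar sketch below checks this numerically; `x0_from_eps` is a hypothetical helper name, and `eps_pred` is a made-up stand-in for a network output rather than a real prediction.

```python
import math
import random

def x0_from_eps(x_t, eps, alpha_bar):
    """Invert the forward process: x_0 = (x_t - sqrt(1 - a) * eps) / sqrt(a)."""
    return (x_t - math.sqrt(1.0 - alpha_bar) * eps) / math.sqrt(alpha_bar)

rng = random.Random(0)
alpha_bar = 0.3                      # illustrative value
x0, eps = rng.gauss(0, 1), rng.gauss(0, 1)
x_t = math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * eps
eps_pred = eps + 0.2                 # stand-in for a slightly wrong network output

loss_x0 = (x0 - x0_from_eps(x_t, eps_pred, alpha_bar)) ** 2   # error measured on x_0
loss_eps = (eps - eps_pred) ** 2                              # error measured on the noise
ratio = loss_x0 / loss_eps           # equals (1 - alpha_bar) / alpha_bar
```

Because the ratio depends only on \(t\), minimising either form of the loss at a fixed timestep gives the same network.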

⚪ Approximating the variance

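By definition, the covariance matrix is computed as follows, where the second line assumes the trained network attains the optimal mean, i.e. \(\bar{\mu}(\mathbf{x}_{t}) = \boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)\):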

\[\begin{aligned} \Sigma(\mathbf{x}_{t}) = &\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ (\mathbf{x}_0 -\bar{\mu}(\mathbf{x}_{t}))(\mathbf{x}_0 -\bar{\mu}(\mathbf{x}_{t}))^T \right] \\ =& \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_0 -\frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}}\right)\left(\mathbf{x}_0 -\frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}}\right)^T \right] \\ =& \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}+\frac{\sqrt{1-\bar{\alpha}_t} }{\sqrt{\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right)\left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}+\frac{\sqrt{1-\bar{\alpha}_t} }{\sqrt{\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right)^T \right] \\ =& \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)\left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)^T \right]+2\frac{\sqrt{1-\bar{\alpha}_t} }{\sqrt{\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}} \right]^T \\ & + \frac{1-\bar{\alpha}_t} {\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \\ =& \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)\left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)^T 
\right]+2\frac{\sqrt{1-\bar{\alpha}_t} }{\sqrt{\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \left[\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} [\mathbf{x}_0] -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}} \right]^T \\ & + \frac{1-\bar{\alpha}_t} {\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \\ =& \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)\left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)^T \right]+2\frac{\sqrt{1-\bar{\alpha}_t} }{\sqrt{\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \left[\bar{\mu}(\mathbf{x}_{t}) -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}} \right]^T \\ & + \frac{1-\bar{\alpha}_t} {\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \\ =& \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)\left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)^T \right]+2\frac{\sqrt{1-\bar{\alpha}_t} }{\sqrt{\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \left[\frac{-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}} \right]^T \\ & + \frac{1-\bar{\alpha}_t} {\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \\ =& \frac{1}{\bar{\alpha}_t}\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)^T \right]- \frac{1-\bar{\alpha}_t} 
{\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \end{aligned}\]

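Take the expectation of both sides over \(q\left(\mathbf{x}_{t}\right)\), so that the covariance no longer depends on \(\mathbf{x}_{t}\) (the symbol \(\Sigma(\mathbf{x}_{t})\) is kept for the averaged quantity):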

\[\begin{aligned} \Sigma(\mathbf{x}_{t}) = &\mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)} \left[\Sigma(\mathbf{x}_{t}) \right] \\ =& \mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)} \left[ \frac{1}{\bar{\alpha}_t}\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)^T \right]- \frac{1-\bar{\alpha}_t} {\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \right] \end{aligned}\]

Note that:

\[\begin{aligned} &\frac{1}{\bar{\alpha}_t}\mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)} \left[ \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)^T \right] \right] \\ =& \frac{1}{\bar{\alpha}_t}\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_{0}\right)} \left[ \underbrace{\mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_t \mid \mathbf{x}_{0}\right)} \left[ \left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)^T \right]}_{\text{covariance of } q\left(\mathbf{x}_t \mid \mathbf{x}_{0}\right)} \right] \\ =& \frac{1}{\bar{\alpha}_t}\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_{0}\right)} \left[ \left(1-\bar{\alpha}_t\right) \mathbf{I} \right] =\frac{1-\bar{\alpha}_t}{\bar{\alpha}_t}\mathbf{I} \end{aligned}\]

Therefore:

\[\begin{aligned} \Sigma(\mathbf{x}_{t}) &= \frac{1-\bar{\alpha}_t}{\bar{\alpha}_t}\mathbf{I}-\mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)} \left[ \frac{1-\bar{\alpha}_t} {\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \right] \\ &= \frac{1-\bar{\alpha}_t}{\bar{\alpha}_t}\left(\mathbf{I}-\mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)} \left[ \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \right] \right) \end{aligned}\]

Taking the trace of both sides and dividing by \(d=\dim(\mathbf{x}_{t})\) gives an estimate \(\hat{\sigma}_t^2\) of the variance \(\bar{\sigma}_t^2\):

\[\hat{\sigma}_t^2 = \frac{1-\bar{\alpha}_t}{\bar{\alpha}_t} \left(1-\frac{1}{d}\mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)} \left[ ||\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ||^2 \right] \right)\]
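
The estimator can be sanity-checked on a toy problem where everything is known in closed form. The sketch below assumes one-dimensional data \(x_0 \sim \mathcal{N}(0, s^2)\), for which the optimal noise prediction and the true posterior variance of \(x_0\) given \(x_t\) are available analytically; the helper name `analytic_x0_variance` and all numbers are illustrative assumptions, not part of the original method's code.

```python
import math
import random

def analytic_x0_variance(alpha_bar, eps_preds):
    """sigma_hat^2 = (1 - a) / a * (1 - mean of squared eps_theta entries)."""
    mean_sq = sum(e * e for e in eps_preds) / len(eps_preds)
    return (1.0 - alpha_bar) / alpha_bar * (1.0 - mean_sq)

rng = random.Random(0)
alpha_bar, data_std = 0.5, 2.0       # illustrative values
var_xt = alpha_bar * data_std ** 2 + (1.0 - alpha_bar)   # Var[x_t] for Gaussian data

eps_preds = []
for _ in range(100_000):
    x0 = rng.gauss(0.0, data_std)
    x_t = math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * rng.gauss(0.0, 1.0)
    # Optimal noise prediction E[eps | x_t] for this Gaussian toy problem.
    eps_preds.append(math.sqrt(1.0 - alpha_bar) * x_t / var_xt)

estimate = analytic_x0_variance(alpha_bar, eps_preds)
exact = data_std ** 2 * (1.0 - alpha_bar) / var_xt       # true Var[x_0 | x_t]
```

For this toy case the Monte Carlo `estimate` agrees with `exact` up to sampling error, which is what the derivation predicts when the noise predictor is optimal.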

⚪ Sampling process

At this point we have the following approximation of \(q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)\):

\[\begin{aligned} q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right) \approx \mathcal{N}\left(\mathbf{x}_0 ; \frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}},\hat{\sigma}_t^2 \mathbf{I}\right) \end{aligned}\]

The reverse diffusion process is then written as:

\[\begin{aligned} \mathbf{x}_{t-1} = \sqrt{\frac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_{t}}}\mathbf{x}_{t}+\left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right) \mathbf{x}_0+ \sigma_t \boldsymbol{\epsilon}_t \end{aligned}\]

where \(\mathbf{x}_0\) is drawn from the Gaussian approximation above, so that:

\[\begin{aligned} &\left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right) \mathbf{x}_0 = \left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right)\left( \frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}}+\hat{\sigma}_t\boldsymbol{\epsilon} \right) \\ =&\frac{1}{\sqrt{\alpha_t}}\mathbf{x}_{t}-\sqrt{\frac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_{t}}}\mathbf{x}_{t}+\left( \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}-\frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\alpha_t}} \right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \\ &+ \left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right)\hat{\sigma}_t\boldsymbol{\epsilon} \end{aligned}\]

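Substituting back into the original expression, and merging the two independent Gaussian noise terms \(\boldsymbol{\epsilon}_1\) and \(\boldsymbol{\epsilon}_2\) into a single one, gives: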

\[\begin{aligned} \mathbf{x}_{t-1} =& \frac{1}{\sqrt{\alpha_t}}\mathbf{x}_{t}+\left( \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}-\frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\alpha_t}} \right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \\ & + \sigma_t \boldsymbol{\epsilon}_1 + \left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right)\hat{\sigma}_t\boldsymbol{\epsilon}_2 \\ =& \frac{1}{\sqrt{\alpha_t}}\mathbf{x}_{t}+\left( \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}-\frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\alpha_t}} \right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \\ & + \sqrt{\sigma_t^2 + \left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right)^2\hat{\sigma}_t^2}\boldsymbol{\epsilon} \\ \end{aligned}\]
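
The update rule has only three scalar coefficients per step, which the sketch below computes. It is a scalar illustration under assumed inputs, with a hypothetical function name `analytic_ddim_step_coeffs`; `sigma_hat` stands for \(\hat{\sigma}_t\) from the variance estimate.

```python
import math

def analytic_ddim_step_coeffs(alpha_bar_t, alpha_bar_prev, sigma_t, sigma_hat):
    """Return (coef on x_t, coef on eps_theta, std of the merged noise)."""
    alpha_t = alpha_bar_t / alpha_bar_prev
    c = math.sqrt(1.0 - alpha_bar_prev - sigma_t ** 2)
    # Coefficient that multiplies x_0 in the DDIM posterior mean.
    b = math.sqrt(alpha_bar_prev) - c * math.sqrt(alpha_bar_t / (1.0 - alpha_bar_t))
    coef_xt = 1.0 / math.sqrt(alpha_t)
    coef_eps = c - math.sqrt(1.0 - alpha_bar_t) / math.sqrt(alpha_t)
    noise_std = math.sqrt(sigma_t ** 2 + b ** 2 * sigma_hat ** 2)
    return coef_xt, coef_eps, noise_std

# Illustrative values only.
plain = analytic_ddim_step_coeffs(0.5, 0.6, sigma_t=0.0, sigma_hat=0.0)
corrected = analytic_ddim_step_coeffs(0.5, 0.6, sigma_t=0.0, sigma_hat=0.5)
```

With `sigma_hat = 0` the step reduces to deterministic DDIM; a positive `sigma_hat` leaves the mean untouched and only enlarges the noise term, even when \(\sigma_t = 0\).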

⚪ Implementation of Analytic-DPM

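The experiments in the original paper show that the variance correction made by Analytic-DPM brings a noticeable improvement mainly when the number of sampling steps is small, so it is most useful for accelerating diffusion models.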

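import os
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import reduce
from PIL import Image
from torchvision import transforms as T
from tqdm import tqdm

# The three helpers below are not shown in the original listing; they are assumed
# to follow the denoising-diffusion-pytorch conventions that this class is built on.
def default(val, d):
    # return val if given, otherwise the fallback d (called if it is a function)
    if val is not None:
        return val
    return d() if callable(d) else d

def extract(a, t, x_shape):
    # gather a[t] per batch element and reshape so it broadcasts against x_shape
    out = a.gather(-1, t)
    return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))

def linear_beta_schedule(timesteps):
    # linear DDPM schedule, rescaled for the chosen number of timesteps
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype = torch.float64)

class GaussianDiffusion(nn.Module):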
    def __init__(
        self,
        model,
        *,
        image_size,
        timesteps = 1000,
        sampling_timesteps = None,
        ddim_sampling_eta = 0.,
    ):
        super().__init__()
        self.model = model # neural network that fits \epsilon(x_t, t)
        self.channels = self.model.channels
        self.image_size = image_size

        betas = linear_beta_schedule(timesteps) # \beta_t
        alphas = 1. - betas # \alpha_t
        alphas_cumprod = torch.cumprod(alphas, dim=0) # \bar{\alpha}_t
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) # \bar{\alpha}_{t-1}

        self.num_timesteps = int(timesteps)
        self.alphas_cumprod = alphas_cumprod

        # sampling related parameters
        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
        assert self.sampling_timesteps <= timesteps
        self.ddim_sampling_eta = ddim_sampling_eta
        
        # helper function to register buffer from float64 to float32
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
        register_buffer('betas', betas)
        
        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) # \sqrt{\bar{\alpha}_t}
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) # \sqrt{1-\bar{\alpha}_t}
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) # \log(1-\bar{\alpha}_t)
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) # \sqrt{1/\bar{\alpha}_t}
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) # \sqrt{1/\bar{\alpha}_t-1}

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # DDPM posterior variance (a variance, i.e. \sigma_t^2)
        register_buffer('posterior_variance', posterior_variance)
        # log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    """
    Training
    """

    # compute x_t=\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon
    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start)) # draw \epsilon when none is given
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
    
    # compute the loss L_t=||\epsilon-\epsilon_\theta(x_t, t)||_1 (an L1 variant of the squared loss derived above)
    def p_losses(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start)) # keep caller-supplied noise if there is any
        target = noise

        # compute x_t
        x = self.q_sample(x_start = x_start, t = t, noise = noise)

        # compute \epsilon_\theta(x_t, t)
        model_out = self.model(x, t)

        loss = F.l1_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        return loss.mean()

    # training step
    def forward(self, img, *args, **kwargs):
        b, c, h, w, device = *img.shape, img.device
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = img * 2 - 1 # data [0, 1] -> [-1, 1]
        return self.p_losses(img, t, *args, **kwargs)

    """
    Sampling
    """
    
    def data_generator(self, batch, folder='./images/official-artwork'):
        transform = T.Compose([
            T.Resize(self.image_size),
            T.RandomHorizontalFlip(),
            T.ToTensor()
        ])

        selected_imgs = random.sample(os.listdir(folder), k=batch)
        batch_imgs = []
        for img in selected_imgs:
            img = Image.open(os.path.join(folder, img))
            batch_imgs.append(transform(img).unsqueeze(0))
        return torch.cat(batch_imgs, dim=0) * 2 - 1
            
    # compute x_0 = \sqrt{1/\bar{\alpha}_t}x_t-\sqrt{1/\bar{\alpha}_t-1}\epsilon_\theta(x_t, t)
    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    @torch.no_grad()
    def ddim_sample(self, shape, return_all_timesteps = False):
        batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta
        
        # estimate the variance-correction term for every timestep from (batch_size * steps) samples
        all_steps = list(range(total_timesteps)) # named all_steps so the torchvision alias T is not shadowed
        factors = []
        for t in tqdm(all_steps, ncols=0, desc = 'computing unbiased term'):
            batch_imgs = self.data_generator(batch).to(device)
            t = torch.full((batch,), t, device = device, dtype = torch.long)
            x_t = self.q_sample(x_start = batch_imgs, t = t, noise = torch.randn(shape, device = device))
            factor_t = torch.mean(self.model(x_t, t) ** 2) # Monte Carlo estimate of E[||\epsilon_\theta(x_t, t)||^2]/d
            factors.append(factor_t)
        
        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        img = torch.randn(shape, device = device)
        imgs = [img]

        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            pred_noise = self.model(img, time_cond) # compute \epsilon_\theta(x_t, t)
            x_start = self.predict_start_from_noise(img, time_cond, pred_noise) # predicted x_0
            if time_next < 0:
                img = x_start
                imgs.append(img)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]
            factor = factors[time]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()
            
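            # extra variance b^2 * \hat{\sigma}_t^2, with \hat{\sigma}_t^2 = (1-\bar{\alpha}_t)/\bar{\alpha}_t * (1 - factor) as derived above (clamped at 0)
            unbias = (alpha_next.sqrt() - c * (alpha / (1-alpha)).sqrt()) ** 2 * (1-alpha) / alpha * (1 - factor).clamp(min = 0)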

            noise = torch.randn_like(img)
            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  (sigma**2 + unbias).sqrt() * noise
            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
        ret = (ret + 1) * 0.5
        return ret
    
    # sampling entry point
    @torch.no_grad()
    def sample(self, batch_size = 16, img_channel = 3, return_all_timesteps = False):
        image_size, channels = self.image_size, img_channel
        return self.ddim_sample((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)
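
The sub-sampled timestep schedule built inside `ddim_sample` is easier to follow in isolation. The sketch below mirrors the `torch.linspace` plus integer-truncation logic using only the standard library; the function name `ddim_time_pairs` is hypothetical, and small floating-point differences from the float32 tensor version are possible for step counts that do not divide the range evenly.

```python
def ddim_time_pairs(total_timesteps, sampling_timesteps):
    """Return [(t, t_next), ...] from T-1 down to -1 in `sampling_timesteps` hops."""
    span = total_timesteps  # distance between the end points -1 and T-1
    times = [int(-1 + i * span / sampling_timesteps) for i in range(sampling_timesteps + 1)]
    times.reverse()
    return list(zip(times[:-1], times[1:]))

pairs = ddim_time_pairs(total_timesteps=1000, sampling_timesteps=4)
# pairs == [(999, 749), (749, 499), (499, 249), (249, -1)]
```

The final pair always ends at `-1`, which is the signal in the sampling loop to return the predicted \(\mathbf{x}_0\) directly instead of adding more noise.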