Analytic-DPM: An Analytic Estimate of the Optimal Reverse Variance in Diffusion Probabilistic Models.
- paper: Analytic-DPM: an Analytic Estimate of the Optimal Reverse Variance in Diffusion Probabilistic Models
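Diffusion models are a class of deep generative models. They first define a forward diffusion process that gradually adds random noise to the data, and then learn a reverse diffusion process that constructs the desired data samples from noise.
In a diffusion model, the forward diffusion process can be defined as: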
\[\begin{aligned} \mathbf{x}_t & =\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon} \\ q\left(\mathbf{x}_t \mid \mathbf{x}_0\right) & =\mathcal{N}\left(\mathbf{x}_t ; \sqrt{\bar{\alpha}_t} \mathbf{x}_0,\left(1-\bar{\alpha}_t\right) \mathbf{I}\right) \end{aligned}\]
Following the DDIM assumption, \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t},\mathbf{x}_0\right)\) is modeled as a Gaussian distribution, where $\sigma_t$ is a tunable standard-deviation parameter:
\[\begin{aligned} q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)&=\mathcal{N}\left(\mathbf{x}_{t-1} ; \sqrt{\frac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_{t}}}\mathbf{x}_{t}+\left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right) \mathbf{x}_0, \sigma_t^2 \mathbf{I}\right) \end{aligned}\]
Sampling from this Gaussian gives:
\[\begin{aligned} \mathbf{x}_{t-1} = \sqrt{\frac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_{t}}}\mathbf{x}_{t}+\left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right) \mathbf{x}_0+ \sigma_t \boldsymbol{\epsilon}_t \end{aligned}\]
In the reverse diffusion process we want to find \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)\). A common approach is to first construct \(\mathbf{x}_{0}\) from \(\mathbf{x}_{t}\) and then compute:
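\[\begin{aligned} q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) \approx q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t},\mathbf{x}_0=\frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}}\right) \end{aligned}\]
Here \(\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\) is the noise predicted by the network. However, constructing \(\mathbf{x}_{0}\) from \(\mathbf{x}_{t}\) is not exact, so it should be described by a probability distribution rather than a deterministic function. In fact, the following holds exactly: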
\[\begin{aligned} q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) = \int q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t},\mathbf{x}_0\right) q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right) d \mathbf{x}_0 \end{aligned}\]
Since \(q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)\) is unknown, it is approximated by a normal distribution \(\mathcal{N}\left(\mathbf{x}_0 ; \bar{\mu}(\mathbf{x}_{t}),\bar{\sigma}_t^2 \mathbf{I}\right)\).
Approximating \(q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)\) by the normal distribution \(\mathcal{N}\left(\mathbf{x}_0 ; \bar{\mu}(\mathbf{x}_{t}),\bar{\sigma}_t^2 \mathbf{I}\right)\) comes down to approximating the mean and the variance of \(q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)\) separately.
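To see why a mean and a variance are all that is needed, the integral over \(\mathbf{x}_0\) can be carried out in closed form under the Gaussian approximation. The following is a sketch; \(a_t\) and \(b_t\) are shorthand introduced here (not notation from the paper) for the coefficients of \(\mathbf{x}_t\) and \(\mathbf{x}_0\) in the mean of \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)\):

\[\begin{aligned} a_t &= \sqrt{\frac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_{t}}}, \qquad b_t = \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\bar{\alpha}_{t}}\, a_t \\ q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) &\approx \int \mathcal{N}\left(\mathbf{x}_{t-1} ; a_t \mathbf{x}_t + b_t \mathbf{x}_0, \sigma_t^2 \mathbf{I}\right) \mathcal{N}\left(\mathbf{x}_0 ; \bar{\mu}(\mathbf{x}_{t}),\bar{\sigma}_t^2 \mathbf{I}\right) d \mathbf{x}_0 \\ &= \mathcal{N}\left(\mathbf{x}_{t-1} ; a_t \mathbf{x}_t + b_t \bar{\mu}(\mathbf{x}_{t}), \left(\sigma_t^2 + b_t^2 \bar{\sigma}_t^2\right) \mathbf{I}\right) \end{aligned}\]

The last line uses the fact that a linear function of a Gaussian variable plus independent Gaussian noise is again Gaussian, with the variances adding. The mean of the reverse step is therefore fixed by \(\bar{\mu}(\mathbf{x}_{t})\), while \(\bar{\sigma}_t^2\) only enters as an additive correction to the variance.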
⚪ Approximating the mean
The mean is approximated as:
\[\begin{aligned} \bar{\mu}(\mathbf{x}_{t}) &= \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \mathbf{x}_0 \right] \\ &= \mathop{\arg\min}_{\mu} \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ ||\mathbf{x}_0 -\mu||^2 \right] \\ &= \mathop{\arg\min}_{\boldsymbol{\mu}_\theta} \mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)}\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ ||\mathbf{x}_0 -\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)||^2 \right] \\ &= \mathop{\arg\min}_{\boldsymbol{\mu}_\theta} \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_{0}\right),\, \mathbf{x}_t \sim q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)} \left[ ||\mathbf{x}_0 -\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)||^2 \right] \end{aligned}\]
Without loss of generality, write \(\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)\) as a function of \(\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\):
\[\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right) = \frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}}\]
The loss is then equivalent to the loss used in standard diffusion models:
\[\begin{aligned} L_t & =\mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[\left\| \frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}}{\sqrt{\bar{\alpha}_t}} - \frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}} \right\|^2\right] \\ & \propto \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right] \\ & =\mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}, t\right)\right\|^2\right] \end{aligned}\]
⚪ Approximating the variance
By definition, the covariance matrix is:
\[\begin{aligned} \Sigma(\mathbf{x}_{t}) = &\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ (\mathbf{x}_0 -\bar{\mu}(\mathbf{x}_{t}))(\mathbf{x}_0 -\bar{\mu}(\mathbf{x}_{t}))^T \right] \\ =& \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_0 -\frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}}\right)\left(\mathbf{x}_0 -\frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}}\right)^T \right] \\ =& \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}+\frac{\sqrt{1-\bar{\alpha}_t} }{\sqrt{\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right)\left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}+\frac{\sqrt{1-\bar{\alpha}_t} }{\sqrt{\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right)^T \right] \\ =& \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)\left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)^T \right]+2\frac{\sqrt{1-\bar{\alpha}_t} }{\sqrt{\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}} \right]^T \\ & + \frac{1-\bar{\alpha}_t} {\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \\ =& \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)\left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)^T 
\right]+2\frac{\sqrt{1-\bar{\alpha}_t} }{\sqrt{\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \left[\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} [\mathbf{x}_0] -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}} \right]^T \\ & + \frac{1-\bar{\alpha}_t} {\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \\ =& \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)\left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)^T \right]+2\frac{\sqrt{1-\bar{\alpha}_t} }{\sqrt{\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \left[\bar{\mu}(\mathbf{x}_{t}) -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}} \right]^T \\ & + \frac{1-\bar{\alpha}_t} {\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \\ =& \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)\left(\mathbf{x}_0 -\frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}\right)^T \right]+2\frac{\sqrt{1-\bar{\alpha}_t} }{\sqrt{\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \left[\frac{-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}} \right]^T \\ & + \frac{1-\bar{\alpha}_t} {\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \\ =& \frac{1}{\bar{\alpha}_t}\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)^T \right]- \frac{1-\bar{\alpha}_t} 
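{\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \end{aligned}\]
Take the expectation of both sides over \(q\left(\mathbf{x}_{t}\right)\), which replaces \(\Sigma(\mathbf{x}_{t})\) by its average over \(\mathbf{x}_{t}\), so that the resulting estimate no longer depends on a particular \(\mathbf{x}_{t}\):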
\[\begin{aligned} \Sigma(\mathbf{x}_{t}) = &\mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)} \left[\Sigma(\mathbf{x}_{t}) \right] \\ =& \mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)} \left[ \frac{1}{\bar{\alpha}_t}\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)^T \right]- \frac{1-\bar{\alpha}_t} {\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \right] \end{aligned}\]
Note that:
\[\begin{aligned} &\frac{1}{\bar{\alpha}_t}\mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)} \left[ \mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)} \left[ \left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)^T \right] \right] \\ =& \frac{1}{\bar{\alpha}_t}\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_{0}\right)} \left[ \underbrace{\mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_t \mid \mathbf{x}_{0}\right)} \left[ \left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0 \right)^T \right]}_{\text{covariance of } q\left(\mathbf{x}_t \mid \mathbf{x}_{0}\right)} \right] \\ =& \frac{1}{\bar{\alpha}_t}\mathbb{E}_{\mathbf{x}_0 \sim q\left(\mathbf{x}_{0}\right)} \left[ \left(1-\bar{\alpha}_t\right) \mathbf{I} \right] =\frac{1-\bar{\alpha}_t}{\bar{\alpha}_t}\mathbf{I} \end{aligned}\]
Therefore:
\[\begin{aligned} \Sigma(\mathbf{x}_{t}) &= \frac{1-\bar{\alpha}_t}{\bar{\alpha}_t}\mathbf{I}-\mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)} \left[ \frac{1-\bar{\alpha}_t} {\bar{\alpha}_t}\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \right] \\ &= \frac{1-\bar{\alpha}_t}{\bar{\alpha}_t}\left(\mathbf{I}-\mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)} \left[ \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ^T \right] \right) \end{aligned}\]
Taking the trace of both sides and dividing by \(d=\dim(\mathbf{x}_{t})\) gives a scalar estimate \(\hat{\sigma}_t^2\) of the variance \(\bar{\sigma}_t^2\):
\[\hat{\sigma}_t^2 = \frac{1-\bar{\alpha}_t}{\bar{\alpha}_t} \left(1-\frac{1}{d}\mathbb{E}_{\mathbf{x}_t \sim q\left(\mathbf{x}_{t}\right)} \left[ ||\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) ||^2 \right] \right)\]
⚪ Sampling procedure
We have thus obtained the following approximation of \(q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right)\):
\[\begin{aligned} q\left(\mathbf{x}_0 \mid \mathbf{x}_{t}\right) \approx \mathcal{N}\left(\mathbf{x}_0 ; \frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}},\hat{\sigma}_t^2 \mathbf{I}\right) \end{aligned}\]
The reverse diffusion step is then written as follows, with \(\mathbf{x}_0\) drawn from this approximation:
\[\begin{aligned} \mathbf{x}_{t-1} = \sqrt{\frac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_{t}}}\mathbf{x}_{t}+\left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right) \mathbf{x}_0+ \sigma_t \boldsymbol{\epsilon}_t \end{aligned}\]
where the \(\mathbf{x}_0\) term expands to:
\[\begin{aligned} &\left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right) \mathbf{x}_0 = \left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right)\left( \frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)}{\sqrt{\bar{\alpha}_t}}+\hat{\sigma}_t\boldsymbol{\epsilon} \right) \\ =&\frac{1}{\sqrt{\alpha_t}}\mathbf{x}_{t}-\sqrt{\frac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_{t}}}\mathbf{x}_{t}+\left( \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}-\frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\alpha_t}} \right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \\ &+ \left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right)\hat{\sigma}_t\boldsymbol{\epsilon} \end{aligned}\]
Substituting this back into the reverse step gives:
\[\begin{aligned} \mathbf{x}_{t-1} =& \frac{1}{\sqrt{\alpha_t}}\mathbf{x}_{t}+\left( \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}-\frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\alpha_t}} \right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \\ & + \sigma_t \boldsymbol{\epsilon}_1 + \left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right)\hat{\sigma}_t\boldsymbol{\epsilon}_2 \\ =& \frac{1}{\sqrt{\alpha_t}}\mathbf{x}_{t}+\left( \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}-\frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\alpha_t}} \right) \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right) \\ & + \sqrt{\sigma_t^2 + \left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\frac{\bar{\alpha}_{t}(1-\bar{\alpha}_{t-1}-\sigma_t^2)}{1-\bar{\alpha}_{t}}} \right)^2\hat{\sigma}_t^2}\boldsymbol{\epsilon} \end{aligned}\]
The last step merges the two independent standard Gaussian noises \(\boldsymbol{\epsilon}_1\) and \(\boldsymbol{\epsilon}_2\) into a single noise term whose variance is the sum of the two variances.
⚪ Implementation of Analytic-DPM
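The experiments in the original paper show that the variance correction of Analytic-DPM brings a noticeable improvement mainly when the number of generation steps is small, so it is most useful for accelerating sampling from diffusion models.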
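Before the full model code, the variance estimate \(\hat{\sigma}_t^2\) derived above can be sanity-checked on a toy case where everything is known in closed form. The sketch below assumes one-dimensional data \(x_0 \sim \mathcal{N}(0, s^2)\), for which the optimal noise predictor is \(\mathbb{E}[\epsilon \mid x_t]\) and the true variance of \(q(x_0 \mid x_t)\) is \((1-\bar{\alpha}_t)s^2 / (\bar{\alpha}_t s^2 + 1-\bar{\alpha}_t)\). The function names and the Gaussian data assumption are illustrative and not part of the original implementation.

```python
import math
import random


def estimate_sigma_hat_sq(alpha_bar, data_std, num_samples=100_000, seed=0):
    """Monte Carlo version of sigma_hat_t^2 = (1-ab)/ab * (1 - E[eps_theta^2]) with d = 1."""
    rng = random.Random(seed)
    # Variance of x_t when x_0 ~ N(0, data_std^2)
    marginal_var = alpha_bar * data_std ** 2 + (1 - alpha_bar)
    total = 0.0
    for _ in range(num_samples):
        x0 = rng.gauss(0.0, data_std)
        eps = rng.gauss(0.0, 1.0)
        x_t = math.sqrt(alpha_bar) * x0 + math.sqrt(1 - alpha_bar) * eps
        # Optimal noise predictor E[eps | x_t] for Gaussian data (stands in for the network)
        eps_star = math.sqrt(1 - alpha_bar) * x_t / marginal_var
        total += eps_star ** 2
    eps_sq_mean = total / num_samples
    return (1 - alpha_bar) / alpha_bar * (1 - eps_sq_mean)


def exact_posterior_var(alpha_bar, data_std):
    """Exact variance of q(x_0 | x_t) for x_0 ~ N(0, data_std^2)."""
    s2 = data_std ** 2
    return (1 - alpha_bar) * s2 / (alpha_bar * s2 + (1 - alpha_bar))


if __name__ == "__main__":
    print(estimate_sigma_hat_sq(0.5, 2.0), exact_posterior_var(0.5, 2.0))
```

In this toy setting the Monte Carlo estimate and the exact posterior variance agree up to sampling error, which is what the derivation predicts when the noise predictor is optimal.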
import os
import random

import torch
import torch.nn.functional as F
from einops import reduce
from PIL import Image
from torch import nn
from torchvision import transforms as T
from tqdm import tqdm

# Helper functions used by the class. They are not shown in the original listing;
# the versions below are assumed, following common DDPM implementations.
def default(val, d):
    if val is not None:
        return val
    return d() if callable(d) else d

def extract(a, t, x_shape):
    # Gather a[t] for each sample and reshape to (b, 1, 1, ...) for broadcasting
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype = torch.float64)

class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        model,
        *,
        image_size,
        timesteps = 1000,
        sampling_timesteps = None,
        ddim_sampling_eta = 0.,
    ):
        super().__init__()
        self.model = model # neural network that fits \epsilon(x_t, t)
        self.channels = self.model.channels
        self.image_size = image_size

        betas = linear_beta_schedule(timesteps) # \beta_t
        alphas = 1. - betas # \alpha_t
        alphas_cumprod = torch.cumprod(alphas, dim=0) # \bar{\alpha}_t
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) # \bar{\alpha}_{t-1}

        self.num_timesteps = int(timesteps)
        self.alphas_cumprod = alphas_cumprod

        # sampling related parameters
        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
        assert self.sampling_timesteps <= timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) # \sqrt{\bar{\alpha}_t}
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) # \sqrt{1-\bar{\alpha}_t}
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) # \log(1-\bar{\alpha}_t)
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) # \sqrt{1/\bar{\alpha}_t}
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) # \sqrt{1/\bar{\alpha}_t-1}

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # \sigma_t^2 (a variance, not a standard deviation)
        register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    """
    Training
    """

    # Compute x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon
    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    # Compute the loss L_t = ||\epsilon - \epsilon(x_t, t)||_1
    def p_losses(self, x_start, t, noise = None):
        b, c, h, w = x_start.shape
        noise = default(noise, lambda: torch.randn_like(x_start))
        target = noise

        # compute x_t
        x = self.q_sample(x_start = x_start, t = t, noise = noise)

        # compute \epsilon(x_t, t)
        model_out = self.model(x, t)

        loss = F.l1_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        return loss.mean()

    # Training step
    def forward(self, img, *args, **kwargs):
        b, c, h, w, device = *img.shape, img.device
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
        img = img * 2 - 1 # data [0, 1] -> [-1, 1]
        return self.p_losses(img, t, *args, **kwargs)

    """
    Sampling
    """

    # Draw a random batch of real images, scaled to [-1, 1], for estimating the correction term
    def data_generator(self, batch, folder='./images/official-artwork'):
        transform = T.Compose([
            T.Resize(self.image_size),
            T.RandomHorizontalFlip(),
            T.ToTensor()
        ])
        selected_imgs = random.sample(os.listdir(folder), k=batch)
        batch_imgs = []
        for img in selected_imgs:
            img = Image.open(os.path.join(folder, img))
            batch_imgs.append(transform(img).unsqueeze(0))
        return torch.cat(batch_imgs, dim=0) * 2 - 1

    # Compute x_0 = \sqrt{1/\bar{\alpha}_t} x_t - \sqrt{1/\bar{\alpha}_t - 1} \epsilon_t
    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    @torch.no_grad()
    def ddim_sample(self, shape, return_all_timesteps = False):
        batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta

        # Estimate (1/d) E||\epsilon(x_t, t)||^2 for every t, using (batch_size * steps) samples in total
        factors = []
        for step in tqdm(range(total_timesteps), ncols=0, desc = 'computing unbiased term'):
            batch_imgs = self.data_generator(batch).to(device)
            t = torch.full((batch,), step, device = device, dtype = torch.long)
            x_t = self.q_sample(x_start = batch_imgs, t = t, noise = torch.randn_like(batch_imgs))
            factors.append(torch.mean(self.model(x_t, t) ** 2))

        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        img = torch.randn(shape, device = device)
        imgs = [img]
        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            pred_noise = self.model(img, time_cond) # compute \epsilon(x_t, t)
            x_start = self.predict_start_from_noise(img, time_cond, pred_noise) # x_0

            if time_next < 0:
                img = x_start
                imgs.append(img)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]
            factor = factors[time]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()
            # Variance correction b_t^2 * \hat{\sigma}_t^2 with \hat{\sigma}_t^2 = (1-\bar{\alpha}_t)/\bar{\alpha}_t * (1 - factor).
            # The original listing multiplied by `factor` itself, which does not match the derived formula;
            # the clamp keeps the estimated variance non-negative.
            sigma_hat_sq = (1 - alpha) / alpha * (1 - factor).clamp(min = 0)
            unbias = (alpha_next.sqrt() - c * (alpha / (1 - alpha)).sqrt()) ** 2 * sigma_hat_sq

            noise = torch.randn_like(img)

            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  (sigma ** 2 + unbias).sqrt() * noise

            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
        ret = (ret + 1) * 0.5
        return ret

    # Sampling entry point
    @torch.no_grad()
    def sample(self, batch_size = 16, img_channel = 3, return_all_timesteps = False):
        image_size, channels = self.image_size, img_channel
        return self.ddim_sample((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)
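The per-step noise scale used by the sampler can also be written as a small scalar function, which makes the formula from the sampling section easy to inspect without a trained model. This is a minimal sketch in plain Python: the function name `analytic_step_noise_std` and the argument `eps_sq_mean` (standing for \(\frac{1}{d}\mathbb{E}\|\boldsymbol{\epsilon}_\theta\|^2\)) are hypothetical, and the clamp at zero is an added safeguard rather than something taken from the derivation.

```python
import math


def analytic_step_noise_std(alpha_bar_t, alpha_bar_prev, eps_sq_mean, eta=0.0):
    """Total noise std of one reverse step x_t -> x_{t-1} (scalar sketch)."""
    # DDIM-style sigma_t; eta = 0 corresponds to a deterministic DDIM step
    sigma_t = eta * math.sqrt(
        (1 - alpha_bar_t / alpha_bar_prev) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    )
    # Coefficient of x_0 in the mean of q(x_{t-1} | x_t, x_0)
    c = math.sqrt(1 - alpha_bar_prev - sigma_t ** 2)
    b = math.sqrt(alpha_bar_prev) - c * math.sqrt(alpha_bar_t / (1 - alpha_bar_t))
    # sigma_hat_t^2 = (1 - ab_t) / ab_t * (1 - (1/d) E||eps_theta||^2), clamped at 0
    sigma_hat_sq = (1 - alpha_bar_t) / alpha_bar_t * max(0.0, 1 - eps_sq_mean)
    # Two independent Gaussian noises merge into one: variances add
    return math.sqrt(sigma_t ** 2 + b ** 2 * sigma_hat_sq)


if __name__ == "__main__":
    print(analytic_step_noise_std(0.5, 0.8, eps_sq_mean=0.5, eta=0.0))
```

With `eta = 0` and `eps_sq_mean = 1` the correction vanishes and the step is the plain deterministic DDIM step; smaller values of `eps_sq_mean` add noise even when `eta = 0`, which is the effect of the variance correction.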