扩散模型DDPM开源代码的剖析对应公式与作者给的开源项目,diffusion model

Posted 旋转的油纸伞

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了扩散模型DDPM开源代码的剖析对应公式与作者给的开源项目,diffusion model相关的知识,希望对你有一定的参考价值。

扩散模型DDPM开源代码的剖析【对应公式与作者给的开源项目,diffusion model】

一、简介

论文地址:https://proceedings.neurips.cc/paper/2020/hash/4c5bcfec8584af0d967f1ab10179ca4b-Abstract.html
项目地址:https://github.com/hojonathanho/diffusion
公式推导参考这篇博客:https://blog.csdn.net/qq_45934285/article/details/129107994?spm=1001.2014.3001.5502
本文主要对扩散模型的关键公式给出原代码帮助理解和学习。有pytorch和TensorFlow版。
原作者给的代码不太好理解,给出了pytorch的好理解一些。

二、扩散过程:输入是x_0和时刻num_steps,输出是x_t

首先值得注意的是:x_0是一个二维数组,例如这里给的是一个10000行2列的数组,即每一行代表一个点。
这里取了s_curve的x轴和z轴的坐标,用点表示看起来就像一个s型

s_curve,_ = make_s_curve(10**4,noise=0.1)
s_curve = s_curve[:,[0,2]]/10.0
dataset = torch.Tensor(s_curve).float()

扩散过程其实就是一个不断加噪的过程,其不含参。可以给出最终公式。
x t = α ‾ t x 0 x_t=\\sqrt \\overline\\alpha_ tx_ 0 xt=αt x0 + 1 − α ‾ t \\sqrt 1-\\overline \\alpha _ t 1α t z ‾ t \\overline z_t zt

在t不断变大的时候 β t \\beta_t βt越来越大, α t = 1 − β t \\alpha_t=1-\\beta_t αt=1βt越来越小。即t增大的时候上面公式的前一项系数越来越小,后一项系数越来越大不断接近一个 z ‾ t \\overline z_t zt的高斯分布。

代码来自diffusion_tf/diffusion_utils_2.py

  def q_sample(self, x_start, t, noise=None):
    """
    Diffuse the data (t == 0 means diffused for 1 step)
    """
    if noise is None:
      noise = tf.random_normal(shape=x_start.shape)
    assert noise.shape == x_start.shape
    return (
        self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
        self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
    )

pytorch:

#计算任意时刻的x采样值,基于x_0和重参数化
def q_x(x_0,t):
    """可以基于x[0]得到任意时刻t的x[t]"""
    noise = torch.randn_like(x_0)# 创建了一个与 x_0 张量具有相同形状的名为 noise 的张量,并且该张量的值是从标准正态分布中随机采样得到的。
    alphas_t = alphas_bar_sqrt[t]
    alphas_1_m_t = one_minus_alphas_bar_sqrt[t]
    return (alphas_t * x_0 + alphas_1_m_t * noise)#在x[0]的基础上添加噪声

可见对于求 x t x_t xt的公式最难理解的就是代码如何实现 z ‾ t \\overline z_t zt代码中是创建了一个与 x_0 张量具有**相同形状**的名为 noise 的张量,并且该张量的值是从标准正态分布中随机采样得到的。这个noise其元素的值是从均值为0、标准差为1的正态分布中随机采样得到的。这个张量可以被用于实现噪声注入,数据增强等操作,也可以被用于一些随机化的算法中。

值得一提的是原项目中用num_diffusion_timesteps=1000来表示t,假如num_steps=100,那么很多需要用到的参数都可以提前算出来。

三、逆扩散过程:输入x_t,不断采样最终输出x_0

最终公式是:

q ( X t − 1 ∣ X t X 0 ) = N ( X t − 1 ; 1 α t ( X t − β t ( 1 − α ˉ t ) Z ) , 1 − α ˉ t − 1 1 − α ˉ t β t ) , Z ∼ N ( 0 , I ) q\\left(X_t-1 \\mid X_t X_0\\right)=N\\left(X_t-1 ; \\frac1\\sqrt\\alpha_t (X_t-\\frac\\beta_t\\sqrt\\left(1-\\bar\\alpha_t\\right) Z), \\frac1-\\bar\\alpha_t-11-\\bar\\alpha_t \\beta_t\\right), Z \\sim N(0, I) q(Xt1XtX0)=N(Xt1;αt 1(Xt(1αˉt) βtZ),1αˉt1αˉt1βt),ZN(0,I)
在论文中方差设置为一个常数 β t \\beta _t βt β ~ t \\tilde\\beta _t β~t其中:
β ~ t = 1 − α ˉ t − 1 1 − α ˉ t β t \\tilde\\beta _t=\\frac1-\\bar\\alpha_t-11-\\bar\\alpha_t \\beta_t β~t=1αˉt1αˉt1βt
因此可训练的参数只存在与其均值之中。

就是这个公式,方差变为 β t \\beta_t βt,其中 ϵ θ \\epsilon_\\theta ϵθ是模型model

def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):
    """从x[T]采样t时刻的重构值"""
    t = torch.tensor([t])
    coeff = betas[t] / one_minus_alphas_bar_sqrt[t]
    eps_theta = model(x,t)
    mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))
    z = torch.randn_like(x)
    sigma_t = betas[t].sqrt()
    sample = mean + sigma_t * z
    return (sample)

代码来自diffusion_tf/diffusion_utils_2.py

  def p_sample(self, denoise_fn, *, x, t, noise_fn, clip_denoised=True, return_pred_xstart: bool):
    """
    Sample from the model
    """
    model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
      denoise_fn, x=x, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
    noise = noise_fn(shape=x.shape, dtype=x.dtype)
    assert noise.shape == x.shape
    # no noise when t == 0
    nonzero_mask = tf.reshape(1 - tf.cast(tf.equal(t, 0), tf.float32), [x.shape[0]] + [1] * (len(x.shape) - 1))
    sample = model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise
    assert sample.shape == pred_xstart.shape
    return (sample, pred_xstart) if return_pred_xstart else sample

循环恢复。
可见初始的x_t完全是一个随机噪声。torch.randn(shape)
cur_x可以看做是一个当前的采样,是一个二维数组,就是上面说的10000行2列。
然后x_seq可以看做是一个三维数组,即元素为cur_x的一个数组。
i是时刻,从n_steps的反向开始。

def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):
    """从x[T]恢复x[T-1]、x[T-2]|...x[0]"""
    cur_x = torch.randn(shape)
    x_seq = [cur_x]
    

以上是关于扩散模型DDPM开源代码的剖析对应公式与作者给的开源项目,diffusion model的主要内容,如果未能解决你的问题,请参考以下文章

简单基础入门理解Denoising Diffusion Probabilistic Model,DDPM扩散模型

一文详解扩散模型:DDPM

去噪扩散概率模型(DDPM)的简单理解

去噪扩散概率模型(DDPM)的简单理解

Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成

DDPM代码详细解读:图解模型各部分结构用ConvNextBlock代替Resnet