扩散模型DDPM开源代码的剖析对应公式与作者给的开源项目,diffusion model
Posted 旋转的油纸伞
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了扩散模型DDPM开源代码的剖析对应公式与作者给的开源项目,diffusion model相关的知识,希望对你有一定的参考价值。
扩散模型DDPM开源代码的剖析【对应公式与作者给的开源项目,diffusion model】
- 一、简介
- 二、扩散过程:输入是x_0和时刻num_steps,输出是x_t
- 三、逆扩散过程:输入x_t,不断采样最终输出x_0
- 四、具体参考算法流程图
- 五、模型model和损失函数(最重要!)
- 六、损失函数的推导
一、简介
论文地址:https://proceedings.neurips.cc/paper/2020/hash/4c5bcfec8584af0d967f1ab10179ca4b-Abstract.html
项目地址:https://github.com/hojonathanho/diffusion
公式推导参考这篇博客:https://blog.csdn.net/qq_45934285/article/details/129107994?spm=1001.2014.3001.5502
本文主要对扩散模型的关键公式给出原代码帮助理解和学习。有pytorch和TensorFlow版。
原作者给的代码不太好理解,给出了pytorch的好理解一些。
二、扩散过程:输入是x_0和时刻num_steps,输出是x_t
首先值得注意的是:x_0是一个二维数组,例如这里给的是一个10000行2列的数组,即每一行代表一个点。
这里取了s_curve的x轴和z轴的坐标,用点表示看起来就像一个s型
s_curve,_ = make_s_curve(10**4,noise=0.1)
s_curve = s_curve[:,[0,2]]/10.0
dataset = torch.Tensor(s_curve).float()
扩散过程其实就是一个不断加噪的过程,其不含参。可以给出最终公式。
x t = α ‾ t x 0 x_t=\\sqrt \\overline\\alpha_ tx_ 0 xt=αtx0 + 1 − α ‾ t \\sqrt 1-\\overline \\alpha _ t 1−αt z ‾ t \\overline z_t zt
在t不断变大的时候 β t \\beta_t βt越来越大, α t = 1 − β t \\alpha_t=1-\\beta_t αt=1−βt越来越小。即t增大的时候上面公式的前一项系数越来越小,后一项系数越来越大不断接近一个 z ‾ t \\overline z_t zt的高斯分布。
代码来自diffusion_tf/diffusion_utils_2.py
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = tf.random_normal(shape=x_start.shape)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
pytorch:
#计算任意时刻的x采样值,基于x_0和重参数化
def q_x(x_0,t):
"""可以基于x[0]得到任意时刻t的x[t]"""
noise = torch.randn_like(x_0)# 创建了一个与 x_0 张量具有相同形状的名为 noise 的张量,并且该张量的值是从标准正态分布中随机采样得到的。
alphas_t = alphas_bar_sqrt[t]
alphas_1_m_t = one_minus_alphas_bar_sqrt[t]
return (alphas_t * x_0 + alphas_1_m_t * noise)#在x[0]的基础上添加噪声
可见对于求 x t x_t xt的公式最难理解的就是代码如何实现 z ‾ t \\overline z_t zt在代码中是创建了一个与 x_0 张量具有**相同形状**的名为 noise 的张量,并且该张量的值是从标准正态分布中随机采样得到的。这个noise其元素的值是从均值为0、标准差为1的正态分布中随机采样得到的。这个张量可以被用于实现噪声注入,数据增强等操作,也可以被用于一些随机化的算法中。
值得一提的是原项目中用num_diffusion_timesteps=1000来表示t,假如num_steps=100,那么很多需要用到的参数都可以提前算出来。
三、逆扩散过程:输入x_t,不断采样最终输出x_0
最终公式是:
q
(
X
t
−
1
∣
X
t
X
0
)
=
N
(
X
t
−
1
;
1
α
t
(
X
t
−
β
t
(
1
−
α
ˉ
t
)
Z
)
,
1
−
α
ˉ
t
−
1
1
−
α
ˉ
t
β
t
)
,
Z
∼
N
(
0
,
I
)
q\\left(X_t-1 \\mid X_t X_0\\right)=N\\left(X_t-1 ; \\frac1\\sqrt\\alpha_t (X_t-\\frac\\beta_t\\sqrt\\left(1-\\bar\\alpha_t\\right) Z), \\frac1-\\bar\\alpha_t-11-\\bar\\alpha_t \\beta_t\\right), Z \\sim N(0, I)
q(Xt−1∣XtX0)=N(Xt−1;αt1(Xt−(1−αˉt)βtZ),1−αˉt1−αˉt−1βt),Z∼N(0,I)
在论文中方差设置为一个常数
β
t
\\beta _t
βt或
β
~
t
\\tilde\\beta _t
β~t其中:
β
~
t
=
1
−
α
ˉ
t
−
1
1
−
α
ˉ
t
β
t
\\tilde\\beta _t=\\frac1-\\bar\\alpha_t-11-\\bar\\alpha_t \\beta_t
β~t=1−αˉt1−αˉt−1βt
因此可训练的参数只存在与其均值之中。
就是这个公式,方差变为
β
t
\\beta_t
βt,其中
ϵ
θ
\\epsilon_\\theta
ϵθ是模型model
def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):
"""从x[T]采样t时刻的重构值"""
t = torch.tensor([t])
coeff = betas[t] / one_minus_alphas_bar_sqrt[t]
eps_theta = model(x,t)
mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))
z = torch.randn_like(x)
sigma_t = betas[t].sqrt()
sample = mean + sigma_t * z
return (sample)
代码来自diffusion_tf/diffusion_utils_2.py
def p_sample(self, denoise_fn, *, x, t, noise_fn, clip_denoised=True, return_pred_xstart: bool):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, x=x, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
noise = noise_fn(shape=x.shape, dtype=x.dtype)
assert noise.shape == x.shape
# no noise when t == 0
nonzero_mask = tf.reshape(1 - tf.cast(tf.equal(t, 0), tf.float32), [x.shape[0]] + [1] * (len(x.shape) - 1))
sample = model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise
assert sample.shape == pred_xstart.shape
return (sample, pred_xstart) if return_pred_xstart else sample
循环恢复。
可见初始的x_t完全是一个随机噪声。torch.randn(shape)
cur_x可以看做是一个当前的采样,是一个二维数组,就是上面说的10000行2列。
然后x_seq可以看做是一个三维数组,即元素为cur_x的一个数组。
i是时刻,从n_steps的反向开始。
def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):
"""从x[T]恢复x[T-1]、x[T-2]|...x[0]"""
cur_x = torch.randn(shape)
x_seq = [cur_x]
以上是关于扩散模型DDPM开源代码的剖析对应公式与作者给的开源项目,diffusion model的主要内容,如果未能解决你的问题,请参考以下文章
简单基础入门理解Denoising Diffusion Probabilistic Model,DDPM扩散模型