Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成
Posted Bubbliiiing
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成相关的知识,希望对你有一定的参考价值。
Diffusion扩散模型学习1——Pytorch搭建DDPM利用深度卷积神经网络实现图片生成
学习前言
我又死了我又死了我又死了!
源码下载地址
https://github.com/bubbliiiing/ddpm-pytorch
喜欢的可以点个star噢。
网络构建
一、什么是Diffusion
如上图所示。DDPM模型主要分为两个过程:
1、Forward加噪过程(从右往左),数据集的真实图片中逐步加入高斯噪声,最终变成一个杂乱无章的高斯噪声,这个过程一般发生在训练的时候。加噪过程满足一定的数学规律。
2、Reverse去噪过程(从左往右),指对加了噪声的图片逐步去噪,从而还原出真实图片,这个过程一般发生在预测生成的时候。尽管在这里说的是加了噪声的图片,但实际去预测生成的时候,是随机生成一个高斯噪声来去噪。去噪的时候不断根据
X
t
X_t
Xt的图片生成
X
t
−
1
X_t-1
Xt−1的噪声,从而实现图片的还原。
1、加噪过程
Forward加噪过程主要符合如下的公式:
x
t
=
α
t
x
t
−
1
+
1
−
α
t
z
1
x_t=\\sqrt\\alpha_t x_t-1+\\sqrt1-\\alpha_t z_1
xt=αtxt−1+1−αtz1
其中
α
t
\\sqrt\\alpha_t
αt是预先设定好的超参数,被称为Noise schedule,通常是小于1的值,在论文中
α
t
\\alpha_t
αt的值从0.9999到0.998。
ϵ
t
−
1
∼
N
(
0
,
1
)
\\epsilon_t-1 \\sim N(0, 1)
ϵt−1∼N(0,1)是高斯噪声。由公式(1)迭代推导。
x t = a t ( a t − 1 x t − 2 + 1 − α t − 1 z 2 ) + 1 − α t z 1 = a t a t − 1 x t − 2 + ( a t ( 1 − α t − 1 ) z 2 + 1 − α t z 1 ) x_t=\\sqrta_t\\left(\\sqrta_t-1 x_t-2+\\sqrt1-\\alpha_t-1 z_2\\right)+\\sqrt1-\\alpha_t z_1=\\sqrta_t a_t-1 x_t-2+\\left(\\sqrta_t\\left(1-\\alpha_t-1\\right) z_2+\\sqrt1-\\alpha_t z_1\\right) xt=at(at−1xt−2+1−αt−1z2)+1−αtz1=atat−1xt−2+(at(1−αt−1)z2+1−αtz1)
其中每次加入的噪声都服从高斯分布
z
1
,
z
2
,
…
∼
N
(
0
,
1
)
z_1, z_2, \\ldots \\sim \\mathcalN(0, 1)
z1,z2,…∼N(0,1),两个高斯分布的相加高斯分布满足公式:
N
(
0
,
σ
1
2
)
+
N
(
0
,
σ
2
2
)
∼
N
(
0
,
(
σ
1
2
+
σ
2
2
)
)
\\mathcalN\\left(0, \\sigma_1^2 \\right)+\\mathcalN\\left(0, \\sigma_2^2 \\right) \\sim \\mathcalN\\left(0,\\left(\\sigma_1^2+\\sigma_2^2\\right) \\right)
N(0,σ12)+N(0,σ22)∼N(0,(σ12+σ22)),因此,得到
x
t
x_t
xt的公式为: 以上是关于Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成的主要内容,如果未能解决你的问题,请参考以下文章 PyTorch笔记 - Diffusion Model 源码开发 PyTorch - Diffusion Model 公式推导 PyTorch笔记 - Diffusion Model 公式推导 扩散模型 Diffusion Models 入门到实践 | 论文学习资源课程整理
x
t
=
a
t
a
t
−
1
x
t
−
2
+
1
−
α
t
α
t
−
1
z
2
x_t = \\sqrta_t a_t-1 x_t-2+\\sqrt1-\\alpha_t \\alpha_t-1 z_2
xt=atat−1xt−2+1−αtαt−1z2
因此不断往里面套,就能发现规律了,其实就是累乘
可以直接得出
x
0
x_0
x0到
x
t
x_t
xt的公式:
x
t
=