DDPM模型——pytorch实现
Posted Peach_____
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了DDPM模型——pytorch实现相关的知识,希望对你有一定的参考价值。
论文传送门:Denoising Diffusion Probabilistic Models
参考文章:The Annotated Diffusion Model
DDPM的目的:
从标准正态分布中采样出噪声图像,经过T次去噪后还原出与训练图像相似的生成图像,从而完成图像生成任务。
DDPM的方法:
①扩散过程(加噪过程):
对训练图像不断加噪,经过T次,使得训练图像近似变成各向独立的标准正态分布的噪声图像。
每次加噪记作
q
(
x
t
∣
x
t
−
1
)
q(x_t|x_t-1)
q(xt∣xt−1),其中t指当前时刻(加噪t次),t-1指上一时刻(加噪t-1次),
x
t
x_t
xt指当前时刻的图像,
x
t
−
1
x_t-1
xt−1指上一时刻的图像。
整个过程是马尔科夫链,即当前时刻的图像仅与其上一时刻有关,而与其他时刻无关。
设定一个长度为T的序列
β
β
β,
β
t
β_t
βt在(0,1)区间内单调递增,t时刻加入噪声的方差为
β
t
β_t
βt,均值由
β
t
β_t
βt和
x
t
x_t
xt共同决定,则可以写出当前时刻
q
(
x
t
∣
x
t
−
1
)
q(x_t|x_t-1)
q(xt∣xt−1)和整个扩散过程
q
(
x
1
:
T
∣
x
0
)
q(x_1:T|x_0)
q(x1:T∣x0)的公式:
当T趋近于∞时,可以认为
x
T
x_T
xT是各向独立的标准正态分布。
可以发现,任意时刻的噪声图像
x
t
x_t
xt可以由初始时刻图像(原图)
x
0
x_0
x0和
β
β
β序列来确定,定义
α
t
=
1
−
β
t
α_t=1-β_t
αt=1−βt,
α
ˉ
=
∏
s
=
1
t
α
s
\\barα=\\prod\\limits_s=1^tα_s
αˉ=s=1∏tαs,则:
扩散过程与网络无关,只要确定初始时刻图像
x
0
x_0
x0和
β
β
β序列,整个扩散过程均可求。
②逆扩散过程(去噪过程):
对噪声图像不断去噪,经过T次,使得噪声图像可以恢复为初始时刻图像
x
0
x_0
x0。
每次去噪记作
p
(
x
t
−
1
∣
x
t
)
p(x_t-1|x_t)
p(xt−1∣xt),公式:
p
(
x
t
−
1
∣
x
t
)
p(x_t-1|x_t)
p(xt−1∣xt)难以直接求解,所以使用网络进行计算。
DDPM的网络实际上在拟合去噪过程后验概率p的分布。
可以写出负对数似然函数的上界,通过最小化其上界达到最大化似然函数的目的:
损失函数:
第一项
L
T
L_T
LT为常数,与模型优化过程无关,第二项
L
t
−
1
L_t-1
Lt−1与第三项
L
0
L_0
L0可以进行展开化简。
注意到
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_t-1|x_t,x_0)
q(xt−1∣xt,x0)可以计算,使用贝叶斯公式和重参数技巧可以进行化简:
第二项
L
t
−
1
L_t-1
Lt−1:
第三项
L
0
L_0
L0:
最终的损失函数Loss:
DDPM的训练与采样过程:
DDPM的结构:
基于U-Net网络,加入了位置编码、残差结构、注意力机制和组标准化等模块。
train.py
import os
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np
from torchvision import transforms, datasets
from model import Unet # DDPM模型
# 定义4种生成β的方法,均需传入总步长T,返回β序列
def cosine_beta_schedule(timesteps, s=0.008):
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0.0001, 0.9999)
def linear_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start, beta_end, timesteps)
def quadratic_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2
def sigmoid_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
betas = torch.linspace(-6, 6, timesteps)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
# 从序列a中取t时刻的值a[t](batch_size个),维度与x_shape相同,第一维为batch_size
def extract(a, t, x_shape):
batch_size = t.shape[0]
out = a.gather(-1, t.cpu())
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
# 扩散过程采样,即通过x0和t计算xt
def q_sample(x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
sqrt_one_minus_alphas_cumpord_t = extract(sqrt_one_minus_alphas_cumpord, t, x_start.shape)
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumpord_t * noise
# 损失函数loss,共3种计算方式,原文使用l2
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
if noise is None:
noise = torch.randn_like(x_start)
x_noisy = q_sample(x_start, t, noise)
predicted_noise = denoise_model(x_noisy, t)
if loss_type == "l1":
loss = F.l1_loss(noise, predicted_noise)
elif loss_type == "l2":
loss = F.mse_loss(noise, predicted_noise)
elif loss_type == "huber":
loss = F.smooth_l1_loss(noise, predicted_noise)
else:
raise NotImplementedError()
return loss
# 逆扩散过程采样,即通过xt和t计算xt-1,此过程需要通过网络
@torch.no_grad()
def p_sample(model, x, t, t_index):
betas_t = extract(betas, t, x.shape)
sqrt_one_minus_alphas_cumpord_t = extract(sqrt_one_minus_alphas_cumpord, t, x.shape)
sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumpord_t)
if t_index == 0:
return model_mean
else:
posterior_variance_t = extract(posterior_variance, t, x.shape)
noise = torch.randn_like(x)
return model_mean + torch.sqrt(posterior_variance_t) * noise
# 逆扩散过程T次采样,即通过xT和T计算xi,获得每一个时刻的图像列表[xi],此过程需要通过网络
@torch.no_grad()
def p_sample_loop(model, shape):
device = next(model.parameters()).device
b = shape[0]
img = torch.randn(shape, device=device)
imgs = []
for i in tqdm(reversed(range(0, timesteps)), desc="sampling loop time step", total=timesteps):
img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
imgs.append(img.cpu())
return imgs
# 逆扩散过程T次采样,允许传入batch_size指定生成图片的个数,用于生成结果的可视化
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=1):
return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
if __name__ == "__main__":
timesteps = 300 # 总步长T
# 以下参数均为序列(List),需要传入t获得对应t时刻的值 xt = X[t]
betas = linear_beta_schedule(timesteps=timesteps) # 选择一种方式,生成β(t)
alphas = 1. - betas # α(t)
alphas_cumprod = torch.cumprod(alphas, axis=0) # α的连乘序列,对应α_bar(t)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0),
value=1.0) # 将α_bar的最后一个值删除,在最开始添加1,对应前一个时刻的α_bar,即α_bar(t-1)
sqrt_recip_alphas = torch.sqrt(1. / alphas) # 1/根号下α(t)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) # 根号下α_bar(t)
sqrt_one_minus_alphas_cumpord = torch.sqrt(1. - alphas_cumprod) # 根号下(1-α_bar(t))
posterior_variance = betas * (1. - alphas_cumprod_prev) 以上是关于DDPM模型——pytorch实现的主要内容,如果未能解决你的问题,请参考以下文章
DDPM代码详细解读:图解模型各部分结构用ConvNextBlock代替Resnet
DDPM代码详细解读:图解模型各部分结构用ConvNextBlock代替Resnet
扩散模型DDPM开源代码的剖析对应公式与作者给的开源项目,diffusion model
PyTorch - Diffusion Model 公式推导