DDPM模型——pytorch实现

Posted Peach_____

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了DDPM模型——pytorch实现相关的知识,希望对你有一定的参考价值。


论文传送门:Denoising Diffusion Probabilistic Models

参考文章:The Annotated Diffusion Model

DDPM的目的:

从标准正态分布中采样出噪声图像,经过T次去噪后还原出与训练图像相似的生成图像,从而完成图像生成任务。

DDPM的方法:

①扩散过程(加噪过程):

对训练图像不断加噪,经过T次,使得训练图像近似变成各向独立的标准正态分布的噪声图像。
每次加噪记作 q ( x t ∣ x t − 1 ) q(x_t|x_t-1) q(xtxt1),其中t指当前时刻(加噪t次),t-1指上一时刻(加噪t-1次), x t x_t xt指当前时刻的图像, x t − 1 x_t-1 xt1指上一时刻的图像。
整个过程是马尔科夫链,即当前时刻的图像仅与其上一时刻有关,而与其他时刻无关。
设定一个长度为T的序列 β β β β t β_t βt在(0,1)区间内单调递增,t时刻加入噪声的方差为 β t β_t βt,均值由 β t β_t βt x t x_t xt共同决定,则可以写出当前时刻 q ( x t ∣ x t − 1 ) q(x_t|x_t-1) q(xtxt1)和整个扩散过程 q ( x 1 : T ∣ x 0 ) q(x_1:T|x_0) q(x1:Tx0)的公式:


当T趋近于∞时,可以认为 x T x_T xT是各向独立的标准正态分布。
可以发现,任意时刻的噪声图像 x t x_t xt可以由初始时刻图像(原图) x 0 x_0 x0 β β β序列来确定,定义 α t = 1 − β t α_t=1-β_t αt=1βt α ˉ = ∏ s = 1 t α s \\barα=\\prod\\limits_s=1^tα_s αˉ=s=1tαs,则:

扩散过程与网络无关,只要确定初始时刻图像 x 0 x_0 x0 β β β序列,整个扩散过程均可求。

②逆扩散过程(去噪过程):

对噪声图像不断去噪,经过T次,使得噪声图像可以恢复为初始时刻图像 x 0 x_0 x0
​每次去噪记作 p ( x t − 1 ∣ x t ) p(x_t-1|x_t) p(xt1xt),公式:

p ( x t − 1 ∣ x t ) p(x_t-1|x_t) p(xt1xt)难以直接求解,所以使用网络进行计算。
DDPM的网络实际上在拟合去噪过程后验概率p的分布。
可以写出负对数似然函数的上界,通过最小化其上界达到最大化似然函数的目的:

损失函数:

第一项 L T L_T LT为常数,与模型优化过程无关,第二项 L t − 1 L_t-1 Lt1与第三项 L 0 L_0 L0可以进行展开化简。
注意到 q ( x t − 1 ∣ x t , x 0 ) q(x_t-1|x_t,x_0) q(xt1xt,x0)可以计算,使用贝叶斯公式重参数技巧可以进行化简:

第二项 L t − 1 L_t-1 Lt1




第三项 L 0 L_0 L0

最终的损失函数Loss:

DDPM的训练与采样过程:

DDPM的结构:

基于U-Net网络,加入了位置编码、残差结构、注意力机制和组标准化等模块。

train.py

import os

import torch
from torch.utils.data import DataLoader

import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import numpy as np
from torchvision import transforms, datasets

from model import Unet  # DDPM模型


# 定义4种生成β的方法,均需传入总步长T,返回β序列
def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)


def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2


def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start


# 从序列a中取t时刻的值a[t](batch_size个),维度与x_shape相同,第一维为batch_size
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


# 扩散过程采样,即通过x0和t计算xt
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumpord_t = extract(sqrt_one_minus_alphas_cumpord, t, x_start.shape)
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumpord_t * noise


# 损失函数loss,共3种计算方式,原文使用l2
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start, t, noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == "l1":
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == "l2":
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss


# 逆扩散过程采样,即通过xt和t计算xt-1,此过程需要通过网络
@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumpord_t = extract(sqrt_one_minus_alphas_cumpord, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumpord_t)
    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise


# 逆扩散过程T次采样,即通过xT和T计算xi,获得每一个时刻的图像列表[xi],此过程需要通过网络
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device
    b = shape[0]
    img = torch.randn(shape, device=device)
    imgs = []
    for i in tqdm(reversed(range(0, timesteps)), desc="sampling loop time step", total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu())
    return imgs


# 逆扩散过程T次采样,允许传入batch_size指定生成图片的个数,用于生成结果的可视化
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=1):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))


if __name__ == "__main__":
    timesteps = 300  # 总步长T
    # 以下参数均为序列(List),需要传入t获得对应t时刻的值 xt = X[t]
    betas = linear_beta_schedule(timesteps=timesteps)  # 选择一种方式,生成β(t)
    alphas = 1. - betas  # α(t)
    alphas_cumprod = torch.cumprod(alphas, axis=0)  # α的连乘序列,对应α_bar(t)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0),
                                value=1.0)  # 将α_bar的最后一个值删除,在最开始添加1,对应前一个时刻的α_bar,即α_bar(t-1)
    sqrt_recip_alphas = torch.sqrt(1. / alphas)  # 1/根号下α(t)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)  # 根号下α_bar(t)
    sqrt_one_minus_alphas_cumpord = torch.sqrt(1. - alphas_cumprod)  # 根号下(1-α_bar(t))
    posterior_variance = betas * (1. - alphas_cumprod_prev) 以上是关于DDPM模型——pytorch实现的主要内容,如果未能解决你的问题,请参考以下文章

DDPM代码详细解读:图解模型各部分结构用ConvNextBlock代替Resnet

DDPM代码详细解读:图解模型各部分结构用ConvNextBlock代替Resnet

扩散模型DDPM开源代码的剖析对应公式与作者给的开源项目,diffusion model

PyTorch - Diffusion Model 公式推导

PyTorch笔记 - Diffusion Model 公式推导

使用DDPM实现三维点云重建