Diffusion models as plug-and-play priors

Posted 馒头and花卷

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Diffusion models as plug-and-play priors相关的知识,希望对你有一定的参考价值。

Graikos A., Malkin N., Jojic N. and Samaras D. Diffusion models as plug-and-play priors. NIPS, 2022.

有了先验分布 \\(p(\\mathbfx)\\) (用一般的扩散模型去拟合), 我们总是像添加一些约束, 即希望从条件概率分布 \\(p(\\mathbfx|\\mathbfy)\\) 中采样. 作者在这里讨论的范围要更大, 只需给定一些约束 \\(c(\\mathbfx, \\mathbfy)\\) 即可.

问题

  • 假设我们对后验概率 \\(p(\\mathbfx|\\mathbfy) \\propto p(\\mathbfx) c(\\mathbfx, \\mathbfy)\\) 感兴趣.

  • \\(c(\\mathbfx, \\mathbfy) = p(y | \\mathbfx)\\) 容易估计的时候, 这个问题是容易的 (这里我们假设先验分布 \\(p(\\mathbfx\\)) 容易通过扩散模型近似).

  • 但是当 \\(c(\\mathbfx, \\mathbfy)\\) 本身不容易估计的时候, 这个问题会比较麻烦.

  • 作者的想法是, 利用一个替代分布 (or variational posterior) \\(q(\\mathbfx)\\) 来近似 \\(p(\\mathbfx|\\mathbfy)\\), 当然 \\(q(\\mathbfx)\\) 在形式上必须是容易采样和求解的.

  • 需要注意的是, 虽然这里的符号是 \\(q(\\mathbfx)\\), 但是不意味着对于所有的 \\(\\mathbfy\\) 都用同一个 \\(q(\\mathbfx)\\) 近似. 实际上, 真正的做法是, 对于每一条件 \\(\\mathbfy\\) 求解一个替代分布 \\(q(\\mathbfx)\\), 所以实际上 \\(q_\\mathbfy(\\mathbfx)\\) 或许更为贴切, 不过这里还是遵循此惯用用法.

  • 一种简单的做法就是通过最小化 KL 散度:

    \\[\\tag1 \\beginarrayll \\min_q \\: \\textKL(q(\\mathbfx) \\| p(\\mathbfx|y)) &= \\int q(\\mathbfx) \\log \\fracq(\\mathbfx)p(\\mathbfx|\\mathbfy) \\\\ &\\Leftrightarrow -\\mathbbE_q(\\mathbfx) \\Bigg [\\log p(\\mathbfx) + \\log c(\\mathbfx, \\mathbfy) - \\log q(\\mathbfx) \\Bigg] \\\\ &=: F. \\endarray \\]

    最后的部分, 称为 variational free energy. 最小化 free energy 可以得到一个合适的近似. 注: variational free energy 又名 Evidence Lower Bound.

  • 现在, 我们假设 \\(p(\\mathbfx)\\) 本身也是难以直接得到的 (正如扩散模型一下), 但我们可以直接的带 \\(p(\\mathbfx, \\mathbfh)\\), 但是我们依旧只添加约束 \\(c(\\mathbfx, \\mathbfy)\\), 此时 free energy 成为了:

    \\[\\tag2 \\beginarrayll F &= -\\mathbbE_q(\\mathbfx) q(\\mathbfh|\\mathbfx) \\Bigg [\\log p(\\mathbfx, \\mathbfh) + \\log c(\\mathbfx, \\mathbfy) - \\log q(\\mathbfx)q(\\mathbfh|\\mathbfx) \\Bigg] \\\\ &=-\\mathbbE_q(\\mathbfx) q(\\mathbfh|\\mathbfx) \\Bigg [\\log p(\\mathbfx, \\mathbfh) - \\log q(\\mathbfx)q(\\mathbfh|\\mathbfx) \\Bigg] - \\mathbbE_q(\\mathbfx) [\\log c(\\mathbfx, \\mathbfy)]. \\endarray \\]

与扩散模型的联系

  • DDPM 拟合得到了如下的一个联合分布:

    \\[p(\\mathbfx_T, \\mathbfx_T-1, \\ldots, \\mathbfx_0) = p(\\mathbfx_T) \\prod_t=1^T p_\\theta(\\mathbfx_t-1|\\mathbfx_t), \\\\ p_\\theta(\\mathbfx_t-1|\\mathbfx_t) = \\mathcalN(\\mathbfx_t-1; \\bm\\mu_\\theta(\\mathbfx_t, t), \\Sigma_\\theta(\\mathbfx_t, t)), \\quad p(\\mathbfx_T) = \\mathcalN(\\bm0, \\mathbfI). \\]

  • \\(\\mathbfx = \\mathbfx_0, \\mathbfh = \\mathbfx_1:T\\):

    1. 首先, 根据扩散模型的性质我们可以知道得到:

      \\[q(\\mathbfh|\\mathbfx=\\mathbfx_0) = \\prod_t=1^T q(\\mathbfx_t|\\mathbfx_t-1), \\: q(\\mathbfx_t|\\mathbfx_t-1) = \\mathcalN(\\mathbfx_t; \\sqrt1 - \\beta_t \\mathbfx_t-1, \\beta_t \\mathbfI), \\quad t=1,\\ldots, T. \\]

    2. 换言之, 我们需要估计仅仅是 \\(q(\\mathbfx)\\):

      \\[\\tag3 \\min_q(\\mathbfx) \\quad F. \\]

  • Ok, 让我们缓一缓, 其实扩散模型的损失和这里的损失是同一个, 只是扩散模型固定 \\(q\\) (2) 来训练 \\(p_\\theta\\), 而现在是固定 \\(p_\\theta\\) 来训练 \\(q(\\mathbfx)\\).

  • 作者直接采用最简单的形式:

    \\[q(\\mathbfx) = \\delta(\\mathbfx - \\bm\\eta), \\]

    故 (3) 相当于:

    \\[\\tag4 \\min_\\bm\\eta \\quad F. \\]

  • 进一步将 \\(F\\) 写成关于 \\(\\bm\\eta\\) 的形式:

    \\[F(\\bm\\eta) = \\sum_t \\mathbbE_\\bm\\epsilon \\sim \\mathcalN(\\bm0, \\mathbfI) [\\|\\bm\\epsilon - \\bm\\epsilon_\\theta(\\mathbfx_t, t)\\|_2^2] - \\log c(\\bm\\eta, \\mathbfy), \\: \\mathbfx_t = \\sqrt\\bar\\alpha_t \\bm\\eta + \\sqrt1 - \\bar\\alpha_t \\bm\\epsilon. \\]

  • 关于 \\(\\bm\\eta\\) 最小化 \\(F(\\bm\\eta)\\) 即可:

  • 考虑如下的优化:

    \\[\\max_\\bm\\eta \\int \\delta(\\mathbfx - \\mathbf\\eta) \\log p(\\mathbfx|\\mathbfy) \\mathrmd\\mathbfx = \\max_\\bm\\eta \\: \\log p(\\mathbfx|\\mathbfy), \\]

    \\[\\mathbf\\bm\\eta^* = \\arg\\max p(\\mathbfx|\\mathbfy) \\]

    实际上是 MAP 估计.

  • 一个好处是, 我们不需要严格保证 \\(c(\\mathbfx, \\mathbfy) = p(\\mathbfy|\\mathbfx)\\), 故严格上任意关于 \\(\\bm\\eta\\) 可导的函数都可以作为约束放在这里.

应用

条件采样

  • \\(c\\) 为 digit 的不同属性的 score (不用归一化):

  • \\[c(\\mathbfx, \\mathbfy) = \\prod_y \\in \\mathbfy p(y | \\mathbfx), \\]

    其中 \\(p(y|\\mathbfx)\\) 是对 \\(\\mathbfx\\) 不同属性的预测, 比如在人脸中: no beard, smiling, blond hair, male 等.

语义分割

  • 没怎么看懂.

解决离散问题

  • 没怎么看懂.

代码

official

从VAE到Diffusion Model

以上是关于Diffusion models as plug-and-play priors的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch - Diffusion Model 公式推导

CVLatent diffusion model 扩散模型体验

PyTorch笔记 - Diffusion Model 公式推导

最新最全Diffusion Models论文代码汇总[三万字总结]

Video Diffusion Models:基于扩散模型的视频生成

Diffusion Models和GANs结合