很多次翻看DDPM,始终不太能理解论文中提到的$\text{Variational Inference}$到底是如何在这个工作中起到作用。五一假期在家,无意间又刷到徐亦达老师早些年录制的理论视频,没想到其中也有介绍这部分的内容。老师的上课方式总是娓娓道来,把每一步都讲解得很仔细。本文记录一下个人对开头问题的思考。

Background

如果需要简略地介绍一下DDPM这个工作,可能会用以下几句话简单地描述:DDPMMarkov的形式对数据(图片)“扩散过程”建模,使用神经网络进行训练拟合,学习数据的概率分布。

所以对于生成任务来说,希望从给定数据中学习到的是数据的潜在信息。比如图片生成,在给定一些图片后,模型学习到的是“正常图片长什么样子”,如:

  1. 一张包含手机正面的图片会有【手机屏幕】;
  2. 一张包含猫咪的图片会有人们观察到的猫咪模样;

对于图片中每个像素点和附近的像素点,进行“合理”布局,才能生成“符合人们认知的图片”。

图片生成能像常见的机器学习任务如分类任务、回归任务,能基于maximize likelihood的形式来训练么?

**结论是很难,**先回顾如何做maximum likelihood。给定一批数据,首先需要假定数据服从的分布,接着写出似然函数,之后直接通过解析解的形式或是梯度下降的形式,求出分布。

问题就出在假定分布这一步,没有人知道图片客观上服从什么分布。那如果使用神经网络直接拟合可以么?这好像也不现实,拿一张512*512*3的图片来说,网络输出层共有约75w的数值。

对于图片生成还有另外一个问题,世界上的图片太多了,目之所及稍做处理,皆为图片。即便使用神经网络能拟合,最后生成的图片很难存在多样性。

那目前图片生成模型都是怎么做的,比如VAE或是本文即将要介绍的Diffusion Model,它们学习的都是数据分布$p(x)$,但直接求$p(x)$这么麻烦,需要怎么做?这其实也是$\text{Variational Inference}$的核心思想,“曲线救国”,通过引入其它分布,将原本难以优化的问题转变为可优化问题。

ELOB

先把上述提到的所有背景先抛开,研究一下$p(x)$,看看能得到什么有意思的结论。

a. 基于条件概率分布,引入新的随机变量$z$:$p(x) = \frac{p(x, z)}{p(z\mid x)}$;

b. 对于两边同时取$\ln$,等式依然成立,因此有:$\ln{p(x)} = \ln{\frac{p(x, z)}{p(z \mid x)}}$;

c. 右边分子分母同乘以$q(z)$:$\ln{p(x)} = \ln{\frac{p(x, z) * q(z)}{p(z \mid x) * q(z)}} = \ln{\left(\frac{p(x, z)}{q(z)} * \frac{q(z)}{p(z \mid x)}\right)} = \ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}}$

d. 再次,对于上式左右两边求关于$q(z)$的期望,等式依然成立:

$$
\begin{aligned}
&\mathbb{E}{z\sim q(z)}{[\ln{p(x)}]} = \mathbb{E}{z\sim q(z)}{(\ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}})} \
\iff & \int_z q(z)\ln{p(x)}dz = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz \
\iff & \ln{p(x)} = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz
\end{aligned}
\tag{1}
$$

一系列变换后,$(1)$式是最后的推导结果,等式右边由两个项组成。第二个项$\int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz$,叫做KL散度,它被用来衡量两个分布之间的“距离”,性质是值不小于0

这样一来,通过$(1)$可以得到不等式$(2)$:

$$
\begin{equation*}
\ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz
\end{equation*}
\tag{2}
$$

$(1)$式右边的第一项,同时也是$(2)$式的右边项,被学者们叫做$\text{ELBO(Evidence Lower Bound)}$。

Objective Function

上述推导的$(2)$式可以被视作“定理”一般的存在,即对于某个分布的对数形式,总可以找到它的下界。

那$(2)$式可以用来做什么?在Background中提到,图片生成任务中的$p(x)$想要对它做maximum likelihood根本无法做起。目标依然是最大化$p(x)$,但有了$(2)$式,求解的目标可以转移到最大化它的下界$\text{ELBO}$。

这也是论文中提到的:

This paper presents progress in diffusion probabilistic models. A diffusion probabilistic model (which we will call a “diffusion model” for brevity) is a parameterized Markov chain trained using variational inference to produce samples matching the data after finite time.

接下来,回到论文中,看看是如何一步步推导出DDPM的优化目标。$(3)$式直接摘录于论文:

$$
\begin{equation*}
\ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz = \mathbb{E}_{z \sim q(z)}\left[\ln{\frac{p(x,z)}{q(z)}}\right]
\end{equation*}
\tag{2}
$$

$$
\begin{equation*}
\mathbb{E}\left[-\log p_\theta\left(\mathbf{x}0\right)\right] \leq \mathbb{E}q\left[-\log \frac{p\theta\left(\mathbf{x}{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}0\right)}\right]=\mathbb{E}q\left[-\log p\left(\mathbf{x}T\right)-\sum{t \geq 1} \log \frac{p\theta\left(\mathbf{x}{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}t \mid \mathbf{x}{t-1}\right)}\right]=: L
\end{equation*}
\tag{3}
$$

下面一项项地对$(3)$ 进行拆解,并且将它与$(2)$比对,能帮助更好地理解:

  1. $(3)$不等号左边的$\mathbb{E}\left[-\log p_\theta\left(\mathbf{x}0\right)\right]$进一步化简就是$-\log p\theta\left(\mathbf{x}0\right)$。其中,$p\theta\left(\mathbf{x}_0\right)$便是模型要学习的最终目标:图像的分布,$\theta$是模型的参数,$\mathbf{x}_0$是图片;

  2. $(2)$式的左右两边同时加上符号,$\geq$变为$\leq$;

  3. 看$(3)$不等式右边部分,$\mathbb{E}q\left[-\log \frac{p\theta\left(\mathbf{x}{0: T}\right)}{q\left(\mathbf{x}{1: T} \mid \mathbf{x}_0\right)}\right]$

    1. 很明显,$q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)$相当于$(2)$中引入的额外分布$q(z)$。对于$z$,在生成模型中会给它一个称呼:隐变量$(\text{latent})$。实际上,在diffusion models里,对$\mathbf{x}_0$加噪后的$\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T$就可以看作隐变量,那不妨记作$z := {\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T}$;

    2. $p_\theta\left(\mathbf{x}{0: T}\right) = p\theta\left(\mathbf{x}{0}, \mathbf{x}{1}, \ldots, \mathbf{x}_{T}\right)$,是关于$\mathbf{x}_0, z$的联合概率分布,因为选用马尔代夫链建模,那么依据马尔可夫链的性质,论文定义:

$$
\begin{equation*}
\begin{aligned}
q\left(\mathbf{x}{1: T} \mid \mathbf{x}0\right)&:=\prod{t=1}^T q\left(\mathbf{x}t \mid \mathbf{x}{t-1}\right) \
p
\theta\left(\mathbf{x}{0: T}\right)&:=p\left(\mathbf{x}T\right) \prod{t=1}^T p\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)
\end{aligned}
\end{equation*}
\tag{4}
$$

  1. 将$(4)$带入$(3)$不等式右边的第一项,得到$L$:

$$
\begin{equation*}
\begin{aligned}
&\mathbb{E}q\left[-\log \frac{p\theta\left(\mathbf{x}{0: T}\right)}{q\left(\mathbf{x}{1: T} \mid \mathbf{x}0\right)}\right] \
=&\mathbb{E}q\left[-\log \frac{p\left(\mathbf{x}T\right) \prod{t=1}^T p\theta\left(\mathbf{x}
{t-1} \mid \mathbf{x}t\right)}{\prod{t=1}^T q\left(\mathbf{x}t \mid \mathbf{x}{t-1}\right)}\right] \
=&\mathbb{E}q\left[-\log p\left(\mathbf{x}T\right)-\sum{t \geq 1} \log \frac{p\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}t \mid \mathbf{x}{t-1}\right)}\right] := L
\end{aligned}
\end{equation*}
$$

到目前为止,经过了很多轮的变换以及数学公式,先捋一遍,再往下。$L$是一个替代的优化目标,
$$\mathop{\arg\min}{(L)} \iff \mathop{\arg\min}{(-\ln{p}_{\theta}(\mathbf{x}0))} \iff \mathop{\arg\max}{(\ln{p}{\theta}(\mathbf{x}_0))}$$

接下来,论文中对$L$进行了重写,以下步骤直接摘录自论文$\text{Appendix A}$

$$
\begin{equation*}
\begin{aligned}
L & =\mathbb{E}q\left[-\log \frac{p\theta\left(\mathbf{x}{0: T}\right)}{q\left(\mathbf{x}{1: T} \mid \mathbf{x}0\right)}\right] \ & =\mathbb{E}q\left[-\log p\left(\mathbf{x}T\right)-\sum{t \geq 1} \log \frac{p\theta\left(\mathbf{x}{t-1} \mid \mathbf{x}t\right)}{q\left(\mathbf{x}t \mid \mathbf{x}{t-1}\right)}\right] \ & =\mathbb{E}q\left[-\log p\left(\mathbf{x}T\right)-\sum{t>1} \log \frac{p\theta\left(\mathbf{x}{t-1} \mid \mathbf{x}t\right)}{q\left(\mathbf{x}t \mid \mathbf{x}{t-1}\right)}-\log \frac{p\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}0\right)}\right] \
&=\mathbb{E}q\left[-\log p\left(\mathbf{x}T\right)-\sum{t>1} \log \left[\frac{p\theta\left(\mathbf{x}
{t-1} \mid \mathbf{x}t\right)}{q\left(\mathbf{x}{t-1} \mid \mathbf{x}_t, \mathbf{x}0\right)} \cdot \frac{q\left(\mathbf{x}{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}0\right)}\right]-\log \frac{p\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}\right]
\end{aligned}
\end{equation*}
\tag{5}
$$

倒数两步的变换发生在第二项,具体依据为:

$$
\begin{aligned}
q\left(\mathbf{x}t \mid \mathbf{x}{t-1}\right)
=& \frac{q\left(\mathbf{x}t, \mathbf{x}{t-1}\right)}{q\left(\mathbf{x}{t-1}\right)} \
=& \frac{q\left(\mathbf{x}t, \mathbf{x}{t-1} \mid \mathbf{x}
{0}\right) *q(\mathbf{x}{0})}{q\left(\mathbf{x}{t-1} \mid \mathbf{x}{0}\right) * q(\mathbf{x}{0})} \
=& \frac{q\left(\mathbf{x}t, \mathbf{x}{t-1} \mid \mathbf{x}{0}\right) }{q\left(\mathbf{x}{t-1} \mid \mathbf{x}0\right)}
\end{aligned}
\quad \Rightarrow \quad
\begin{aligned}
&\sum
{t>1} \log \frac{p_\theta\left(\mathbf{x}{t-1} \mid \mathbf{x}t\right)}{q\left(\mathbf{x}t \mid \mathbf{x}{t-1}\right)} \
=& \sum
{t>1} \log \frac{p
\theta\left(\mathbf{x}{t-1} \mid \mathbf{x}t\right)}{q\left(\mathbf{x}t, \mathbf{x}{t-1} \mid \mathbf{x}{0}\right) } \cdot {q\left(\mathbf{x}{t-1} \mid \mathbf{x}0\right)} \
=& \sum
{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}t\right)}{q\left(\mathbf{x}{t-1} \mid \mathbf{x}_t, \mathbf{x}0\right)} \cdot \frac{q\left(\mathbf{x}{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}
\end{aligned}
$$

接着对$(5)$进行改写得到最终形式$(6)$:
$$
\begin{aligned}
L &=\mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}T\right)}{q\left(\mathbf{x}T \mid \mathbf{x}0\right)}-\sum{t>1} \log \frac{p\theta\left(\mathbf{x}{t-1} \mid \mathbf{x}t\right)}{q\left(\mathbf{x}{t-1} \mid \mathbf{x}t, \mathbf{x}0\right)}-\log p\theta\left(\mathbf{x}0 \mid \mathbf{x}1\right)\right] \
&=\mathbb{E}q[\underbrace{D{\mathrm{KL}}\left(q\left(\mathbf{x}T \mid \mathbf{x}0\right) | p\left(\mathbf{x}T\right)\right)}{L_T}+\sum{t>1} \underbrace{D{\mathrm{KL}}\left(q\left(\mathbf{x}
{t-1} \mid \mathbf{x}t, \mathbf{x}0\right) | p\theta\left(\mathbf{x}{t-1} \mid \mathbf{x}t\right)\right)}{L
{t-1}} \underbrace{-\log p
\theta\left(\mathbf{x}_0 \mid \mathbf{x}1\right)}{L_0}]
\end{aligned}
\tag{6}
$$

Summary

太好了,对于$(6)$来说,它最起码是个可以优化的目标函数了,因为论文中定义马尔可夫链相邻状态的转变是服从高斯分布的。当然在论文中,$(6)$还会进一步被改写,得到更加精简的$\text{loss function}$形式。
DDPM是应用$\text{variational inference}$进行优化求解的典型例子,很值得借鉴学习。

Reference