变分推断
Posted demonhunter
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了变分推断相关的知识,希望对你有一定的参考价值。
变分推断的基本形式
变分推断是使(q(z))逼近(p(zvert x))来求得隐变量(z)的后验分布(p(zvert x))。根据贝叶斯公式,有
[egin{align*} underbrace{logleft(p(x)
ight)}_{ ext{evidence}} &= logleft(pleft(x, z
ight)
ight)-logleft(p(zvert x)
ight)\ &= underbrace{int_z q(z)logleft(frac{p(x, z)}{q(z)}
ight)}_{ ext{evidence low bound}}-underbrace{int_z q(z)logleft(frac{p(zvert x)}{q(z)}
ight)}_{ ext{KL divergence}}end{align*}
]
(logleft(p(x)
ight))被称为Evidence
的原因是因为它是来自我们观察到的,又因为KL-divergence
不为负,为了使得(q(z))逼近(p(zvert x)),优化的目标就是上面的ELOB
。
Mean field
中场理论(Mean field)一般假设
[q(z)=prod_{i=1}^M q(z_i) label{eq:mean field} ag{1}
]
代入(~
ef{eq:mean field})得到
[int_z prod_{i=1}^M q(z_i) logleft(p(x,z)
ight),dz-\ int_{z_1}q(z_i)int_{z_2}q(z_2)cdotsint_{z_M}q(z_M) sum_{i=1}^Mlogleft(q(z_i)
ight),dz_Mcdots dz_1
]
即
[egin{align*}int_z prod_{i=1}^M q(z_i) logleft(q(x,z)
ight),dz-sum_{i=1}^Mint_{z_1}q(z_1)logleft(q(z_1)
ight)end{align*}
]
令
[logleft( ilde{p}_j(x, z)
ight)=E_{i
eq j}left[logleft(p(x,z_j)
ight)
ight]
]
针对第(z_j),ELOW
为
[int_{z_j} q(z_j) logleft( ilde{q}_j(x,)
ight) - int_{z_j} q(z_j) logleft(q(z_j)
ight)
]
因此当(q(z_j)= ilde{q}_j(x,z))时上式取得最小值(0)。因此通过迭代(z_j)可以求得逼近(p(zvert x))的(q(z))
指数函数变分推断例子
假设(p(x),,p(xvert z))都来自某指数族分布,指数族分布形式如下
[p(xvert eta)=h(x)expleft(eta^TT(x)-A(eta)
ight)
]
且满足
[egin{align*} A‘(eta_{MLE}) &= frac{1}{n}sum_{i=1}^n T(x_i)\ A‘(eta) &= E_{p(xvert eta)}left[T(x)
ight]\ A‘‘(eta) &= Varleft[T(x)
ight]end{align*}
]
假设隐变量(z)可以分为两部分(Z)和(eta),那么ELOB
可以写为
[int_{Z,eta} q(Z,eta)logleft(p(x,Z,eta)
ight)-int_{Z,eta}q(Z,eta)logleft(q(Z,eta)
ight)
]
根据指数族的性质后验分布(p(etavert Z,x))和(p(Zvert eta, z))都属于指数族
[egin{align*}p(etavert Z,x) &= h(eta) expleft(T(eta)^Teta(Z,x)-Aleft(eta(Z,x)
ight)
ight)\p(Zvert eta,x) &= h(Z) expleft(T(Z)^Teta(eta,x)-Aleft(eta(eta,x)
ight)
ight)end{align*}
]
这里只展示(p(etavert Z,x))的近似分布(q(etavert lambda))求解,对于(p(Zvert eta, x))的近似分布(q(Zvert phi))也类似
[q(etavert lambda)=h(lambda)expleft(T(eta)^Tlambda-A(lambda)
ight)
]
根据上一节的结果,ELOB
是关于(lambda,, phi)的函数
[E_{q(Z,eta)}left[logleft(p(etavert Z,x)
ight)logleft(p( Zvert x)
ight)logleft(p(x)
ight)
ight]-\E_{q(Z,eta)}left[logleft(q(Z)
ight)logleft(q(eta)
ight)
ight]
]
固定(phi),上式中与(lambda)有关的项为
[E_{q(Z,eta)}left[logleft(q(etavert Z,x)
ight)
ight]-E_{q(Z,eta)}left[logleft(q(eta)
ight)
ight]
]
将(logleft(q(etavert Z,x)
ight))和(logleft(q(eta)
ight))定义带入,得到与(lambda)有关的项为
[E_{q(eta)}left[T(eta)
ight]^TE_{q(Z)}left[eta(Z,x)
ight]-lambda^TE_{q(eta)}left[T(eta)
ight]+A(lambda)
]
利用(A‘(eta)=E_{p(xvert eta)}left[T(x)
ight])得到
[egin{align*}L(lambda,phi) &= A‘(lambda)^TE_{q(Z)}left[eta(Z,x)
ight]-lambda^T A‘(lambda) + A(lambda)\frac{partial L(lambda, phi)}{partial lambda}&=A‘‘(lambda)^TE_{q(Z)}left[eta(Z,x)
ight]-A‘(lambda) - lambda^T A‘‘(lambda) + A‘(lambda)end{align*}
]
因为(A‘‘(lambda)
eq0),因此
[lambda=E_{q(Zvert phi)}left[eta(Z, x)
ight]
]
同理
[phi = E_{q(etavert lambda)}left[eta(eta, x)
ight]
]
随机梯度变分推断
不同于mean field,随机梯度变分推断将分布(q(zvert phi))看为关于(phi)的分布,通过对(phi)进行优化得到最优的分布。
[egin{align*}
abla_{phi}L&=
abla_{phi}E_{q(zvert phi)}left[logleft(p(x,z)
ight)-logleft(q(zvertphi)
ight)
ight]\&=E_{q(zvert phi)}left[
abla_{phi}left[log q(zvert phi
ight]left(log p(x,z)-log q(zvert phi)
ight)
ight]end{align*}
]
随后用蒙塔卡罗就可以近似出梯度,虽然直接使用蒙特卡洛会造成方差较大,可以通过重参数技巧进行减小方差(在VAE中也有用到)。重参数后的计算参见SGVI。
参考
-
WallE-Chang SGVI repository
-
ws13685555932 machine learning derivative repository
-
shuhuai008 SGVI
以上是关于变分推断的主要内容,如果未能解决你的问题,请参考以下文章
变分推断
变分推断—— 进阶(续)
变分推断—— 进阶
文本主题模型之LDA LDA求解之变分推断EM算法
变分推断
文本主题模型之LDA LDA求解之变分推断EM算法