变分贝叶斯推断 问题描述 假设当前有一个贝叶斯模型,且其中的参数都有相应的先验分布。同时,模型中还可能有潜变量,将其与各种参数标记为Θ 。同样地,把所有观测变量集合标记为Y 。因此,我们希望找到分布q ( Θ ) 来逼近真实后验分布p ( Θ ∣ Y ) ,而这可以通过最小化KL散度实现,也即:
KL ( q ( Θ ) ‖ p ( Θ | Y ) ) = ∫ q ( Θ ) ln { q ( Θ ) p ( Θ ∣ Y ) } d Θ = ln p ( Y ) − ∫ q ( Θ ) ln { p ( Y , Θ ) q ( Θ ) } d Θ
其中ln p ( Y ) 表示模型证据(Evidence),则其下界(lower bound)可以定义为L ( q ) = ∫ q ( Θ ) ln p ( Y , Θ ) q ( Θ ) d Θ 。因为模型证据是一个常量,当KL散度为0时,下界出现最大值,这就意味着q ( Θ ) = p ( Y , Θ ) 。
平均场理论 根据平均场理论(mean field theory),我们假设变分分布q ( Θ ) 可以被分解成各组变量分布的乘积,即可以写作:
q ( Θ ) = M ∏ j q j ( Θ j )
这是针对该分布的唯一假设,其中每一个独立因子q j ( Θ j ) 的特定函数形式可以具体地一个个推导出来。通过最大化下届L ( q ) ,第j 个因子的优化形式由下式给出:
ln q j ( Θ j ) = E q ( Θ ∖ Θ j ) [ ln p ( Y , Θ ) ] + const
其中E q ( Θ ∖ Θ j ) [ ⋅ ] 表示关于除Θ j 外所有变量的q 分布的期望。因为所有参数的分布都属于指数族分布,且都与其父节点共轭,因此我们能够通过公式(3)推导出Θ 中每个参数的后验分布更新的近似形式。
变分下界 在模型计算时,可以通过直接计算变分下界L ( q ) 来判断算法是否收敛,因为在每一次迭代中变分下界不减。变分下界可以通过下式进行计算
L ( q ) = ∫ q ( Θ ) ln { p ( Y , Θ ) p ( Θ ) } d Θ = E q [ ln p ( Y , Θ ) ] − E q [ ln q ( Θ ) ] = E q [ ln p ( Y ∣ Θ ) ] + E q [ ln p ( Θ ) ] − ∑ j E q j ( Θ j ) [ ln q j ( Θ j ) ]
其中第一项表示联合分布的后验期望,而第二项则表示后验q 分布的熵H ( q ( Θ ) ) = − E q [ ln p ( q ( Θ ) ) ] 。
例子-一元高斯模型 模型推断 给定一组观测数据值D = x 1 , x 2 , … , x N ,假定数据是独立地从高斯分布中抽取的,目标是通过最大化后验分布推断得到均值参数μ 和精度参数τ ,其似然函数为
p ( D ∣ μ , τ ) = N ∏ i = 1 ( τ 2 π ) 1 2 exp { − τ ( x i − μ ) 2 2 } = ( τ 2 π ) N 2 exp { − τ 2 N ∑ i = 1 ( x i − μ ) 2 }
同时引入参数μ 和τ 的共轭先验分布,其形式为
p ( μ ∣ τ ) = N ( μ ∣ μ 0 , [ λ 0 τ ] − 1 ) p ( τ ) = Gam ( τ ∣ a 0 , b 0 )
则其联合分布可以表示为
p ( D , μ , τ ) = p ( D ∣ μ , τ ) p ( μ ∣ τ ) p ( τ )
对后验概率分布的变分近似进行分解
q ( μ , τ ) = q μ ( μ ) q τ ( τ )
根据公式(3),可以推导出各个因子的优化形式,对于q μ ( μ ) ,我们有
ln q ∗ μ ( μ ) = E τ [ ln p ( D ∣ μ , τ ) + ln p ( μ ∣ τ ) ] + const = − E [ τ ] 2 { N ∑ i = 1 ( x i − μ ) 2 + λ 0 ( μ − μ 0 ) 2 } + const = − 1 2 { ( λ 0 + N ) E [ τ ] μ 2 − 2 ( λ 0 μ 0 + N ˉ x ) E [ τ ] μ } + const
由共轭性,可以看到q μ ( μ ) 是一个高斯分布,其参数为
λ ∗ = ( λ 0 + N ) E [ τ ] μ ∗ = [ λ ∗ ] − 1 ( λ 0 μ 0 + N ˉ x ) E [ τ ] = λ 0 μ 0 + N ˉ x λ 0 + N
同理,因子q τ ( τ ) 的最优形式为
ln q ∗ τ ( τ ) = E μ [ ln p ( D ∣ μ , τ ) + ln p ( μ ∣ τ ) + ln p ( τ ) ] + const = ( a 0 − 1 + N + 1 2 ) ln τ − { b 0 + 1 2 E μ [ N ∑ i = 1 ( x i − μ ) 2 + λ 0 ( μ − μ 0 ) 2 ] } τ + const
因此q τ ( τ ) 是一个Gamma分布,参数为
a ∗ = a 0 + N + 1 2 b ∗ = b 0 + 1 2 E μ [ N ∑ i = 1 ( x i − μ ) 2 + λ 0 ( μ − μ 0 ) 2 ]
为了评估模型的收敛性,我们进一步计算变分下界,其公式如下
L ( q ) = E q [ ln p ( Y , Θ ) ] − E q [ ln q ( Θ ) ] = E q ( μ , τ ) [ ln p ( D ∣ μ , τ ) ] + E q ( μ , τ ) [ ln p ( μ ∣ τ ) ] + E q ( τ ) [ ln p ( τ ) ] − E q ( μ ) [ ln q ( μ ) ] − E q ( τ ) [ ln q ( τ ) ]
上式各个部分可以通过下面公式计算得到,其中$\mathbb{E}q[\mu^2]=\mathbb{E} {\mu}[\mu^2]=(\mathbb{E}[\mu])^2+\text{Var}[\mu], \mathbb{E}q[\text{ln}\tau]=\mathbb{E} {\tau}[\text{ln}\tau]=\psi(a)-\text{ln}b, \psi(\cdot)为 双 伽 马 函 数 ( d i g a m m a f u n c t i o n ) 而 \Gamma(\cdot)$为伽马函数。
E q [ ln p ( D ∣ μ , τ ) ] = − N 2 ln ( 2 π ) + N 2 E q [ ln τ ] − 1 2 E q [ τ ] E q [ N ∑ i = 1 ( x i − μ ) 2 ] = − N 2 ln ( 2 π ) + N 2 ( ψ ( a ∗ ) − ln ( b ∗ ) ) − a ∗ 2 b ∗ { N ∑ i = 1 x 2 i − 2 μ ∗ N ∑ i = 1 x i + E q [ μ 2 ] }
E q [ ln p ( μ ∣ τ ) ] = − 1 2 ln ( 2 π ) + 1 2 E q [ ln ( λ 0 τ ) ] − λ 0 2 E q [ τ ] E q [ ( μ − μ 0 ) 2 ] = − 1 2 ln ( 2 π ) + ln λ 0 2 + 1 2 ( ψ ( a ∗ ) − ln ( b ∗ ) ) − λ 0 a ∗ 2 b ∗ { E q [ μ 2 ] − 2 μ 0 μ ∗ + μ 2 0 }
E q [ ln p ( τ ) ] = − ln Γ ( a 0 ) + a 0 ln b 0 + ( a 0 − 1 ) E q [ ln τ ] − b 0 E q [ τ ] = − ln Γ ( a 0 ) + a 0 ln b 0 + ( a 0 − 1 ) ( ψ ( a ∗ ) − ln b ∗ ) − b 0 a ∗ b ∗
− E q [ ln q ( μ ) ] = 1 2 ln ( 2 π ) − ln λ ∗ 2 + λ ∗ 2 { E q [ μ 2 ] − 2 ( μ ∗ ) 2 + ( μ ∗ ) 2 } = 1 2 ln ( 2 π ) − ln λ ∗ 2 + 1 2
− E q [ ln q ( τ ) ] = ln Γ ( a ∗ ) − ( a ∗ − 1 ) ψ ( a ∗ ) − ln b ∗ + a ∗
至此,我们得到了关于参数μ 与τ 最优分布的表达式,且各自依赖于另一个分布的计算所得的一阶矩或二阶矩。因此通过初始化参数的值,可以通过不断迭代计算出后验分布。
在此例子中,由于模型简单且参数数量只有两个,因此我们可以通过直接求解上式的因子找到显示解。首先我们可以通过设置无信息先验(noninformative prior)来简化上述表达式,也即μ 0 = λ 0 = a 0 = b 0 = 0 。根据Gamma分布的均值计算公式,我们有
1 E [ τ ] = b ∗ a ∗ = E [ 1 N + 1 N ∑ i = 1 ( x i − μ ) 2 ] = N N + 1 ( ¯ x 2 − 2 ˉ x E [ μ ] + E [ μ 2 ] )
由公式(10),我们可以获得近似分布q μ ( μ ) 的一阶矩与二阶矩
E [ μ ] = ˉ x , E [ μ 2 ] = ˉ x 2 + 1 N E [ τ ]
将其代入公式(19),可以解出E [ τ ]
1 E [ τ ] = ¯ x 2 − ˉ x 2 = 1 N N ∑ i = 1 ( x i − ˉ x ) 2
代码示例 上节一元高斯的求解MATLAB代码如下所示,首先通过随机数生成器生成N 个高斯分布的随机数作为模型观测值
clear; mu_real = 1 ; tau_real = 1.5 ; N = 200 ; D = normrnd(mu_real, sqrt (1. /tau_real), [N,1 ]); sumD = sum(D); sumD2 = sum(D.^2 ); mu_est = mean(D); tau_est = 1. /(mean(D.^2 )-mean(D).^2 ); x_real=-4 +mu_real:0.1 :mu_real+4 ; y_real=normpdf(x_real,mu_real,sqrt (1. /tau_real)); x_est=-4 +mu_est:0.1 :mu_est+4 ; y_est=normpdf(x_est,mu_est,sqrt (1. /tau_est)); figure;plot(x_real,y_real,'-r.' ,x_est,y_est,'--b.' );grid;
初始化模型参数和超参数
mu0 = 1e-6 ; lambda0 = 1e-6 ; a0 = 1e-6 ; b0 = 1e-6 ; mu = randn (); mus(1 ) = mu; tau = rand (); taus(1 ) = tau; LB = 0 ; tol = 1e-5 ; maxiters = 100 ;
可视化求解过程,将分别绘出模型参数的真值与估计值比较图、变分下界变化图和精度参数τ 的后验分布变化图
scrnsz = get(0 ,'ScreenSize' ); h = figure('Position' ,[scrnsz(3 )*0.25 scrnsz(4 )*0.25 scrnsz(3 )*0.5 scrnsz(4 )*0.5 ]); set(0 ,'CurrentFigure' ,h); subplot(2 ,2 ,1 ); plot(mu_real, '-r.' ,'LineWidth' ,1.5 ,'MarkerSize' ,10 ); title('Model parameter \mu' ); xlabel('Iteration' ); grid on; subplot(2 ,2 ,2 ); plot(tau_real, '-r.' ,'LineWidth' ,1.5 ,'MarkerSize' ,10 ); title('Model parameter \tau' ); xlabel('Iteration' ); grid on; subplot(2 ,2 ,3 ); plot(LB, '-r.' ,'LineWidth' ,1.5 ,'MarkerSize' ,10 ); title('Lower bound' ); xlabel('Iteration' ); grid on; subplot(2 ,2 ,4 ); plot(0 :0.1 :20 , gampdf(0 :0.1 :20 , a0, 1. /b0), 'r-' ); title('Posterior pdf' ); xlabel('Noise precision \tau' ); grid on; set(findall(h,'type' ,'text' ),'fontSize' ,12 ); drawnow;
模型迭代求解
for it=1 :maxiters lambda_new = (lambda0+N)*tau; mu_new = (1. /lambda_new)*(lambda0*mu0+sumD)*tau; mu = mu_new; mus(it+1 ) = mu; E_mu2 = mu_new^2 +1. /lambda_new; a_new = a0+(N+1 )./2 ; b_new = b0+0.5 *(sumD2-2 *(sumD+lambda0*mu0)*mu+(N+lambda0)*E_mu2+lambda0*(mu0^2 )); tau = a_new./b_new; taus(it+1 ) = tau; E_lntau = psi (a_new)-safelog(b_new); E_pD = -0.5 *N*safelog(2 *pi )+0.5 *N*E_lntau-0.5 *tau*(sumD2-2 *mu*sumD+E_mu2); E_pmu = -0.5 *safelog(2 *pi )+0.5 *E_lntau+0.5 *safelog(lambda0)-0.5 *lambda0*tau*(E_mu2-2 *mu0*mu+mu0^2 ); E_ptau = -safelog(gamma (a0))+a0*safelog(b0)+(a0-1 )*E_lntau-b0*tau; E_qmu = 0.5 *safelog(2 *pi )-0.5 *safelog(lambda_new)+0.5 ; E_qtau = safelog(gamma (a_new))-(a_new-1 )*psi (a_new)-safelog(b_new)+a_new; LB(it) = E_pD + E_pmu + E_ptau + E_qmu + E_qtau; set(0 ,'CurrentFigure' ,h); subplot(2 ,2 ,1 ); plot(mu_real*ones (1 ,it+1 ), '-r.' ,'LineWidth' ,1.5 ,'MarkerSize' ,10 );hold on;plot(mu_est*ones (1 ,it+1 ), '--r.' ,'LineWidth' ,1.5 ,'MarkerSize' ,10 );hold on;plot(mus, '-b.' ,'LineWidth' ,1.5 ,'MarkerSize' ,10 );hold off; title('Model parameter \mu' ); xlabel('Iteration' ); grid on; subplot(2 ,2 ,2 ); plot(tau_real*ones (1 ,it+1 ), '-r.' ,'LineWidth' ,1.5 ,'MarkerSize' ,10 );hold on;plot(tau_est*ones (1 ,it+1 ), '--r.' ,'LineWidth' ,1.5 ,'MarkerSize' ,10 );hold on;plot(taus, '-b.' ,'LineWidth' ,1.5 ,'MarkerSize' ,10 );hold off; title('Model parameter \tau' ); xlabel('Iteration' ); grid on; subplot(2 ,2 ,3 ); plot(LB, '-r.' ,'LineWidth' ,1.5 ,'MarkerSize' ,10 ); title('Lower bound' ); xlabel('Iteration' ); grid on; subplot(2 ,2 ,4 ); plot(0 :0.05 :2 *tau, gampdf(0 :0.05 :2 *tau, a_new, 1. /b_new), '-r.' , 'LineWidth' ,1.5 ); title('Posterior pdf' ); xlabel('Noise precision \tau' ); grid on; set(findall(h,'type' ,'text' ),'fontSize' ,12 ); drawnow; if it>3 LB_change = -1 *(LB(it) - LB(it-1 ))/LB(3 ); else LB_change = NaN; end if it>10 && (abs (LB_change) < tol) disp ('Converged!' ); break ; end end function y = safelog (x) x(x<1e-300 )=1e-200 ; x(x>1e300 )=1e300 ; y=log (x); end
求解结果如下图所示,左上与右上两图中的红色实线 表示生成观测数据模型参数的真值,红色虚线 表示利用观测数据显式计算得到的模型参数的估计值,蓝色实线 表示利用变分推断迭代求解的参数估计值。在本例子中,由于推断公式较简单,E [ μ ] 的计算不依赖于E [ τ ] ,所以算法在第二次迭代就已经收敛到显式计算的估计值。Result Convergence
例子-混合高斯模型 模型推断 代码示例 参考
Christopher M. Bishop. “Pattern recognition and machine learning.” Springer , 2006.
PRML Errata 1st: https://www.microsoft.com/en-us/research/wp-content/uploads/2016/05/prml-errata-1st-20110921.pdf
Qibin Zhao, Liqing Zhang, and Andrzej Cichocki. “Bayesian CP factorization of incomplete tensors with automatic rank determination.” IEEE transactions on pattern analysis and machine intelligence , 2015.
https://github.com/qbzhao/BCPF
https://en.wikipedia.org/wiki/Gamma_distribution
https://en.wikipedia.org/wiki/Conjugate_prior