EM学习-思想和代码
Posted 今夜无风
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了EM学习-思想和代码相关的知识,希望对你有一定的参考价值。
EM算法的简明实现
当然是教学用的简明实现了,这份实现是针对双硬币模型的。
双硬币模型
假设有两枚硬币A、B,以相同的概率随机选择一个硬币,进行如下的抛硬币实验:共做5次实验,每次实验独立的抛十次,结果如图中a所示,例如某次实验产生了H、T、T、T、H、H、T、H、T、H,H代表正面朝上。
假设试验数据记录员可能是实习生,业务不一定熟悉,造成a和b两种情况
a表示实习生记录了详细的试验数据,我们可以观测到试验数据中每次选择的是A还是B
b表示实习生忘了记录每次试验选择的是A还是B,我们无法观测实验数据中选择的硬币是哪个
问在两种情况下分别如何估计两个硬币正面出现的概率?
a情况相信大家都很熟悉,既然能观测到试验数据是哪枚硬币产生的,就可以统计正反面的出现次数,直接利用最大似然估计即可。
b情况就无法直接进行最大似然估计了,只能用EM算法,接下来引用nipunbatra博主的简明EM算法Python实现。
1 # -*- coding: utf-8 -*- 2 """ 3 Created on Tue Jul 4 18:23:28 2017 4 5 @author: Administrator 6 """ 7 8 import numpy as np 9 from scipy import stats 10 11 priors = [0.6, 0.5] 12 observations = np.array([[1,0,0,0,1,1,0,1,0,1], 13 [1,1,1,1,0,1,1,1,1,1], 14 [1,0,1,1,1,1,1,0,1,1], 15 [1,0,1,0,0,0,1,1,0,0], 16 [0,1,1,1,0,1,1,1,0,1]]) 17 18 def em_single(priors, observations): 19 """ 20 input: 21 priors:[theta_A, theta_B] 22 obvervations:m*n matrix 23 24 output: 25 26 """ 27 theta_A = priors[0] 28 theta_B = priors[1] 29 counts = {‘A‘:{‘H‘:0,‘T‘:0}, ‘B‘:{‘H‘:0,‘T‘:0}} 30 31 # e-step 32 for observation in observations: 33 len_observation = len(observation) 34 num_heads = observation.sum() # 正面个数 35 num_tails = len_observation - num_heads # 反面个数 36 # 两个二项分布 37 contribution_A = stats.binom.pmf(num_heads, len_observation, theta_A) 38 contribution_B = stats.binom.pmf(num_heads, len_observation, theta_B) 39 # 采用各自硬币的权重 40 weight_A = contribution_A/(contribution_A+contribution_B) 41 weight_B = contribution_B/(contribution_A+contribution_B) 42 43 # 更新在当前参数下,硬币A和B产生正反面的次数 44 counts[‘A‘][‘H‘] += weight_A * num_heads 45 counts[‘A‘][‘T‘] += weight_A * num_tails 46 counts[‘B‘][‘H‘] += weight_B * num_heads 47 counts[‘B‘][‘T‘] += weight_B * num_tails 48 49 # M-step 50 new_theta_A = counts[‘A‘][‘H‘]/(counts[‘A‘][‘H‘] + counts[‘A‘][‘T‘]) 51 new_theta_B = counts[‘B‘][‘H‘]/(counts[‘B‘][‘H‘] + counts[‘B‘][‘T‘]) 52 53 return [new_theta_A, new_theta_B] 54 55 56 57 def em(observations, prior, tol=1e-6, iterations=10000): 58 """ 59 EM算法 60 param observations: 观察数据 61 param prior: 模型初值 62 param tol: 迭代结束阈值 63 param iteration: 最大迭代数 64 return: 局部最优的模型参数 65 """ 66 import math 67 iter = 0 68 while iter < iterations: 69 new_prior = em_single(prior, observations) 70 delta_change = np.abs(new_prior[0]-prior[0]) 71 if delta_change < tol: 72 break 73 else: 74 prior = new_prior 75 iter += 1 76 print (iter) 77 78 return [new_prior, iter] 79 80 y = em(observations, priors)
参考自:http://www.hankcs.com/ml/em-algorithm-and-its-generalization.html
以上是关于EM学习-思想和代码的主要内容,如果未能解决你的问题,请参考以下文章
EM算法(Expectation Maximization Algorithm)详解(附代码)---大道至简之机器学习系列---通俗理解EM算法。