如何让 HMM 在 TensorFlow 中处理实值数据
Posted
技术标签:
【中文标题】如何让 HMM 在 TensorFlow 中处理实值数据【英文标题】:How to get HMM working with real-valued data in Tensorflow 【发布时间】:2021-03-07 15:02:51 【问题描述】:我正在使用一个包含来自 IoT 设备的数据的数据集,我发现隐马尔可夫模型非常适合我的用例。因此,我正在尝试更改我发现 here 的 Tensorflow 教程中的一些代码。与教程中显示的计数数据相比,该数据集包含观察变量的实数值。
特别是,我认为需要更改以下内容,以使 HMM 具有正态分布的排放。不幸的是,我找不到任何关于如何改变模型以产生不同于泊松的发射的任何代码。
我应该如何更改代码以发出正态分布的值?
# Define variable to represent the unknown log rates.
trainable_log_rates = tf.Variable(
np.log(np.mean(observed_counts)) + tf.random.normal([num_states]),
name='log_rates')
hmm = tfd.HiddenMarkovModel(
initial_distribution=tfd.Categorical(
logits=initial_state_logits),
transition_distribution=tfd.Categorical(probs=transition_probs),
observation_distribution=tfd.Poisson(log_rate=trainable_log_rates),
num_steps=len(observed_counts))
rate_prior = tfd.LogNormal(5, 5)
def log_prob():
return (tf.reduce_sum(rate_prior.log_prob(tf.math.exp(trainable_log_rates))) +
hmm.log_prob(observed_counts))
optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
@tf.function(autograph=False)
def train_op():
with tf.GradientTape() as tape:
neg_log_prob = -log_prob()
grads = tape.gradient(neg_log_prob, [trainable_log_rates])[0]
optimizer.apply_gradients([(grads, trainable_log_rates)])
return neg_log_prob, tf.math.exp(trainable_log_rates)
【问题讨论】:
对不起,如果这很明显......但你不能将正态分布传递给observation_distribution
吗? (例如MultivariateNormalDiag 或 MultivariateNormalTriL)
@rvinas 不幸的是,有些功能需要在示例中更改
什么功能?如果你能说明确切的问题是什么,我可能会提供帮助
【参考方案1】:
示例模型假定排放量x
是泊松分布的,其中四种比率是由潜在变量z
确定的。因此,它定义了可训练率(或对数率),定义了在 z
上具有均匀初始分布的 HMM、转换概率,以及来自泊松分布的观察值,对数率由可训练的概率给出。
为了更改为正态分布,您是说x
应该是正态分布,其可训练均值和标准差由潜在变量z
确定。因此,您需要将trainable_log_rates
替换为trainable_loc
和trainable_scale
并更改
observation_distribution=tfd.Poisson(log_rate=trainable_log_rates)
到
observation_distribution=tfd.Normal(loc=trainable_loc, scale=trainable_scale)
然后您需要用您选择的loc_prior
和scale_prior
替换您的rate_prior
,并使用它们来计算您的新log_prob
函数。
【讨论】:
以上是关于如何让 HMM 在 TensorFlow 中处理实值数据的主要内容,如果未能解决你的问题,请参考以下文章
Keras还是TensorFlow?深度学习框架选型实操分享
推荐阅读 | 如何让TensorFlow模型运行提速36.8%(续)