张量流图模式下 tfp.distributions.Categorical.log_prob 的解决方法/备用值
Posted
技术标签:
【中文标题】张量流图模式下 tfp.distributions.Categorical.log_prob 的解决方法/备用值【英文标题】:Workaround / fallback value for tfp.distributions.Categorical.log_prob in tensorflow graph mode 【发布时间】:2021-08-14 00:26:29 【问题描述】:如果输入的标签超出范围,有没有办法避免tfp.distributions.Categorical.log_prob
引发错误?
我将一批样本传递给log_prob
方法,其中一些具有n_categories + 1
的值,这是您从全零概率分布中采样时得到的后备值。我的probs
批次中的一些概率分布全为零**。
dec_output, h_state, c_state = self.decoder(dec_inp, [h_state, c_state])
probs = self.attention(enc_output, dec_output, pointer_mask, len_mask)
distr = tfp.distributions.Categorical(probs=probs)
pointer = distr.sample()
log_prob = distr.log_prob(pointer) # log of the probability of choosing that action
我不在乎在这些情况下我从log_prob
得到什么价值,因为稍后我会屏蔽它而不使用它。不确定是否可以以某种方式实现 fallback
值。如果没有,是否有任何解决方法可以避免在我以图形模式(使用@tf.function
)执行时引发错误?
**这是因为我正在使用 RNN 进行随机解码,该 RNN 是一批可变长度的序列,一个 seq to seq 任务。
【问题讨论】:
【参考方案1】:如果您可以屏蔽 log_prob,您也可以将 probs 屏蔽为 1 / n。 请注意,使用 Categorical 的 logits 参数化并放弃(可能)上游 softmax 激活在数值上更稳定。
【讨论】:
以上是关于张量流图模式下 tfp.distributions.Categorical.log_prob 的解决方法/备用值的主要内容,如果未能解决你的问题,请参考以下文章