张量流图模式下 tfp.distributions.Categorical.log_prob 的解决方法/备用值

Posted

技术标签:

【中文标题】张量流图模式下 tfp.distributions.Categorical.log_prob 的解决方法/备用值【英文标题】:Workaround / fallback value for tfp.distributions.Categorical.log_prob in tensorflow graph mode 【发布时间】:2021-08-14 00:26:29 【问题描述】:

如果输入的标签超出范围,有没有办法避免tfp.distributions.Categorical.log_prob引发错误?

我将一批样本传递给log_prob 方法,其中一些具有n_categories + 1 的值,这是您从全零概率分布中采样时得到的后备值。我的probs 批次中的一些概率分布全为零**。

dec_output, h_state, c_state = self.decoder(dec_inp, [h_state, c_state])
probs = self.attention(enc_output, dec_output, pointer_mask, len_mask)
distr = tfp.distributions.Categorical(probs=probs)
pointer = distr.sample()
log_prob = distr.log_prob(pointer) # log of the probability of choosing that action

我不在乎在这些情况下我从log_prob 得到什么价值,因为稍后我会屏蔽它而不使用它。不确定是否可以以某种方式实现 fallback 值。如果没有,是否有任何解决方法可以避免在我以图形模式(使用@tf.function)执行时引发错误?

**这是因为我正在使用 RNN 进行随机解码,该 RNN 是一批可变长度的序列,一个 seq to seq 任务。

【问题讨论】:

【参考方案1】:

如果您可以屏蔽 log_prob,您也可以将 probs 屏蔽为 1 / n。 请注意,使用 Categorical 的 logits 参数化并放弃(可能)上游 softmax 激活在数值上更稳定。

【讨论】:

以上是关于张量流图模式下 tfp.distributions.Categorical.log_prob 的解决方法/备用值的主要内容,如果未能解决你的问题,请参考以下文章

在 Android 上使用来自冻结的张量流图的变量

pytorch 可以优化顺序操作(如张量流图或 JAX 的 jit)吗?

如何找出冻结的张量流图的正确输入和输出操作?

Tensorflow瞎搞

Tensorflow(4) 张量属性:维数、形状、数据类型

tensorflow中张量(tensor)的属性——维数(阶)形状和数据类型