从 TensorFlow 中给定的非均匀分布进行无替换采样

Posted

技术标签:

【中文标题】从 TensorFlow 中给定的非均匀分布进行无替换采样【英文标题】:Sampling without replacement from a given non-uniform distribution in TensorFlow 【发布时间】:2017-09-04 17:23:10 【问题描述】:

我正在寻找类似于numpy.random.choice(range(3),replacement=False,size=2,p=[0.1,0.2,0.7])的东西 在 TensorFlow 中。

最接近它的Op 似乎是tf.multinomial(tf.log(p)),它将logits 作为输入,但如果没有替换它就无法采样。有没有其他方法可以从 TensorFlow 中的非均匀分布中进行采样?

谢谢。

【问题讨论】:

【参考方案1】:

是的,有。有关一些背景信息,请参阅 here 和 here。解决办法是:

z = -tf.log(-tf.log(tf.random_uniform(tf.shape(p),0,1))) 
_, indices = tf.nn.top_k(tf.log(p) + z, size)

【讨论】:

【参考方案2】:

您可以使用 tf.py_func 包装 numpy.random.choice 并将其作为 TensorFlow 操作提供:

a = tf.placeholder(tf.float32)
size = tf.placeholder(tf.int32)
replace = tf.placeholder(tf.bool)
p = tf.placeholder(tf.float32)

y = tf.py_func(np.random.choice, [a, size, replace, p], tf.float32)

with tf.Session() as sess:
    print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))

你可以像往常一样指定 numpy 种子:

np.random.seed(1)
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
np.random.seed(1)
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
np.random.seed(1)
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))

将打印:

[ 2.  0.]
[ 2.  1.]
[ 0.  1.]
[ 2.  0.]
[ 2.  1.]
[ 0.  1.]
[ 2.  0.]

【讨论】:

使用tf.py_func 很糟糕,因为调用python 操作会极大地减慢计算速度。特别是如果您使用 GPU,则利用率可能会根据任务从 100% 下降到 5%。因为我认为没有新的操作可以做到这一点,所以最好是编写自定义 c++ 操作。但是,您可以通过迭代地从数组中采样一个元素,然后通过 tf.gather 之类的方式删除该元素来破解它。

以上是关于从 TensorFlow 中给定的非均匀分布进行无替换采样的主要内容,如果未能解决你的问题,请参考以下文章

在给定任意颜色范围的情况下均匀分布画布渐变颜色

从非均匀数据创建均匀分布的示例

漫谈“采样”(sampling)

证明一个随机生成的数是均匀分布的

如何在哈希表中均匀分布不同的键?

如何在 Postgres 中从具有非均匀分布的表中选择随机行?