从 TensorFlow 中给定的非均匀分布进行无替换采样
Posted
技术标签:
【中文标题】从 TensorFlow 中给定的非均匀分布进行无替换采样【英文标题】:Sampling without replacement from a given non-uniform distribution in TensorFlow 【发布时间】:2017-09-04 17:23:10 【问题描述】:我正在寻找类似于numpy.random.choice(range(3),replacement=False,size=2,p=[0.1,0.2,0.7])
的东西
在 TensorFlow 中。
最接近它的Op
似乎是tf.multinomial(tf.log(p))
,它将logits 作为输入,但如果没有替换它就无法采样。有没有其他方法可以从 TensorFlow 中的非均匀分布中进行采样?
谢谢。
【问题讨论】:
【参考方案1】:是的,有。有关一些背景信息,请参阅 here 和 here。解决办法是:
z = -tf.log(-tf.log(tf.random_uniform(tf.shape(p),0,1)))
_, indices = tf.nn.top_k(tf.log(p) + z, size)
【讨论】:
【参考方案2】:您可以使用 tf.py_func
包装 numpy.random.choice
并将其作为 TensorFlow 操作提供:
a = tf.placeholder(tf.float32)
size = tf.placeholder(tf.int32)
replace = tf.placeholder(tf.bool)
p = tf.placeholder(tf.float32)
y = tf.py_func(np.random.choice, [a, size, replace, p], tf.float32)
with tf.Session() as sess:
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
你可以像往常一样指定 numpy 种子:
np.random.seed(1)
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
np.random.seed(1)
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
np.random.seed(1)
print(sess.run(y, a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]))
将打印:
[ 2. 0.]
[ 2. 1.]
[ 0. 1.]
[ 2. 0.]
[ 2. 1.]
[ 0. 1.]
[ 2. 0.]
【讨论】:
使用tf.py_func
很糟糕,因为调用python 操作会极大地减慢计算速度。特别是如果您使用 GPU,则利用率可能会根据任务从 100% 下降到 5%。因为我认为没有新的操作可以做到这一点,所以最好是编写自定义 c++ 操作。但是,您可以通过迭代地从数组中采样一个元素,然后通过 tf.gather
之类的方式删除该元素来破解它。以上是关于从 TensorFlow 中给定的非均匀分布进行无替换采样的主要内容,如果未能解决你的问题,请参考以下文章