tensorflow从列表中收集类似的值
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow从列表中收集类似的值相关的知识,希望对你有一定的参考价值。
我有一个张量如下:
arr = [[1.5,0.2],[2.3,0.1],[1.3,0.21],[2.2,0.09],[4.4,0.8]]
我想收集小数组,其第一个元素的差异在0.3之内,第二个元素在0.03之内。例如,[1.5,0.2]和[1.3,0.21]应属于同一类别。它们的第一元素的差异是0.2 <0.3并且第二元素0.01 <0.03。
我想张力看起来像这样
arr = {[[1.5,0.2],[1.3,0.21]],[[2.3,0.1],[2.2,0.09]]}
如何在tensorflow中执行此操作?渴望模式还可以。
我找到了一种有点丑陋和缓慢的方法:
samples = np.array([[1.5,0.2],[2.3,0.1],[1.3,0.2],[2.2,0.09],[4.4,0.8],[2.3,0.11]],dtype=np.float32)
ini_samples = samples
samples = tf.split(samples,2,1)
a = samples[0]
b = samples[1]
find_match1 = tf.reduce_sum(tf.abs(tf.expand_dims(a,0) - tf.expand_dims(a,1)),2)
a = tf.logical_and(tf.greater(find_match1, tf.zeros_like(find_match1)),tf.less(find_match1, 0.3*tf.ones_like(find_match1)))
find_match2 = tf.reduce_sum(tf.abs(tf.expand_dims(b,0) - tf.expand_dims(b,1)),2)
b = tf.logical_and(tf.greater(find_match2, tf.zeros_like(find_match2)),tf.less(find_match2, 0.03*tf.ones_like(find_match2)))
x,y = tf.unique(tf.reshape(tf.where(tf.logical_or(a,b)),[1,-1])[0])
r = tf.gather(ini_samples, x)
张量流是否具有更优雅的功能?
答案
您无法获得由具有不同大小的向量的“组”组成的结果。相反,您可以创建一个“组ID”张量,根据您的标准将每个向量分类为一个组。使这个更复杂的部分是你必须用共同的元素“融合”组,我认为只能通过循环来完成。这段代码做了类似的事情:
import tensorflow as tf
def make_groups(correspondences):
# Multiply each row by its index
m = tf.to_int32(correspondences) * tf.range(tf.shape(correspondences)[0])
# Pick the largest index for each row
r = tf.reduce_max(m, axis=1)
# While loop accounts for transitive correspondences
# (e.g. if A and B go toghether and B and C go together, then A, B and C go together)
# The loop makes sure every element gets the largest common group id
r_prev = -tf.ones_like(r)
r, _ = tf.while_loop(lambda r, r_prev: tf.reduce_any(tf.not_equal(r, r_prev)),
lambda r, r_prev: (tf.gather(r, r), tf.identity(r)),
[r, r_prev])
# Use unique indices to make sequential group ids starting from 0
return tf.unique(r)[1]
# Test
with tf.Graph().as_default(), tf.Session() as sess:
arr = tf.constant([[1.5 , 0.2 ],
[2.3 , 0.1 ],
[1.3 , 0.21],
[2.2 , 0.09],
[4.4 , 0.8 ],
[1.1 , 0.23]])
a = arr[:, 0]
b = arr[:, 0]
cond = (tf.abs(a - a[:, tf.newaxis]) < 0.3) | (tf.abs(b - b[:, tf.newaxis]) < 0.03)
groups = make_groups(cond)
print(sess.run(groups))
# [0 1 0 1 2 0]
所以在这种情况下,这些组将是:
[1.5, 0.2]
,[1.3, 0.21]
和[1.1, 0.23]
[2.3, 0.1]
和[2.2, 0.09]
[4.4, 0.8]
以上是关于tensorflow从列表中收集类似的值的主要内容,如果未能解决你的问题,请参考以下文章