tf实现用二维的索引从二维数组获取对应值 tf.gather_nd

Posted Z-Pilgrim

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tf实现用二维的索引从二维数组获取对应值 tf.gather_nd相关的知识,希望对你有一定的参考价值。

a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
inds = tf.constant([[0, 2], [2, 1], [1, 1]])

#目的是实现 从[1,2,3]获取index为[0,2]的值也就是[1,3]作为第一行,
从[4,5,6]获取index为[2,1]的值也就是[6,5]作为第二行, 
从[7,8,9]获取index[1,1]的值作为第三行,也就是输出是
[[1 3]
 [6 5]
 [8 8]]






这种需求应该很常见,但是想通过look_up_table好像不行,以及想通过tf.gather_fn似乎可以但是也不好写

本文提供一种写法:

import tensorflow as tf

def gather_batch(v, inds):
    return tf.gather(v, inds)

def test2():
    a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inds = tf.constant([[0, 2], [2, 1], [1, 1]])
    vs = tf.map_fn(fn=lambda x: gather_batch(x[:3], x[3:]), elems=tf.concat([a, inds], 1))

    with tf.Session() as sess:
        print(sess.run(vs))
 

if __name__ == '__main__':
    # test1()
    test2()

 

但是上面写法还是用了循环 会很慢 所以更好写法

def test3():
    a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inds = tf.constant([[0, 2], [2, 1], [1, 1]])
    batch_size = inds.shape[0]
    cnt = inds.shape[1]
    left_inds = tf.tile(
        tf.expand_dims(tf.range(batch_size), 1),
        [1, cnt]
    )
    ind = tf.squeeze(
        tf.stack(
            [
                tf.expand_dims(left_inds, 2),
                tf.expand_dims(inds, 2),
            ],
            2
        )
        ,-1
    )

    vs = tf.gather_nd(a, ind)
    with tf.Session() as sess:
        # print(sess.run(ind))
        print(sess.run(vs))

 

 

以上是关于tf实现用二维的索引从二维数组获取对应值 tf.gather_nd的主要内容,如果未能解决你的问题,请参考以下文章

python使用np.argsort对一维numpy概率值数据排序获取倒序索引获取的top索引(例如top2top5top10)索引二维numpy数组中对应的原始数据:原始数据概率最大的头部数据

python使用np.argsort对一维numpy概率值数据排序获取升序索引获取的top索引(例如top2top5top10)索引二维numpy数组中对应的原始数据:原始数据概率最小的头部数据

FLASH AS3 二维数组如何查找某个元素的索引?

Labview中 怎么获取 波形数据中 其Y最小值对应的X轴坐标,谢谢,

c++输入一个5行5列的二维数组,求最大值和最小值其对应行列的位置。。

提取二维数组中最小点的索引