tf.gather和tf.gather_nd的详细用法--tensorflow通过索引取tensor里的数据

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tf.gather和tf.gather_nd的详细用法--tensorflow通过索引取tensor里的数据相关的知识,希望对你有一定的参考价值。

参考技术A 在numpy里取矩阵数据非常方便,比如:

这样就把矩阵a中的1,3,5行取出来了。

如果是只取某一维中单个索引的数据可以直接写成 tensor[:, 2] , 但如果要提取的索引不连续的话,在tensorflow里面的用法就要用到tf.gather.

tf.gather_nd允许在多维上进行索引:
matrix中直接通过坐标取数(索引维度与tensor维度相同):

取第二行和第一行:

3维tensor的结果:

另外还有tf.batch_gather的用法如下:
tf.batch_gather(params, indices, name=None)
Gather slices from params according to indices with leading batch dims.

This operation assumes that the leading dimensions of indices are dense,
and the gathers on the axis corresponding to the last dimension of indices .

Therefore params should be a Tensor of shape [A1, ..., AN, B1, ..., BM],
indices should be a Tensor of shape [A1, ..., AN-1, C] and result will be
a Tensor of size [A1, ..., AN-1, C, B1, ..., BM] .

如果索引是一维的tensor,结果和 tf.gather 是一样的.

tf实现用二维的索引从二维数组获取对应值 tf.gather_nd

a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
inds = tf.constant([[0, 2], [2, 1], [1, 1]])

#目的是实现 从[1,2,3]获取index为[0,2]的值也就是[1,3]作为第一行,
从[4,5,6]获取index为[2,1]的值也就是[6,5]作为第二行, 
从[7,8,9]获取index[1,1]的值作为第三行,也就是输出是
[[1 3]
 [6 5]
 [8 8]]






这种需求应该很常见,但是想通过look_up_table好像不行,以及想通过tf.gather_fn似乎可以但是也不好写

本文提供一种写法:

import tensorflow as tf

def gather_batch(v, inds):
    return tf.gather(v, inds)

def test2():
    a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inds = tf.constant([[0, 2], [2, 1], [1, 1]])
    vs = tf.map_fn(fn=lambda x: gather_batch(x[:3], x[3:]), elems=tf.concat([a, inds], 1))

    with tf.Session() as sess:
        print(sess.run(vs))
 

if __name__ == '__main__':
    # test1()
    test2()

 

但是上面写法还是用了循环 会很慢 所以更好写法

def test3():
    a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inds = tf.constant([[0, 2], [2, 1], [1, 1]])
    batch_size = inds.shape[0]
    cnt = inds.shape[1]
    left_inds = tf.tile(
        tf.expand_dims(tf.range(batch_size), 1),
        [1, cnt]
    )
    ind = tf.squeeze(
        tf.stack(
            [
                tf.expand_dims(left_inds, 2),
                tf.expand_dims(inds, 2),
            ],
            2
        )
        ,-1
    )

    vs = tf.gather_nd(a, ind)
    with tf.Session() as sess:
        # print(sess.run(ind))
        print(sess.run(vs))

 

 

以上是关于tf.gather和tf.gather_nd的详细用法--tensorflow通过索引取tensor里的数据的主要内容,如果未能解决你的问题,请参考以下文章

tf实现用二维的索引从二维数组获取对应值 tf.gather_nd

张量流中张量对象的非连续索引切片(高级索引,如numpy)

获取tensorflow中tensor的值

tensorflow 高级函数 where,gather,gather_nd

Tensorflow进阶

tensorflow切片