如何在张量流中使用索引数组?

Posted

技术标签:

【中文标题】如何在张量流中使用索引数组?【英文标题】:How can I use the index array in tensorflow? 【发布时间】:2017-07-26 03:46:04 【问题描述】:

如果给定一个形状为(5,3)的矩阵a和形状为(5,)的索引数组b,我们可以很容易地得到对应的向量c通过,

c = a[np.arange(5), b]

但是,我不能用 tensorflow 做同样的事情,

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, [5,])
# this line throws error
c = a[tf.range(5), b]

Traceback(最近一次调用最后一次):文件“”,第 1 行,in 文件 "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", 第 513 行,在 _SliceHelper 名称=名称)

文件 "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", 第 671 行,在 strided_slice 中 shrink_axis_mask=shrink_axis_mask)文件“~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py”, 第 3688 行,在 strided_slice 中 收缩轴掩码=收缩轴掩码,名称=名称)文件“~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py”, 第 763 行,在 apply_op 中 op_def=op_def) 文件 "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", 第 2397 行,在 create_op 中 set_shapes_for_outputs(ret) 文件“~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py”, 第 1757 行,在 set_shapes_for_outputs 形状 = shape_func(op) 文件“~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py”, 第 1707 行,在 call_with_requiring 中 返回 call_cpp_shape_fn(op, require_shape_fn=True) 文件 "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", 第 610 行,在 call_cpp_shape_fn 中 debug_python_shape_fn, require_shape_fn) 文件 "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", 第 675 行,在 _call_cpp_shape_fn_impl raise ValueError(err.message) ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice_14' (op: 'StridedSlice') with input 形状:[5,3]、[2,5]、[2,5]、[2]。

我的问题是,如果使用上述方法在 tensorflow 中无法像在 numpy 中那样产生预期的结果,我该怎么办?

【问题讨论】:

【参考方案1】:

TensorFlow 目前未实现此功能。 GitHub issue #4638 正在跟踪 NumPy 风格的“高级”索引的实现。但是,您可以使用tf.gather_nd() 运算符来实现您的程序:

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, (5,))

row_indices = tf.range(5)

# `indices` is a 5 x 2 matrix of coordinates into `a`.
indices = tf.transpose([row_indices, b])

c = tf.gather_nd(a, indices)

【讨论】:

以上是关于如何在张量流中使用索引数组?的主要内容,如果未能解决你的问题,请参考以下文章

如何根据张量流中的列条件获取张量值的索引

当切片本身是张量流中的张量时如何进行切片分配

张量流中张量对象的非连续索引切片(高级索引,如numpy)

如何在张量流中对张量进行子集化?

在张量流中,如何迭代存储在张量中的输入序列?

在 Tensorflow 中使用索引对张量进行切片