如何在张量流中使用索引数组?
Posted
技术标签:
【中文标题】如何在张量流中使用索引数组?【英文标题】:How can I use the index array in tensorflow? 【发布时间】:2017-07-26 03:46:04 【问题描述】:如果给定一个形状为(5,3)
的矩阵a
和形状为(5,)
的索引数组b
,我们可以很容易地得到对应的向量c
通过,
c = a[np.arange(5), b]
但是,我不能用 tensorflow 做同样的事情,
a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, [5,])
# this line throws error
c = a[tf.range(5), b]
Traceback(最近一次调用最后一次):文件“”,第 1 行,in 文件 "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", 第 513 行,在 _SliceHelper 名称=名称)
文件 "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", 第 671 行,在 strided_slice 中 shrink_axis_mask=shrink_axis_mask)文件“~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py”, 第 3688 行,在 strided_slice 中 收缩轴掩码=收缩轴掩码,名称=名称)文件“~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py”, 第 763 行,在 apply_op 中 op_def=op_def) 文件 "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", 第 2397 行,在 create_op 中 set_shapes_for_outputs(ret) 文件“~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py”, 第 1757 行,在 set_shapes_for_outputs 形状 = shape_func(op) 文件“~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py”, 第 1707 行,在 call_with_requiring 中 返回 call_cpp_shape_fn(op, require_shape_fn=True) 文件 "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", 第 610 行,在 call_cpp_shape_fn 中 debug_python_shape_fn, require_shape_fn) 文件 "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", 第 675 行,在 _call_cpp_shape_fn_impl raise ValueError(err.message) ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice_14' (op: 'StridedSlice') with input 形状:[5,3]、[2,5]、[2,5]、[2]。
我的问题是,如果使用上述方法在 tensorflow 中无法像在 numpy 中那样产生预期的结果,我该怎么办?
【问题讨论】:
【参考方案1】:TensorFlow 目前未实现此功能。 GitHub issue #4638 正在跟踪 NumPy 风格的“高级”索引的实现。但是,您可以使用tf.gather_nd()
运算符来实现您的程序:
a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, (5,))
row_indices = tf.range(5)
# `indices` is a 5 x 2 matrix of coordinates into `a`.
indices = tf.transpose([row_indices, b])
c = tf.gather_nd(a, indices)
【讨论】:
以上是关于如何在张量流中使用索引数组?的主要内容,如果未能解决你的问题,请参考以下文章