如何使用 TensorFlow 张量索引列表?
Posted
技术标签:
【中文标题】如何使用 TensorFlow 张量索引列表?【英文标题】:How to index a list with a TensorFlow tensor? 【发布时间】:2017-05-02 09:57:11 【问题描述】:假设一个列表包含需要通过查找表访问的不可连接对象。所以列表索引将是一个张量对象,但这是不可能的。
tf_look_up = tf.constant(np.array([3, 2, 1, 0, 4]))
index = tf.constant(2)
list = [0,1,2,3,4]
target = list[tf_look_up[index]]
这将显示以下错误消息。
TypeError: list indices must be integers or slices, not Tensor
是用张量索引列表的方法/解决方法吗?
【问题讨论】:
先使用sess.run将张量转为numpy @YaroslavBulatov 如果列表是动态生成的,例如RNN 产生的状态。任何方式列表动态索引都可以工作吗? 也许tf.gather
就像@soloice 的回答一样?
【参考方案1】:
tf.gather
就是为此目的而设计的。
只需运行tf.gather(list, tf_look_up[index])
,你就会得到你想要的。
【讨论】:
如果列表类似于[[1], [2, 3], [4]]
怎么办?【参考方案2】:
Tensorflow 实际上支持HashTable
。有关详细信息,请参阅documentation。
在这里,您可以执行以下操作:
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(tf_look_up, list), -1)
然后通过运行获得所需的输入
target = table.lookup(index)
请注意,如果找不到密钥,-1
是默认值。根据张量的配置,您可能必须将key_dtype
和value_dtype
添加到构造函数中。
【讨论】:
以上是关于如何使用 TensorFlow 张量索引列表?的主要内容,如果未能解决你的问题,请参考以下文章