如何根据张量流中的列条件获取张量值的索引

Posted

技术标签:

【中文标题】如何根据张量流中的列条件获取张量值的索引【英文标题】:how to get indices of a tensor values based on a column condition in tensorflow 【发布时间】:2019-12-28 02:18:30 【问题描述】:

我有一个这样的张量:

sim_topics = [[0.65 0.   0.   0.   0.42  0.   0.   0.51 0.   0.34 0.]
              [0.   0.51 0.   0.   0.52  0.   0.   0.   0.53 0.42 0.]
              [0.   0.32 0.   0.50 0.34  0.   0.   0.39 0.32 0.52 0.]
              [0.   0.23 0.37 0.   0.    0.37 0.37 0.   0.47 0.39 0.3 ]]

我想根据张量条件获取该张量中的索引:

masked_t = [True  False  True  False True True False True False True False]

所以输出应该是这样的:

[[0.65 0. 0.   0.   0.42  0.   0.   0.51 0.   0.34 0.]
 [0.   0. 0.   0.   0.52  0.   0.   0.   0.   0.42 0.]
 [0.   0. 0.   0.   0.34  0.   0.   0.39 0.   0.52 0.]
 [0.   0. 0.37 0.   0.    0.37 0.   0.   0.   0.39 0.]]

所以条件对初始张量的列起作用。实际上我需要在maske_t 中为真的元素的索引。

所以索引应该是:

[[0, 0],
 [1,0],
 [2, 0],
 [3,0],
 [0,2],
 [1,2],
 [2,2],
 [3,2],
 ....]]

实际上,这种方法在我按行执行时有效,但在这里我想根据条件选择特定列,因此会引发不兼容错误:

out = tf.cast(tf.zeros(shape=tf.shape(sim_topics), dtype=tf.float64), tf.float64)
indices = tf.where(tf.where(masked_t, out, sim_topics))

【问题讨论】:

【参考方案1】:

你可以像这样直接获取你需要的张量:

result = tf.multiply(sim_topics, tf.cast(masked_t, dtype=tf.float64))

让广播完成 masked_t 与 sim_topics 大小相同的工作

【讨论】:

以上是关于如何根据张量流中的列条件获取张量值的索引的主要内容,如果未能解决你的问题,请参考以下文章

由张量流中的索引张量指定的切片二维张量

当切片本身是张量流中的张量时如何进行切片分配

张量流中的条件图和访问张量大小的for循环

在张量流中,如何迭代存储在张量中的输入序列?

如何使用条件索引在单元格上获取标量值

如何在张量流中张量的某些索引处插入某些值?