张量流中张量对象的非连续索引切片(高级索引,如numpy)

Posted

技术标签:

【中文标题】张量流中张量对象的非连续索引切片(高级索引,如numpy)【英文标题】:non continuous index slicing on tensor object in tensorflow (Advanced indexing like numpy) 【发布时间】:2019-10-31 09:11:17 【问题描述】:

我研究了 tensorflow 中不同的切片方式,即tf.gathertf.gather_nd。 在 tf.gather 中,它只是对一个维度进行切片,在 tf.gather_nd 中,它只接受一个 indices 应用于输入张量。

我需要的是不同的,我想使用两个不同的张量对输入张量进行切片;一个切片在行上,第二个切片在列上,它们的形状不一定相同。

例如:

假设这是我想要提取其中一部分的输入张量。

input_tf = tf.Variable([ [9.968594,  8.655439,  0.,        0.       ],
                         [0.,        8.3356,    0.,        8.8974   ],
                         [0.,        0.,        6.103182,  7.330564 ],
                         [6.609862,  0.,        3.0614321, 0.       ],
                         [9.497023,  0.,        3.8914037, 0.       ],
                         [0.,        8.457685,  8.602337,  0.       ],
                         [0.,        0.,        5.826657,  8.283971 ],
                         [0.,        0.,        0.,        0.       ]])

第二个是:

 rows_tf = tf.constant (
[[1, 2, 5],
 [1, 2, 5],
 [1, 2, 5],
 [1, 4, 6],
 [1, 4, 6],
 [2, 3, 6],
 [2, 3, 6],
 [2, 4, 7]])

第三张量:

columns_tf = tf.constant(
[[1],
 [2],
 [3],
 [2],
 [3],
 [2],
 [3],
 [2]])

现在,我想使用rows_tfcolumns_tfinput_tf 进行切片。在行中索引[1 2 5],在columns_tf 中索引[1]。同样,[1 2 5][2]columns_tf 中的行。

或者,[1 4 6][2]

总的来说,rows_tf 中的每个索引,与columns_tf 中的相同索引都会提取input_tf 的一部分。

因此,预期的输出将是:

[[8.3356,    0.,        8.457685 ],
 [0.,        6.103182,  8.602337 ],
 [8.8974,    7.330564,  0.       ],
 [0.,        3.8914037, 5.826657 ],
 [8.8974,    0.,        8.283971 ],
 [6.103182,  3.0614321, 5.826657 ],
 [7.330564,  0.,        8.283971 ],
 [6.103182,  3.8914037, 0.       ]]

例如,这里的第一行 [8.3356, 0., 8.457685 ] 正在使用

提取
rows in rows_tf [1,2,5] and column in columns_tf [1](row 1 and column 1, row 2 and column 1 and row 5 and column 1 in the input_tf)

有几个关于 tensorflow 切片的问题,尽管他们使用了 tf.gathertf.gather_ndtf.stack,但没有给出我想要的输出。

无需提及,在numpy 中,我们可以通过调用:input_tf[rows_tf, columns_tf] 轻松做到这一点。

我也看过这个高级索引,它试图模拟 numpy 中可用的高级索引,但它仍然不像 numpy 灵活https://github.com/SpinachR/ubuntuTest/blob/master/beautifulCodes/tensorflow_advanced_index_slicing.ipynb

这是我尝试过的不正确的:

tf.gather(tf.transpose(tf.gather(input_tf,rows_tf)),columns_tf)

这段代码的维度输出是(8, 1, 3, 8),完全不正确。

提前致谢!

【问题讨论】:

您应该编辑您的问题,以便正确格式化所有常量(添加 , @DSC 你是对的,我现在就做,谢谢 为什么您对您提到的“收集”操作的输出不满意?听起来它可以工作。是因为它返回它变平吗?如果是这样,您可以知道“rows_tf”和“columns_tf”的尺寸来重塑它 在***.com/questions/56640222/… 中添加scatter_idx 后,您应该可以使用tf.gather_nd(params=input_tf, indices=scatter_idx) 后跟tf.reshape 来获得所需的形状。 我复制了下面的完整代码作为答案,您可以使用其他线程中任何其他答案的类似方式来获取sparse_indices 【参考方案1】:

这个想法是首先将稀疏索引(通过连接行索引和列索引)作为一个列表。然后您可以使用gather_nd 来检索值。


tf.reset_default_graph()
input_tf = tf.Variable([ [9.968594,  8.655439,  0.,        0.       ],
                         [0.,        8.3356,    0.,        8.8974   ],
                         [0.,        0.,        6.103182,  7.330564 ],
                         [6.609862,  0.,        3.0614321, 0.       ],
                         [9.497023,  0.,        3.8914037, 0.       ],
                         [0.,        8.457685,  8.602337,  0.       ],
                         [0.,        0.,        5.826657,  8.283971 ],
                         [0.,        0.,        0.,        0.       ]])
rows_tf = tf.constant (
[[1, 2, 5],
 [1, 2, 5],
 [1, 2, 5],
 [1, 4, 6],
 [1, 4, 6],
 [2, 3, 6],
 [2, 3, 6],
 [2, 4, 7]])
columns_tf = tf.constant(
[[1],
 [2],
 [3],
 [2],
 [3],
 [2],
 [3],
 [2]])
rows_tf = tf.reshape(rows_tf, shape=[-1, 1])
columns_tf = tf.reshape(
    tf.tile(columns_tf, multiples=[1, 3]), 
    shape=[-1, 1])
sparse_indices = tf.reshape(
    tf.concat([rows_tf, columns_tf], axis=-1), 
    shape=[-1, 2])

v = tf.gather_nd(input_tf, sparse_indices)
v = tf.reshape(v, [-1, 3])

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  #print 'rows\n', sess.run(rows_tf)
  #print 'columns\n', sess.run(columns_tf)
  print sess.run(v)

结果是:

[[ 8.3355999   0.          8.45768547]
 [ 0.          6.10318184  8.60233688]
 [ 8.8973999   7.33056402  0.        ]
 [ 0.          3.89140368  5.82665682]
 [ 8.8973999   0.          8.28397083]
 [ 6.10318184  3.06143212  5.82665682]
 [ 7.33056402  0.          8.28397083]
 [ 6.10318184  3.89140368  0.        ]]

【讨论】:

我真诚地感谢您的帮助:) np,玩得开心 tf. :) 我在这里有一个困惑,这听起来像是一个逻辑问题。正是这段代码运行良好,当我在我的程序中包含这段代码时,它输出不正确。我找到了问题的根源,但这对我来说没有意义。我想知道你是否对此有任何想法。 rows_tfcolumns_tftf.constant' here, but in my code they are the result of some other process. if I give it tf.constant 上给出,就像这里一样,它工作得很好,但是一旦我改变它而不是 tf.constant,输出将与我上面报告的一样不正确。为什么会这样? 你得到rows_tf, columns_tf 正确吗?您可能希望包含所有代码以进行澄清。随时在此处更新您的问题***.com/questions/56568981/… 或者开始一个新线程。我很高兴能尽快看看。 很高兴知道这一点。

以上是关于张量流中张量对象的非连续索引切片(高级索引,如numpy)的主要内容,如果未能解决你的问题,请参考以下文章

当切片本身是张量流中的张量时如何进行切片分配

在 Tensorflow 中使用索引对张量进行切片

如何在张量流中张量的某些索引处插入某些值?

如何在张量流中使用索引数组?

[TensorFlow系列-17]:TensorFlow基础 - 张量的索引与切片

[PyTroch系列-17]:PyTorch基础 - 张量的索引与切片