张量流中张量对象的非连续索引切片(高级索引,如numpy)
Posted
技术标签:
【中文标题】张量流中张量对象的非连续索引切片(高级索引,如numpy)【英文标题】:non continuous index slicing on tensor object in tensorflow (Advanced indexing like numpy) 【发布时间】:2019-10-31 09:11:17 【问题描述】:我研究了 tensorflow 中不同的切片方式,即tf.gather
和tf.gather_nd
。
在 tf.gather 中,它只是对一个维度进行切片,在 tf.gather_nd
中,它只接受一个 indices
应用于输入张量。
我需要的是不同的,我想使用两个不同的张量对输入张量进行切片;一个切片在行上,第二个切片在列上,它们的形状不一定相同。
例如:
假设这是我想要提取其中一部分的输入张量。
input_tf = tf.Variable([ [9.968594, 8.655439, 0., 0. ],
[0., 8.3356, 0., 8.8974 ],
[0., 0., 6.103182, 7.330564 ],
[6.609862, 0., 3.0614321, 0. ],
[9.497023, 0., 3.8914037, 0. ],
[0., 8.457685, 8.602337, 0. ],
[0., 0., 5.826657, 8.283971 ],
[0., 0., 0., 0. ]])
第二个是:
rows_tf = tf.constant (
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]])
第三张量:
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]])
现在,我想使用rows_tf
和columns_tf
对input_tf
进行切片。在行中索引[1 2 5]
,在columns_tf
中索引[1]
。同样,[1 2 5]
和 [2]
在 columns_tf
中的行。
或者,[1 4 6]
和 [2]
。
总的来说,rows_tf
中的每个索引,与columns_tf
中的相同索引都会提取input_tf
的一部分。
因此,预期的输出将是:
[[8.3356, 0., 8.457685 ],
[0., 6.103182, 8.602337 ],
[8.8974, 7.330564, 0. ],
[0., 3.8914037, 5.826657 ],
[8.8974, 0., 8.283971 ],
[6.103182, 3.0614321, 5.826657 ],
[7.330564, 0., 8.283971 ],
[6.103182, 3.8914037, 0. ]]
例如,这里的第一行 [8.3356, 0., 8.457685 ]
正在使用
rows in rows_tf [1,2,5] and column in columns_tf [1](row 1 and column 1, row 2 and column 1 and row 5 and column 1 in the input_tf)
有几个关于 tensorflow 切片的问题,尽管他们使用了 tf.gather
或 tf.gather_nd
和 tf.stack
,但没有给出我想要的输出。
无需提及,在numpy
中,我们可以通过调用:input_tf[rows_tf, columns_tf]
轻松做到这一点。
我也看过这个高级索引,它试图模拟 numpy 中可用的高级索引,但它仍然不像 numpy 灵活https://github.com/SpinachR/ubuntuTest/blob/master/beautifulCodes/tensorflow_advanced_index_slicing.ipynb
这是我尝试过的不正确的:
tf.gather(tf.transpose(tf.gather(input_tf,rows_tf)),columns_tf)
这段代码的维度输出是(8, 1, 3, 8)
,完全不正确。
提前致谢!
【问题讨论】:
您应该编辑您的问题,以便正确格式化所有常量(添加,
)
@DSC 你是对的,我现在就做,谢谢
为什么您对您提到的“收集”操作的输出不满意?听起来它可以工作。是因为它返回它变平吗?如果是这样,您可以知道“rows_tf”和“columns_tf”的尺寸来重塑它
在***.com/questions/56640222/… 中添加scatter_idx
后,您应该可以使用tf.gather_nd(params=input_tf, indices=scatter_idx)
后跟tf.reshape
来获得所需的形状。
我复制了下面的完整代码作为答案,您可以使用其他线程中任何其他答案的类似方式来获取sparse_indices
。
【参考方案1】:
这个想法是首先将稀疏索引(通过连接行索引和列索引)作为一个列表。然后您可以使用gather_nd
来检索值。
tf.reset_default_graph()
input_tf = tf.Variable([ [9.968594, 8.655439, 0., 0. ],
[0., 8.3356, 0., 8.8974 ],
[0., 0., 6.103182, 7.330564 ],
[6.609862, 0., 3.0614321, 0. ],
[9.497023, 0., 3.8914037, 0. ],
[0., 8.457685, 8.602337, 0. ],
[0., 0., 5.826657, 8.283971 ],
[0., 0., 0., 0. ]])
rows_tf = tf.constant (
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]])
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]])
rows_tf = tf.reshape(rows_tf, shape=[-1, 1])
columns_tf = tf.reshape(
tf.tile(columns_tf, multiples=[1, 3]),
shape=[-1, 1])
sparse_indices = tf.reshape(
tf.concat([rows_tf, columns_tf], axis=-1),
shape=[-1, 2])
v = tf.gather_nd(input_tf, sparse_indices)
v = tf.reshape(v, [-1, 3])
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
#print 'rows\n', sess.run(rows_tf)
#print 'columns\n', sess.run(columns_tf)
print sess.run(v)
结果是:
[[ 8.3355999 0. 8.45768547]
[ 0. 6.10318184 8.60233688]
[ 8.8973999 7.33056402 0. ]
[ 0. 3.89140368 5.82665682]
[ 8.8973999 0. 8.28397083]
[ 6.10318184 3.06143212 5.82665682]
[ 7.33056402 0. 8.28397083]
[ 6.10318184 3.89140368 0. ]]
【讨论】:
我真诚地感谢您的帮助:) np,玩得开心 tf. :) 我在这里有一个困惑,这听起来像是一个逻辑问题。正是这段代码运行良好,当我在我的程序中包含这段代码时,它输出不正确。我找到了问题的根源,但这对我来说没有意义。我想知道你是否对此有任何想法。rows_tf
和 columns_tf
在 tf.constant' here, but in my code they are the result of some other process. if I give it
tf.constant 上给出,就像这里一样,它工作得很好,但是一旦我改变它而不是 tf.constant
,输出将与我上面报告的一样不正确。为什么会这样?
你得到rows_tf
, columns_tf
正确吗?您可能希望包含所有代码以进行澄清。随时在此处更新您的问题***.com/questions/56568981/… 或者开始一个新线程。我很高兴能尽快看看。
很高兴知道这一点。以上是关于张量流中张量对象的非连续索引切片(高级索引,如numpy)的主要内容,如果未能解决你的问题,请参考以下文章