TensorFlow 中类似 NumPy 的切片

Posted

技术标签:

【中文标题】TensorFlow 中类似 NumPy 的切片【英文标题】:NumPy-like slicing in TensorFlow 【发布时间】:2019-10-27 07:59:36 【问题描述】:

我有一个张量对象,我想切片它的一部分。

tf_a1 = tf.Variable([    [9.968594,  8.655439,  0.,        0.       ],
                         [0.,        8.3356,    0.,        8.8974   ],
                         [0.,        0.,        6.103182,  7.330564 ],
                         [6.609862,  0.,        3.0614321, 0.       ],
                         [9.497023,  0.,        3.8914037, 0.       ],
                         [0.,        8.457685,  8.602337,  0.       ],
                         [0.,        0.,        5.826657,  8.283971 ],
                         [0.,        0.,        0.,        0.       ]])

另外,我有这个数组:

tf_a2 = tf.constant([[1, 2, 5],
                    [1, 4, 6],
                    [0, 7, 7],
                    [2, 3, 6],
                    [2, 4, 7]])

我想像切片一样做这个 numpy:

 tf_a1[tf_a2]

numpy 代码的预期输出如下:

[[[0.        8.3356    0.        8.8974   ]
  [0.        0.        6.103182  7.330564 ]
  [0.        8.457685  8.602337  0.       ]]

 [[0.        8.3356    0.        8.8974   ]
  [9.497023  0.        3.8914037 0.       ]
  [0.        0.        5.826657  8.283971 ]]

 [[9.968594  8.655439  0.        0.       ]
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]]

 [[0.        0.        6.103182  7.330564 ]
  [6.609862  0.        3.0614321 0.       ]
  [0.        0.        5.826657  8.283971 ]]

 [[0.        0.        6.103182  7.330564 ]
  [9.497023  0.        3.8914037 0.       ]
  [0.        0.        0.        0.       ]]]

我想我可以使用以下方法在 tensorflow 中进行类似的操作:

tf.gather_nd(tf_a1, tf_a2)

但它会引发此错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: index innermost dimension length must be <= params rank; saw: 3 vs. 2 [Op:GatherNd]

任何帮助表示赞赏:)

【问题讨论】:

行数不相等。这个问题似乎没有很好的定义... @cs95,是的,行数不一样,但我们可以在 numpy 中做同样的事情。我会添加更多关于我想要什么的解释 此操作的预期输出是否应该是形状为 (5, 3, 4) 的张量? @cs95 完全是 3d 张量 【参考方案1】:

我觉得你可以用tf.gather:

tf.gather(tf_a1, tf_a2, axis=0)                                                                                        
# <tf.Tensor 'GatherV2_10:0' shape=(5, 3, 4) dtype=float32>

TensorFlow 2.0 上的可重现示例

tf.__version__
# '2.0.0-beta0'

tf.gather(tf_a1, tf_a2, axis=0)

<tf.Tensor: id=9, shape=(5, 3, 4), dtype=float32, numpy=
array([[[0.       , 8.3356   , 0.       , 8.8974   ],
        [0.       , 0.       , 6.103182 , 7.330564 ],
        [0.       , 8.457685 , 8.602337 , 0.       ]],

       [[0.       , 8.3356   , 0.       , 8.8974   ],
        [9.497023 , 0.       , 3.8914037, 0.       ],
        [0.       , 0.       , 5.826657 , 8.283971 ]],

       [[9.968594 , 8.655439 , 0.       , 0.       ],
        [0.       , 0.       , 0.       , 0.       ],
        [0.       , 0.       , 0.       , 0.       ]],

       [[0.       , 0.       , 6.103182 , 7.330564 ],
        [6.609862 , 0.       , 3.0614321, 0.       ],
        [0.       , 0.       , 5.826657 , 8.283971 ]],

       [[0.       , 0.       , 6.103182 , 7.330564 ],
        [9.497023 , 0.       , 3.8914037, 0.       ],
        [0.       , 0.       , 0.       , 0.       ]]], dtype=float32)>

【讨论】:

太棒了,非常感谢,这正是我想要的。如果我对问题进行一项修改可以吗 @sariii 如果它是一个重要的扩展,希望您提出一个新问题,但请交给我。 再次感谢您的回复。我在这里问了我的另一个问题***.com/questions/56568981/…如果你能看一下,我很感激

以上是关于TensorFlow 中类似 NumPy 的切片的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow 2 - tf.slice 及其 NumPy 切片语法不兼容的行为

像在 numpy 中一样使用 tf.slice 检测越界切片

吴裕雄--天生自然TensorFlow2教程:numpy [ ] 索引

如何在 Tensorflow 中进行切片分配

张量流中张量对象的非连续索引切片(高级索引,如numpy)

在 Tensorflow 中进行这种基于切片的乘法的最有效方法