像在 numpy 中一样使用 tf.slice 检测越界切片

Posted

技术标签:

【中文标题】像在 numpy 中一样使用 tf.slice 检测越界切片【英文标题】:Detecting out-of-bounds slicing with tf.slice like in numpy 【发布时间】:2017-12-15 06:49:25 【问题描述】:

在 tensorflow 中,我正在尝试使用 tf.slice,但 as its documentation states,它需要切片以适合输入数组。例如,如果您尝试对张量 [1,2,3,4] 的前 5 个位置进行切片,它将崩溃。我希望拥有与 python 列表或 numpy 数组相同的功能,其中切片可以让您获得原始数组和您要求的切片的交集。例如,如果您要求 [1,2,3,4] 的位置 2 到 6,您将得到 [2,3,4]。

我如何在张量流中做到这一点?

谢谢!

【问题讨论】:

【参考方案1】:

您可以使用 tensorflow 的 python 切片运算符,它比 tf.slice 稍微强大一些,特别是进行一些绑定检查以类似于 numpy

x = tf.zeros((10,))
y = tf.slice(x, [5], [15])
print(y.shape)
# (15,)
z = x[5:42]
print(z.shape)
# (5,)

【讨论】:

【参考方案2】:

你可以使用tf.clip_by_value:

def robust_slice( t, begin, size, name=None ):
    with tf.name_scope('robust_slice'):
        new_size = tf.clip_by_value(size + begin, clip_value_min = 0, clip_value_max = t.shape()) - begin
        return tf.slice(t, begin, new_size, name)

(未经测试,但类似的东西应该可以完成这项工作)

【讨论】:

我试过了,但如果部分形状未知,它似乎不起作用。

以上是关于像在 numpy 中一样使用 tf.slice 检测越界切片的主要内容,如果未能解决你的问题,请参考以下文章

tf.strided_slice_and_tf.fill_and_tf.concat

tf.slice()

了解 tensorflow 中 tf.slice 的参数以及为啥我不能更改它

tf.slice函数解析

tf.slice()函数详解(极详细)

tf.slice()