像在 numpy 中一样使用 tf.slice 检测越界切片
Posted
技术标签:
【中文标题】像在 numpy 中一样使用 tf.slice 检测越界切片【英文标题】:Detecting out-of-bounds slicing with tf.slice like in numpy 【发布时间】:2017-12-15 06:49:25 【问题描述】:在 tensorflow 中,我正在尝试使用 tf.slice,但 as its documentation states,它需要切片以适合输入数组。例如,如果您尝试对张量 [1,2,3,4] 的前 5 个位置进行切片,它将崩溃。我希望拥有与 python 列表或 numpy 数组相同的功能,其中切片可以让您获得原始数组和您要求的切片的交集。例如,如果您要求 [1,2,3,4] 的位置 2 到 6,您将得到 [2,3,4]。
我如何在张量流中做到这一点?
谢谢!
【问题讨论】:
【参考方案1】:您可以使用 tensorflow 的 python 切片运算符,它比 tf.slice
稍微强大一些,特别是进行一些绑定检查以类似于 numpy
。
x = tf.zeros((10,))
y = tf.slice(x, [5], [15])
print(y.shape)
# (15,)
z = x[5:42]
print(z.shape)
# (5,)
【讨论】:
【参考方案2】:你可以使用tf.clip_by_value:
def robust_slice( t, begin, size, name=None ):
with tf.name_scope('robust_slice'):
new_size = tf.clip_by_value(size + begin, clip_value_min = 0, clip_value_max = t.shape()) - begin
return tf.slice(t, begin, new_size, name)
(未经测试,但类似的东西应该可以完成这项工作)
【讨论】:
我试过了,但如果部分形状未知,它似乎不起作用。以上是关于像在 numpy 中一样使用 tf.slice 检测越界切片的主要内容,如果未能解决你的问题,请参考以下文章
tf.strided_slice_and_tf.fill_and_tf.concat