当切片本身是张量流中的张量时如何进行切片分配
Posted
技术标签:
【中文标题】当切片本身是张量流中的张量时如何进行切片分配【英文标题】:how to do slice assignment while the slice itself is a tensor in tensorflow 【发布时间】:2019-11-02 12:58:12 【问题描述】:我想在 tensorflow 中进行切片分配。我知道我可以使用:
my_var = my_var[4:8].assign(tf.zeros(4))
基于此link。
正如您在 my_var[4:8]
中看到的,我们在这里有特定的索引 4、8 用于切片和分配。
我的情况不同,我想根据张量进行切片,然后进行分配。
out = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))
rows_tf = tf.constant (
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]])
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]])
changed_tensor = [[8.3356, 0., 8.457685 ],
[0., 6.103182, 8.602337 ],
[8.8974, 7.330564, 0. ],
[0., 3.8914037, 5.826657 ],
[8.8974, 0., 8.283971 ],
[6.103182, 3.0614321, 5.826657 ],
[7.330564, 0., 8.283971 ],
[6.103182, 3.8914037, 0. ]]
另外,这是 sparse_indices
张量,它是 rows_tf
和 columns_tf
的连接,使得需要更新的整个索引(以防它可以提供帮助:)
sparse_indices = tf.constant(
[[1 1]
[2 1]
[5 1]
[1 2]
[2 2]
[5 2]
[1 3]
[2 3]
[5 3]
[1 2]
[4 2]
[6 2]
[1 3]
[4 3]
[6 3]
[2 2]
[3 2]
[6 2]
[2 3]
[3 3]
[6 3]
[2 2]
[4 2]
[4 2]])
我想做的就是做这个简单的任务:
out[rows_tf, columns_tf] = changed_tensor
为此我正在这样做:
out[rows_tf:column_tf].assign(changed_tensor)
但是,我收到了这个错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1,8,3], [1,8,1], and [1] instead. [Op:StridedSlice] name: strided_slice/
这是预期的输出:
[[0. 0. 0. 0. ]
[0. 8.3356 0. 8.8974 ]
[0. 0. 6.103182 7.330564 ]
[0. 0. 3.0614321 0. ]
[0. 0. 3.8914037 0. ]
[0. 8.457685 8.602337 0. ]
[0. 0. 5.826657 8.283971 ]
[0. 0. 0. 0. ]]
知道如何完成这个任务吗?
提前谢谢你:)
【问题讨论】:
tf.scatter_nd_update
应该适用于您的情况。在这个函数中,indices
应该是scatter_idx
,它是一个二维索引列表[[1,1], [2,1], [5,1], ...]
。 updates
是这些指数的预期值。 tf.scatter_nd
上的 tf 文档有很好的用法示例。
@greeness sparse_index = [[1 1] [2 1] [5 1] [1 2] [2 2] [5 2] [1 3] [2 3] [5 3] [ 1 2] [4 2] [6 2] [1 3] [4 3] [6 3] [2 2] [3 2] [6 2] [2 3] [3 3] [6 3] [2 2 ] [4 2] [4 2]] 与 updates
(8,3) 不匹配,这就是它引发错误 Outer dimensions of indices and update must match. Indices shape: [24,2], updates shape:[8,3] [Op:ScatterNd]
的原因。谢谢:)
你应该将tf.reshape
你的updates
张量形状为[-1]
,否则它与indices
的形状不匹配。我建议你仔细阅读关于“tensorflow sparse tensor”的文档。从长远来看,这将对您有所帮助。 tensorflow.org/api_docs/python/tf/sparse/SparseTensor
@greeness 你说得对,我需要花很多时间在这上面,谢谢你的建议。
即使在重塑后仍然无法运行:|
【参考方案1】:
这个例子(扩展自 tf 文档tf.scatter_nd_update
here)应该会有所帮助。
您想首先将您的 row_indices 和 column_indices 组合成一个二维索引列表,即 indices
到 tf.scatter_nd_update
的参数。然后你输入了一个期望值列表,即updates
。
ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))
indices = tf.constant([[0, 2], [2, 2]])
updates = tf.constant([1.0, 2.0])
update = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print sess.run(update)
Result:
[[ 0. 0. 1. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 2. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]]
专门针对您的数据,
ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))
changed_tensor = [[8.3356, 0., 8.457685 ],
[0., 6.103182, 8.602337 ],
[8.8974, 7.330564, 0. ],
[0., 3.8914037, 5.826657 ],
[8.8974, 0., 8.283971 ],
[6.103182, 3.0614321, 5.826657 ],
[7.330564, 0., 8.283971 ],
[6.103182, 3.8914037, 0. ]]
updates = tf.reshape(changed_tensor, shape=[-1])
sparse_indices = tf.constant(
[[1, 1],
[2, 1],
[5, 1],
[1, 2],
[2, 2],
[5, 2],
[1, 3],
[2, 3],
[5, 3],
[1, 2],
[4, 2],
[6, 2],
[1, 3],
[4, 3],
[6, 3],
[2, 2],
[3, 2],
[6, 2],
[2, 3],
[3, 3],
[6, 3],
[2, 2],
[4, 2],
[4, 2]])
update = tf.scatter_nd_update(ref, sparse_indices, updates)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print sess.run(update)
Result:
[[ 0. 0. 0. 0. ]
[ 0. 8.3355999 0. 8.8973999 ]
[ 0. 0. 6.10318184 7.33056402]
[ 0. 0. 3.06143212 0. ]
[ 0. 0. 0. 0. ]
[ 0. 8.45768547 8.60233688 0. ]
[ 0. 0. 5.82665682 8.28397083]
[ 0. 0. 0. 0. ]]
【讨论】:
对许多 cmets 感到抱歉。我得到了你的例子,但是当我将更新重塑为shape=[-1, 2]
时,我的indices
形状是(24,2)
和updates
形状是(8,3)
我仍然收到错误`索引的外部尺寸和更新必须匹配。索引形状:[24,2],更新形状:[12,2] [Op:ScatterNd]`。如果您能告诉我这里出了什么问题,我将不胜感激
您应该将更新重塑为 [-1],而不是 [-1,2]。 updates
的预期形状是 (k,),indices
的形状是 (k, 2)。在你的情况下,k = 24。
你的意思是像这样tf.reshape(updates,shape=[-1,1]))
?
我想我完全迷路了,文档中参数的顺序与您的不同,我不知道我是否遗漏了什么?另外,我用sparse_indices
更新了这个问题,以防我更容易理解为什么我得到错误。再次非常感谢您抽出时间:)
我不知道为什么它对你不起作用。我整合你的数据。答案已修改。以上是关于当切片本身是张量流中的张量时如何进行切片分配的主要内容,如果未能解决你的问题,请参考以下文章