如何在 TensorFlow 中执行 PyTorch 风格的张量切片更新?

Posted

技术标签:

【中文标题】如何在 TensorFlow 中执行 PyTorch 风格的张量切片更新?【英文标题】:How to perform PyTorch style tensor slice update in TensorFlow? 【发布时间】:2019-03-09 07:54:15 【问题描述】:

在 Pytorch 中,您可以像这样轻松更新张量:

 for i in range(x_len):
     tensor_abc[:, i, i] = 0

我们如何在 tensorflow 中更新这样的张量?

我尝试了tf.assigntf.scatter_update,但都不起作用。

【问题讨论】:

【参考方案1】:

此答案仅适用于变量。

import tensorflow as tf

sess = tf.InteractiveSession()
v = tf.zeros((5,5,5))
var = tf.Variable(initial_value=v)


init = tf.variables_initializer([var])
sess.run(init)


var = var[ 1 : 2 ,
           1 : 2 ,
           1 : 2 ].assign(tf.ones((1,1,1)))

print(sess.run(var))

这会产生

[[[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]]

还有这个

var = var[ 1 : 2 ,
           0 : 1 ,
           0 : 1 ].assign(tf.ones((1,1,1)))

生产

  [[[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[1. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

  ....
  ....]]

另一个例子是

var = var[ 1 : 2 ,
             : 2 ,
             : 2 ].assign(tf.ones((1,2,2)))

[[[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[1. 1. 0. 0. 0.]
  [1. 1. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

      ....
      ....]]

您应该探索 tf.scatter_nd 的张量。

【讨论】:

【参考方案2】:

tf.Variable 是唯一可以更新的张量。对于变量,您可以使用 gatherscatter_update 之类的代码进行切片。

请注意,其他张量不适合赋值。如果这是你想要做的,我想知道为什么它是必要的。但是,仍然可以使用您想要的值(而不是就地分配)创建新的张量,代码有点复杂。例如,以下内容不起作用:

index = ... tensor = tf.constant([0,1,2,3,4]) 
tensor[i] = 0  
## Doesn't work (TypeError: `Tensor` object does not support item assignment)

但其中任何一个都可以做同样的事情:

tensor = tf.constant([0,1,2,3,4]) 
tensor = tf.concat([tensor[:i], tf.zeros_like(tensor[i:i+1]), tensor[i+1:]], 0)  
## This works, creates a new tensor
张量 = tf.constant([0,1,2,3,4]) 张量 = tf.concat([tensor[:i], tf.fill([1], 0), tensor[i+1:]], 0) ## 这行得通,创建一个新的张量

【讨论】:

创建新张量可以,但是多维张量可以吗?例如将下面张量的所有对角线更新为 0: tensor = tf.Variable([ [[1,2,3],[4,5,6],[7,8,9] ], [[3,2, 1],[6,5,4],[9,8,7] ], [[2,2,2],[2,2,2],[2,2,2] ] ]) 如果你使用变量,你可以使用tf.scatter_update而不是创建新的张量。对于您的具体示例,“所有对角线”是什么意思? (你指的是哪个轴?)你能告诉我一个所需结果矩阵的例子吗?这将允许我提供更具体的代码示例

以上是关于如何在 TensorFlow 中执行 PyTorch 风格的张量切片更新?的主要内容,如果未能解决你的问题,请参考以下文章

如何在TensorFlow中执行稀疏矩阵*稀疏矩阵乘法?

在具有急切执行的 TensorFlow 2.0 中,如何计算特定层的网络输出的梯度?

书籍深度学习框架:PyTorch入门与实践(附代码)

Tensorflow 2:如何将执行从 GPU 切换到 CPU 并返回?

如何理解TensorFlow中的tensor

如何使用 Tensorflow 2.0 数据集在训练时执行 10 次裁剪图像增强