在tensorflow中匹配pytorch scatter输出

Posted

技术标签:

【中文标题】在tensorflow中匹配pytorch scatter输出【英文标题】:Match pytorch scatter output in tensorflow 【发布时间】:2022-01-13 11:27:42 【问题描述】:

如何在tensorflow做同样的操作?

tensor = np.random.RandomState(42).uniform(size=(2, 4, 2)).astype(np.float32)
tensor = torch.from_numpy(tensor)
index = tensor.max(-1, keepdim=True)[1]
output = torch.zeros_like(tensor).scatter_(-1, index, 1.0)

expected output:
tensor([[[0., 1.],
         [1., 0.],
         [1., 0.],
         [0., 1.]],

        [[0., 1.],
         [0., 1.],
         [1., 0.],
         [0., 1.]]])

【问题讨论】:

【参考方案1】:

与往常一样,Tensorflow 的一切都变得更加复杂:

import tensorflow as tf
import numpy as np

tensor = np.random.RandomState(42).uniform(size=(2, 4, 2)).astype(np.float32)
tensor = tf.constant(tensor)
_, indices = tf.math.top_k(tensor)
zeros = tf.zeros_like(tensor)

ij = tf.stack(tf.meshgrid(
    tf.range(zeros.shape[0], dtype=tf.int32), 
    tf.range(zeros.shape[1], dtype=tf.int32),
                              indexing='ij'), axis=-1)
gathered_indices = tf.concat([ij, indices], axis=-1)
indices_shape = tf.shape(indices)
values = tf.ones((indices_shape[0], indices_shape[1]))
output = tf.tensor_scatter_nd_update(zeros, gathered_indices, values)

print(output)
tf.Tensor(
[[[0. 1.]
  [1. 0.]
  [1. 0.]
  [0. 1.]]

 [[0. 1.]
  [0. 1.]
  [1. 0.]
  [0. 1.]]], shape=(2, 4, 2), dtype=float32)

【讨论】:

【参考方案2】:

所以,这是我最终使用的解决方案。

tensor = np.random.RandomState(42).uniform(size=(2, 4, 2)).astype(np.float32)
tensor = tf.convert_to_tensor(tensor, dtype=tf.float32)


fill_value = 1.0
indices = tf.argmax(tensor, axis=-1)
depth = tensor.shape[-1]
output = tf.cast(tf.one_hot(indices, depth, on_value=fill_value), dtype=tf.float32)
tf.Tensor(
[[[0. 1.]
  [1. 0.]
  [1. 0.]
  [0. 1.]]

 [[0. 1.]
  [0. 1.]
  [1. 0.]
  [0. 1.]]], shape=(2, 4, 2), dtype=float32)

【讨论】:

不是通用的(如果您想使用除 1 之外的其他值),但如果它适用于您的用例。去吧。 @AloneTogether 我稍微编辑了答案。可变深度现在取自张量形状。这就是您所说的“如果您想使用除 1 之外的其他值”吗? 不。如果您想用值 2 而不是 1 替换某些索引处的张量中的值...您的解决方案会起作用吗? @AloneTogether 我相信tf.one_hoton_value 参数可以解决问题。检查编辑。 @AloneTogether 是的,所有值都必须与我的解决方案相同。但是我仍然赞成您的回答,因为它正在执行所要求的操作,并且更新值也增加了灵活性。

以上是关于在tensorflow中匹配pytorch scatter输出的主要内容,如果未能解决你的问题,请参考以下文章

比较 Conv2D 与 Tensorflow 和 PyTorch 之间的填充

如何在 pytorch 和 tensorflow 中使用张量核心?

Pytorch OR TensorFlow

如何在 TensorFlow 中执行 PyTorch 风格的张量切片更新?

如何在 TensorFlow、Keras 或 PyTorch 中部署 CoreML 模型?

从Tensorflow转移到Pytorch