TypeError:由于 tensor_scatter_update 而预期单个张量时的张量列表
Posted
技术标签:
【中文标题】TypeError:由于 tensor_scatter_update 而预期单个张量时的张量列表【英文标题】:TypeError: List of Tensors when single Tensor expected due to tensor_scatter_update 【发布时间】:2022-01-10 07:42:10 【问题描述】:看看下面的代码示例:
def myFun(my_tensor):
#The following line works
my_tensor= tf.tensor_scatter_update(my_tensor, tf.constant([[0]]), tf.constant([1]))
#The following line leads to error
p = tf.cond(tf.math.equal(0, 0), lambda: 1, lambda: 1)
my_tensor= tf.tensor_scatter_update(my_tensor, tf.constant([[p]]), tf.constant([1]))
我用一个简单的案例来描述我面临的问题 此函数(myFun)被称为 tf.while_loop 的主体(如果相关) my_tensor的定义
my_tensor = tf.zeros(5, tf.int32)
如何定义 tf.tensor_scatter_update 的 indices 参数? 我正在使用 tensorflow1.15
【问题讨论】:
【参考方案1】:您不能使用张量 p
作为 tf.constant
的参数。也许尝试这样的事情:
%tensorflow_version 1.x
import tensorflow as tf
def myFun(my_tensor):
my_tensor= tf.tensor_scatter_update(my_tensor, tf.constant([[0]]), tf.constant([1]))
p = tf.cond(tf.math.equal(0, 0), lambda: 1, lambda: 1)
new_tensor= tf.tensor_scatter_update(my_tensor, [[p]], tf.constant([1]))
with tf.Session() as sess:
p_value = p.eval()
tensor_values = my_tensor.eval()
new_tensor_values = new_tensor.eval()
print('p -->', p_value)
print('my_tensor -->', tensor_values)
print('new_tensor -->', new_tensor_values)
my_tensor = tf.zeros(5, tf.int32)
myFun(my_tensor)
p --> 1
my_tensor --> [1 0 0 0 0]
new_tensor --> [1 1 0 0 0]
您也可以将p
包裹在tf.Variable
周围:
def myFun(my_tensor):
my_tensor= tf.tensor_scatter_update(my_tensor, tf.constant([[0]]), tf.constant([1]))
p = tf.cond(tf.math.equal(0, 0), lambda: 1, lambda: 1)
indices = tf.Variable([[p]])
new_tensor= tf.tensor_scatter_update(my_tensor, indices, tf.constant([1]))
with tf.Session() as sess:
sess.run(indices.initializer)
p_value = p.eval()
tensor_values = my_tensor.eval()
new_tensor_values = new_tensor.eval()
【讨论】:
谢谢!你成就了我的一天以上是关于TypeError:由于 tensor_scatter_update 而预期单个张量时的张量列表的主要内容,如果未能解决你的问题,请参考以下文章
TypeError:由于 tensor_scatter_update 而预期单个张量时的张量列表
TypeError: can't pickle generator objects: Spark collect() 由于不可序列化的生成器返回类型(dict_key)而失败