Why can't I use TensorArray.gather() in @tf.function?
Posted: 2021-05-20 21:36:49

Question:

Reading from the TensorArray:
import tensorflow as tf


class ReplayBuffer:  # class name assumed; the traceback points to utils/replay_buffer.py
    def __init__(self, size):
        self.size = size  # assumed: upper bound for the sampled indices
        self.obs_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
        self.obs2_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
        self.act_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
        self.rew_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
        self.done_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)

    def get_sample(self, batch_size):
        # Draw batch_size random indices in [0, self.size)
        idxs = tf.random.uniform(shape=[batch_size], maxval=self.size, dtype=tf.int32)
        tf.print(idxs)
        # HERE IS THE ISSUE: gather() fails when called inside @tf.function
        return (
            self.obs_buf.gather(indices=idxs),
            self.act_buf.gather(indices=idxs),
            self.rew_buf.gather(indices=idxs),
            self.obs2_buf.gather(indices=idxs),
            self.done_buf.gather(indices=idxs),
        )
Usage:
@tf.function
def train(self, rpm, batch_size, gradient_steps):
    for gradient_step in tf.range(1, gradient_steps + 1):
        obs, act, rew, next_obs, done = rpm.get_sample(batch_size)
        with tf.GradientTape() as tape:
            ...
Problem:
Traceback (most recent call last):
  File ".\main.py", line 130, in <module>
    rl_training.train()
  File "C:\Users\user\Documents\Projects\rl-toolkit\rl_training.py", line 129, in train
    self._rpm, self.batch_size, self.gradient_steps, logging_wandb=self.logging_wandb
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py", line 871, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py", line 726, in _initialize
    *args, **kwds))
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 3206, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 3887, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\func_graph.py", line 977, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in user code:

    C:\Users\user\Documents\Projects\rl-toolkit\policy\sac\sac.py:183 update *
        obs, act, rew, next_obs, done = rpm.get_sample(batch_size)
    C:\Users\user\Documents\Projects\rl-toolkit\utils\replay_buffer.py:39 __call__ *
        return self.obs_buf.gather(indices=idxs), self.act_buf.gather(indices=idxs), self.rew_buf.gather(indices=idxs), self.obs2_buf.gather(indices=idxs), self.done_buf.gather(indices=idxs)
    C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\tensor_array_ops.py:1190 gather **
        return self._implementation.gather(indices, name=name)
    C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\tensor_array_ops.py:861 gather
        return array_ops.stack([self._maybe_zero(i) for i in indices])
    C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py:505 __iter__
        self._disallow_iteration()
    C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py:498 _disallow_iteration
        self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
    C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py:476 _disallow_when_autograph_enabled
        " indicate you are trying to use an unsupported feature.".format(task))

    OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
Why can't I use a TensorArray in this case? What are my options?
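Reading the last frames of the traceback, the cause appears to be where the buffers are created: they are built in `__init__`, outside any `tf.function`, so TensorArray uses its eager implementation, which keeps the elements in a Python list. Its `gather()` then builds the result with a Python loop over `indices` (`tensor_array_ops.py:861` in the traceback), and inside `@tf.function` the `idxs` tensor is symbolic and cannot be iterated in Python. The sketch below is a minimal reproduction under that interpretation; `eager_ta`, `gather_from_captured` and `gather_from_local` are hypothetical names. A TensorArray created during tracing gathers fine, but it only lives for a single call, so it cannot serve as persistent replay-buffer storage.

```python
import tensorflow as tf

# Built eagerly, outside any tf.function, like the buffers in __init__.
# In eager mode a TensorArray is backed by a plain Python list of tensors.
eager_ta = tf.TensorArray(tf.float32, size=3, clear_after_read=False)
eager_ta = eager_ta.unstack(tf.constant([10.0, 20.0, 30.0]))


@tf.function
def gather_from_captured(idxs):
    # During tracing idxs is symbolic; the eager gather() runs a Python
    # loop over it, which raises OperatorNotAllowedInGraphError.
    return eager_ta.gather(idxs)


@tf.function
def gather_from_local(values, idxs):
    # A TensorArray created during tracing is graph-backed, so gather()
    # becomes a graph op and accepts symbolic indices.
    ta = tf.TensorArray(tf.float32, size=3, clear_after_read=False)
    ta = ta.unstack(values)
    return ta.gather(idxs)


idxs = tf.constant([2, 0], dtype=tf.int32)
print(gather_from_local(tf.constant([10.0, 20.0, 30.0]), idxs).numpy())  # [30. 10.]

try:
    gather_from_captured(idxs)
    captured_failed = False
except TypeError as err:  # OperatorNotAllowedInGraphError subclasses TypeError
    captured_failed = True
    print(type(err).__name__)
```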
Comments:

Does github.com/tensorflow/tensorflow/issues/31952 help?

Sorry, my problem is with tf.TensorArray.gather(), not tf.gather(); the +0 solution does not work in this case.

Answer 1:

Solved here. I had to use tf.Variable instead of tf.TensorArray.
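A minimal sketch of the tf.Variable approach the answer mentions, assuming fixed-size observations and actions. The class name `VariableReplayBuffer`, the `store()` method and the `obs_dim`/`act_dim` parameters are hypothetical, not the asker's actual code. Storage is preallocated as non-trainable variables, and the fill counter is also a variable, so a traced function reads its current value; sampling uses `tf.gather`, which accepts variables and symbolic indices inside `@tf.function`.

```python
import tensorflow as tf


class VariableReplayBuffer:
    """Hypothetical replay buffer whose storage lives in tf.Variable objects."""

    def __init__(self, obs_dim, act_dim, size):
        self.max_size = size
        # Preallocated storage; trainable=False keeps it out of gradient tapes.
        self.obs_buf = tf.Variable(tf.zeros([size, obs_dim]), trainable=False)
        self.obs2_buf = tf.Variable(tf.zeros([size, obs_dim]), trainable=False)
        self.act_buf = tf.Variable(tf.zeros([size, act_dim]), trainable=False)
        self.rew_buf = tf.Variable(tf.zeros([size]), trainable=False)
        self.done_buf = tf.Variable(tf.zeros([size]), trainable=False)
        # Counters are variables too, so traced functions see current values.
        self.ptr = tf.Variable(0, dtype=tf.int32, trainable=False)
        self.size = tf.Variable(0, dtype=tf.int32, trainable=False)

    def store(self, obs, act, rew, next_obs, done):
        # Overwrite slot `ptr` with one transition (ring buffer).
        i = self.ptr.value()
        self.obs_buf[i].assign(obs)
        self.act_buf[i].assign(act)
        self.rew_buf[i].assign(rew)
        self.obs2_buf[i].assign(next_obs)
        self.done_buf[i].assign(done)
        self.ptr.assign((i + 1) % self.max_size)
        self.size.assign(tf.minimum(self.size + 1, self.max_size))

    def get_sample(self, batch_size):
        # Random indices in [0, size); tf.gather on variables works in graph mode.
        idxs = tf.random.uniform(shape=[batch_size], maxval=self.size, dtype=tf.int32)
        return (
            tf.gather(self.obs_buf, idxs),
            tf.gather(self.act_buf, idxs),
            tf.gather(self.rew_buf, idxs),
            tf.gather(self.obs2_buf, idxs),
            tf.gather(self.done_buf, idxs),
        )


buffer = VariableReplayBuffer(obs_dim=3, act_dim=2, size=100)
for step in range(10):
    buffer.store([float(step)] * 3, [0.5, -0.5], float(step), [float(step + 1)] * 3, 0.0)


@tf.function
def sample_batch(batch_size):
    # Same call pattern as rpm.get_sample(batch_size) inside the @tf.function train().
    return buffer.get_sample(batch_size)


obs, act, rew, next_obs, done = sample_batch(4)
print(obs.shape, act.shape, rew.shape)  # (4, 3) (4, 2) (4,)
```

Because `size` is a variable, transitions stored after `sample_batch` has been traced are visible on later calls without retracing; a plain Python int attribute would be frozen into the graph as a constant at trace time.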