使用 tf.function() 获取形状

Posted

技术标签:

【中文标题】使用 tf.function() 获取形状【英文标题】:Get shape using tf.function() 【发布时间】:2022-01-16 10:00:02 【问题描述】:

我有这个功能

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int32)
]

@tf.function(input_signature=train_step_signature)
def train_step(inp):
   # do stuff

我需要在一个操作中使用 inp 的第一个暗淡(一个范围为 inp 形状 0 的循环),但是当我尝试时,会弹出错误:

TypeError: 'NoneType' object cannot be interpreted as an integer

这显然是因为train_step_signature。我已经看到如果我从 args 中删除 train_step_signature 它会起作用,但是处理我的代码需要更多时间。我的问题是,有没有办法在不丢失train_step_signature arg 的情况下获得第一个形状?

【问题讨论】:

【参考方案1】:

您可能正在使用像for i in range(inp.shape[0]) 这样的pythonic 循环,这是不可能的,因为inp.shape[0]tf.function 中是None。不要害怕使用 tf.while_looptf.function.

或者,尝试使用tf.shape(inp)[0] 代替inp.shape[0]

【讨论】:

以上是关于使用 tf.function() 获取形状的主要内容,如果未能解决你的问题,请参考以下文章

为啥我不能在 @tf.function 中使用 TensorArray.gather()?

使用@tf.function 进行自定义张量流训练的内存泄漏

我应该将@tf.function 用于所有功能吗?

从具有适当形状的现有迭代中创建3D numpy数组

关于TensorFlow2的tf.function()和AutoGraph的一些问题解决

Tensorflow 2.0:自定义 keras 指标导致 tf.function 回溯警告