使用 tf.function() 获取形状
Posted
技术标签:
【中文标题】使用 tf.function() 获取形状【英文标题】:Get shape using tf.function() 【发布时间】:2022-01-16 10:00:02 【问题描述】:我有这个功能
train_step_signature = [
tf.TensorSpec(shape=(None, None), dtype=tf.int32)
]
@tf.function(input_signature=train_step_signature)
def train_step(inp):
# do stuff
我需要在一个操作中使用 inp 的第一个暗淡(一个范围为 inp 形状 0 的循环),但是当我尝试时,会弹出错误:
TypeError: 'NoneType' object cannot be interpreted as an integer
这显然是因为train_step_signature
。我已经看到如果我从 args 中删除 train_step_signature
它会起作用,但是处理我的代码需要更多时间。我的问题是,有没有办法在不丢失train_step_signature
arg 的情况下获得第一个形状?
【问题讨论】:
【参考方案1】:您可能正在使用像for i in range(inp.shape[0])
这样的pythonic 循环,这是不可能的,因为inp.shape[0]
在tf.function
中是None
。不要害怕使用
tf.while_loop
在tf.function
.
或者,尝试使用tf.shape(inp)[0]
代替inp.shape[0]
。
【讨论】:
以上是关于使用 tf.function() 获取形状的主要内容,如果未能解决你的问题,请参考以下文章
为啥我不能在 @tf.function 中使用 TensorArray.gather()?
使用@tf.function 进行自定义张量流训练的内存泄漏