模型创建时的简单 RNN Python Tensorflow 错误

Posted

技术标签:

【中文标题】模型创建时的简单 RNN Python Tensorflow 错误【英文标题】:Simple RNN Python Tensorflow error on model creation 【发布时间】:2021-09-02 05:50:08 【问题描述】:

我正在运行直接取自谷歌示例之一的示例代码,用于创建 RNN,但运行时出现错误。我在 VisualStudio 2019、带有 i7-10510U 和 mx230 的 Windows 10 x64 上运行它

代码:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential()
# Add an Embedding layer expecting input vocab of size 1000, and
# output embedding dimension of size 64.
model.add(layers.Embedding(input_dim=1000, output_dim=64))

# Add a LSTM layer with 128 internal units.
model.add(layers.SimpleRNN(128))

# Add a Dense layer with 10 units.
model.add(layers.Dense(10))

model.summary()

model.add(layers.SimpleRNN(128)) 上的错误:

无法将符号张量 (simple_rnn/strided_slice:0) 转换为 numpy 数组。此错误可能表明您正在尝试通过 不支持 NumPy 调用的张量

【问题讨论】:

【参考方案1】:

您可以尝试将 Tensorflow 升级到最新版本。我可以在Tensorflow 2.5.0 中毫无问题地执行代码,如下所示

import numpy as np
import tensorflow as tf
print(tf.__version__)
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential()
model.add(layers.Embedding(input_dim=1000, output_dim=64))
model.add(layers.SimpleRNN(128))
model.add(layers.Dense(10))

model.summary()

输出:

2.5.0
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, None, 64)          64000     
_________________________________________________________________
simple_rnn (SimpleRNN)       (None, 128)               24704     
_________________________________________________________________
dense (Dense)                (None, 10)                1290      
=================================================================
Total params: 89,994
Trainable params: 89,994
Non-trainable params: 0
_________________________________________________________________

【讨论】:

以上是关于模型创建时的简单 RNN Python Tensorflow 错误的主要内容,如果未能解决你的问题,请参考以下文章

Python - RNN LSTM模型精度低

Pytorch基础——使用 RNN 生成简单序列

在 Julia Flux 中评估简单的 RNN

循环神经网络(RNN)

RNN+CTC 模型似乎没有正确获取数据维度

Pytorch Note39 RNN 序列预测