Fashion Mnist TensorFlow 数据形状不兼容
Posted
技术标签:
【中文标题】Fashion Mnist TensorFlow 数据形状不兼容【英文标题】:Fashion Mnist Tensorflow Data Shape Incompatibility 【发布时间】:2021-10-19 12:27:18 【问题描述】:我知道有类似的问题。虽然我已经检查过了,但我没有解决我的问题。
我尝试在时尚 Mnist 数据集上实现小批量。因此,我使用tf.data.Dataset.from_tensor_slices
将数据集从 np.array 转换为张量,但我无法解决数据形状不兼容的问题。这是我的代码:
加载数据
(train_images, train_labels) , (test_images, test_labels) = fashion_mnist.load_data()
转换为 tf.Dataset:
train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
我的模型
model_1 = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape = [28,28]),
tf.keras.layers.Dense(50, activation = "relu"),
tf.keras.layers.Dense(30, activation = "relu"),
tf.keras.layers.Dense(10, activation = "softmax"),
])
model_1.compile( loss = tf.keras.losses.SparseCategoricalCrossentropy(),
optimizer = tf.keras.optimizers.Adam(),
metrics = ["accuracy"])
info = model_1.fit(train_ds,
epochs = 10,
validation_data = (test_images, test_labels))
但这给了我这个错误:
ValueError: Input 0 of layer dense_1 is incompatible with the layer: expected axis -1 of input shape to have value 784 but received input with shape [28, 28]
我使用以下代码检查了输入形状:(输出为 [28, 28])
list(train_ds.as_numpy_iterator().next()[0].shape)
我该如何解决这个问题,如果你能帮助我,我将不胜感激。
谢谢!
【问题讨论】:
【参考方案1】:由于您使用 tf.data.Dataset
API 来提供模型,因此您应该从数据集中定义 batch_size。
train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).batch(256)
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(256)
现在您可以使用这两个数据集来训练您的模型,例如:
info = model_1.fit(x=train_ds, epochs = 10, validation_data=test_ds)
【讨论】:
以上是关于Fashion Mnist TensorFlow 数据形状不兼容的主要内容,如果未能解决你的问题,请参考以下文章
Tensorflow 2 fashion-mnist离线数据集手动下载离线安装本地加载快速读取
3.3 Fashion-MNIST softmax分类tensorflow2实现——python实战
机器学习(TensorFlow)---Fashion MNIST数据集使用范例(计算机视觉)
Fashion Mnist TensorFlow 数据形状不兼容