Fashion Mnist TensorFlow 数据形状不兼容

Posted

技术标签:

【中文标题】Fashion Mnist TensorFlow 数据形状不兼容【英文标题】:Fashion Mnist Tensorflow Data Shape Incompatibility 【发布时间】:2021-10-19 12:27:18 【问题描述】:

我知道有类似的问题。虽然我已经检查过了,但我没有解决我的问题。

我尝试在时尚 Mnist 数据集上实现小批量。因此,我使用tf.data.Dataset.from_tensor_slices 将数据集从 np.array 转换为张量,但我无法解决数据形状不兼容的问题。这是我的代码:

加载数据

(train_images, train_labels) , (test_images, test_labels) = fashion_mnist.load_data()

转换为 tf.Dataset:

 train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
 test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

我的模型

model_1 = tf.keras.Sequential([
    
    tf.keras.layers.Flatten(input_shape = [28,28]),
    tf.keras.layers.Dense(50, activation = "relu"),
    tf.keras.layers.Dense(30, activation = "relu"),
    tf.keras.layers.Dense(10, activation = "softmax"),
    
])

model_1.compile( loss = tf.keras.losses.SparseCategoricalCrossentropy(),
               optimizer = tf.keras.optimizers.Adam(),
               metrics = ["accuracy"])

info = model_1.fit(train_ds,
                  epochs = 10,
                  validation_data = (test_images, test_labels))

但这给了我这个错误:

ValueError: Input 0 of layer dense_1 is incompatible with the layer: expected axis -1 of input shape to have value 784 but received input with shape [28, 28]

我使用以下代码检查了输入形状:(输出为 [28, 28])

list(train_ds.as_numpy_iterator().next()[0].shape)

我该如何解决这个问题,如果你能帮助我,我将不胜感激。

谢谢!

【问题讨论】:

【参考方案1】:

由于您使用 tf.data.Dataset API 来提供模型,因此您应该从数据集中定义 batch_size。

train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).batch(256)
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(256)

现在您可以使用这两个数据集来训练您的模型,例如:

info = model_1.fit(x=train_ds, epochs = 10, validation_data=test_ds)

【讨论】:

以上是关于Fashion Mnist TensorFlow 数据形状不兼容的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow 2 fashion-mnist离线数据集手动下载离线安装本地加载快速读取

3.3 Fashion-MNIST softmax分类tensorflow2实现——python实战

机器学习(TensorFlow)---Fashion MNIST数据集使用范例(计算机视觉)

Fashion Mnist TensorFlow 数据形状不兼容

TensorFlow2 手把手教你训练 Fashion Mnist

深度学习基于tensorflow的服装图像分类训练(数据集:Fashion-MNIST)