如何使用 Python 在 Keras 中加载数据集?

Posted

技术标签:

【中文标题】如何使用 Python 在 Keras 中加载数据集?【英文标题】:How to load dataset for in Keras using Python? 【发布时间】:2020-06-08 05:33:31 【问题描述】:

我是使用 Python 学习 Keras 的初学者。

我已经阅读了一些使用 MNIST 数据集加载数据集的示例代码。

我不理解变量 (X_train, y_train) 和 (X_test, y_test)。

请帮我解释一下这些变量的用途。

另外,这些变量分配了什么类型的数据?

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import np_utils

# Load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()

【问题讨论】:

【参考方案1】:

Mnist 数据集包含大约 75 000 张手写数字的样本图像。每个数字还带有一个标签,其中包含可以在图像中看到的数字。每个图像的大小为28x28 像素。这些图像被分成两部分。 training-ImagesTest-Images。您使用training-images 来训练您的模型。然后你通过测试生成的神经网络在之前未使用和看不见的test-images 上的工作情况来验证你的accuracyloss

(X_train, Y_train) 是一个元组,两个值的组合存储在一个变量/列表元素中......

图像然后作为数组存储在这些列表中。所以X_train 包含大约 60 000 个大小为 784 (28*28) 的数组。每个单元格代表一个像素的值。它可以是从 0(白色)到 255(黑色)的任何值

X_test 包含一个包含大约 15 000 个此类数组的列表。适合图像的Labels存储在所属的Y_train/Y_test中

【讨论】:

【参考方案2】:

根据keras documentation:

x_train, x_test: uint8 灰度图像数据数组,形状为 (num_samples, 28, 28)。

y_train, y_test: uint8 数字标签数组(0-9 范围内的整数),形状为 (num_samples,)。

x_trainy_train 分别是用于训练的特征和标签。 x_testy_test 分别是用于测试的特征和标签。

【讨论】:

以上是关于如何使用 Python 在 Keras 中加载数据集?的主要内容,如果未能解决你的问题,请参考以下文章

尝试在 Keras 中加载顺序模型时出现“KeyError:0”

在 python 中加载 Tensorflow Lite 模型

无法在 Keras 2.1.0(使用 Tensorflow 1.3.0)中保存的 Keras 2.4.3(使用 Tensorflow 2.3.0)中加载 Keras 模型

如何在张量流中加载本地图像?

在 keras 中加载模型时图形断开错误

Keras + Tensorflow 和 Python 中的多处理