TensorFlow:dataset.train.next_batch 是如何定义的?

Posted

技术标签:

【中文标题】TensorFlow:dataset.train.next_batch 是如何定义的?【英文标题】:TensorFlow: how is dataset.train.next_batch defined? 【发布时间】:2017-05-18 04:40:14 【问题描述】:

我正在尝试学习 TensorFlow 并在以下位置学习示例:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb

然后我在下面的代码中有一些问题:

for epoch in range(training_epochs):
    # Loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # Run optimization op (backprop) and cost op (to get loss value)
        _, c = sess.run([optimizer, cost], feed_dict=X: batch_xs)
    # Display logs per epoch step
    if epoch % display_step == 0:
        print("Epoch:", '%04d' % (epoch+1),
              "cost=", ":.9f".format(c))

由于 mnist 只是一个数据集,mnist.train.next_batch 到底是什么意思? dataset.train.next_batch 是如何定义的?

谢谢!

【问题讨论】:

【参考方案1】:

mnist 对象从tf.contrib.learn 模块中定义的read_data_sets() function 返回。 mnist.train.next_batch(batch_size) 方法实现here,它返回两个数组的元组,其中第一个表示一批batch_size MNIST 图像,第二个表示与这些图像对应的一批batch-size 标签。 /p>

图像以大小为 [batch_size, 784] 的二维 NumPy 数组形式返回(因为 MNIST 图像中有 784 个像素),标签以大小为 [batch_size] 的一维 NumPy 数组形式返回(如果read_data_sets() 是用one_hot=False 调用的)或一个大小为[batch_size, 10] 的二维NumPy 数组(如果read_data_sets() 是用one_hot=True 调用的)。

【讨论】:

值得一提的是,next_batch 在每个 epoch 遍历完所有示例后都会重新洗牌。您可以通过DataSet._index_in_epoch 跟踪您在纪元中的位置,例如mnist.train._index_in_epoch

以上是关于TensorFlow:dataset.train.next_batch 是如何定义的?的主要内容,如果未能解决你的问题,请参考以下文章

pytorch闪电模型的输出预测

我该如何解决这个问题?输入必须有 3 个维度,得到 4

python将一个文件中的所有图片的名称放入一个.txt文件中

为图像分类模型绘制混淆矩阵

python读取MNIST image数据

如何加载数据以及如何使用 pytorch 进行数据扩充