TensorFlow:dataset.train.next_batch 是如何定义的?
Posted
技术标签:
【中文标题】TensorFlow:dataset.train.next_batch 是如何定义的?【英文标题】:TensorFlow: how is dataset.train.next_batch defined? 【发布时间】:2017-05-18 04:40:14 【问题描述】:我正在尝试学习 TensorFlow 并在以下位置学习示例:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb
然后我在下面的代码中有一些问题:
for epoch in range(training_epochs):
# Loop over all batches
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
# Run optimization op (backprop) and cost op (to get loss value)
_, c = sess.run([optimizer, cost], feed_dict=X: batch_xs)
# Display logs per epoch step
if epoch % display_step == 0:
print("Epoch:", '%04d' % (epoch+1),
"cost=", ":.9f".format(c))
由于 mnist 只是一个数据集,mnist.train.next_batch
到底是什么意思? dataset.train.next_batch
是如何定义的?
谢谢!
【问题讨论】:
【参考方案1】:mnist
对象从tf.contrib.learn
模块中定义的read_data_sets()
function 返回。 mnist.train.next_batch(batch_size)
方法实现here,它返回两个数组的元组,其中第一个表示一批batch_size
MNIST 图像,第二个表示与这些图像对应的一批batch-size
标签。 /p>
图像以大小为 [batch_size, 784]
的二维 NumPy 数组形式返回(因为 MNIST 图像中有 784 个像素),标签以大小为 [batch_size]
的一维 NumPy 数组形式返回(如果read_data_sets()
是用one_hot=False
调用的)或一个大小为[batch_size, 10]
的二维NumPy 数组(如果read_data_sets()
是用one_hot=True
调用的)。
【讨论】:
值得一提的是,next_batch 在每个 epoch 遍历完所有示例后都会重新洗牌。您可以通过DataSet._index_in_epoch
跟踪您在纪元中的位置,例如mnist.train._index_in_epoch
以上是关于TensorFlow:dataset.train.next_batch 是如何定义的?的主要内容,如果未能解决你的问题,请参考以下文章