使用队列时如何在张量流中训练期间测试网络
Posted
技术标签:
【中文标题】使用队列时如何在张量流中训练期间测试网络【英文标题】:How to test a network during training in tensorflow when using a queue 【发布时间】:2016-11-12 19:50:35 【问题描述】:我正在使用下面的代码使用队列将我的训练示例提供给我的网络,并且它可以正常工作。
但是,我希望能够在每 n 次迭代中提供一些 测试数据,但我真的不知道应该如何进行。我应该暂时停止队列并手动提供测试数据吗?我应该创建另一个队列来仅用于测试数据吗?
编辑: 正确的做法是创建一个单独的文件,比如eval.py
,持续读取最后一个检查点并评估网络吗?这就是他们在 CIFAR10 示例中的做法。
batch = 128 # size of the batch
x = tf.placeholder("float32", [None, n_steps, n_input])
y = tf.placeholder("float32", [None, n_classes])
queue = tf.RandomShuffleQueue(capacity=4*batch,
min_after_dequeue=3*batch,
dtypes=[tf.float32, tf.float32],
shapes=[[n_steps, n_input], [n_classes]])
enqueue_op = queue.enqueue_many([x, y])
X_batch, Y_batch = queue.dequeue_many(batch)
sess = tf.Session()
def load_and_enqueue(data):
while True:
X, Y = data.get_next_batch(batch)
sess.run(enqueue_op, feed_dict=x: X, y: Y)
train_thread = threading.Thread(target=load_and_enqueue, args=(data))
train_thread.daemon = True
train_thread.start()
for _ in xrange(max_iter):
sess.run(train_op)
【问题讨论】:
最近添加到github repository 中有一些很好的高级功能。它们基于使用单独的可执行文件运行评估,该可执行文件读取训练创建的检查点文件。 @user728291,是否有任何示例可以在同一个脚本中执行此操作?似乎其他工具(如 Caffe)就是这样做的。 如何使用两个队列(或一个队列和一个馈送的占位符),并使用tf.where
来决定使用这两个源中的哪一个来馈送网络?
【参考方案1】:
您可以在代码中添加 eval_op,然后在每 n 次(例如 n=1000)次迭代中进行评估。一个例子如下:
for niter in xrange(max_iter):
sess.run(train_op)
if niter % 1000 == 0:
sess.run(eval_op)
【讨论】:
【参考方案2】:您可以像这样构建另一个测试队列和训练模型的副本作为测试模型:
trainX, trainY = Queue0(batchSize, ...)...
testX, testY= Queue1(batchSize, ...)...
modelTrain = inference(trainX, trainY, ...)
# reuse variables
modelTest = inference(testX, testY, ...)
sess.run(train_op,loss_op,trainX,trainY)
sess.run(test_op,testX,testY)
由于初始化了2个模型,这种方式可能会消耗更多内存,希望看到更好的解决方案
【讨论】:
以上是关于使用队列时如何在张量流中训练期间测试网络的主要内容,如果未能解决你的问题,请参考以下文章