用 tf.data API 替换 tf.placeholder 和 feed_dict
Posted
技术标签:
【中文标题】用 tf.data API 替换 tf.placeholder 和 feed_dict【英文标题】:Replacing tf.placeholder and feed_dict with tf.data API 【发布时间】:2018-09-20 14:03:29 【问题描述】:我有一个现有的 TensorFlow 模型,它使用 tf.placeholder 作为模型输入,并使用 tf.Session().run 的 feed_dict 参数来输入数据。以前整个数据集被读入内存并以这种方式传递。
我想使用更大的数据集并利用 tf.data API 的性能改进。我已经定义了一个 tf.data.TextLineDataset 和其中的一次性迭代器,但我很难弄清楚如何将数据导入模型来训练它。
起初我试图将 feed_dict 定义为从占位符到 iterator.get_next() 的字典,但这给了我一个错误,说 feed 的值不能是 tf.Tensor 对象。更多的挖掘使我明白这是因为 iterator.get_next() 返回的对象已经是图形的一部分,这与您输入 feed_dict 的对象不同——而且我根本不应该尝试使用 feed_dict性能原因。
所以现在我已经摆脱了输入 tf.placeholder 并将其替换为定义我的模型的类的构造函数的参数;在我的训练代码中构建模型时,我将 iterator.get_next() 的输出传递给该参数。这似乎有点笨拙,因为它打破了模型定义和数据集/训练过程之间的分离。我现在收到一条错误消息,表示(我相信)我的模型输入的张量必须来自与 iterator.get_next() 中的张量相同的图表。
我在这种方法上是否走在正确的轨道上,只是在设置图表和会话的方式上做错了什么,或者类似的事情? (数据集和模型都在会话之外初始化,并且在我尝试创建之前发生错误。)
或者我是否完全不了解这个,需要做一些不同的事情,比如使用 Estimator API 并在输入函数中定义所有内容?
这里有一些代码演示了一个最小的例子:
import tensorflow as tf
import numpy as np
class Network:
def __init__(self, x_in, input_size):
self.input_size = input_size
# self.x_in = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size)) # Original
self.x_in = x_in
self.output_size = 3
tf.reset_default_graph() # This turned out to be the problem
self.layer = tf.layers.dense(self.x_in, self.output_size, activation=tf.nn.relu)
self.loss = tf.reduce_sum(tf.square(self.layer - tf.constant(0, dtype=tf.float32, shape=[self.output_size])))
data_array = np.random.standard_normal([4, 10]).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data_array).batch(2)
model = Network(x_in=dataset.make_one_shot_iterator().get_next(), input_size=dataset.output_shapes[-1])
【问题讨论】:
你能发布一些示例代码吗?这可能有助于弄清楚出了什么问题。 【参考方案1】:我也花了一点时间才弄明白。你在正确的轨道上。整个数据集定义只是图表的一部分。我通常将它创建为与我的模型类不同的类,并将数据集传递给模型类。我在命令行上指定要加载的 Dataset 类,然后动态加载该类,从而模块化地解耦 Dataset 和图形。
请注意,您可以(并且应该)命名数据集中的所有张量,这确实有助于在您通过各种所需的转换传递数据时使事情变得容易理解。
您可以编写简单的测试用例,从iterator.get_next()
中提取样本并显示它们,您将得到类似sess.run(next_element_tensor)
的内容,而不是feed_dict
,正如您正确指出的那样。
一旦您开始了解它,您可能会开始喜欢 Dataset 输入管道。它迫使你很好地模块化你的代码,并迫使它成为一个易于单元测试的结构。
请务必阅读开发者指南,那里有大量示例:
https://www.tensorflow.org/programmers_guide/datasets
我要注意的另一件事是,使用此管道处理训练和测试数据集是多么容易。这很重要,因为您经常在训练数据集上执行数据增强,而不是在测试数据集上执行,from_string_handle
允许您这样做,并且在上面的指南中有明确描述。
【讨论】:
感谢您让我知道我在正确的轨道上!结果只是我拥有的模型的构造函数中的 tf.reset_default_graph() 行。【参考方案2】:我得到的原始代码中模型的构造函数中的 tf.reset_default_graph()
行是导致它的原因。删除修复它。
【讨论】:
以上是关于用 tf.data API 替换 tf.placeholder 和 feed_dict的主要内容,如果未能解决你的问题,请参考以下文章
Tensorflow:如何查找 tf.data.Dataset API 对象的大小