有没有一种简单的方法可以在 tensorflow 中使用 tf.data.Dataset.from_generator 和自定义 model_fn(Estimator) 中的功能

Posted

技术标签:

【中文标题】有没有一种简单的方法可以在 tensorflow 中使用 tf.data.Dataset.from_generator 和自定义 model_fn(Estimator) 中的功能【英文标题】:is there a simple way to use features from tf.data.Dataset.from_generator with a custom model_fn(Estimator) in tensorflow 【发布时间】:2018-05-03 13:57:00 【问题描述】:

我将 tensorflow dataset api 用于我的训练数据,tf.data.Dataset.from_generator api 的 input_fn 和生成器

def generator():
    ......
    yield  "x" : features , label


def input_fn():
    ds = tf.data.Dataset.from_generator(generator, ......)
    ......
    feature, label = ds.make_one_shot_iterator().get_next()
    return feature, label

然后我为我的 Estimator 创建了一个自定义 model_fn,其中包含如下代码:

def model_fn(features, labels, mode, params):
    print(features)
    ......
    layer = network.create_full_connect(input_tensor=features["x"], 
    (or layer = tf.layers.dense(features["x"], 200, ......)
    ......

训练时:

estimator.train(input_fn=input_fn)

但是,代码不起作用,因为函数 model_fn 的 features 参数是一些东西:

Tensor("IteratorGetNext:0", dtype=float32, device=/device:CPU:0)

code "features["x"]" 会失败并告诉我:

......“site-packages\tensorflow\python\ops\array_ops.py”,第 504 行,在 _SliceHelper 中 end.append(s + 1) TypeError: 必须是 str,而不是 int

如果我将 input_fn 更改为:

input_fn = tf.estimator.inputs.numpy_input_fn(
  x="x": np.array([[1,2,3,4,5,6]]),
  y=np.array([1]),

代码继续,因为 features 现在是一个 dict。

我搜索了估算器的代码,发现它使用了一些函数,例如

features, labels = self._get_features_and_labels_from_input_fn(
      input_fn, model_fn_lib.ModeKeys.TRAIN)

从 input_fn 检索特征和标签,但我不知道为什么它通过使用不同的数据集实现传递给我(model_fn)两种不同数据类型的特征,如果我想使用我的生成器模式,那么如何使用它类型(IteratorGetNext)的功能?

感谢您的帮助!

[更新]

我对代码做了一些修改,

def generator():
    ......
    yield features, label

def input_fn():
    ds = tf.data.Dataset.from_generator(generator, ......)
    ......
    feature, label = ds.make_one_shot_iterator().get_next()
    return "x": feature, label

然而,在 tf.layers.dense 仍然失败,现在它说

“dense_1 层的输入 0 与该层不兼容:其秩未定义,但该层需要已定义的秩。”

虽然特征是一个字典:

'x': tf.Tensor 'IteratorGetNext:0' shape=unknown dtype=float64

在正确的情况下,它是一些东西:

'x': tf.Tensor 'random_shuffle_queue_DequeueMany:1' shape=(128, 6) dtype=float64

我从

那里学到了类似的用法

https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html

def my_input_fn(file_path, perform_shuffle=False, repeat_count=1):
   def decode_csv(line):
      ......
      d = dict(zip(feature_names, features)), label
      return d

   dataset = (tf.data.TextLineDataset(file_path)

但是对于将迭代器返回到自定义 model_fn 的生成器案例,没有官方示例。

【问题讨论】:

【参考方案1】:

根据examples on how to use from_generator,生成器返回 以放入数据集中,而不是特征字典。相反,您在 input_fn 中构建字典。

如下修改代码应该可以正常工作:

def generator():
    ......
    yield features, label

def input_fn():
    ds = tf.data.Dataset.from_generator(generator, ......)
    ......
    feature, label = ds.make_one_shot_iterator().get_next()
    return "x": feature, label

回复更新:

您的代码失败是因为Dataset.from_generator 的迭代器生成的张量没有定义静态shape(因为生成器原则上可以返回具有不同形状的数据)。 假设您的数据确实始终具有相同的形状,您可以在 returning 之前从 input_fn 调用 feature.set_shape(<the_shape_of_your_data>)(有关正确方法,请参阅编辑打击)。

编辑:

正如您在评论中指出的那样,tf.data.Dataset.from_generator() 有第三个参数来设置输出张量的形状,因此在 from_generator() 中将形状作为 output_shapes 传递而不是 feature.set_shape()

【讨论】:

谢谢!但是代码在 tf.layers.dense 上仍然失败,我做了一些有问题的描述 太棒了!现在可以了。根据你的回答。我添加了第三个参数(我只是在正弦第三个可选之前使用了第一个和第二个参数)来定义张量的形状: tf.data.Dataset.from_generator(functools.partial(geneator, dataset), (tf.float64, tf. float64), (tf.TensorShape([FEATURE_SIZE]), tf.TensorShape([]))),再次感谢! 顺便说一句,如果我直接在 features["x"] 上调用 set_shape,它只会返回一个 None 值????所以我必须在 from_generator 函数上定义类型,然后代码才能工作,不知道,也许 Tensor Type "IteratorGetNext" 不能直接重塑。 它返回None 因为set_shape 直接作用于它被调用的张量,它不会返回一个新的张量(即,写my_tensor = my_tensor.set_shape([42]) 是错误的,而只是my_tensor.set_shape([42]) 会工作)。 另外,感谢您让我知道from_generator 的第三个参数,我更新了答案以提及它,因为这显然是在这种情况下要走的路。

以上是关于有没有一种简单的方法可以在 tensorflow 中使用 tf.data.Dataset.from_generator 和自定义 model_fn(Estimator) 中的功能的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow实战

培训批次:哪种Tensorflow方法是正确的?

Tensorflow卷积神经网络[转]

如何在 keras 中使用 tensorflow ctc_batch_cost 函数?

Tensorflow 中的笛卡尔积

如何部署Tensorflow训练模型以推断Windows独立应用程序