如何将 PrefetchDataset 转换为 TF 张量?

Posted

技术标签:

【中文标题】如何将 PrefetchDataset 转换为 TF 张量?【英文标题】:How to convert PrefetchDataset to a TF tensor? 【发布时间】:2020-04-17 22:33:47 【问题描述】:
def get_train_dataset(file_path, **kwargs):
  dataset = tf.data.experimental.make_csv_dataset(
      file_path,
      batch_size=5, 
      label_name=LABEL_COLUMN,
      na_value="?",
      num_epochs=1,
      ignore_errors=True,
      **kwargs)
  return dataset

raw_train_data = get_train_dataset(train_file_path, select_columns=CSV_COLUMNS)

我从“make_csv_dataset”函数创建了一个 DataSet,它是 OrderDict 的 PrefectDataset。但是,当我拟合模型时:

embedding = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1"
hub_layer = hub.KerasLayer(embedding, input_shape=[],
                           dtype=tf.string, trainable=True)
model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(16, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])
history = model.fit(train_data.shuffle(10000),
                    epochs=20,
                    validation_data=val_data,
                    verbose=1)

报错:

  File "/home/my-env/tf/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 118, in <listcomp>
    inputs = [inputs[key] for key in model._feed_input_names]
KeyError: 'keras_layer_input'

我希望将这个 OrderedDictionary 转换成 TF.Tensor,那么 'fit' 方法应该可以工作。怎么做?还是有其他方法可以解决这个问题?

在另一篇文章中,我看到了:

The not very elegant workaround you can try is to match the name of input layer with csv column name

我的 csv 文本列名称是“文本”。如果我想使用上面的解决方法,该怎么做?

【问题讨论】:

【参考方案1】:

我认为问题出在其他地方,但首先更改 fit 内的数据集名称。您没有在任何地方定义 train_data 变量。

如果这仍然不起作用,则将您的 get_train_data 替换为这一行,同时将值提供给其中的相应参数。

dataset = tf.data.experimental.make_csv_dataset('PATH OR FILE NAME', batch_size = 1,select_columns = ['Column1','Label'], label_name = 'Label', num_epochs = 1 ,shuffle = True)

【讨论】:

以上是关于如何将 PrefetchDataset 转换为 TF 张量?的主要内容,如果未能解决你的问题,请参考以下文章

如何将DataTable转换成List<T>

如何将字节数组转换为 boost::multiprecision::uint128_t?

如何将 Optional<T> 转换为 Stream<T>?

如何将字节从 uint64_t 转换为 double?

如何将 std::string 转换为原始字符串? [复制]

如何将 List<T> 转换为 DataSet?