如何从 TensorFlow 数据集中提取数据/标签
Posted
技术标签:
【中文标题】如何从 TensorFlow 数据集中提取数据/标签【英文标题】:How to extract data/labels back from TensorFlow dataset 【发布时间】:2019-10-07 03:49:35 【问题描述】:有很多示例如何创建和使用 TensorFlow 数据集,例如
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
我的问题是如何以 numpy 形式从 TF 数据集中取回数据/标签?换句话说,希望是上面行的反向操作,即我有一个 TF 数据集并希望从中取回图像和标签。
【问题讨论】:
***.com/questions/70535683/… 【参考方案1】:如果您的 tf.data.Dataset
被批处理,以下代码将检索所有 y 标签:
y = np.concatenate([y for x, y in ds], axis=0)
【讨论】:
优雅和pythonic! +1 @TimMironov 谢谢。我也可以使用 _ 作为单行中的 x。实际上,我认为如果你想同时提取 x 和 y,会有一个缺点。我还没有弄清楚你是否可以在类似的单线中做到这一点。【参考方案2】:假设我们的 tf.data.Dataset 被称为 train_dataset
并开启 eager_execution
(在 TF 2.x 中默认),您可以像这样检索图像和标签:
for images, labels in train_dataset.take(1): # only take first element of dataset
numpy_images = images.numpy()
numpy_labels = labels.numpy()
内联操作.numpy()
将 tf.Tensor 转换为 numpy 数组
如果要检索数据集的更多元素,只需增加take 方法中的数字即可。如果你想要所有元素,只需插入-1
【讨论】:
需要注意的是,该方法在某些情况下会返回count
批量图片,而不是单个图片。【参考方案3】:
我认为我们在这里得到了一个很好的例子:
https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/overview.ipynb#scrollTo=BC4pEXtkp4K-
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
# where mnsit train is a tf dataset
mnist_train = tfds.load(name="mnist", split=tfds.Split.TRAIN)
assert isinstance(mnist_train, tf.data.Dataset)
mnist_example, = mnist_train.take(1)
image, label = mnist_example["image"], mnist_example["label"]
plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
print("Label: %d" % label.numpy())
因此,可以像访问字典一样访问数据集的每个单独组件。大概不同的数据集有不同的字段名称(波士顿住房没有图像和价值,但可能有“特征”和“目标”或“价格”:
cnn = tfds.load(name="cnn_dailymail", split=tfds.Split.TRAIN)
assert isinstance(cnn, tf.data.Dataset)
cnn_ex, = cnn.take(1)
print(cnn_ex)
返回一个带有键 ['article', 'highlight'] 的 dict(),其中包含 numpy 字符串。
【讨论】:
【参考方案4】:如果您可以将图像和标签保留为tf.Tensor
s,您可以这样做
images, labels = tuple(zip(*dataset))
将数据集的效果想象为zip(images, labels)
。当我们想要取回图像和标签时,我们可以简单地unzip 它。
如果您需要 numpy 数组版本,请使用 np.array()
进行转换:
images = np.array(images)
labels = np.array(labels)
【讨论】:
【参考方案5】:这是我自己解决问题的方法:
def dataset2numpy(dataset, steps=1):
"Helper function to get data/labels back from TF dataset"
iterator = dataset.make_one_shot_iterator()
next_val = iterator.get_next()
with tf.Session() as sess:
for _ in range(steps):
inputs, labels = sess.run(next_val)
yield inputs, labels
请注意,此函数将产生数据集批次的输入/标签。这些步骤控制从数据集中取出多少批次。
【讨论】:
【参考方案6】:这对我有用
features = np.array([list(x[0].numpy()) for x in list(ds_test)])
labels = np.array([x[1].numpy() for x in list(ds_test)])
# NOTE: ds_test was created
iris, iris_info = tfds.load('iris', with_info=True)
ds_orig = iris['train']
ds_orig = ds_orig.shuffle(150, reshuffle_each_iteration=False)
ds_train = ds_orig.take(100)
ds_test = ds_orig.skip(100)
【讨论】:
【参考方案7】:import numpy as np
import tensorflow as tf
batched_features = tf.constant([[[1, 3], [2, 3]],
[[2, 1], [1, 2]],
[[3, 3], [3, 2]]], shape=(3, 2, 2))
batched_labels = tf.constant([[0, 0],
[1, 1],
[0, 1]], shape=(3, 2, 1))
dataset = tf.data.Dataset.from_tensor_slices((batched_features, batched_labels))
classes = np.concatenate([y for x, y in dataset], axis=0)
unique = np.unique(classes, return_counts=True)
labels_dict = dict(zip(unique[0], unique[1]))
print(classes)
print(labels_dict)
# 0: 3, 1: 3
【讨论】:
虽然这可能会回答这个问题,但如果可能的话,您应该edit 回答您的回答,以简要说明 如何 此代码块回答该问题。这有助于提供上下文,并使您的答案对未来的读者更有用。【参考方案8】:TensorFlow 的 get_single_element()
最后是 around,可用于从数据集中提取数据和标签。
这避免了使用.map()
或iter()
生成和使用迭代器的需要(这对于大型数据集可能成本很高)。
get_single_element()
返回一个封装数据集所有成员的张量(或张量的元组或字典)。我们需要将数据集的所有成员批量传递到单个元素中。
这可用于获取特征作为张量数组,或特征和标签作为元组或字典(张量数组),具体取决于原始数据集的创建方式。
查看 SO 上的 answer 以获取将特征和标签解包到张量数组元组中的示例。
【讨论】:
【参考方案9】:https://www.tensorflow.org/tutorials/images/classification
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.title(class_names[labels[i]])
plt.axis("off")
【讨论】:
【参考方案10】:您可以使用 TF Dataset 方法 unbatch() 取消批量数据集,然后您可以轻松地从中检索数据和标签:
ds_labels=[]
for images, labels in ds.unbatch():
ds_labels.append(labels) # or labels.numpy().argmax() for int labels
【讨论】:
以上是关于如何从 TensorFlow 数据集中提取数据/标签的主要内容,如果未能解决你的问题,请参考以下文章