过滤数据集以仅获取特定类的图像
Posted
技术标签:
【中文标题】过滤数据集以仅获取特定类的图像【英文标题】:Filter Dataset to get just images from specific class 【发布时间】:2019-09-07 22:47:35 【问题描述】:我想为 n-shot 学习准备全能数据集。 因此我需要来自 10 个类(字母)的 5 个样本
复制代码
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
builder = tfds.builder("omniglot")
# assert builder.info.splits['train'].num_examples == 60000
builder.download_and_prepare()
# Load data from disk as tf.data.Datasets
datasets = builder.as_dataset()
dataset, test_dataset = datasets['train'], datasets['test']
def resize(example):
image = example['image']
image = tf.image.resize(image, [28, 28])
image = tf.image.rgb_to_grayscale(image, )
image = image / 255
one_hot_label = np.zeros((51, 10))
return image, one_hot_label, example['alphabet']
def stack(image, label, alphabet):
return (image, label), label[-1]
def filter_func(image, label, alphabet):
# get just images from alphabet in array, not just 2
arr = np.array(2,3,4,5)
result = tf.reshape(tf.equal(alphabet, 2 ), [])
return result
# correct size
dataset = dataset.map(resize)
# now filter the dataset for the batch
dataset = dataset.filter(filter_func)
# infinite stream of batches (classes*samples + 1)
dataset = dataset.repeat().shuffle(1024).batch(51)
# stack the images together
dataset = dataset.map(stack)
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.batch(32)
for i, (image, label) in enumerate(tfds.as_numpy(dataset)):
print(i, image[0].shape)
现在我想使用过滤器功能过滤数据集中的图像。 tf.equal 只是让我按一个类过滤,我想要数组中的张量。
您是否看到使用过滤器功能的方法? 或者这是错误的方法,还有更简单的方法?
我想创建一批 51 个图像和相应的标签,它们来自相同的 N=10 个类。在每一堂课中,我需要 K=5 个不同的图像和一个额外的图像(我需要对其进行分类)。每批 N*K+1 (51) 张图片应该来自 10 个新的随机类。
非常感谢您。
【问题讨论】:
Also: this filtering must be applied for every new batch (of size 51) randomly :-/
<--
澄清这一点。随机应用过滤是什么意思?
我想创建一批 51 个图像和相应的标签,它们来自相同的 10 个类。每批 51 张图片应该来自 10 个新的随机类。
更糟糕的是:我需要每类 K (5) 张图片,来自 N (10) 个随机类,另外还有一张图片 -> 批量大小为 N*K+1 (51) 张图片
刚刚浏览了tf.Dataset
文档。在我看来,使用当前的tf.Dataset
API 是不可能的。但是您可以将其转换为 numpy,在 Python/numpy 中准备此数据集,然后创建新数据集。你应该从测试数据中抽取第 51 个样本进行分类。它不应该是训练数据批次的一部分。
好吧,太糟糕了。非常感谢您的宝贵时间。第51个样本是否需要来自测试数据?!
【参考方案1】:
只保留特定标签使用这个谓词:
dataset = datasets['train']
def predicate(x, allowed_labels=tf.constant([0, 1, 2])):
label = x['label']
isallowed = tf.equal(allowed_labels, tf.cast(label, allowed_labels.dtype))
reduced = tf.reduce_sum(tf.cast(isallowed, tf.float32))
return tf.greater(reduced, tf.constant(0.))
dataset = dataset.filter(predicate).batch(20)
for i, x in enumerate(tfds.as_numpy(dataset)):
print(x['label'])
# [1 0 0 1 2 1 1 2 1 0 0 1 2 0 1 0 2 2 0 1]
# [1 0 2 2 0 2 1 2 1 2 2 2 0 2 0 2 1 2 1 1]
# [2 1 2 1 0 1 1 0 1 2 2 0 2 0 1 0 0 0 0 0]
allowed_labels
指定要保留的标签。所有不在这个张量中的标签都会被过滤掉。
【讨论】:
是否可以在一批的每一代中更改 tf.constant 的值?这样我就可以为每个生成的批次提供 10 个随机类 是的,当然。我明天会更新我的答案,在我这里已经过了午夜。 这里也是,只需要看看你的答案 :-) 非常感谢你的时间和精力! 请查看上面的最后一条评论和修改后的文字。 嗨@Vlad 和 janbolle,我想要完全相同的功能,例如“更改 tf.constant 的值与每一批次的每一代?”。你找到解决方案了吗?如果您能更新答案,我们将不胜感激。非常感谢。以上是关于过滤数据集以仅获取特定类的图像的主要内容,如果未能解决你的问题,请参考以下文章