如何按特定值过滤 tf.data.Dataset?

Posted

技术标签:

【中文标题】如何按特定值过滤 tf.data.Dataset?【英文标题】:How can I filter tf.data.Dataset by specific values? 【发布时间】:2018-07-27 07:18:31 【问题描述】:

我通过读取 TFRecords 创建了一个数据集,我映射了这些值,我想过滤数据集以获取特定值,但由于结果是一个带有张量的 dict,我无法获得一个张量的实际值或用tf.cond()/tf.equal检查它。我该怎么做?

def mapping_func(serialized_example):
    feature =  'label': tf.FixedLenFeature([1], tf.string) 
    features = tf.parse_single_example(serialized_example, features=feature)
    return features

def filter_func(features):
    # this doesn't work
    #result = features['label'] == 'some_label_value'
    # neither this
    result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
    return result

def main():
    file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
    dataset = tf.contrib.data.TFRecordDataset(file_names)
    dataset = dataset.map(mapping_func)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.filter(filter_func)
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    sample = iterator.get_next()

【问题讨论】:

你得到哪个错误? 【参考方案1】:

我正在回答我自己的问题。我找到了问题!

我需要做的是tf.unstack()这样的标签:

label = tf.unstack(features['label'])
label = label[0]

在我把它交给tf.equal()之前:

result = tf.reshape(tf.equal(label, 'some_label_value'), [])

我想问题是标签被定义为一个数组,其中包含一个类型为字符串tf.FixedLenFeature([1], tf.string) 的元素,所以为了获得第一个和单个元素,我必须解包它(创建一个列表)然后得到索引为0的元素,如果我错了,请纠正我。

【讨论】:

【参考方案2】:

我认为您首先不需要将 label 设为一维数组。

与:

feature = 'label': tf.FixedLenFeature((), tf.string)

您无需在 filter_func 中取消堆叠标签

【讨论】:

【参考方案3】:

读取、过滤数据集非常简单,无需拆开任何东西。

读取数据集:

print(my_dataset, '\n\n')
##let us print the first 3 records
for record in my_dataset.take(3):
    ##below could be large in case of image
    print(record)
    ##let us print a specific key
    print(record['key2'])

过滤同样简单:

my_filtereddataset = my_dataset.filter(_filtcond1)

您可以根据需要在哪里定义 _filtcond1。假设您的数据集中有一个 'true' 'false' 布尔标志,那么:

@tf.function
def _filtcond1(x):
    return x['key_bool'] == 1

甚至是 lambda 函数:

my_filtereddataset = my_dataset.filter(lambda x: x['key_int']>13)

如果您正在读取尚未创建的数据集或者您不知道键(似乎是 OP 的情况),您可以使用它首先了解键和结构:

import json
from google.protobuf.json_format import MessageToJson

for raw_record in noidea_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    ##print(example) ##if image it will be toooolong
    m = json.loads(MessageToJson(example))
    print(m['features']['feature'].keys())

现在您可以继续过滤

【讨论】:

【参考方案4】:

你应该尝试使用 apply 函数 tf.data.TFRecordDataset tensorflow documentation

否则...阅读这篇关于 TFRecords 的文章以更好地了解 TFRecords TFRecords for humans

但最可能的情况是你不能访问也不能修改TFRecord...github上有一个关于这个主题的请求TFRecords request

我的建议是让事情尽可能简单...您必须知道您正在使用图表和会话...

无论如何...如果一切都失败了,请尽可能简单地尝试在 tensorflow 会话中不起作用的代码部分...可能所有这些操作都应该在 tf.session 运行时完成。 ..

【讨论】:

以上是关于如何按特定值过滤 tf.data.Dataset?的主要内容,如果未能解决你的问题,请参考以下文章

如何将 tf.data.Dataset 与 kedro 一起使用?

如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?

如何在 tf.data.Dataset 中输入不同大小的列表列表

如何更改 tf.data.Dataset 中数据的 dtype?

如何在 keras 自定义回调中访问 tf.data.Dataset?

如何在 tf 2.1.0 中创建 tf.data.Dataset 的训练、测试和验证拆分