如何使用 tf.data.Dataset 对象的 map 方法删除或省略数据?

Posted

技术标签:

【中文标题】如何使用 tf.data.Dataset 对象的 map 方法删除或省略数据?【英文标题】:How can I remove or omit data using map method for tf.data.Dataset objects? 【发布时间】:2021-02-28 08:01:57 【问题描述】:

我使用的是张量流 2.3.0

我有一个 python 数据生成器-

import tensorflow as tf
import numpy as np

vocab = [1,2,3,4,5]

def create_generator():
    'generates a random number from 0 to len(vocab)-1'
    count = 0
    while count < 4:
        x = np.random.randint(0, len(vocab))
        yield x
        count +=1

我将其设为 tf.data.Dataset 对象

gen = tf.data.Dataset.from_generator(create_generator, 
                                     args=[], 
                                     output_types=tf.int32, 
                                     output_shapes = (), )

现在我想使用 ma​​p 方法对项目进行子采样,这样 tf 生成器就不会输出任何偶数。

def subsample(x):
    'remove item if it is present in an even number [2,4]'
    
    '''
    #TODO
    '''
    return x
    
gen = gen.map(subsample)   

如何使用 ma​​p 方法实现这一点?

【问题讨论】:

【参考方案1】:

不,您不能使用map 过滤数据。映射函数对数据集的每个元素应用一些转换。您想要的是检查某个谓词的每个元素,并仅获取那些满足谓词的元素。

那个函数是filter()

所以你可以这样做:

gen = gen.filter(lambda x: x % 2 != 0)

更新:

如果你想使用自定义函数而不是lambda,你可以这样做:

def filter_func(x):
    if x**2 < 500:
        return True
    return False
gen = gen.filter(filter_func)

如果将此函数传递给filter,所有平方小于500的数字都将被返回。

【讨论】:

哇!这很有帮助。我也可以使用自定义python函数根据一些自定义规则过滤掉数据集的元素吗? 我真正想做的是,根据项目的出现频率从项目列表中对项目进行子样本。项目列表可以像 [1,3,4,6,1,3,9],现在假设在应用二次采样后,项目列表减少到 [1,3,6,1]。接下来我要做的是丢弃所有长度小于 2 的序列。如果长度 >= 2,我想使用它。我怎样才能做到这一点? 可以,只要您的自定义函数返回布尔值。我会更新答案 但是,布尔值只会决定是否过滤该特定数据集元素。我不确定如何使用自定义 python 函数转换数据集元素,如果序列长度 对于您的特定问题,lambda x: gen.count(x) &gt; 2。但是您的gen 必须是外部范围内的变量。

以上是关于如何使用 tf.data.Dataset 对象的 map 方法删除或省略数据?的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow:如何查找 tf.data.Dataset API 对象的大小

如何将 tf.data.Dataset 与 kedro 一起使用?

如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?

提供给 `tf.data.Dataset.from_generator(...)` 的 map 函数可以解析张量对象吗?

如何在 tf.data.Dataset 中输入不同大小的列表列表

如何更改 tf.data.Dataset 中数据的 dtype?