如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?

Posted

技术标签:

【中文标题】如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?【英文标题】:How to use sklearn.preprocessing in tf.data.Dataset.map? 【发布时间】:2021-06-02 16:05:12 【问题描述】:

我想找到一种在 TF2 中在 tf.data.Dataset.map() 中使用 sklearn.preprocessing 的方法。

假设我有一个从

生成的数据集
import tensorflow as tf
ds = tf.data.Dataset.from_tensor_slices((tf.random.uniform((3, 3))))
ds = ds.batch(1)
x = tf.concat(list(ds.as_numpy_iterator()), axis=0)
print(x)
# tf.Tensor(
# [[0.51869464 0.9198195  0.87195873]
#  [0.5842893  0.5363847  0.93642473]
#  [0.0109899  0.7908174  0.25996208]], shape=(3, 3), dtype=float32)

然后计算 QuantileTransformer

from sklearn.preprocessing import QuantileTransformer
qt = QuantileTransformer(n_quantiles=2, random_state=0)
qt.fit_transform(x)
print(qt.quantiles_)
# [[0.0109899  0.5363847  0.25996208]
#  [0.58428931 0.91981947 0.93642473]]

但是,我无法在tf.data.Dataset.map 中使用QuantileTransformer。例如,

ds.map(lambda x: qt.transform(x))

报错

TypeError: in user code:

    <ipython-input-106-867a262b9b69>:13 None  *
        lambda x: qt.transform(x)
    /lib/python3.8/site-packages/sklearn/preprocessing/_data.py:2769 transform  *
        X = self._check_inputs(X, in_fit=False, copy=self.copy)
    /lib/python3.8/site-packages/sklearn/preprocessing/_data.py:2699 _check_inputs  *
        X = self._validate_data(X, reset=in_fit,
    /lib/python3.8/site-packages/sklearn/base.py:420 _validate_data  *
        X = check_array(X, **check_params)
    /lib/python3.8/site-packages/sklearn/utils/validation.py:981 inner_f  *
        return f(*args, **kwargs)
    /lib/python3.8/site-packages/sklearn/utils/validation.py:616 check_array  *
        array = np.asarray(array, order=order, dtype=dtype)
    /lib/python3.8/site-packages/numpy/core/_asarray.py:83 asarray  **
        return array(a, dtype, copy=False, order=order)

    TypeError: __array__() takes 1 positional argument but 2 were given```

【问题讨论】:

【参考方案1】:
def map_decorator(func):
    def wrapper(inp):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=[inp],
            Tout=(inp.dtype)
        )
    return wrapper

tf.concat(list(ds.map(map_decorator(qt.transform)).as_numpy_iterator()), axis=0)

【讨论】:

以上是关于如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?的主要内容,如果未能解决你的问题,请参考以下文章

具有渴望模式的 TF.data.dataset.map(map_func)

tensorflow数据流水线如何分离?

在tensorflow中使用`dataset.map()`访问张量numpy数组

如何在图像中找到明亮区域(以及如何在图像中找到阴影区域)

在QGIS中如何添加天地图的WMTS

如何在异步任务中调用意图?或者如何在 onPostExecute 中开始新的活动?