如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?
Posted
技术标签:
【中文标题】如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?【英文标题】:How to use sklearn.preprocessing in tf.data.Dataset.map? 【发布时间】:2021-06-02 16:05:12 【问题描述】:我想找到一种在 TF2 中在 tf.data.Dataset.map()
中使用 sklearn.preprocessing
的方法。
假设我有一个从
生成的数据集import tensorflow as tf
ds = tf.data.Dataset.from_tensor_slices((tf.random.uniform((3, 3))))
ds = ds.batch(1)
x = tf.concat(list(ds.as_numpy_iterator()), axis=0)
print(x)
# tf.Tensor(
# [[0.51869464 0.9198195 0.87195873]
# [0.5842893 0.5363847 0.93642473]
# [0.0109899 0.7908174 0.25996208]], shape=(3, 3), dtype=float32)
然后计算 QuantileTransformer
from sklearn.preprocessing import QuantileTransformer
qt = QuantileTransformer(n_quantiles=2, random_state=0)
qt.fit_transform(x)
print(qt.quantiles_)
# [[0.0109899 0.5363847 0.25996208]
# [0.58428931 0.91981947 0.93642473]]
但是,我无法在tf.data.Dataset.map
中使用QuantileTransformer
。例如,
ds.map(lambda x: qt.transform(x))
报错
TypeError: in user code:
<ipython-input-106-867a262b9b69>:13 None *
lambda x: qt.transform(x)
/lib/python3.8/site-packages/sklearn/preprocessing/_data.py:2769 transform *
X = self._check_inputs(X, in_fit=False, copy=self.copy)
/lib/python3.8/site-packages/sklearn/preprocessing/_data.py:2699 _check_inputs *
X = self._validate_data(X, reset=in_fit,
/lib/python3.8/site-packages/sklearn/base.py:420 _validate_data *
X = check_array(X, **check_params)
/lib/python3.8/site-packages/sklearn/utils/validation.py:981 inner_f *
return f(*args, **kwargs)
/lib/python3.8/site-packages/sklearn/utils/validation.py:616 check_array *
array = np.asarray(array, order=order, dtype=dtype)
/lib/python3.8/site-packages/numpy/core/_asarray.py:83 asarray **
return array(a, dtype, copy=False, order=order)
TypeError: __array__() takes 1 positional argument but 2 were given```
【问题讨论】:
【参考方案1】:def map_decorator(func):
def wrapper(inp):
# Use a tf.py_function to prevent auto-graph from compiling the method
return tf.py_function(
func,
inp=[inp],
Tout=(inp.dtype)
)
return wrapper
tf.concat(list(ds.map(map_decorator(qt.transform)).as_numpy_iterator()), axis=0)
【讨论】:
以上是关于如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?的主要内容,如果未能解决你的问题,请参考以下文章
具有渴望模式的 TF.data.dataset.map(map_func)