如何使用 TF1.3 中的新 Dataset api 映射具有附加参数的函数?
Posted
技术标签:
【中文标题】如何使用 TF1.3 中的新 Dataset api 映射具有附加参数的函数?【英文标题】:How to map a function with additional parameter using the new Dataset api in TF1.3? 【发布时间】:2018-02-26 02:03:39 【问题描述】:我在玩the Dataset API in Tensorflow v1.3。这很棒。
可以使用here 中描述的函数映射数据集。我很想知道如何传递一个有附加参数的函数,例如arg1
:
def _parse_function(example_proto, arg1):
features = "image": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int32, default_value=0)
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features["image"], parsed_features["label"]
当然,
dataset = dataset.map(_parse_function)
由于无法传入arg1
,因此无法使用。
【问题讨论】:
只是一个想法,也许我们可以通过传入一个以 arg1 作为其类成员的 python 类以及定义的 __call__ 方法来伪造它。arg1
是个什么样的参数?如果它是一个普通的 Python 变量(不是 TensorFlow),你可以在另一个已知 arg1
的函数中定义你的 _parse_function
函数,你不必再传递它了。
【参考方案1】:
这是一个使用 lambda 表达式来包装我们想要传递参数的函数的示例:
import tensorflow as tf
def fun(x, arg):
return x * arg
my_arg = tf.constant(2, dtype=tf.int64)
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: fun(x, my_arg))
在上面,提供给map
的函数的签名必须与我们数据集的内容相匹配。所以我们必须编写我们的 lambda 表达式来匹配它。这里很简单,因为数据集中只包含一个元素,x
包含从 0 到 4 范围内的元素。
如有必要,您可以从数据集外部传入任意数量的外部参数:ds = ds.map(lambda x: my_other_fun(x, arg1, arg2, arg3)
,等等。
为了验证上述方法是否有效,我们可以观察到映射确实将每个数据集元素乘以 2:
iterator = ds.make_initializable_iterator()
next_x = iterator.get_next()
with tf.Session() as sess:
sess.run(iterator.initializer)
while True:
try:
print(sess.run(next_x))
except tf.errors.OutOfRangeError:
break
输出:
0
2
4
6
8
【讨论】:
【参考方案2】:您也可以使用Partial
函数来包装您的参数:
def _parse_function(arg1, example_proto):
features = "image": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int32, default_value=0)
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features["image"], parsed_features["label"]
更改函数的参数顺序以适应偏向性,然后您可以使用如下参数值包装函数:
from functools import partial
arg1 = ...
dataset = dataset.map(partial(_parse_function, arg1))
【讨论】:
【参考方案3】:另一个解决方案是使用类包装器。在下面的代码中,我将参数 shape 传递给了解析函数。
class MyDataSets:
def __init__(self, shape):
self.shape = shape
def parse_sample(self.sample):
features = ...
f = tf.parse_example([example], features=features)
image_raw = tf.decode_raw(f['image_raw'], tf.uint8)
image = image.reshape(image_raw, self.shape)
label = tf.cast(f['label'], tf.int32)
return image, label
def init(self):
ds = tf.data.TFRecordDataSets(...)
ds = ds.map(self.parse_sample)
...
return ds.make_initializable_iterator()
【讨论】:
以上是关于如何使用 TF1.3 中的新 Dataset api 映射具有附加参数的函数?的主要内容,如果未能解决你的问题,请参考以下文章
如何通过 API(补丁)授予对 BigQuery 数据集的新视图访问权限