如何在 TensorFlow 中使用“group_by_window”函数
Posted
技术标签:
【中文标题】如何在 TensorFlow 中使用“group_by_window”函数【英文标题】:How do I use the "group_by_window" function in TensorFlow 【发布时间】:2017-12-30 17:50:45 【问题描述】:在 TensorFlow 的一组新输入管道函数中,可以使用“group_by_window”函数将记录集分组在一起。它在此处的文档中进行了描述:
https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#group_by_window
我不完全理解这里用于描述函数的解释,我倾向于通过示例来学习。我在互联网上的任何地方都找不到此功能的任何示例代码。有人可以为这个函数制作一个准系统和可运行的例子来展示它是如何工作的,以及赋予这个函数什么?
【问题讨论】:
【参考方案1】:对于 tensorflow 版本 1.9.0 这是我可以想出的一个简单示例:
import tensorflow as tf
import numpy as np
components = np.arange(100).astype(np.int64)
dataset = tf.data.Dataset.from_tensor_slices(components)
dataset = dataset.apply(tf.contrib.data.group_by_window(key_func=lambda x: x%2, reduce_func=lambda _, els: els.batch(10), window_size=100)
iterator = dataset.make_one_shot_iterator()
features = iterator.get_next()
sess = tf.Session()
sess.run(features) # array([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18], dtype=int64)
第一个参数key_func
将数据集中的每个元素映射到一个键。
window_size
定义了分配给reduce_fund
的存储桶大小。
在reduce_func
中,您会收到一个window_size
元素块。您可以随心所欲地随机播放、批处理或填充。
使用 group_by_window 函数more here 编辑动态填充和分桶:
如果你有一个 tf.contrib.dataset
包含 (sequence, sequence_length, label)
并且序列是 tf.int64 的张量:
def bucketing_fn(sequence_length, buckets):
"""Given a sequence_length returns a bucket id"""
t = tf.clip_by_value(buckets, 0, sequence_length)
return tf.argmax(t)
def reduc_fn(key, elements, window_size):
"""Receives `window_size` elements"""
return elements.shuffle(window_size, seed=0)
# Create buckets from 0 to 500 with an increment of 15 -> [0, 15, 30, ... , 500]
buckets = [tf.constant(num, dtype=tf.int64) for num in range(0, 500, 15)
window_size = 1000
# Bucketing
dataset = dataset.group_by_window(
lambda x, y, z: bucketing_fn(x, buckets),
lambda key, x: reduc_fn(key, x, window_size), window_size)
# You could pad it in the reduc_func, but I'll do it here for clarity
# The last element of the dataset is the dynamic sentences. By giving it tf.Dimension(None) it will pad the sencentences (with 0) according to the longest sentence.
dataset = dataset.padded_batch(batch_size, padded_shapes=(
tf.TensorShape([]), tf.TensorShape([]), tf.Dimension(None)))
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
features = iterator.get_next()
【讨论】:
key_func=lambda x: x%2
将 x 映射到 0 和 1,对吗?我不明白为什么结果只有偶数元素?
是的。它基本上创建了两个桶:一个用于偶数,一个用于奇数。第一个打印语句中只有偶数元素,因为它从该批次的偶数桶中获取元素
嗨@MaximeDeBruyn,你能解释一下为什么你有lambda x, y, z: bucketing_fn(x, buckets)
和lambda key, x: reduc_fn(key, x, window_size), window_size)
吗?只传递函数而不是 lambda 会有什么问题?以上是关于如何在 TensorFlow 中使用“group_by_window”函数的主要内容,如果未能解决你的问题,请参考以下文章
如何让 Tensorflow Profiler 在 Tensorflow 2.5 中使用“tensorflow-macos”和“tensorflow-metal”工作