如何在特定数量的列中找到 tensorflow 数据集批次中的最大值?
Posted
技术标签:
【中文标题】如何在特定数量的列中找到 tensorflow 数据集批次中的最大值?【英文标题】:How do I find the max value in a tensorflow dataset batch across a specific number of columns? 【发布时间】:2021-12-19 17:40:03 【问题描述】:假设下面的代码:
import tensorflow as tf
import numpy as np
simple_data_samples = np.array([
[1, 1, 1, 7, -1],
[2, -2, 2, -2, -2],
[3, 3, 3, -3, -3],
[-4, 4, 4, -4, -4],
[5, 5, 5, -5, -5],
[6, 6, 6, -4, -6],
[7, 7, 8, -7, -7],
[8, 8, 8, -8, -8],
[9, 4, 9, -9, -9],
[10, 10, 10, -10, -10],
[11, 5, 11, -11, -11],
[12, 12, 12, -12, -12],
])
def print_dataset(ds):
for inputs, targets in ds:
print("---Batch---")
print("Feature:", inputs.numpy())
print("Label:", targets.numpy())
print("")
def timeseries_dataset_multistep_combined(features, label_slice, input_sequence_length, output_sequence_length, sequence_stride, batch_size):
feature_ds = tf.keras.preprocessing.timeseries_dataset_from_array(features, None, sequence_length=input_sequence_length + output_sequence_length, sequence_stride=sequence_stride ,batch_size=batch_size, shuffle=False)
def split_feature_label(x):
return x[:, :input_sequence_length, :]+ tf.reduce_max(x[:,:,:],axis=1), x[:, input_sequence_length:, label_slice]+ tf.reduce_max(x[:,:,:],axis=1)
feature_ds = feature_ds.map(split_feature_label)
return feature_ds
ds = timeseries_dataset_multistep_combined(simple_data_samples, slice(None, None, None), input_sequence_length=4, output_sequence_length=2, sequence_stride=2, batch_size=1)
print_dataset(ds)
让我解释一下上面的代码是做什么的。它创建了许多特征和标签。然后它从每列中获取最大值并将其添加到列中的各个值中。比如这个特征及其对应的标签:
Feature: [[[ 1 1 1 7 -1]
[ 2 -2 2 -2 -2]
[ 3 3 3 -3 -3]
[-4 4 4 -4 -4]]]
Label: [[[ 5 5 5 -5 -5]
[ 6 6 6 -4 -6]]]
在每列中具有以下最大值:
6,6,6,7,-1
然后将最大值添加到相应的列中,然后得到最终输出:
[[ 7 7 7 14 -2]
[ 8 4 8 4 -3]
[ 9 9 9 3 -4]
[ 2 10 10 2 -5]]]
Label: [[[11 11 11 1 -6]
[12 12 12 2 -7]]]
我想从每个特征的前三列和最后两列及其对应的标签中提取最大值,而不是从每一列中提取最大值。提取后,我想将最大值添加到相应列中的每个值。例如,在上面的示例中,前三列的最大值为 6,后两列的最大值为 7。之后,前三列中的每个值将添加 6,后两列中的每个值添加 7。第一批的最终输出如下所示:
Feature: [[[ 7 7 7 14 6]
[ 8 4 8 5 5]
[ 9 9 9 4 4]
[ 2 10 10 3 3]]]
Label: [[[ 11 11 11 2 2]
[ 12 12 12 3 1]]]
有人知道如何从每批的前三列和后两列中提取最大值吗?
【问题讨论】:
【参考方案1】:像这样使用tf.tile
和tf.reduce_max
对你有用吗:
import tensorflow as tf
import numpy as np
simple_data_samples = np.array([
[1, 1, 1, 7, -1],
[2, -2, 2, -2, -2],
[3, 3, 3, -3, -3],
[-4, 4, 4, -4, -4],
[5, 5, 5, -5, -5],
[6, 6, 6, -4, -6],
[7, 7, 8, -7, -7],
[8, 8, 8, -8, -8],
[9, 4, 9, -9, -9],
[10, 10, 10, -10, -10],
[11, 5, 11, -11, -11],
[12, 12, 12, -12, -12],
])
def print_dataset(ds):
for inputs, targets in ds:
print("---Batch---")
print("Feature:", inputs.numpy())
print("Label:", targets.numpy())
print("")
def timeseries_dataset_multistep_combined(features, label_slice, input_sequence_length, output_sequence_length, sequence_stride, batch_size):
feature_ds = tf.keras.preprocessing.timeseries_dataset_from_array(features, None, sequence_length=input_sequence_length + output_sequence_length, sequence_stride=sequence_stride ,batch_size=batch_size, shuffle=False)
def split_feature_label(x):
reduced_first_max_columns = tf.reduce_max(x[:,:,:3], axis=1, keepdims=True)
reduced_last_max_columns = tf.reduce_max(x[:,:,3:], axis=1, keepdims=True)
reduced_first_max_columns = tf.tile(tf.reduce_max(reduced_first_max_columns, axis=-1), [1, 3])
reduced_last_max_columns = tf.tile(tf.reduce_max(reduced_last_max_columns, axis=-1), [1, 2])
reduced_x = tf.expand_dims(tf.concat([reduced_first_max_columns, reduced_last_max_columns], axis=1), axis=0)
return x[:, :input_sequence_length, :] + reduced_x, x[:, input_sequence_length:, label_slice] + reduced_x
feature_ds = feature_ds.map(split_feature_label)
return feature_ds
ds = timeseries_dataset_multistep_combined(simple_data_samples, slice(None, None, None), input_sequence_length=4, output_sequence_length=2, sequence_stride=2, batch_size=1)
print_dataset(ds)
---Batch---
Feature: [[[ 7 7 7 14 6]
[ 8 4 8 5 5]
[ 9 9 9 4 4]
[ 2 10 10 3 3]]]
Label: [[[11 11 11 2 2]
[12 12 12 3 1]]]
---Batch---
Feature: [[[11 11 11 -6 -6]
[ 4 12 12 -7 -7]
[13 13 13 -8 -8]
[14 14 14 -7 -9]]]
Label: [[[ 15 15 16 -10 -10]
[ 16 16 16 -11 -11]]]
...
【讨论】:
以上是关于如何在特定数量的列中找到 tensorflow 数据集批次中的最大值?的主要内容,如果未能解决你的问题,请参考以下文章