如何在 TensorFlow 中对数据集进行切片?

Posted

技术标签:

【中文标题】如何在 TensorFlow 中对数据集进行切片?【英文标题】:How can slicing dataset in TensorFlow? 【发布时间】:2020-12-05 08:15:12 【问题描述】:

我想在tf.data 中切片数据集。我的数据是这样的:

dataset = tf.data.Dataset.from_tensor_slices([[0, 1, 2, 3, 4],
                                              [1, 2, 3, 4, 5],
                                              [2, 3, 4, 5, 6],
                                              [3, 4, 5, 6, 7],
                                              [4, 5, 6, 7, 8]])

那么主要数据是:

[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]
[3 4 5 6 7]
[4 5 6 7 8]

我想创建包含如下数据的其他张量数据集:

       [[1, 2],
       [2, 3],
       [3, 4],
       [4, 5],
       [5, 6]]

在numpy中是这样的:

dataset[:,1:3]

如何在 TensorFlow 中做到这一点?

更新:

我这样做:

dataset2 = dataset.map(lambda data: data[1:3])
for val in dataset2:
    print(val.numpy())

但我认为有很好的解决方案。

【问题讨论】:

【参考方案1】:

在我看来,您的解决方案是最好的解决方案。为了社区的利益,我使用tf.data.Dataset 的as_numpy_iterator() 方法对数据集进行切片(对您的代码进行小的语法更改)。

请参考以下代码

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([[0, 1, 2, 3, 4],
                                              [1, 2, 3, 4, 5],
                                              [2, 3, 4, 5, 6],
                                              [3, 4, 5, 6, 7],
                                              [4, 5, 6, 7, 8]])


dataset2 = dataset.map(lambda data: data[1:3])
for val in dataset2.as_numpy_iterator():
    print(val)

输出:

[1 2]
[2 3]
[3 4]
[4 5]
[5 6]

【讨论】:

以上是关于如何在 TensorFlow 中对数据集进行切片?的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Tensorflow 中对无维度的张量进行切片

如何在 Javascript 中对对象进行切片?

如何在 Bash 中对数组进行切片

如何在 Swift 中对字符串进行切片? [复制]

如何在 Excel VBA 中对数组进行切片?

如何在 Tensorflow 中进行切片分配