tf.data,用不同的数据构造一个批次?

Posted

技术标签:

【中文标题】tf.data,用不同的数据构造一个批次?【英文标题】:tf.data, construct a batch with different data? 【发布时间】:2020-10-24 20:51:50 【问题描述】:

我想构造一个batchsize为16的数据,使用tf.data,其中[:8]是一种数据A,[8:16]是一种数据B。

没有tf.data 很容易做到。如果使用tf.data,代码可能是:

def _decode_record(record, name_to_features):
    example = tf.parse_single_example(record, name_to_features)
    return example

dataA = tf.data.TFRecordDataset(input_files)
dataA = dataA.apply(
            tf.contrib.data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size)
           )

接下来该怎么做? 我试试:

dataB = tf.data.TFRecordDataset(input_files2)
dataB = dataB.apply(
            tf.contrib.data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size)
           )
dataC = dataA.concatenate(dataB)

concatenate 是:将整个数据集dataB 附加到dataA 的末尾。

对于concatenate,请注意name_to_features 对于dataAdataB 应该相同,这意味着我应该填充很多虚拟数据。

我不想用tf.condtf.where来判断tf.estimatormodel_fn里面的不同数据,也很难调试。

【问题讨论】:

【参考方案1】:

一种解决方法是判断不同的数据:

import tensorflow as tf

data_type = tf.constant([1, 2, 1, 2])
where_index1 = tf.where(tf.equal(data_type, 1))
where_index2 = tf.where(tf.equal(data_type, 2))

data = tf.constant([[10,10],[20,20],[30,30],[40,40]])

data1 = tf.gather_nd(data,where_index1)
data2 = tf.gather_nd(data,where_index2)

sess = tf.Session()

print(sess.run(data1))
print(sess.run(data2))

但这个答案只是以某种方式绕过了这个问题。

【讨论】:

【参考方案2】:

您可以将数据集压缩在一起,然后从 (dataA, dataB) 对构造批次:

import tensorflow as tf

dataset_1 = tf.data.Dataset.from_tensors(1).repeat(100)
dataset_2 = tf.data.Dataset.from_tensors(2).repeat(100)

dataset = tf.data.Dataset.zip((dataset_1, dataset_2))
dataset = dataset.batch(8)
dataset = dataset.map(lambda a, b: tf.concat([a, b], 0))

生产

tf.Tensor([1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2], shape=(16,), dtype=int32)
tf.Tensor([1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2], shape=(16,), dtype=int32)
tf.Tensor([1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2], shape=(16,), dtype=int32)
tf.Tensor([1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2], shape=(16,), dtype=int32)
...

【讨论】:

我没有在我的真实代码中尝试过。请先接受答案。

以上是关于tf.data,用不同的数据构造一个批次?的主要内容,如果未能解决你的问题,请参考以下文章

如何在 tf.data.Dataset 中输入不同大小的列表列表

tensorflow-读写数据tf.data

tf.data.Dataset.padded_batch 以不同的方式填充每个功能

tf.data.Dataset.interleave() 与 map() 和 flat_map() 究竟有何不同?

你能交错来自多个文件的 tf.data 数据集吗?

如何处理批次内不同实例中的不确定句子数量?