Build a TensorFlow dataset with a database as the data source


【Title】: Build tensorflow dataset with datasource as database 【Posted】: 2019-11-14 22:42:49 【Question】:

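I have to create a data input pipeline using TensorFlow's tf.data. The data sources are MongoDB and SQL Server. How do I create a tf.data object from a database? All the articles I have seen use .tfrecords or .csv files as the data source for TensorFlow.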

Thanks. I appreciate your input.

【Comments on the question】:

【Answer 1】:

Retrieve the data from the database and store it as a numpy array. If the array is too large for memory, try using a memmap array.
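
The retrieval step could look roughly like the sketch below. It is only an illustration under stated assumptions: the standard-library `sqlite3` module stands in for SQL Server or MongoDB, and the table `samples` and the helpers `load_table_as_array` and `fetch_into_memmap` are hypothetical names that do not come from the answer. The second helper fills a disk-backed `np.memmap` chunk by chunk, so the full result never has to sit in memory at once.

```python
import os
import sqlite3
import tempfile

import numpy as np


def load_table_as_array(conn, query, dtype=np.float32):
    """Fetch all rows of `query` and pack them into a 2-D NumPy array."""
    rows = conn.execute(query).fetchall()
    return np.asarray(rows, dtype=dtype)


def fetch_into_memmap(conn, query, n_rows, n_cols, path,
                      chunk_size=1000, dtype=np.float32):
    """Stream rows of `query` into a disk-backed memmap, chunk by chunk."""
    out = np.memmap(path, dtype=dtype, mode="w+", shape=(n_rows, n_cols))
    cursor = conn.execute(query)
    start = 0
    while True:
        rows = cursor.fetchmany(chunk_size)
        if not rows:
            break
        out[start:start + len(rows)] = rows
        start += len(rows)
    out.flush()
    return out


# Hypothetical in-memory table standing in for the real database.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE samples (f1 REAL, f2 REAL, label REAL)")
conn.executemany("INSERT INTO samples VALUES (?, ?, ?)",
                 [(0.1, 0.2, 0.0), (0.3, 0.4, 1.0), (0.5, 0.6, 0.0)])

query = "SELECT f1, f2, label FROM samples ORDER BY rowid"

# Small data: load everything into an ordinary array.
data = load_table_as_array(conn, query)
features, labels = data[:, :2], data[:, 2]

# Large data: write to a memmap instead (chunk_size=2 forces two chunks here).
n_rows = conn.execute("SELECT COUNT(*) FROM samples").fetchone()[0]
mm_path = os.path.join(tempfile.mkdtemp(), "samples.dat")
data_mm = fetch_into_memmap(conn, query, n_rows, 3, mm_path, chunk_size=2)
```

With a real SQL Server or MongoDB source, only the connection and query parts would change; the array and memmap handling would stay the same.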

Then create a generator. Here is an example from my own code for images and their one-hot encodings:

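# Written against the TensorFlow 1.x API (tf.placeholder, sessions).
# The imports below are assumed; the original snippet did not show them.
import tensorflow as tf
from tensorflow.keras import backend as K

def tf_augmented_image_generator(images,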
                                 onehots,
                                 batch_size,
                                 map_fn,
                                 shuffle_size=1000,
                                 num_parallel_calls=tf.data.experimental.AUTOTUNE):
    """
    Create a generator using a tf.data.Dataset with augmentation via a map function.
    The generator can then be used for training with model.fit_generator.

    The map function must consist of tensorflow operators (not numpy).

    On Windows machines this will lead to faster augmentation, as there are some
    problems performing augmentation in parallel when multiprocessing is enabled
    in model.fit / model.fit_generator and the default Keras numpy-based augmentation
    is used, e.g. in ImageDataGenerator.

    :param images: Images to augment
    :param onehots: Onehot encoding of target class
    :param batch_size: Batch size for training
    :param map_fn: The augmentation map function
    :param shuffle_size: Size of the shuffle buffer (number of images). Smaller values reduce memory consumption.
    :param num_parallel_calls: Number of calls in parallel, default is automatic tuning.
    :return: Generator yielding (images, onehots) batches as numpy arrays
    """
    # Get shapes from input data
    img_size = images.shape
    img_size = (None, img_size[1], img_size[2], img_size[3])
    onehot_size = onehots.shape
    onehot_size = (None, onehot_size[1])
    images_tensor = tf.placeholder(tf.float32, shape=img_size)
    onehots_tensor = tf.placeholder(tf.float32, shape=onehot_size)

    # Create dataset
    dataset = tf.data.Dataset.from_tensor_slices((images_tensor, onehots_tensor))
    if map_fn is not None:
        dataset = dataset.map(lambda x, y: (map_fn(x), y), num_parallel_calls=num_parallel_calls)
    dataset = dataset.shuffle(shuffle_size, reshuffle_each_iteration=True).repeat()
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(1)

    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    next_val = iterator.get_next()

    with K.get_session().as_default() as sess:
        # Feed the full arrays once when initializing the iterator
        sess.run(init_op, feed_dict={images_tensor: images, onehots_tensor: onehots})
        while True:
            inputs, labels = sess.run(next_val)
            yield inputs, labels

Then train the model using fit_generator.
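
When the data should not be materialized as an array first, an alternative is to let a Python generator read batches straight from a database cursor. The sketch below is not the answer author's code; it is a minimal illustration in which `sqlite3` again stands in for the real database, and `batch_generator` and the `samples` table are hypothetical names. It assumes the last selected column is the label.

```python
import sqlite3


def batch_generator(conn, query, batch_size):
    """Yield (features, labels) batches streamed from a database cursor.

    Assumes the last column of each row is the label.
    """
    cursor = conn.execute(query)
    while True:
        rows = cursor.fetchmany(batch_size)
        if not rows:
            return  # cursor exhausted: one pass over the table is done
        features = [list(row[:-1]) for row in rows]
        labels = [row[-1] for row in rows]
        yield features, labels


# Hypothetical in-memory table standing in for the real database.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE samples (f1 REAL, f2 REAL, label INTEGER)")
conn.executemany("INSERT INTO samples VALUES (?, ?, ?)",
                 [(0.1, 0.2, 0), (0.3, 0.4, 1), (0.5, 0.6, 0)])

query = "SELECT f1, f2, label FROM samples ORDER BY rowid"
batches = list(batch_generator(conn, query, batch_size=2))
```

A generator of this shape could be wrapped with `tf.data.Dataset.from_generator` to obtain a tf.data object. To use it with fit_generator directly, it would need to restart the query in an outer loop so that it never runs dry, and to convert each batch to numpy arrays.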

【Comments】:

【Answer 2】:

Check out TFMongoDB, a dataset op for TensorFlow implemented in C++ that lets you connect to MongoDB.

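# TensorFlow 1.x API. MongoDBDataset comes from TFMongoDB and _parse_line is
# a user-defined parser; neither import/definition was shown in the original.
import tensorflow as tf
from tensorflow.python.data.ops import iterator_ops

dataset = MongoDBDataset("dbname", "collname")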
dataset = dataset.map(_parse_line)
repeat_dataset2 = dataset.repeat()
batch_dataset = repeat_dataset2.batch(20)

iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
#init_op = iterator.make_initializer(dataset)
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()

with tf.Session() as sess:
    sess.run(init_batch_op, feed_dict={})

    for i in range(5):
        print(sess.run(get_next))

【Comments】:
