读取自己的数据集

Posted zonghui

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了读取自己的数据集相关的知识,希望对你有一定的参考价值。

  图像分类任务中,大多数教程是直接导入深度学习库中的数据集直接用于模型训练,如果采用自己的数据集,会难以下手,这篇博客主要介绍使用Tensorflow2.1或Keras来读取自己的数据集。

1、Tensorflow方法制作数据集

   Tensorflow制作数据集,主要用到tf.data进行操作。步骤为制作csv文件、读取csv、读取数据、数据处理。

需要用到的库

import os
import random
import glob
import csv
import tensorflow as tf

1.1 制作csv文件

# 创建csv文件,输入分别为路径和要创建的csv文件名
def build_csv(root, filename):
    # 对种类进行编号,相当于用0,1,2分别表示这三个水果类别
    name2label = {}
    for name in sorted(os.listdir(os.path.join(root))):
        # 判断文件夹下的对象是否是一个文件夹
        # 不是文件夹,直接进行下一次判断
        # 是文件夹,对该目录进行编号
        if not os.path.isdir(os.path.join(root, name)):
            continue
        name2label[name] = len(name2label.keys())
    # 准备从每个文件夹中读取图片路径与编号
    images = []
    # 遍历数据集中的每个文件夹
    for name in name2label.keys():
        # 读取所有的png,jpg,jpeg格式的文件
        images += glob.glob(os.path.join(root, name, *.png))
        images += glob.glob(os.path.join(root, name, *.jpg))
        images += glob.glob(os.path.join(root, name, *.jpeg))
    print(len(images), images)
    random.shuffle(images)
    # 创建并写csv文件
    with open(os.path.join(root, filename), mode=w, newline=‘‘) as f:
        writer = csv.writer(f)
        for img in images:
            # 更改路径的分隔符
            name = img.split(os.sep)[-2]
            label = name2label[name]
            writer.writerow([img, label])
        print(written into csv file:, filename)

1.2读取csv文件

# 输入分别为路径和刚刚创建的csv文件名
def load_csv(root, filename):
    images, labels = [], []
    with open(os.path.join(root, filename)) as f:
        reader = csv.reader(f)
        for row in reader:
            img, label = row
            label = int(label)
            images.append(img)
            labels.append(label)
    return images, labels

1.3将数据集转换为tf.data格式

# 读取csv文件
images, labels = load_csv(root, filename)
# 转换为tf.data格式
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
# 数据处理操作,其中preprocessing是需要自己编写的一个实现数据处理功能的函数
dataset = dataset.shuffle(1000).map(preprocess).batch(32)

1.4数据处理操作

# 输入为路径和标签
def preprocess(x, y):
    # 根据路径读取图片
    x = tf.io.read_file(x)
    # 将图片数值转换为张量
    x = tf.image.decode_jpeg(x, channels=3)
    # 更改尺寸
    x = tf.image.resize(x, [244, 244])
    # 归一化
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.convert_to_tensor(y)

    return x, y

2、Keras方法制作数据集

   Keras制作数据集,使用Keras进行导入数据集。使用keras导入数据集,过程简单方便。

需要用到的库

from keras.preprocessing.image import ImageDataGenerator

2.1读取数据

# 将照片[0-255]数据缩放为[0-1]
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

# 训练集与验证集路径
train_dir = "train/"
validation_dir = "validation/"

# 生成了224x224的RGB图像,形状为[20,224,224,3]与二进制标签[20,]的批量,每个批量包含20个样本
train_generator = train_datagen.flow_from_directory(
    train_dir,                  # 训练集路径
    target_size=(224, 224),     # 训练集样本尺寸大小为(224, 224)
    batch_size=32,              # 训练集每批包含20个样本
    class_mode=‘categorical)    
validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=(224, 224),
    batch_size=16,
    class_mode=‘categorical)

2.2 输入数据到模型

history = model.fit_generator(
    train_generator,           
    validation_data=validation_generator,
  ......
)

         

以上是关于读取自己的数据集的主要内容,如果未能解决你的问题,请参考以下文章

tensorflowxun训练自己的数据集之从tfrecords读取数据

机器学习初探(手写数字识别)matlab读取数据集

将 ARB 程序集翻译成 GLSL?

手写数字识别——基于全连接层和MNIST数据集

深度学习(tensorflow) —— 自己数据集读取opencv

TensorFlow 制作自己的TFRecord数据集