读取自己的数据集
Posted zonghui
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了读取自己的数据集相关的知识,希望对你有一定的参考价值。
图像分类任务中,大多数教程是直接导入深度学习库中的数据集直接用于模型训练,如果采用自己的数据集,会难以下手,这篇博客主要介绍使用Tensorflow2.1或Keras来读取自己的数据集。
1、Tensorflow方法制作数据集
Tensorflow制作数据集,主要用到tf.data进行操作。步骤为制作csv文件、读取csv、读取数据、数据处理。
需要用到的库
import os import random import glob import csv import tensorflow as tf
1.1 制作csv文件
# 创建csv文件,输入分别为路径和要创建的csv文件名 def build_csv(root, filename): # 对种类进行编号,相当于用0,1,2分别表示这三个水果类别 name2label = {} for name in sorted(os.listdir(os.path.join(root))): # 判断文件夹下的对象是否是一个文件夹 # 不是文件夹,直接进行下一次判断 # 是文件夹,对该目录进行编号 if not os.path.isdir(os.path.join(root, name)): continue name2label[name] = len(name2label.keys()) # 准备从每个文件夹中读取图片路径与编号 images = [] # 遍历数据集中的每个文件夹 for name in name2label.keys(): # 读取所有的png,jpg,jpeg格式的文件 images += glob.glob(os.path.join(root, name, ‘*.png‘)) images += glob.glob(os.path.join(root, name, ‘*.jpg‘)) images += glob.glob(os.path.join(root, name, ‘*.jpeg‘)) print(len(images), images) random.shuffle(images) # 创建并写csv文件 with open(os.path.join(root, filename), mode=‘w‘, newline=‘‘) as f: writer = csv.writer(f) for img in images: # 更改路径的分隔符 name = img.split(os.sep)[-2] label = name2label[name] writer.writerow([img, label]) print(‘written into csv file:‘, filename)
1.2读取csv文件
# 输入分别为路径和刚刚创建的csv文件名 def load_csv(root, filename): images, labels = [], [] with open(os.path.join(root, filename)) as f: reader = csv.reader(f) for row in reader: img, label = row label = int(label) images.append(img) labels.append(label) return images, labels
1.3将数据集转换为tf.data格式
# 读取csv文件 images, labels = load_csv(root, filename) # 转换为tf.data格式 dataset = tf.data.Dataset.from_tensor_slices((images, labels)) # 数据处理操作,其中preprocessing是需要自己编写的一个实现数据处理功能的函数 dataset = dataset.shuffle(1000).map(preprocess).batch(32)
1.4数据处理操作
# 输入为路径和标签 def preprocess(x, y): # 根据路径读取图片 x = tf.io.read_file(x) # 将图片数值转换为张量 x = tf.image.decode_jpeg(x, channels=3) # 更改尺寸 x = tf.image.resize(x, [244, 244]) # 归一化 x = tf.cast(x, dtype=tf.float32) / 255. y = tf.convert_to_tensor(y) return x, y
2、Keras方法制作数据集
Keras制作数据集,使用Keras进行导入数据集。使用keras导入数据集,过程简单方便。
需要用到的库
from keras.preprocessing.image import ImageDataGenerator
2.1读取数据
# 将照片[0-255]数据缩放为[0-1] train_datagen = ImageDataGenerator(rescale=1./255) test_datagen = ImageDataGenerator(rescale=1./255) # 训练集与验证集路径 train_dir = "train/" validation_dir = "validation/" # 生成了224x224的RGB图像,形状为[20,224,224,3]与二进制标签[20,]的批量,每个批量包含20个样本 train_generator = train_datagen.flow_from_directory( train_dir, # 训练集路径 target_size=(224, 224), # 训练集样本尺寸大小为(224, 224) batch_size=32, # 训练集每批包含20个样本 class_mode=‘categorical‘) validation_generator = test_datagen.flow_from_directory( validation_dir, target_size=(224, 224), batch_size=16, class_mode=‘categorical‘)
2.2 输入数据到模型
history = model.fit_generator( train_generator, validation_data=validation_generator,
......
)
以上是关于读取自己的数据集的主要内容,如果未能解决你的问题,请参考以下文章
tensorflowxun训练自己的数据集之从tfrecords读取数据