keras.utils.Sequence:FileSequence生成文件序列流

Posted ZSYL

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了keras.utils.Sequence:FileSequence生成文件序列流相关的知识,希望对你有一定的参考价值。

keras.utils.Sequence:FileSequence文件序列流

前言

最近参加【2021年第三届全国高校计算机能力挑战赛】大数据应用赛,题目属于计算机视觉方向的图像扶正。

在训练模型时避免不了需要批量加载图片文件,因此需要构建一个FileSequence文件序列流类,进行加载图片文件。

我对官方示例提交代码进行浅显研究,仅供个人学习使用。

tensorflow.keras.utils.Sequence学习

tensorflow.keras.utils.Sequence的使用:控制模型从文件读入batch_size的数据.(数据生成器

在使用keras的时候,一般使用model.fit()来传入训练数据,fit() 接受多种类型的数据:

  1. 数组类型
  2. dataset类型
  3. python generator,但是限制比较多,一般要在编写python generator的平 台下运行模型
  4. tensorflow.keras.utils.Sequence和python generator差不多,但是限制较少,可迁移性更好.

官方例子

from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math

# Here, `x_set` is list of path to the images(图片路径列表)
# and `y_set` are the associated classes.(标签)

class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) *
        self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) *
        self.batch_size]

        return np.array([
            resize(imread(file_name), (200, 200))
               for file_name in batch_x]), np.array(batch_y)

官方链接

  • init():初始化类
  • len():返回batch_size的个数,也就是完整跑一遍数据要运行运行模型多少次
  • getitem():返回一个batch_size的数据(data,label)
  • on_epoch_end():这个函数例子中没有用到,但是官网有给,就是在每个 epoch跑完之后,你要做什么可以通过这个函数实现

这是以上函数的作用,虽然官方给的例子是像上面那样的。但是我们却不一定要写和它一模一样的格式,只要每个函数返回的东西和上面例子一样就行(比如:getitem()返回的是一个batch_size的数据,只要你在这个函数返回的是一个batch_size的数据,那么函数里面怎么运行的都可以)。

下面是比赛主办方提供的一个Sequence类,用于批量读取图片文件及进行图片处理等。

代码展示

# import区域,sys为必须导入,其他根据需求导入
import os
import sys
import random
import numpy as np
import pandas as pd
import cv2
import tensorflow as tf

# 设置GPU
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0],True)
import tensorflow.keras as keras
from tensorflow.keras import layers

# 文件流序列模型
class FileSequence(keras.utils.Sequence):
    def __init__(self,filenames,batch_size,filefunc,fileargs=(),labels=None,labelfunc=None,labelargs=(),shuffle=False):
        if labels: assert len(filenames) == len(labels)
        self.filenames  = filenames  # 文件名列表
        self.batch_size = batch_size  # 批次大小,一次训练多少文件(图片)
        self.filefunc   = filefunc
        self.fileargs   = fileargs
        self.labels     = labels  # 文件对应的标签列表
        self.labelfunc  = labelfunc
        self.labelargs  = labelargs  
        
        # 是否打乱
        if shuffle:
            idx_list = list(range(len(self.filenames)))  # 生成1-文件列表长度的序列下标
            random.shuffle(idx_list)  # 打乱
            self.filenames = [self.filenames[idx] for idx in idx_list]  # 按照打乱的下标,更新文件名
            if self.labels: self.labels = [self.labels[idx] for idx in idx_list]  # 按照打乱的下标,更新文件名

    def __len__(self):
        """文件流长度=文件数/batch_size"""
        return int(np.ceil(len(self.filenames) / float(self.batch_size)))

    def __getitem__(self, idx):
        """获取一个batch_size的文件"""
        batch_filenames = self.filenames[idx * self.batch_size: (idx+1) * self.batch_size]
        
        files = []
        for filename in batch_filenames:
            # tf.print(filename)
            file = self.filefunc(filename,*self.fileargs)
            files.append(file)
        if self.labels:
            batch_labels = self.labels[idx * self.batch_size: (idx+1) * self.batch_size]
            if self.labelfunc:
                return np.array(files), self.labelfunc(batch_labels,*self.labelargs)
            else:
                return np.array(files), batch_labels
        else:
            return np.array(files)

def fillWhite(img,size,mode=None):
    if len(img.shape) == 2: img = img.reshape(*img.shape,-1)
    assert len(img.shape) == 3  # 图片必须是h, w, c三通道
    h, w, c = img.shape
    assert (h < size) and (w < size)
    fillImg = np.zeros(shape=(size,size,c))
    if mode == "random":
        sh = random.randint(0,size-h)
        sw = random.randint(0,size-w)
        fillImg[sh:sh+h,sw:sw+w,...] = img
    elif mode == "centre" or mode == "center":
        fillImg[(size-h)//2:(size+h)//2,(size-w)//2:(size+w)//2,...] = img
    else:
        fillImg[:h,:w,...] = img
    return fillImg

def cropImg(img,size,mode=None):
    if len(img.shape) == 2: img = img.reshape(*img.shape,-1)
    assert len(img.shape) == 3
    h, w, c = img.shape
    assert (h >= size) and (w >= size)
    if mode == "random":
        sh = random.randint(0,h-size)
        sw = random.randint(0,w-size)
        cropImg = img[sh:sh+size,sw:sw+size,...]
    elif mode == "centre" or mode == "center":
        cropImg = img[(h-size)//2:(h+size)//2,(w-size)//2:(w+size)//2,...]
    else:
        cropImg = img[:size,:size,...]
    return cropImg

def fillCrop(img,size,mode=None):
    if len(img.shape) == 2: img = img.reshape(*img.shape,-1)
    assert len(img.shape) == 3
    h, w, c = img.shape
    assert ((h >= size) and (w < size)) or ((h < size) and (w >= size))
    fillcropImg = np.zeros(shape=(size,size,c))
    if mode == "random":
        if (h >= size) and (w < size):
            sh = random.randint(0,h-size)
            sw = random.randint(0,size-w)
            fillcropImg[:,sw:sw+w,:] = img[sh:sh+size,...]
        else:
            sh = random.randint(0,size-h)
            sw = random.randint(0,w-size)
            fillcropImg[sh:sh+h,...] = img[:,sw:sw+size,:]
    elif mode == "centre" or mode == "center":
        if (h >= size) and (w < size):
            fillcropImg[:,(size-w)//2:(size+w)//2,:] = img[(h-size)//2:(h+size)//2,...]
        else:
            fillcropImg[(size-h)//2:(size+h)//2,...] = img[:,(w-size)//2:(w+size)//2,:]
    else:
        if (h >= size) and (w < size):
            fillcropImg[:,:size,:] = img[:size,...]
        else:
            fillcropImg[:size,...] = img[:,:size,:]
    return fillcropImg

def resizeImg(img,size,mode=None):
    """图片大小重设"""
    if len(img.shape) == 2: img = img.reshape(*img.shape,-1)
    assert len(img.shape) == 3
    h, w, c = img.shape
    if (h < size) and (w < size): return fillWhite(img,size,mode)  # 小的话,填充白色
    elif (h >= size) and (w >= size): return cropImg(img,size,mode)  # 大的话,进行裁减
    else: return fillCrop(img,size,mode)

def filefunc(filename,mode):
    """图片文件大小处理功能函数"""
    tf.print(filename)
    img = cv2.imread(filename)
    if not isinstance(img,np.ndarray):
        tf.print(filename)
    h, w, c = img.shape
    # 进行裁减
    if (h >=256) or (w >= 256):
        img = resizeImg(img,256,mode)
        img = cv2.resize(img,(64,64))
    elif (h >=128) or (w >= 128):
        img = resizeImg(img,128,mode)
        img = cv2.resize(img,(64,64))
    else:
        img = resizeImg(img,64,mode)
    return img    

# 主函数,格式固定,to_pred_dir为预测所在文件夹,result_save_path为预测结果生成路径
def main(to_pred_dir, result_save_path):
    runpyp = os.path.abspath(__file__)  # 返回当前脚本文件的绝对路径
    modeldirp = os.path.dirname(runpyp)  # 返回当前脚本文件的所在路径(上一层目录)
    modelp = os.path.join(modeldirp,"model_6_4_0.9085.h5")  # 拼接路径
    model = keras.models.load_model(modelp)  # 加载模型

    pred_imgs = os.listdir(to_pred_dir)  # 准备训练的数据文件夹
    pred_imgsp_lines = [os.path.join(to_pred_dir,p) for p in pred_imgs]  # 路径拼接
    # pred_imgsp_lines = "\\n".join(pred_imgsp)
    # tf.print(pred_imgsp_lines)

    imgsets = FileSequence(pred_imgsp_lines,32,filefunc,fileargs=("centre",),shuffle=False)
    preds = model.predict(imgsets)
    preds = np.argmax(preds,axis=1)
    preds[preds<10] = 0
    preds[preds>=10] = 1

    prepreds = np.zeros_like(preds).astype(str)
    # 预测标签
    testset_dicts = 
        0:'cat',   1:'dog'
    
    for i in range(2):
        prepreds[preds==i]=testset_dicts[i]  # 赋值
    
    df = pd.DataFrame("id":pred_imgs,"label":prepreds)  # 生成DF对象
    
    # 生成csv结果
    df.to_csv(result_save_path,index=None)

# !!!注意:
# 图片赛题给出的参数为to_pred_dir,是一个文件夹,其图片内容为
# to_pred_dir/to_pred_0.png
# to_pred_dir/to_pred_1.png
# to_pred_dir/......
# 所需要生成的csv文件头为id,label,如下
# image_id,label
# to_pred_0,4
# to_pred_1,76
# to_pred_2,...

if __name__ == "__main__":
    to_pred_dir = sys.argv[1]  # 所需预测的文件夹路径
    result_save_path = sys.argv[2]  # 预测结果保存文件路径
    main(to_pred_dir, result_save_path)

参考Link


加油!

感谢!

努力!

以上是关于keras.utils.Sequence:FileSequence生成文件序列流的主要内容,如果未能解决你的问题,请参考以下文章

keras.utils.Sequence使用注意事项

keras.utils.Sequence使用注意事项

on_epoch_end() 未在 keras fit_generator() 中调用

:模型训练和预测的三种方法(fit&tf.GradientTape&train_step&tf.data)

if (!file) 和 if (file == NULL) 的区别

java是否用File f=new file;创建文件的