商品检测数据集训练项目训练结构介绍

Posted ZSYL

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了商品检测数据集训练项目训练结构介绍相关的知识,希望对你有一定的参考价值。

项目训练结构介绍

1. 项目目录结构

  • ckpt:分为预训练与微调模型
  • datasets:放训练原始数据以及存储数据、读取数据代码以及模型priorbox
  • servingmodel:模型部署使用的模型位置
  • export_serving_model:导出TFserving指定模型类型
  • train_ssd:训练模型代码逻辑

2. train_ssd.py

"""商品检测数据集训练"""
import pickle
from computerVision.utils.detection_generate import Generator
from computerVision.utils.ssd_utils import BBoxUtility
from computerVision.nets.ssd_net import SSD300
from computerVision.utils.ssd_losses import MultiboxLoss
from tensorflow.python.keras.callbacks import ModelCheckpoint, TensorBoard
import tensorflow as tf
import keras

class SSDTrain(object):

    def __init__(self, num_classes=9, input_shape=(300, 300, 3), epochs=30):
        """
        初始化网络指定一些参数,训练数据类别,图片需要指定模型输入大小,迭代次数
        """
        self.num_classes = num_classes
        self.batch_size = 32
        self.input_shape = input_shape
        self.epochs = epochs

        # 指定训练读取数据的相关参数
        self.gt_path = './datasets/commodity_gt.pkl'
        self.image_path = './datasets/commodity/JPEGImages/'

        prior = pickle.load(open('./datasets/prior_boxes_ssd300.pkl', 'rb'))
        self.bbox_util = BBoxUtility(self.num_classes, prior)

        # 权重weights参数
        self.pre_trained = './ckpt/pre_trained/weights_SSD300.hdf5'

        # 初始化模型
        self.model = SSD300(self.input_shape, num_classes=self.num_classes)

    def get_detection_data(self):
        """
        获取检测的迭代数据
        :return:
        """
        # 1. 读取标注数据,构造训练图片名字列表,测试图片名字列表
        gt = pickle.load(open(self.gt_path, 'rb'))
        # 图片名字列表
        name_keys = keys = sorted(gt.keys())
        number = int(0.8 * len(name_keys))
        train_keys = name_keys[:number]
        val_keys = name_keys[number:]

        # 2. 通过generator去获取迭代批次数据
        # gt: 所有数据的目标值字典
        # path_prefix:图片路径
        # Generator:生成图片的目标值等数据
        gen = Generator(gt, self.bbox_util, self.batch_size, self.image_path,
                  train_keys, val_keys, (self.input_shape[0], self.input_shape[1]), do_crop=False)

        return gen

    def init_model_param(self):
        """
        初始化网络模型参数,指定微调的时候,训练部分
        :return:
        """
        # 1. 加载本地预训练好的模型
        self.model.load_weights(self.pre_trained, by_name=True)
        # 2. 指定模型当中默写结构freeze
        # 冻结模型部分为 SSD当中的VGG前半部分
        freeze = ['input_1', 'conv1_1', 'conv1_2', 'pool1',
                  'conv2_1', 'conv2_2', 'pool2',
                  'conv3_1', 'conv3_2', 'conv3_3', 'pool3']
        # 遍历每一层的结果
        for L in self.model.layers:
            if L.name in freeze:
                L.trainable = False

        return None

    def compile(self):
        """
        编译模型:
        SSD网络的损失函数计算MultiboxLoss的compute_loss
        :return:
        """
        # MultiboxLoss: N个类别+1个背景类别
        # tf.keras.optimizers.Adam() 出现问题,给4个,是需要3个
        # keras 1.2.2 optimizers.Adam()
        distribution = tf.contrib.distribute.MirroredStrategy()
        self.model.compile(optimizer=keras.optimizers.Adam(),
                           loss=MultiboxLoss(self.num_classes).compute_loss,
                           distribution=distribution)

    def fit_genrator(self, gen):
        """
        进行训练
        :return:
        """
        # 建立回调函数
        callback = [
            ModelCheckpoint('./ckpt/fine_tuning/weights.{epoch:02d}-{val_acc:.2f}.hdf5',
                                            monitor='val_acc',
                                            save_weights_only=True,
                                            save_best_only=True,
                                            mode='auto',
                                            period=1),  # 保存模型的次数及损失变换情况
            TensorBoard(log_dir='./graph')  # 保存图的信息,整个图的变换等 , histogram_freq=1, write_graph=True, write_images=True
        ]

        self.model.fit_generator(gen.generate(train=True), gen.train_batches, self.epochs, callbacks=callback, validation_data=gen.generate(train=False),
                                 nb_val_samples=gen.val_batches)


if __name__ == '__main__':
    ssd = SSDTrain(num_classes=9)
    gen = ssd.get_detection_data()
    ssd.init_model_param()
    ssd.compile()
    ssd.fit_genrator(gen)

以上是关于商品检测数据集训练项目训练结构介绍的主要内容,如果未能解决你的问题,请参考以下文章

商品检测数据集训练目标检测数据集与标记

商品检测数据集训练应用API完成商品数据集的训练

深度学习目标检测---使用yolov5训练自己的数据集模型(Windows系统)

YOLOv7训练自己的数据集(口罩检测)

戴眼镜检测和识别2:Pytorch实现戴眼镜检测和识别(含戴眼镜数据集和训练代码)

开源一个安全帽佩戴检测数据集及预训练模型