商品检测数据集训练项目训练结构介绍
Posted ZSYL
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了商品检测数据集训练项目训练结构介绍相关的知识,希望对你有一定的参考价值。
项目训练结构介绍
1. 项目目录结构
- ckpt:分为预训练与微调模型
- datasets:放训练原始数据以及存储数据、读取数据代码以及模型priorbox
- servingmodel:模型部署使用的模型位置
- export_serving_model:导出TFserving指定模型类型
- train_ssd:训练模型代码逻辑
2. train_ssd.py
"""商品检测数据集训练"""
import pickle
from computerVision.utils.detection_generate import Generator
from computerVision.utils.ssd_utils import BBoxUtility
from computerVision.nets.ssd_net import SSD300
from computerVision.utils.ssd_losses import MultiboxLoss
from tensorflow.python.keras.callbacks import ModelCheckpoint, TensorBoard
import tensorflow as tf
import keras
class SSDTrain(object):
def __init__(self, num_classes=9, input_shape=(300, 300, 3), epochs=30):
"""
初始化网络指定一些参数,训练数据类别,图片需要指定模型输入大小,迭代次数
"""
self.num_classes = num_classes
self.batch_size = 32
self.input_shape = input_shape
self.epochs = epochs
# 指定训练读取数据的相关参数
self.gt_path = './datasets/commodity_gt.pkl'
self.image_path = './datasets/commodity/JPEGImages/'
prior = pickle.load(open('./datasets/prior_boxes_ssd300.pkl', 'rb'))
self.bbox_util = BBoxUtility(self.num_classes, prior)
# 权重weights参数
self.pre_trained = './ckpt/pre_trained/weights_SSD300.hdf5'
# 初始化模型
self.model = SSD300(self.input_shape, num_classes=self.num_classes)
def get_detection_data(self):
"""
获取检测的迭代数据
:return:
"""
# 1. 读取标注数据,构造训练图片名字列表,测试图片名字列表
gt = pickle.load(open(self.gt_path, 'rb'))
# 图片名字列表
name_keys = keys = sorted(gt.keys())
number = int(0.8 * len(name_keys))
train_keys = name_keys[:number]
val_keys = name_keys[number:]
# 2. 通过generator去获取迭代批次数据
# gt: 所有数据的目标值字典
# path_prefix:图片路径
# Generator:生成图片的目标值等数据
gen = Generator(gt, self.bbox_util, self.batch_size, self.image_path,
train_keys, val_keys, (self.input_shape[0], self.input_shape[1]), do_crop=False)
return gen
def init_model_param(self):
"""
初始化网络模型参数,指定微调的时候,训练部分
:return:
"""
# 1. 加载本地预训练好的模型
self.model.load_weights(self.pre_trained, by_name=True)
# 2. 指定模型当中默写结构freeze
# 冻结模型部分为 SSD当中的VGG前半部分
freeze = ['input_1', 'conv1_1', 'conv1_2', 'pool1',
'conv2_1', 'conv2_2', 'pool2',
'conv3_1', 'conv3_2', 'conv3_3', 'pool3']
# 遍历每一层的结果
for L in self.model.layers:
if L.name in freeze:
L.trainable = False
return None
def compile(self):
"""
编译模型:
SSD网络的损失函数计算MultiboxLoss的compute_loss
:return:
"""
# MultiboxLoss: N个类别+1个背景类别
# tf.keras.optimizers.Adam() 出现问题,给4个,是需要3个
# keras 1.2.2 optimizers.Adam()
distribution = tf.contrib.distribute.MirroredStrategy()
self.model.compile(optimizer=keras.optimizers.Adam(),
loss=MultiboxLoss(self.num_classes).compute_loss,
distribution=distribution)
def fit_genrator(self, gen):
"""
进行训练
:return:
"""
# 建立回调函数
callback = [
ModelCheckpoint('./ckpt/fine_tuning/weights.{epoch:02d}-{val_acc:.2f}.hdf5',
monitor='val_acc',
save_weights_only=True,
save_best_only=True,
mode='auto',
period=1), # 保存模型的次数及损失变换情况
TensorBoard(log_dir='./graph') # 保存图的信息,整个图的变换等 , histogram_freq=1, write_graph=True, write_images=True
]
self.model.fit_generator(gen.generate(train=True), gen.train_batches, self.epochs, callbacks=callback, validation_data=gen.generate(train=False),
nb_val_samples=gen.val_batches)
if __name__ == '__main__':
ssd = SSDTrain(num_classes=9)
gen = ssd.get_detection_data()
ssd.init_model_param()
ssd.compile()
ssd.fit_genrator(gen)
以上是关于商品检测数据集训练项目训练结构介绍的主要内容,如果未能解决你的问题,请参考以下文章
深度学习目标检测---使用yolov5训练自己的数据集模型(Windows系统)