PyTorch 99%程序员都不知道, 深度学习还能这样玩 (建议收藏)

Posted 我是小白呀

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch 99%程序员都不知道, 深度学习还能这样玩 (建议收藏)相关的知识,希望对你有一定的参考价值。

概述

你还在为训练无从下手而苦恼么?
你还在为模型训练时间漫长而痛苦么?
你还在为模型准确率提升困难在深夜一个人啜泣么?

在这里插入图片描述
今天教大家一个方法, 使得我们的模型起跑线上直接甩开别人几条街. 隔壁王叔叔都学会了!

迁移学习

迁移学习 (Transfer Learning) 是把已学训练好的模型参数用作新训练模型的起始参数.

入住 GitHub

经过几天的日夜狂肝, 本人完成了在 GitHub 上的第一个项目. 把迁移学习封装成了一个有手就能用的黑盒模型.

在这里插入图片描述
大家只要替换自己的数据集就可以实现多个可选模型迁移学习并自动保存. 就是两个字简单

项目详解

GitHub 链接
在这里插入图片描述

get_data.py (获取数据)

目前支持 MNIST, Fashion MNIST, CIFAR 10 和 CIFAR 100 数据集.

可以在```get_data.py``下自行替换成自己需要的数据集:
在这里插入图片描述

传入数据的格式为:

data_loader = {"train": train_loader, "valid": test_loader}

get_model (获取模型)

目前支持:

  • resnet18
  • resnet34
  • resnet50
  • resnet101
  • resnet152
  • alexnet
  • squeezenet
  • vgg11
  • vgg13
  • vgg16
  • vgg19

替换模型的方法:

python main.py --model_name "模型名称"

例如, 使用 vgg 13:

python main.py --model_name vgg13

例如, 使用 resnet 152:

python main.py --model_name resnet152

参数详解

必填参数:

  • model_name: 模型名称, 类型为 string
  • num_classes: 输出类别数, 类型为 int (例如 MNIST 是 10 分类, CIFAR 100 是 100 分类)

重要参数:

  • data_name: 数据名称, 类型为 string, 默认为 CIFAR10
  • data_gray: 是否为灰度图, 类型为 boolean, 默认为 False
  • num_epochs: 迭代次数, 类型为 int, 默认为 20
  • batch_size: 一个批次的样本数目, 默认为 512

可选参数 (不建议修改):

  • feature_exact: 是否冻层, 类型为 boolean, 默认为 False
  • use_pretrained: 是否使用预训练权重, 类型为 boolean, 默认为 True
  • pretrained_model_path: 预训练权重, 类型为 string, 默认为 pretrained_model/
  • model_save_path: 模型保存路径, 类型为 string, 默认为 “checkpoint/”
  • visualize: 模型可视化, 类型为 boolean, 默认为 True

使用说明

首先我们需要cd到文件路径, 例如:

cd C:\\Users\\Windows\\Desktop\\Project\\transfer_learning-main

训练 MNIST

使用 resnet18 训练 MNIST 数据集:

python main.py --data_name MNIST --data_gray True --model_name resnet18 --num_classes 10 --batch_size 512

训练 Fashion MNIST

使用 resnet34 训练 Fashion MNIST 数据集:

python main.py --data_name FashionMNIST --data_gray True --model_name resnet34 --num_classes 10 --batch_size 512

训练 CIFAR 10

使用 resnet50 训练 CIFAR 10 数据集:

python main.py --data_name CIFAR10 --model_name resnet50 --num_classes 10 --batch_size 512

训练 CIFAR 100

使用 resnet152 训练 CIFAR 10 数据集:

python main.py --data_name CIFAR100 --model_name resnet152 --num_classes 100 --batch_size 512

训练自己的数据

python main.py --data_name other --model_name ? --num_classes ? --batch_size ? --epochs ?

以上是关于PyTorch 99%程序员都不知道, 深度学习还能这样玩 (建议收藏)的主要内容,如果未能解决你的问题,请参考以下文章

有必要TensorFlow和pytorch都配重吗?

一个简单而强大的深度学习库—PyTorch

深度学习框架PyTorch为何值得学

Java程序员学深度学习 DJL上手7 使用Pytorch引擎

PyTorch深度学习60分钟快速入门 Part1:PyTorch是什么?

每月好书深度学习框架PyTorch入门与实践