inference.py篇
Posted 樱木之
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了inference.py篇相关的知识,希望对你有一定的参考价值。
inference.py 篇
目录:
- 前言
- 思考自己需要载入的超参
- 书写代码
- 函数手册
前言
在该模块中加载训练好的模型,对测试集的image进行推理。
思考自己需要载入的超参
该模块的书写,是train的简约版,例如你可能需要设置和train相同的batch_size
,device
,dataloader
等信息,但是这次你不需要设置epoch等信息,对模型的参数进行优化等。
书写代码
书写顺序如下:
写argparse()
方法收集需要传递的所有参数,传入main函数中(可选)。
main函数中思路如下:
- 写路径等信息
- 书写dataloder。设置
transforms
,dataset
,dataloader
,batch_size
等参数,因为dataloader中要用到。 - 设置其余超参,如
device
等,这次你必须要加载train中产生的预训练权重。 - 对测试集进行推理
下以AlexNet中的inference.py为例:
# add path
import os, sys
root_path = os.path.dirname(os.path.dirname(__file__))
project_path = os.path.dirname(__file__)
sys.path.append(project_path)
# add module
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
import torch
import numpy as np
from model import AlexNet
def parse_args():
"""get your args"""
def convert_image(image_path:str = ""):
"""transform png to jpg"""
def main():
# 路径
root_path = os.path.dirname(os.path.dirname(__file__))
project_path = os.path.dirname(__file__)
weight_path = os.path.join(root_path, "weight", "AlexNet_2.pth")
image_path = "/home/yingmuzhi/AlexNet/daisy.jpg"
# 加载预测图片
img = None
data_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
img = Image.open(image_path)
print(np.array(img).shape)
img = data_transform(img) # 只接受[height, width, channel=3]的图片, 即RGB的jpg
img = torch.unsqueeze(img, dim = 0) # 传入网络需要[batch, channel, height, width]
# 加载json文件
try:
json_file = open(project_path + "/class_indices.json","r")
class_indict = json.load(json_file)
except Exception as e:
print(e)
exit(-1)
# 测试参数
net = AlexNet(num_classes=2)
net.load_state_dict(torch.load(weight_path))
net.eval() # 关闭dropout层并且不会梯度回传
with torch.no_grad():
# predict class
output = net(img)
# print(output.shape)
output = torch.squeeze(output)
# print(output.shape)
predict = torch.softmax(output, dim = 0)
# print(predict.shape)
predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
if __name__ == "__main__":
args = parse_args()
main(args)
函数手册
以上是关于inference.py篇的主要内容,如果未能解决你的问题,请参考以下文章