Pytorch加载模型并进行图像分类预测
Posted 无穷QQ君
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch加载模型并进行图像分类预测相关的知识,希望对你有一定的参考价值。
目录
1) How can i convert an RGB image into grayscale in Python?
4)Convert 3 channel image to 2 channel
7)torch.from_numpy VS torch.Tensor
8)torch.squeeze() 和torch.unsqueeze()
1)TypeError: 'module' object is not callable
2)TypeError: 'collections.OrderedDict' object is not callable
3) TypeError: __init__() missing 1 required positional argument: 'XX'
4) RuntimeError: Error(s) in loading state_dict for PythonNet: Missing key(s) in state_dict:
1. 整体流程
1)实例化模型
Assume that the content of YourClass.py is:
class YourClass:
# ......
If you use:
from YourClassParentDir import YourClass # means YourClass
from model import PythonNet
net= PythonNet(T=16).eval().cuda()
2)加载模型
import torch
net.load_state_dict(torch.load('checkpoint_max.pth'),False)
3)输入图像
目的:从文件夹中加载所有图像组合为一个 numpy array 作为模型输入
原始图像输入维度:(480,640)
目标图像输入维度:(16,1,128,128)
import glob
from PIL import Image
import numpy as np
from torch.autograd import Variable
#获取图像路径
filelist = glob.glob('./testdata/a/*.jpg')
#打开图像open('frame_path')--》转换为灰度图convert('L')--》缩放图像resize((width, height)) --》合并文件夹中的所有图像为一个numpy array
x = np.array([np.array(Image.open(frame).convert('L').resize((128,128))) for frame in filelist])
#用torch.from_numpy这个方法将numpy类转换成tensor类
x = torch.from_numpy(x).type(torch.FloatTensor).cuda()
#扩充数据维度
x = Variable(torch.unsqueeze(x,dim=1).float(),requires_grad=False)
4)输出分类结果
outputs = net(x)
_, predicted = torch.max(outputs,1)
torch.max()这个函数返回输入张量中所有元素的最大值。
返回一个命名元组(values,indices),其中values是给定维dim中输入张量的每一行的最大值。indices是找到的每个最大值的索引位置(argmax)。也就是说,返回的第一个值是对应图像在类别中的最大概率值,第二个值是最大概率值的对应类别。
Pytorch 分类问题输出结果的数据整理方式:_, predicted = torch.max(outputs.data, 1) - stardsd - 博客园
5)完整代码
from PIL import Image
from torch.autograd import Variable
import numpy as np
import torch
import glob
from model import PythonNet
##############处理输入图像#######################################
#获取图像路径
filelist = glob.glob('./testdata/a/*.jpg')
#打开图像open('frame_path')--》转换为灰度图convert('L')--》缩放图像resize((width, height)) --》合并文件夹中的所有图像为一个numpy array
x = np.array([np.array(Image.open(frame).convert('L').resize((128,128))) for frame in filelist])
#用torch.from_numpy这个方法将numpy类转换成tensor类
x = torch.from_numpy(x).type(torch.FloatTensor).cuda()
#扩充数据维度
x = Variable(torch.unsqueeze(x,dim=1).float(),requires_grad=False)
#############定义预测函数#######################################
def predict(x):
net= PythonNet(T=16).eval().cuda()
net.load_state_dict(torch.load('checkpoint_max.pth'),False)
outputs = net(x)
_, predicted = torch.max(outputs,1)
print("_:",_)
print("predicted:",predicted)
print("outputs:",outputs)
############输入图像进行预测#####################################
predict(x)
2. 处理图像
1) How can i convert an RGB image into grayscale in Python?
matplotlib - How can I convert an RGB image into grayscale in Python? - Stack Overflow
2)PIL 处理图像的基本操作
python——PIL Image处理图像_aaon22357的博客-CSDN博客
3)图像通道数的理解
关于图像通道的理解 | TinaCristal's Blog
4)Convert 3 channel image to 2 channel
5)图像通道转换
图像通道转换——从np.ndarray的[w, h, c]转为Tensor的[c, w, h]_莫邪莫急的博客-CSDN博客
6)将所有的图像合并为一个numpy数组
7)torch.from_numpy VS torch.Tensor
torch.from_numpy VS torch.Tensor_麦克斯韦恶魔的博客-CSDN博客
8)torch.squeeze() 和torch.unsqueeze()
pytorch学习 中 torch.squeeze() 和torch.unsqueeze()的用法_xiexu911的博客-CSDN博客_torch.unsqueeze
3.issue
1)TypeError: 'module' object is not callable
python - TypeError: 'module' object is not callable - Stack Overflow
2)TypeError: 'collections.OrderedDict' object is not callable
pytorch加载模型报错TypeError: ‘collections.OrderedDict‘ object is not callable_xiaoqiaoliushuiCC的博客-CSDN博客
3) TypeError: __init__() missing 1 required positional argument: 'XX'
Python成功解决TypeError: __init__() missing 1 required positional argument: ‘comment‘_肥鼠路易的博客-CSDN博客
4) RuntimeError: Error(s) in loading state_dict for PythonNet: Missing key(s) in state_dict:
5) RuntimeError: Expected 4-dimensional input for 4-dimensional weight [128, 1, 3, 3], but got 2-dimensional input of size [480, 640] instead
6)RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
以上是关于Pytorch加载模型并进行图像分类预测的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch CIFAR10图像分类 GoogLeNet篇