Pytorch加载模型并进行图像分类预测

Posted 无穷QQ君

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch加载模型并进行图像分类预测相关的知识,希望对你有一定的参考价值。

目录

1. 整体流程

1)实例化模型

2)加载模型

3)输入图像

4)输出分类结果

5)完整代码

2. 处理图像

 1) How can i convert an RGB image into grayscale in Python?

2)PIL 处理图像的基本操作

3)图像通道数的理解

4)Convert 3 channel image to  2 channel

5)图像通道转换

6)将所有的图像合并为一个numpy数组

7)torch.from_numpy VS torch.Tensor

8)torch.squeeze() 和torch.unsqueeze()

3.issue

1)TypeError: 'module' object is not callable

2)TypeError: 'collections.OrderedDict' object is not callable

3)   TypeError: __init__() missing 1 required positional argument: 'XX'

4)    RuntimeError: Error(s) in loading state_dict for PythonNet: Missing key(s) in state_dict:

5)    RuntimeError: Expected 4-dimensional input for 4-dimensional weight [128, 1, 3, 3], but got 2-dimensional input of size [480, 640] instead

6)RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor



1. 整体流程

1)实例化模型

Assume that the content of YourClass.py is:

class YourClass:
    # ......

If you use:

from YourClassParentDir import YourClass  # means YourClass
from model import PythonNet

net= PythonNet(T=16).eval().cuda()

2)加载模型

import torch

net.load_state_dict(torch.load('checkpoint_max.pth'),False)

3)输入图像

目的:从文件夹中加载所有图像组合为一个 numpy array 作为模型输入

原始图像输入维度:(480,640)

目标图像输入维度:(16,1,128,128)

import glob
from PIL import Image
import numpy as np
from torch.autograd import Variable

#获取图像路径
filelist = glob.glob('./testdata/a/*.jpg')

#打开图像open('frame_path')--》转换为灰度图convert('L')--》缩放图像resize((width, height)) --》合并文件夹中的所有图像为一个numpy array
x = np.array([np.array(Image.open(frame).convert('L').resize((128,128))) for frame in filelist])

#用torch.from_numpy这个方法将numpy类转换成tensor类
x = torch.from_numpy(x).type(torch.FloatTensor).cuda()

#扩充数据维度
x = Variable(torch.unsqueeze(x,dim=1).float(),requires_grad=False)

4)输出分类结果

outputs = net(x)
_, predicted = torch.max(outputs,1)

torch.max()这个函数返回输入张量中所有元素的最大值。

返回一个命名元组(values,indices),其中values是给定维dim中输入张量的每一行的最大值。indices是找到的每个最大值的索引位置(argmax)。也就是说,返回的第一个值是对应图像在类别中的最大概率值,第二个值是最大概率值的对应类别。

Pytorch 分类问题输出结果的数据整理方式:_, predicted = torch.max(outputs.data, 1) - stardsd - 博客园

5)完整代码

from PIL import Image
from torch.autograd import Variable
import numpy as np
import torch
import glob
from model import PythonNet

##############处理输入图像#######################################
#获取图像路径
filelist = glob.glob('./testdata/a/*.jpg')

#打开图像open('frame_path')--》转换为灰度图convert('L')--》缩放图像resize((width, height)) --》合并文件夹中的所有图像为一个numpy array
x = np.array([np.array(Image.open(frame).convert('L').resize((128,128))) for frame in filelist])

#用torch.from_numpy这个方法将numpy类转换成tensor类
x = torch.from_numpy(x).type(torch.FloatTensor).cuda()

#扩充数据维度
x = Variable(torch.unsqueeze(x,dim=1).float(),requires_grad=False)

#############定义预测函数#######################################
def predict(x):
    net= PythonNet(T=16).eval().cuda()
    net.load_state_dict(torch.load('checkpoint_max.pth'),False)

    outputs = net(x)
    _, predicted = torch.max(outputs,1)
    print("_:",_)
    print("predicted:",predicted)
    print("outputs:",outputs)

############输入图像进行预测#####################################
predict(x)

2. 处理图像

 1) How can i convert an RGB image into grayscale in Python?

matplotlib - How can I convert an RGB image into grayscale in Python? - Stack Overflow

2)PIL 处理图像的基本操作

python——PIL Image处理图像_aaon22357的博客-CSDN博客

3)图像通道数的理解

关于图像通道的理解 | TinaCristal's Blog

4)Convert 3 channel image to  2 channel

python - I have converted 3 channel RGB image into 2 channels grayscale image, How to decrease greyscale channels to 1? - Stack Overflow

5)图像通道转换

图像通道转换——从np.ndarray的[w, h, c]转为Tensor的[c, w, h]_莫邪莫急的博客-CSDN博客

6)将所有的图像合并为一个numpy数组

python — 如何在numpy数组中加载多个图像?

7)torch.from_numpy VS torch.Tensor

torch.from_numpy VS torch.Tensor_麦克斯韦恶魔的博客-CSDN博客

8)torch.squeeze() 和torch.unsqueeze()

pytorch学习 中 torch.squeeze() 和torch.unsqueeze()的用法_xiexu911的博客-CSDN博客_torch.unsqueeze

3.issue

1)TypeError: 'module' object is not callable

python - TypeError: 'module' object is not callable - Stack Overflow

2)TypeError: 'collections.OrderedDict' object is not callable

pytorch加载模型报错TypeError: ‘collections.OrderedDict‘ object is not callable_xiaoqiaoliushuiCC的博客-CSDN博客

3)   TypeError: __init__() missing 1 required positional argument: 'XX'

Python成功解决TypeError: __init__() missing 1 required positional argument: ‘comment‘_肥鼠路易的博客-CSDN博客

4)    RuntimeError: Error(s) in loading state_dict for PythonNet: Missing key(s) in state_dict:

pytorch加载模型报错RuntimeError: Error(s) in loading state_dict for SSD:Missing key(s) in state_dict:... - 代码先锋网

5)    RuntimeError: Expected 4-dimensional input for 4-dimensional weight [128, 1, 3, 3], but got 2-dimensional input of size [480, 640] instead

RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 3 3, but got 3-dimensional in_Steven_ycs的博客-CSDN博客

6)RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.DoubleTensor) should be the_一千零一夜的博客-CSDN博客


​​​​​​

以上是关于Pytorch加载模型并进行图像分类预测的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch CIFAR10图像分类 GoogLeNet篇

Pytorch CIFAR10图像分类 AlexNet篇

Pytorch CIFAR10图像分类 VGG篇

Pytorch CIFAR10图像分类 MobileNet v1篇

Pytorch CIFAR10图像分类 自定义网络篇

Pytorch CIFAR10图像分类 LeNet5篇