无法导入图像来测试Scikit学习应用。

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了无法导入图像来测试Scikit学习应用。相关的知识,希望对你有一定的参考价值。

我是一个使用Scikit learn的新手程序员,所以我的问题是一个基本问题。我用草图数据集创建了第一个机器学习代码程序,用于识别苹果和香蕉草图之间的对象,它在训练和测试方面工作得很好。

import cv2 as cv
import numpy as np
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split as tts
from sklearn.metrics import accuracy_score

#loading datasets
apples_Full = np.load('dataset/apple.npy')
bananas_Full = np.load('dataset/banana.npy')

N_Samples = 1000
test_Number = 0.2
APPLE = 0
BANANA = 1

def normalize(data):
    return np.interp(data , [0 , 255] , [-1 , 1])

apples = apples_Full[:N_Samples]
bananas = bananas_Full[:N_Samples]
dataset = np.concatenate((apples , bananas))
dataset = normalize(dataset)
labels = [APPLE] * N_Samples + [BANANA] * N_Samples
#spliting data
x_train , x_test , y_train , y_test = tts(dataset , labels ,test_size = test_Number)

alg = SVC()
alg.fit(x_train , y_train)
preds = alg.predict(x_test)
Result = accuracy_score(y_test , preds)
print(Result)

现在我想输入一张草图图像,以便将其作为对象识别应用程序。我试着导入一张图像,并将其转换为一个.npy文件,并将其作为数据集使用,就像测试步骤一样,但我得到了错误:X.shape[1] = 151875,应该等于784,训练时的特征数量。

testfile = "My_test.jpg"
Image = cv.imread(testfile)
TEST = np.array(Image , dtype = 'uint8')
np.save('My_test' + '.npy' , TEST)
Sketch = np.load('My_test.npy')
Sketch = np.reshape(Sketch, (1 , -1))
Testdata = normalizer(Sketch)
finaltest = alg.predict(Testdata)
print(finaltest)

我应该怎么做?

答案

通过误差,我认为你在推理过程中输入的形状与你在模型训练过程中使用的图像形状不同。您使用的图像大小为 28*28 在训练过程中,通过在推理过程中输入不同尺寸的图像,你只需要 resize 您的测试图像,如下图所示。

testfile = "My_test.jpg"
Image = cv.imread(testfile)
Image = cv.resize(Image,(28,28))  # will convert your image to 28*28
TEST = np.array(Image , dtype = 'uint8')

另外,如果你想把数据标准化,你可以把整个图像除以: 255 而不是使用 np.interp,类似这样的东西。

Image = Image/255.0

希望能帮到你!

以上是关于无法导入图像来测试Scikit学习应用。的主要内容,如果未能解决你的问题,请参考以下文章

python机器学习工具包scikit-learn

scikit-learn.impute 没有使用机器学习 A-Z 教程中的代码通过 Spyder 从 Imputer 导入

使用 scikit learn 在字典学习中出现内存错误

Scikit-Learn 朴素贝叶斯分类丨数析学院

如何将 .mat 文件导入 Jupyter notebook 以在 scikit-learn 中使用它们进行机器学习? [复制]

如何在 Scikit 中应用二元分类器来学习何时属性是字符串(不是 int 或 float)