Sklearn load digit ValueError: Found array with dim 3. Estimator expected <= 2

Posted

技术标签:

【中文标题】Sklearn load digit ValueError: Found array with dim 3. Estimator expected <= 2【英文标题】: 【发布时间】:2020-09-07 12:52:12 【问题描述】:

我刚开始学习 sklearn 模块,遇到了我的第一个障碍。首先,我创建了一个数字识别逻辑回归模型,它似乎工作正常。但后来我决定自己测试一张随机图片,所以我使用 OpenCV 模块打开了一张我从网上随机选择的图片(如下图所示)。图像的形状为 (425,425,3)。我把它们变成了灰度和阈值。然后我尝试使用我刚刚创建的模型来预测它。但我得到了一个“ValueError:找到暗淡 3 的数组。估计器预期

import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.linear_model import LogisticRegression 

test_case = cv2.imread('unnamed.jpg')
test_case = cv2.cvtColor(test_case, cv2.COLOR_BGR2GRAY)
ret, test_case = cv2.threshold(test_case, 200, 255, cv2.THRESH_BINARY)



digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size = 0.23, random_state = 2)

#test_case = test_case.reshape(test_case.shape[0], -1)

train = LogisticRegression(dual = False, max_iter = 1000000)
train.fit(x_train, y_train)
nbr = train.predict(np.array([test_case], 'float64'))

【问题讨论】:

【参考方案1】:

首先,您根据 sklearn 中的数据训练模型。每个数字都有 = (64) 的形式,这意味着它们是连续展平的。

其次,你必须将你的图片尺寸改为 8x8

例如:

test_case = cv2.imread('ELXkj.jpg')
test_case = cv2.cvtColor(test_case, cv2.COLOR_BGR2GRAY)
test_case = cv2.resize(test_case, (8,8))
ret, test_case = cv2.threshold(test_case, 200, 255, cv2.THRESH_BINARY)


digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size = 0.23, random_state = 2)

train = LogisticRegression(dual = False, max_iter = 1000000)
train.fit(x_train, y_train)
nbr = train.predict(np.array([test_case.flatten()], 'float64'))

【讨论】:

以上是关于Sklearn load digit ValueError: Found array with dim 3. Estimator expected <= 2的主要内容,如果未能解决你的问题,请参考以下文章

Sklearn 数字数据集

利用sklearn获取手写数字数据集,并进行可视化

数据集中的目标,图像 [0] sklearn

留出法K折交叉验证留一法进行数据集划分

机器学习-kNN-寻找最好的超参数

如何使用 sklearn.datasets.load_files 加载数据百分比