sklearn 使用SGDClassifier 使用 kerasMNIST数据集 进行图片二分类

Posted wanluN1

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了sklearn 使用SGDClassifier 使用 kerasMNIST数据集 进行图片二分类相关的知识,希望对你有一定的参考价值。

依赖:

tensorflow、matplotlib、numpy、sklearn、pickle

代码

from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import StandardScaler
import pickle

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
print(len(x_train))#60000个数据集
print(len(y_test))#10000个测试集
#print(x_train[0]) #二维List
some_digit = x_train[0]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap="binary")
plt.axis("off")
plt.show()
y_train = y_train.astype(np.uint8)
y_test = y_test.astype(np.uint8)
print("x_train is ",y_train[0])

#转变数据集形式
x_train_transed=[]
for index in range(len(x_train)):
    x_train_transed.append(x_train[index].reshape(-1))
    print("疯狂转变中 ",(index/len(x_train))*100,"%")
print("转变完了")
sgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=42)

#将数字标签转换为bool型标签,List内item的转换
y_train_5 = (y_train == 5)
print("训练中")
sgd_clf.fit(x_train_transed,y_train_5)
print("训练完了")
print("正在保存模型")
with open('./5_image_test.model', 'wb') as fw:
    pickle.dump(sgd_clf, fw)
print("正在加载模型")
with open('./5_image_test.model','rb') as fr:
    test_5_model=pickle.load(fr)
    print("使用模型中")
    for i in range(10):
        some_digit_image = x_train[i].reshape(28, 28)
        plt.imshow(some_digit_image, cmap="binary")
        print("image smaple ",i,"predict result ,is 5 :",\\
            test_5_model.predict([x_train_transed[i]]))
        plt.axis("off")
        plt.show()

效果


以上是关于sklearn 使用SGDClassifier 使用 kerasMNIST数据集 进行图片二分类的主要内容,如果未能解决你的问题,请参考以下文章

sklearn SGDClassifier,当没有匹配时产生标签?

sklearn SGDClassifier fit() 与 partial_fit()

sklearn中的SGDClassifier

sklearn SGDClassifier 模型阈值与模型分数有何关系?

Scikit-learn——LogisticRegression与SGDClassifier

向 Sklearn 分类器添加功能