A Simple Classifier for Tabular Data such as the Iris Dataset (the model can be swapped out)
Posted by yzpopulation
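Preface: This article was compiled by the editors of 小常识网 (cha138.com). It introduces a simple classifier for tabular data such as the Iris dataset, in which the model can be swapped out. We hope you find it useful.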
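Tested and working with the following versions:
- Keras 2.2.4
- Keras-Applications 1.0.6
- Keras-Preprocessing 1.0.5
- tensorflow 1.11.0
- numpy 1.15.2
- pandas 0.23.4
- scikit-learn 0.20.0

The script reads `iris.csv` from the working directory. The file has no header row: four numeric feature columns are followed by the class label.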
```python
# -*- coding: utf-8 -*-
import numpy
import pandas
from keras.layers.core import Dense, Dropout, Activation
from keras.models import Sequential
from keras.utils import np_utils
from keras.utils import plot_model
from sklearn import utils
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import LabelEncoder


def load_data():
    """
    Load and split the data set.
    :return: x_train, y_train, x_test, y_test, encoder
    """
    # Read the CSV (no header row)
    data_frame = pandas.read_csv("iris.csv", header=None)
    data_set = data_frame.values
    # All rows, columns 0 to 3 (column 4 excluded): the features
    x_data = data_set[:, 0:4].astype(float)
    # All rows, column 4: the class label
    y_data = data_set[:, 4]
    # Label encoding
    encoder = LabelEncoder()
    # Map the string labels to integer classes 0, 1, 2, ...
    # The mapping is stored in encoder.classes_ and can be saved and restored with NumPy:
    # numpy.save('encoder.npy', encoder.classes_); encoder.classes_ = numpy.load('encoder.npy')
    encoded_transform_y = encoder.fit_transform(y_data)
    # One-hot encode the integer classes
    y_data = np_utils.to_categorical(encoded_transform_y)
    # Shuffle the data set
    x_data, y_data = utils.shuffle(x_data, y_data)
    # Stratified 80/20 train/test split
    train_idx, test_idx = next(iter(
        StratifiedShuffleSplit(n_splits=1, test_size=0.2,
                               random_state=0).split(x_data, y_data)))
    x_train = x_data[train_idx]
    y_train = y_data[train_idx]
    x_test = x_data[test_idx]
    y_test = y_data[test_idx]
    return x_train, y_train, x_test, y_test, encoder


def compile_model():
    # Model: 4 inputs -> Dense(10) + tanh -> Dropout(0.2) -> Dense(3) + softmax
    _model = Sequential()
    _model.add(Dense(10, input_shape=(4,)))
    _model.add(Activation('tanh'))
    _model.add(Dropout(0.2))
    _model.add(Dense(3))
    _model.add(Activation('softmax'))
    _model.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy'])
    # Save a diagram of the model; this needs pydot and graphviz installed
    try:
        plot_model(_model, to_file='model.png', show_shapes=True)
    except ImportError as err:
        print('Skipping model diagram:', err)
    return _model


def train_model(_model, _x_train, _y_train, _x_test, _y_test):
    # Train; note that the test set doubles as validation data here
    history = _model.fit(_x_train, _y_train, epochs=100, batch_size=12,
                         verbose=1, validation_data=(_x_test, _y_test))
    # Evaluate on the test set
    score = _model.evaluate(_x_test, _y_test, verbose=1)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])


def test(_model, _encoder, _x_test):
    # Predict, then map the outputs back to the original string labels
    result = _model.predict(_x_test)
    result = numpy.argmax(result, axis=1)
    result = _encoder.inverse_transform(result)
    print(result)


if __name__ == '__main__':
    x_train, y_train, x_test, y_test, encoder = load_data()
    model = compile_model()
    train_model(model, x_train, y_train, x_test, y_test)
    test(model, encoder, x_test)
```
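The encoding and decoding steps in `load_data` and `test` can be illustrated without Keras or scikit-learn. The sketch below is a NumPy-only approximation built on a small, hypothetical label array. `numpy.unique` returns the same sorted classes and integer codes that `LabelEncoder` produces. Indexing an identity matrix yields the same one-hot rows as `to_categorical`. Applying `argmax` and then looking up the result in the saved class array reverses the process, as `inverse_transform` does.

```python
import os
import tempfile

import numpy as np

# Hypothetical string labels, in the same form as column 4 of iris.csv
labels = np.array(["Iris-virginica", "Iris-setosa", "Iris-versicolor", "Iris-setosa"])

# Sorted unique classes plus integer codes (what LabelEncoder.fit_transform yields)
classes, codes = np.unique(labels, return_inverse=True)

# One-hot encoding: row i of the identity matrix stands for class i (like to_categorical)
one_hot = np.eye(len(classes))[codes]

# Persist and reload the encoding rule, as the comment in load_data suggests
path = os.path.join(tempfile.mkdtemp(), "encoder.npy")
np.save(path, classes)
restored_classes = np.load(path)

# Decoding: argmax over the class axis, then index back into the class names
fake_probs = one_hot * 0.9 + 0.05  # stand-in for model.predict output
decoded = restored_classes[np.argmax(fake_probs, axis=1)]
print(decoded)
```

When the labels come from `data_frame.values`, `encoder.classes_` has object dtype. Starting with version 1.16.3, NumPy refuses to load pickled object arrays by default. In that case, either pass `allow_pickle=True` to `numpy.load` or save `encoder.classes_.astype(str)` instead.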
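`StratifiedShuffleSplit` keeps the class proportions the same in the training and test parts. The sketch below shows the idea in NumPy through a hypothetical helper, `stratified_split`, which shuffles each class separately and moves a fixed fraction of it into the test set. This is a simplified illustration, not scikit-learn's actual algorithm. The balanced labels (50 per class) mimic the encoded Iris targets.

```python
import numpy as np


def stratified_split(y, test_size=0.2, seed=0):
    """Hypothetical helper: per-class shuffle-and-split.

    Returns (train_idx, test_idx) so that every class keeps roughly the
    same proportion in both parts. Simplified sketch, not scikit-learn's code.
    """
    rng = np.random.RandomState(seed)
    train_parts, test_parts = [], []
    for cls in np.unique(y):
        idx = np.flatnonzero(y == cls)  # positions belonging to this class
        rng.shuffle(idx)                # shuffle within the class only
        n_test = int(round(len(idx) * test_size))
        test_parts.append(idx[:n_test])
        train_parts.append(idx[n_test:])
    return np.concatenate(train_parts), np.concatenate(test_parts)


# 150 integer labels, 50 per class, like the encoded Iris targets
y = np.repeat([0, 1, 2], 50)
train_idx, test_idx = stratified_split(y, test_size=0.2)
print(np.bincount(y[test_idx]))   # -> [10 10 10]
print(np.bincount(y[train_idx]))  # -> [40 40 40]
```

In `load_data` the split is stratified on the one-hot matrix. scikit-learn treats each distinct row of a 2-D `y` as its own class, so this is equivalent to stratifying on the integer labels. The `utils.shuffle` call before the split is harmless but not strictly needed, because `StratifiedShuffleSplit` already draws a random permutation.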
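The network built in `compile_model` is small enough to write out by hand. The NumPy sketch below runs the inference-time forward pass and computes the categorical cross-entropy for one one-hot target. Dropout is inactive outside training, so the sketch skips it. The weights are random placeholders rather than trained values, the input row is just an example with four features, and its target is assumed to be class 0.

```python
import numpy as np


def softmax(z):
    # Subtract the row maximum for numerical stability before exponentiating
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def forward(x, w1, b1, w2, b2):
    """Inference pass of Dense(10) -> tanh -> Dropout -> Dense(3) -> softmax.

    Dropout only acts during training, so it is the identity here.
    """
    hidden = np.tanh(x @ w1 + b1)      # (batch, 4) -> (batch, 10)
    return softmax(hidden @ w2 + b2)   # (batch, 10) -> (batch, 3)


rng = np.random.RandomState(0)
# Random placeholder weights with the shapes Keras would create (not trained)
w1, b1 = rng.normal(size=(4, 10)), np.zeros(10)
w2, b2 = rng.normal(size=(10, 3)), np.zeros(3)

x = np.array([[5.1, 3.5, 1.4, 0.2]])  # one example input with 4 features
probs = forward(x, w1, b1, w2, b2)

# Categorical cross-entropy: -sum(y_true * log(p)), averaged over the batch
# (Keras additionally clips p away from 0 and 1 before taking the log)
y_true = np.array([[1.0, 0.0, 0.0]])  # assumed target: class 0
loss = -np.sum(y_true * np.log(probs), axis=1).mean()

n_params = w1.size + b1.size + w2.size + b2.size
print(probs, loss)
print(n_params)  # -> 83
```

The parameter count of 83 breaks down as 4×10 + 10 weights and biases in the hidden layer plus 10×3 + 3 in the output layer. It should match what `model.summary()` reports for this architecture. Swapping the model, as the title suggests, only requires changing `compile_model`, as long as the input shape stays `(4,)` and the output layer has one softmax unit per class.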