TensorFlow 2 基础
Posted wt7018
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了 TensorFlow 2 基础相关的知识,希望对你有一定的参考价值。
https://blog.csdn.net/lzs781/article/details/104742043/
TensorFlow 官方教程:
https://tensorflow.google.cn/tutorials/images/classification
一、训练并保存模型。为了提高训练准确率,可以适当增大 epochs 的值。
import os

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D
from tensorflow.keras.models import Sequential

# 1. Training data location and hyper-parameters.
#    Adjust PATH to wherever the cats_and_dogs_filtered dataset was extracted;
#    this script reads images from its train/ sub-directory.
PATH = r'C:\Users\wuhao\Desktop\cats_and_dogs_filtered\cats_and_dogs_filtered'
train_dir = os.path.join(PATH, 'train')

batch_size = 128
epochs = 5  # increase for higher training accuracy
IMG_HEIGHT = 150
IMG_WIDTH = 150

# 2. Wrap the training directory in a generator: pixels are rescaled by 1/255,
#    images resized to IMG_HEIGHT x IMG_WIDTH, batches shuffled, and labels
#    produced in 'binary' mode.
train_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
train_data_gen = train_image_generator.flow_from_directory(
    batch_size=batch_size,
    directory=train_dir,
    shuffle=True,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='binary',
)
sample_training_images, _ = next(train_data_gen)


# 3. Preview some training images (optional).
def plot_images(images_arr):
    """Display up to five images from ``images_arr`` in a single row."""
    _, axes = plt.subplots(1, 5, figsize=(20, 20))
    axes = axes.flatten()
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()


plot_images(sample_training_images[:5])

# 4. Build the model. The final Dense(1) has no activation, so the network
#    outputs a raw logit rather than a probability.
model = Sequential([
    Conv2D(16, 3, padding='same', activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    MaxPooling2D(),
    Conv2D(32, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Conv2D(64, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(1),
])

# 5. Compile the model.
#    The model emits logits, so the loss uses from_logits=True and accuracy
#    must threshold at 0.0 (logit 0 == probability 0.5). The plain 'accuracy'
#    string would pick binary accuracy with a 0.5 threshold applied to the raw
#    logits. name='accuracy' keeps the history key 'accuracy' used below.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.BinaryAccuracy(name='accuracy', threshold=0.0)],
)
model.summary()

# 6. Train the model. fit() accepts generators directly (fit_generator is
#    deprecated). len(train_data_gen) is the number of batches needed to cover
#    every image once, including a final partial batch; the previous
#    `total // batch_size` dropped that batch and was 0 for small datasets.
history = model.fit(
    train_data_gen,
    steps_per_epoch=len(train_data_gen),
    epochs=epochs,
)

# 7. Plot per-epoch training accuracy and loss. No validation data is passed
#    to fit(), so only training curves are available.
acc = history.history['accuracy']
loss = history.history['loss']
epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.legend(loc='lower right')
plt.title('Training Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.legend(loc='upper right')
plt.title('Training Loss')
plt.show()

# 8. Save the trained model to an HDF5 file.
model.save('path_to_my_model.h5')
二、加载模型并在验证集上评估
import os

import tensorflow as tf

# Load the model saved by the training step and evaluate it on the
# validation split of the cats_and_dogs_filtered dataset.
batch_size = 128
# Must match the input size the model was trained with.
IMG_HEIGHT = 150
IMG_WIDTH = 150

# Adjust PATH to wherever the dataset was extracted; images are read from its
# validation/ sub-directory.
PATH = r'C:\Users\wuhao\Desktop\cats_and_dogs_filtered\cats_and_dogs_filtered'
validation_dir = os.path.join(PATH, 'validation')

# 1. Load the saved model.
new_model = tf.keras.models.load_model('path_to_my_model.h5')
new_model.summary()

# 2. Build the validation generator with the same preprocessing used for
#    training: pixels rescaled by 1/255, images resized to
#    IMG_HEIGHT x IMG_WIDTH, labels in 'binary' mode.
validation_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
val_data_gen = validation_image_generator.flow_from_directory(
    batch_size=batch_size,
    directory=validation_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='binary',
)

# 3. Evaluate the loss and the metrics the model was compiled with, then print them.
res = new_model.evaluate(val_data_gen)
print(res)
以上是关于 TensorFlow 2 基础的主要内容,如果未能解决你的问题,请参考以下文章
[Go] 通过 17 个简短代码片段,彻底弄懂 channel 基础