TensorFlow 2 Basics

Posted by wt7018

https://blog.csdn.net/lzs781/article/details/104742043/

Official tutorial:

https://tensorflow.google.cn/tutorials/images/classification

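1. Train the model. To raise training accuracy, increase the value of epochs; note that this improves accuracy on the training set, not necessarily on the validation set.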

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
import os
import matplotlib.pyplot as plt

# 1. Training data paths (adjust PATH to wherever cats_and_dogs_filtered was extracted)
PATH = r'C:\Users\wuhao\Desktop\cats_and_dogs_filtered\cats_and_dogs_filtered'
train_dir = os.path.join(PATH, 'train')
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')


batch_size = 128
epochs = 5
IMG_HEIGHT = 150
IMG_WIDTH = 150

# 2. Wrap the images in a generator (pixel values rescaled from [0, 255] to [0, 1])
train_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                           class_mode='binary')


# Take one batch of (images, labels) from the generator for display
sample_training_images, _ = next(train_data_gen)


# 3. Display sample images (optional)
def plot_images(images_arr):
    fig, axes = plt.subplots(1, 5, figsize=(20, 20))
    axes = axes.flatten()
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()


# Show 5 images
plot_images(sample_training_images[:5])

# 4. Build the model

model = Sequential([
    Conv2D(16, 3, padding='same', activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    MaxPooling2D(),
    Conv2D(32, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Conv2D(64, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(1)  # no activation: outputs a single logit per image
])


# 5. Compile the model (from_logits=True matches the logit output of Dense(1))
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy']
)

model.summary()


# 6. Train the model
num_cats_tr = len(os.listdir(train_cats_dir))
num_dogs_tr = len(os.listdir(train_dogs_dir))
total_train = num_cats_tr + num_dogs_tr
# model.fit accepts generators directly; fit_generator is deprecated since TF 2.1
history = model.fit(
    train_data_gen,
    steps_per_epoch=total_train // batch_size,
    epochs=epochs,
)
# 7. Plot the training results (no validation data is passed to fit, so only training curves exist)
acc = history.history['accuracy']
loss = history.history['loss']
epochs_range = range(epochs)
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.legend(loc='lower right')
plt.title('Training Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.legend(loc='upper right')
plt.title('Training Loss')
plt.show()

# 8. Save the trained model (HDF5 format, chosen by the .h5 extension)
model.save('path_to_my_model.h5')
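As a rough cross-check of what model.summary() prints, the feature-map sizes and parameter counts of the model above can be worked out by hand. The plain-Python sketch below (no TensorFlow required) assumes the Keras defaults used in the script: a 3x3 Conv2D with padding='same' and stride 1 keeps the spatial size, and MaxPooling2D() uses a 2x2 window with stride 2 and 'valid' padding, so odd sizes are rounded down. The helper names (pool_out, conv_params, dense_params, summarize) are hypothetical.

```python
def pool_out(size, pool=2, stride=2):
    """Output length along one axis of a 'valid' max-pooling layer."""
    return (size - pool) // stride + 1


def conv_params(kernel, in_ch, out_ch):
    """Conv2D weights (kernel*kernel*in_ch*out_ch) plus one bias per filter."""
    return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(in_units, out_units):
    """Dense weight matrix plus one bias per output unit."""
    return in_units * out_units + out_units


def summarize(img_size=150, channels=3, filters=(16, 32, 64), kernel=3,
              hidden=512, outputs=1):
    """Return (Flatten output size, params per trainable layer, total params)."""
    size, in_ch, params = img_size, channels, []
    for out_ch in filters:
        # padding='same' with stride 1: the conv layer keeps the spatial size
        params.append(conv_params(kernel, in_ch, out_ch))
        # MaxPooling2D() halves the size, rounding down
        size = pool_out(size)
        in_ch = out_ch
    flat = size * size * in_ch
    params.append(dense_params(flat, hidden))
    params.append(dense_params(hidden, outputs))
    return flat, params, sum(params)


if __name__ == "__main__":
    flat, params, total = summarize()
    print("Flatten output:", flat)
    print("Params per trainable layer:", params)
    print("Total params:", total)
```

With the script's settings the spatial size goes 150 → 75 → 37 → 18, so Flatten produces 18 × 18 × 64 = 20736 values and the total comes to 10,641,441 parameters, of which 10,617,344 sit in the Dense(512) layer. These figures should agree with the Param # column of model.summary().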

2. Load the model and evaluate it on the validation set

import tensorflow as tf
import os

batch_size = 128
epochs = 5
IMG_HEIGHT = 150
IMG_WIDTH = 150
PATH = r'C:\Users\wuhao\Desktop\cats_and_dogs_filtered\cats_and_dogs_filtered'
validation_dir = os.path.join(PATH, 'validation')
# 1. Load the saved model
new_model = tf.keras.models.load_model('path_to_my_model.h5')

new_model.summary()

# 2. Build a generator over the validation images (same preprocessing as in training)
validation_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
val_data_gen = validation_image_generator.flow_from_directory(batch_size=batch_size,
                                                              directory=validation_dir,
                                                              target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                              class_mode='binary')
# 3. Evaluate the model: returns [loss, accuracy] on the validation set
res = new_model.evaluate(val_data_gen)
print(res)
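Because the last layer is Dense(1) with no activation and the loss uses from_logits=True, the model outputs raw logits rather than probabilities. To turn a logit into a class name, apply the sigmoid and threshold at 0.5. flow_from_directory numbers classes in alphanumeric order of the subdirectory names and exposes the mapping as val_data_gen.class_indices, which here should be {'cats': 0, 'dogs': 1}. The plain-Python sketch below illustrates both the prediction step and the kind of loss value evaluate() reports. It is an illustration, not TensorFlow's implementation: the helper names and the sample logits and labels are hypothetical, and the loss uses the standard numerically stable form of binary cross-entropy on logits.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_from_logits(labels, logits):
    """Mean binary cross-entropy computed directly from logits.

    Stable form: max(z, 0) - z*y + log(1 + exp(-|z|)), which avoids
    taking log(0) when sigmoid(z) saturates at 0 or 1.
    """
    total = 0.0
    for y, z in zip(labels, logits):
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def logits_to_labels(logits, class_indices, threshold=0.5):
    """Map each logit to a class name using a {name: index} dict."""
    index_to_name = {i: name for name, i in class_indices.items()}
    return [index_to_name[int(sigmoid(z) > threshold)] for z in logits]


if __name__ == "__main__":
    class_indices = {"cats": 0, "dogs": 1}  # alphanumeric order of subdirectories
    logits = [-2.3, 0.4, 3.1]               # hypothetical model outputs
    labels = [0, 1, 1]                      # hypothetical true labels
    print(logits_to_labels(logits, class_indices))  # ['cats', 'dogs', 'dogs']
    print(round(bce_from_logits(labels, logits), 4))
```

Thresholding the sigmoid at 0.5 is equivalent to checking whether the logit is positive, so the sigmoid only matters when the probability itself is needed.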
