Tensorflow函数式API的使用
Posted Geek Song
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow函数式API的使用相关的知识,希望对你有一定的参考价值。
在我们使用tensorflow时,如果不能使用函数式api进行编程,那么一些复杂的神经网络结构就不会实现出来,只能使用简单的单向模型进行一层一层地堆叠。如果稍微复杂一点,遇到了Resnet这种带有残差模块的神经网络,那么用简单的神经网络堆叠的方式则不可能把这种网络堆叠出来。下面我们来使用函数式API来编写一个简单的全连接神经网络:
首先导包:
from tensorflow import keras import tensorflow as tf import pandas as pd import numpy as np import matplotlib.pyplot as plt
导入图片数据集:mnist
(train_image,train_label),(test_image,test_label)=tf.keras.datasets.fashion_mnist.load_data()
归一化:
train_image=train_image/255 test_image=test_image/255#进行数据的归一化,加快计算的进程
搭建全连接神经网络:
input=keras.Input(shape=(28,28)) x=keras.layers.Flatten()(input)#调用input x=keras.layers.Dense(32,activation="relu")(x) x=keras.layers.Dropout(0.5)(x)#一层一层的进行调用上一层的结果 output=keras.layers.Dense(10,activation="softmax")(x) model=keras.Model(inputs=input,outputs=output) model.summary()
输出:
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 28, 28)] 0 _________________________________________________________________ flatten (Flatten) (None, 784) 0 _________________________________________________________________ dense (Dense) (None, 32) 25120 _________________________________________________________________ dropout (Dropout) (None, 32) 0 _________________________________________________________________ dense_1 (Dense) (None, 10) 330 ================================================================= Total params: 25,450 Trainable params: 25,450 Non-trainable params: 0
拟合模型:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss="sparse_categorical_crossentropy", metrics=[‘acc‘] ) history=model.fit(train_image, train_label, epochs=15, validation_data=(test_image,test_label))
输出:
Train on 60000 samples, validate on 10000 samples Epoch 1/15 60000/60000 [==============================] - 4s 64us/sample - loss: 0.8931 - acc: 0.6737 - val_loss: 0.5185 - val_acc: 0.8160 Epoch 2/15 60000/60000 [==============================] - 3s 57us/sample - loss: 0.6757 - acc: 0.7508 - val_loss: 0.4805 - val_acc: 0.8230 Epoch 3/15 60000/60000 [==============================] - 3s 50us/sample - loss: 0.6336 - acc: 0.7647 - val_loss: 0.4587 - val_acc: 0.8369 Epoch 4/15 60000/60000 [==============================] - 3s 49us/sample - loss: 0.6174 - acc: 0.7689 - val_loss: 0.4712 - val_acc: 0.8294 Epoch 5/15 60000/60000 [==============================] - 3s 48us/sample - loss: 0.6080 - acc: 0.7732 - val_loss: 0.4511 - val_acc: 0.8404 Epoch 6/15 60000/60000 [==============================] - 3s 48us/sample - loss: 0.5932 - acc: 0.7773 - val_loss: 0.4545 - val_acc: 0.8407 Epoch 7/15 60000/60000 [==============================] - 3s 47us/sample - loss: 0.5886 - acc: 0.7772 - val_loss: 0.4394 - val_acc: 0.8428 Epoch 8/15 60000/60000 [==============================] - 3s 52us/sample - loss: 0.5820 - acc: 0.7788 - val_loss: 0.4338 - val_acc: 0.8506 Epoch 9/15 60000/60000 [==============================] - 3s 48us/sample - loss: 0.5742 - acc: 0.7839 - val_loss: 0.4393 - val_acc: 0.8454 Epoch 10/15 60000/60000 [==============================] - 3s 49us/sample - loss: 0.5713 - acc: 0.7847 - val_loss: 0.4422 - val_acc: 0.8477 Epoch 11/15 60000/60000 [==============================] - 3s 47us/sample - loss: 0.5642 - acc: 0.7858 - val_loss: 0.4325 - val_acc: 0.8488 Epoch 12/15 60000/60000 [==============================] - 3s 48us/sample - loss: 0.5582 - acc: 0.7873 - val_loss: 0.4294 - val_acc: 0.8492 Epoch 13/15 60000/60000 [==============================] - 3s 48us/sample - loss: 0.5574 - acc: 0.7882 - val_loss: 0.4263 - val_acc: 0.8523 Epoch 14/15 60000/60000 [==============================] - 3s 48us/sample - loss: 0.5524 - acc: 0.7888 - val_loss: 0.4350 - val_acc: 0.8448 Epoch 15/15 60000/60000 [==============================] - 3s 47us/sample - loss: 0.5486 - acc: 0.7901 - val_loss: 0.4297 - val_acc: 0.8493
最后验证集的精度达到了84%,这是一个仅仅使用全连接神经网络和softmax就能够得到的一个很不错的结果了!
以上是关于Tensorflow函数式API的使用的主要内容,如果未能解决你的问题,请参考以下文章
一文详解 TensorFlow 2.0 的 符号式 API 和命令式 API
使用 Tensorflow 2 的 Keras 功能 API 时传递 `training=true`
tensorflow2.0高阶api--主要为tf.keras.models提供的模型的类接口
当我在 Tensorflow 上使用 Keras API 连接两个模型时,模型的输入张量必须来自 `tf.layers.Input`