Dropout与过拟合抑制函数式API
Posted AI与计算机视觉
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Dropout与过拟合抑制函数式API相关的知识,希望对你有一定的参考价值。
重磅干货,第一时间送达
如何添加Dropout层
在网络中添加Dropout层,主要是在隐藏层中使用,依然是使用之前的例子,如下:
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
最后可以看到训练集和测试集的损失和正确率的曲线比较可以看出都没有过拟合。但是因为Dropout层数多啦,发现训练集的损失和准确率要比测试集的高和低。
减小网络规模也是抑制过拟合的非常好的方法。
正则化的原理就是控制网络规模,控制参数规模。
函数式API
Sequential()模型就只有一个输入和一个输出,中间的隐藏层都是顺序连接的,结构单一。如果要建立残差网络,还有输入直接连着输出的情况,这就需要使用函数式API。
下面是代码示例:(依然是使用Fashion_MNIST数据集为例)
# -*- coding: UTF-8 -*-"""
Author: LGD
FileName: functional_api
DateTime: 2020/11/12 09:23
SoftWare: PyCharm
matplotlib.pyplot as plt
from tensorflow import keras
fashion_mnist = keras.datasets.fashion_mnist
train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# 数据归一化
train_images = train_images / 255.0
test_images = test_images / 255.0
# 使用函数式api建立模型# 输入层
input = keras.Input(shape=(28, 28)) # [(None, 28, 28)] None代表它是个数维度,任意个
x = keras.layers.Flatten()(input)
# 隐藏层
x = keras.layers.Dense(32, activation='relu')(x)
x = keras.layers.Dropout(0.5)(x)
x = keras.layers.Dense(64, activation='relu')(x)
# 输出层
output = keras.layers.Dense(10, activation='softmax')(x)
# 输入参数建立模型
model = keras.Model(inputs=input, outputs=output)
# 查看模型
model.summary()
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
history = model.fit(
train_images,
train_labels,
epochs=30,
validation_data=(test_images, test_labels)
)
test_acc = model.evaluate(test_images, test_labels)
history.history['loss'], 'r', label='loss')
history.history['val_loss'], 'b--', label='val_loss')
'loss curve of tests and trains') =
plt.show()
·END·
以上是关于Dropout与过拟合抑制函数式API的主要内容,如果未能解决你的问题,请参考以下文章