Tensorflow2.0语法 - keras_API的使用

Posted whw1314

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow2.0语法 - keras_API的使用相关的知识,希望对你有一定的参考价值。

转自 https://segmentfault.com/a/1190000021181739

前言

keras接口大都实现了 _call_ 方法。
母类 _call_ 调用了 call()。
因此下面说的几乎所有模型/网络层 都可以在定义后,直接像函数一样调用。
eg:

模型对象(参数) 
网络层对象(参数)

我们还可以实现继承模板

导入

from tensorflow import keras

metrics (统计平均)

里面有各种度量值的接口
如:二分类、多分类交叉熵损失容器,MSE、MAE的损失值容器, Accuracy精确率容器等。
下面以Accuracy伪码为例:

acc_meter = keras.metrics.Accuracy() # 建立一个容器
for _ in epoches:
    for _ in batches:
        y = ...
        y_predict = ...
        acc_meter.update_state(y, y_predict) # 每次扔进去数据,容器都会自动计算accuracy,并储存
    
        if times % 100 == 0: # 一百次一输出, 设置一个阈值/阀门
            print(acc_meter.result().numpy())   # 取出容器内所有储存的数据的,均值准确率
    acc_meter。reset_states()    # 容器缓存清空, 下一epoch从头计数。

激活函数+损失函数+优化器

导入方式:

keras.activations.relu()    # 激活函数:以relu为例,还有很多
keras.losses.categorical_crossentropy() # 损失函数:以交叉熵为例,还有很多
keras.optimizers.SGD()      # 优化器:以随机梯度下降优化器为例
keras.callbacks.EarlyStopping()  # 回调函数: 以‘按指定条件提前暂停训练’回调为例

Sequential(继承自Model)属于模型

模型定义方式

定义方式1:

model = keras.models.Sequential( [首层网络,第二层网络。。。] )

定义方式1:

model = keras.models.Sequential()
model.add(首层网络)
model.add(第二层网络)

模型相关回调配置

logdir = ‘callbacks‘
if not os.path.exists(logdir):
    os.mkdir(logdir)
save_model_file = os.path.join(logdir, ‘mymodel.h5‘)

callbacks = [
    keras.callbacks.TensorBoard(logdir),    # 写入tensorboard
    keras.callbacks.ModelCheckpoint(output_model_file, save_best_only=True),  # 模型保存
    keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3)  # 按条件终止模型训练
    # 验证集,每次都会提升,如果提升不动了,提升小于这个min_delta阈值,则会耐心等待5次。
    # 5次过后,要是还提升这么点。就提前结束。
]
# 代码写在这里,如何传递调用, 下面 “模型相关量度配置” 会提到

模型相关量度配置:((损失,优化器,准确率等)

说明,下面的各种量度属性,可通过字符串方式,也可通过上面讲的导入实例化对象方式。

model.compile(
    loss="sparse_categorical_crossentropy",    # 损失函数,这是字符串方式
    optimizer= keras.optimizers.SGD()          # 这是实例化对象的方式,这种方式可以传参
    metrics=[‘accuracy‘]  # 这项会在fit()时打印出来
)  # compile() 操作,没有真正的训练。
model.fit(
    x,y,
    epochs=10,                              # 反复训练 10 轮
    validation_data = (x_valid,y_valid),    # 把划分好的验证集放进来(fit时打印loss和val)
    validation_freq = 5,                    # 训练5次,验证一次。  可不传,默认为1。
    callbacks=callbacks,                    # 指定回调函数, 请衔接上面‘模型相关回调配置’
    
)   # fit()才是真正的训练 

模型 验证&测试

一般我们会把 数据先分成三部分(如果用相同的数据,起不到测试和验证效果,参考考试作弊思想):

  1. 训练集: (大批量,主体)
  2. 测试集: (模型所有训练结束后, 才用到)
  3. 验证集: (训练的过程种就用到)

说明1:(如何分离?)

1. 它们的分离是需要(x,y)组合在一起的,如果手动实现,需要随机打散、zip等操作。
2. 但我们可以通过 scikit-learn库,的 train_test_split() 方法来实现 (2次分隔)
3. 可以使用 tf.split()来手动实现

具体分离案例:参考上一篇文章: https://segmentfault.com/a/11...

说明2:(为什么我们有了测试集,还需要验证集?)

  1. 测试集是用来在最终,模型训练成型后(参数固定),进行测试,并且返回的是预测的结果值!!!!
  2. 验证集是伴随着模型训练过程中而验证)

代码如下:

loss, accuracy = model.evaluate( (x_test, y_test) ) # 度量, 注意,返回的是精度指标等
target = model.predict( (x_test, y_test) )          # 测试, 注意,返回的是 预测的结果!

可用参数

model.trainable_variables    # 返回模型中所有可训练的变量
# 使用场景: 就像我们之前说过的 gradient 中用到的 zip(求导结果, model.trainable_variables)

自定义Model

Model相当于母版, 你继承了它,并实现对应方法,同样也能简便实现模型的定义。

自定义Layer

同Model, Layer也相当于母版, 你继承了它,并实现对应方法,同样也能简便实现网络层的定义。

模型保存与加载

方法1:之前callback说的

方法2:只保存weight(模型不完全一致)

保存:

model = keras.Sequential([...])
...
model.fit()
model.save_weights(‘weights.ckpt‘)

加载:

假如在另一个文件中。(当然要把保存的权重要复制到本地目录)
model = keras.Sequential([...])    # 此模型构建必须和保存时候定义结构的一模一样的!
model.load_weights(‘weights.ckpt‘)
model.evaluate(...)
model.predict(...)

方法3:保存整个模型(模型完全一致)

保存:

model = keras.Sequential([...])
...
model.fit()
model.save(‘model.h5‘)    # 注意 这里变了,是 save

加载:(直接加载即可,不需要重新复原建模过程)

假如在另一个文件中。(当然要把保存的模型要复制到本地目录)
model = keras.models.load_model(‘model.h5‘)  # load_model是在 keras.models下
model.evaluate(...)
model.predict(...)

方法4:导出可供其他语言使用(工业化)

保存: (使用tf.saved_model模块)

model = keras.Sequential([...])
...
model.fit()
tf.saved_model.save(model, ‘目录‘)

加载:(使用tf.saved_model模块)

model = tf.saved_model.load(‘目录‘)

以上是关于Tensorflow2.0语法 - keras_API的使用的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow2.0语法 - keras_API的使用

TensorFlow 2.0 语法变更

TensorFlow2.0--TensorFlow2.0构架

Tensorflow2.0笔记

Tensorflow2.0笔记

Tensorflow2.0笔记