TensorFlow 对 Model 检查点的操作：model.get_layer、从 checkpoint 加载权重（set_weights）、model 层属性获取
Posted 炫云云
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow 对Model检测点的操作model.get_layer从 checkpoint加载权重set_weightsmodel层属性获取相关的知识,希望对你有一定的参考价值。
model层属性获取
在tensorflow中,要想获取层的输出的各种信息,可以先获取层对象,再通过层对象的属性获取层输出的其他特性.
获取model对应层的方法为:
get_layer(self, name=None, index=None):
函数功能:根据层的名称(这个名称具有唯一性)或者索引号检索model获取对应的层.
获取层输出的其他特性
- model.get_layer(index=0).output  # 输出张量
- model.get_layer(index=0).output_shape  # 输出张量的形状
- model.get_layer(index=0).input  # 输入张量
- model.get_layer(index=0).input_shape  # 输入张量的形状
- # 当该层被多次调用、因而拥有多个节点时（node_index 为节点序号）：
- layer.get_input_at(node_index)
- layer.get_output_at(node_index)
- layer.get_input_shape_at(node_index)
- layer.get_output_shape_at(node_index)
- model.get_layer("word_embeddings").set_weights(weights)  # 将权重（由 numpy 数组组成的列表）加载到该层
- model.get_layer("word_embeddings").get_weights()  # 以 numpy 数组列表的形式返回该层的权重
- config = model.get_layer("word_embeddings").get_config()  # 获取该层的配置（字典）
检查点
保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。
这种在训练中保存模型,习惯上称之为保存检查点。
load_checkpoint
tf.train.load_checkpoint(ckpt_dir_or_file):
返回 `ckpt_dir_or_file` 中检查点对应的 `CheckpointReader`。
如果 `ckpt_dir_or_file` 是包含多个检查点的目录，则返回其中最新检查点的 reader。例如：`variables = tf.train.load_checkpoint(init_checkpoint)`
从 CheckpointReader 对象中按变量名获取张量：
# Read a single tensor from the CheckpointReader by its checkpoint key.
variables.get_tensor("bert/embeddings/word_embeddings")
BERT 示例：用 load_checkpoint 读取检查点，再通过 get_layer().set_weights() 将权重写入对应的层
import tensorflow as tf  # this snippet runs before the file's own `import tensorflow as tf`

# Copy the weights stored in a BERT checkpoint into the Keras encoder, one
# layer at a time, via get_layer(<name>).set_weights([...]).  The tensors in
# each list are given in the same order as the original listing.
variables = tf.train.load_checkpoint(init_checkpoint)
encoder = model._encoder_layer

# Embedding tables.
encoder.get_layer("word_embeddings").set_weights([
    variables.get_tensor("bert/embeddings/word_embeddings")])
encoder.get_layer("position_embeddings").set_weights([
    variables.get_tensor("bert/embeddings/position_embeddings")])
encoder.get_layer("type_embeddings").set_weights([
    variables.get_tensor("bert/embeddings/token_type_embeddings")])
# Embedding LayerNorm: [gamma, beta].
encoder.get_layer("embeddings/layer_norm").set_weights([
    variables.get_tensor("bert/embeddings/LayerNorm/gamma"),
    variables.get_tensor("bert/embeddings/LayerNorm/beta"),
])
# Embedding projection (checkpoint scope embedding_hidden_mapping_in): [kernel, bias].
encoder.get_layer("embedding_projection").set_weights([
    variables.get_tensor("bert/encoder/embedding_hidden_mapping_in/kernel"),
    variables.get_tensor("bert/encoder/embedding_hidden_mapping_in/bias"),
])

bert_config = model.bert_config
hidden_size = bert_config.hidden_size
num_heads = bert_config.num_attention_heads
# NOTE(review): as in the original, the layer count is read from
# model._config['bert_config'] while the sizes come from model.bert_config;
# presumably these are the same config object — confirm.
num_layers = model._config['bert_config'].num_hidden_layers


def _layer_tensor(layer_index, suffix):
    """Return the checkpoint tensor ``bert/encoder/layer_<layer_index>/<suffix>``."""
    return variables.get_tensor(
        "bert/encoder/layer_{}/{}".format(layer_index, suffix))


# Multi-head attention + feed-forward weights of every transformer layer.
for i in range(num_layers):
    weights = []
    # Query/key/value: kernels are reshaped to [hidden_size, num_heads, -1] and
    # biases to [num_heads, -1]; -1 lets tf.reshape infer the per-head size.
    for proj in ("query", "key", "value"):
        weights.append(tf.reshape(
            _layer_tensor(i, "attention/self/{}/kernel".format(proj)),
            [hidden_size, num_heads, -1]))
        weights.append(tf.reshape(
            _layer_tensor(i, "attention/self/{}/bias".format(proj)),
            [num_heads, -1]))
    # Attention output projection kernel: reshaped to [num_heads, -1, hidden_size].
    weights.append(tf.reshape(
        _layer_tensor(i, "attention/output/dense/kernel"),
        [num_heads, -1, hidden_size]))
    # The remaining tensors are copied without reshaping.
    weights.extend(
        _layer_tensor(i, suffix)
        for suffix in (
            "attention/output/dense/bias",
            "attention/output/LayerNorm/gamma",
            "attention/output/LayerNorm/beta",
            "intermediate/dense/kernel",
            "intermediate/dense/bias",
            "output/dense/kernel",
            "output/dense/bias",
            "output/LayerNorm/gamma",
            "output/LayerNorm/beta",
        ))
    encoder.get_layer("transformer/layer_{}".format(i)).set_weights(weights)

# Pooler dense layer: [kernel, bias].
encoder.get_layer("pooler_transform").set_weights([
    variables.get_tensor("bert/pooler/dense/kernel"),
    variables.get_tensor("bert/pooler/dense/bias"),
])
import tensorflow as tf  # this snippet runs before the file's own `import tensorflow as tf`

# tf.train.list_variables(ckpt_dir_or_file) lists the checkpoint key and shape
# of every variable stored in the checkpoint, as (name, shape) pairs.
# BERT example: print every variable whose key starts with "bert".
init_vars = tf.train.list_variables(init_checkpoint)
for name, shape in init_vars:
    if name.startswith("bert"):
        # NOTE(review): nothing is initialized here; the message only marks
        # variables present in the checkpoint.
        print(f"{name}, shape={shape}, *INIT FROM CKPT SUCCESS*")
import tensorflow as tf
import os
# Track the optimizer and model, keeping at most the 3 newest checkpoints.
ckpt_directory = "/tmp/training_checkpoints/ckpt"
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(ckpt, ckpt_directory, max_to_keep=3)
train_and_checkpoint(model, manager)

# latest_checkpoint is None when no checkpoint has been written yet, and
# tf.train.list_variables cannot accept None, so guard before listing.
if manager.latest_checkpoint is None:
    print("No checkpoint found in", ckpt_directory)
else:
    for name, shape in tf.train.list_variables(manager.latest_checkpoint):
        print(name, shape)
保存检查点
# Prepare the data splits and build the model from them.
x_train, y_train, x_test, y_test = process_data()
model = mode(x_train, y_train, x_test, y_test)

# Track `model` in a checkpoint under the attribute name `A` and save it.
checkpoint = tf.train.Checkpoint(A=model)
# The argument is a directory plus a file-name prefix: files go into a
# ./checkpoint folder (relative path) with names starting with "01.ckpt".
checkpoint.save(file_prefix='./checkpoint/01.ckpt')
以上是关于tensorflow 对Model检测点的操作model.get_layer从 checkpoint加载权重set_weightsmodel层属性获取的主要内容,如果未能解决你的问题,请参考以下文章
如何从代码运行 tensorflow 对象检测 api (model_main_tf2)?
如何在 TensorFlow 对象检测 API 中从头开始训练?