keras中保存自定义层和loss

Posted flyuz

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了keras中保存自定义层和loss相关的知识,希望对你有一定的参考价值。

在keras中保存模型有几种方式:

(1):使用callbacks,可以保存训练中任意的模型,或选择最好的模型

logdir = './callbacks'
if not os.path.exists(logdir):
    os.mkdir(logdir)
output_model_file = os.path.join(logdir, "xxxx.h5")
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(output_model_file, save_best_file = True)
]

hist = model.fit_generator(xxxxx, callbacks = callbacks)

(2): 使用model.save(),会把整个模型保存下来,包括网络和参数

(3): 使用model.save_weights(),只保存模型的参数

当使用自定义的层或loss时,只有(3)可以直接使用,1 2会报下面这种错:

NotImplementedError: Layers with arguments in `__init__` must override `get_config`.
ValueError: Unknown loss function:loss
ValueError: Unknown layer: xxxlayer

解决办法:

在自定义网络层时重写get_config函数

我们主要看传入__init__接口时有哪些配置参数,然后在get_config内一一的将它们转为字典键值并且返回使用,以Mylayer为例:

class MyLayer(tf.keras.layers.Layer):
    def __init__(self, num_outputs, name="MyLayer", **kwargs):
        super(MyLayer, self).__init__(name=name, **kwargs)
        self.num_outputs = num_outputs

    def build(self, input_shape):
        self.kernel = self.add_variable("kernel", shape=[int(input_shape[-1]), self.num_outputs])
        super().build(input_shape)

    def call(self, input):
        output = tf.matmul(input, self.kernel)
        return output

    def get_config(self):
       config = {"num_outputs":self.num_outputs}
       base_config = super(Mylayer, self).get_config()
       return dict(list(base_config.items()) + list(config.items()))

一般来说,父类的config也是需要一并保存的,其中base_config即是父类网络层实现的配置参数,最后把父类及继承类的config组装为字典形式即可解决该问题

然后 在加载模型的时候,建立一个字典,该字典的键是自定义网络层时设定该层的名字,其值为该自定义网络层的类名,该字典将用于加载模型时使用

如果还使用了自定义的loss,则把loss也加到_custom_objects中

_custom_objects = {
    "Mylayer" :  Mylayer,
   "loss" : Myloss
}

最后在load模型的时候把_custom_objects传入

model = tf.keras.models.load_model("path/to/your/model", custom_objects=_custom_objects)

以上是关于keras中保存自定义层和loss的主要内容,如果未能解决你的问题,请参考以下文章

Keras 自定义loss函数 focal loss + triplet loss

Keras 中的自定义损失函数 - 遍历 TensorFlow

keras 中带有 train_on_batch 的自定义 Loss fnc 用于重放学习

Tensorflow+keras使用keras API保存模型权重plot画loss损失函数保存训练loss值

Tensorflow+keras使用keras API保存模型权重plot画loss损失函数保存训练loss值

keras库的安装及使用,以全连接层和手写数字识别MNIST为例