来自自定义 keras 层的 tflite 转换器

Posted

技术标签:

【中文标题】来自自定义 keras 层的 tflite 转换器【英文标题】:tflite converter from a custom keras layer 【发布时间】:2019-09-10 07:10:19 【问题描述】:

我在尝试将 keras .h5 文件转换为 tflite 时遇到 TypeError。 新层是高斯核(径向基层)。 为了能够保存和加载 keras 模型,我还在自定义层中定义了 get_config() 方法。所以我能够正确保存和加载模型。

class RBFLayer(Layer):
    def __init__(self, output_dim, centers=None, tol = 1e-6, gamma=0, **kwargs):
        super(RBFLayer, self).__init__(**kwargs)
        self.centers_ = centers
        self.output_dim= output_dim
        self.gamma_ = gamma
        self.tol_ = tol

    def build(self, input_shape):
        self.mu = K.variable(self.centers_, name='centers')
        self.gamma = K.variable(self.gamma_, name='gamma')
        self.tol = K.constant(self.tol_,name='tol')            
        super(RBFLayer, self).build(input_shape)
    def call(self, inputs): #Kernel radial
        a,b = self.mu.shape
        diff = K.reshape( K.tile(inputs,(1,a))-K.reshape(self.mu,(1,-1)), (-1,a,b))
        l2 =   K.sum(K.pow(diff, 2), axis=-1)
        res =  K.exp(-1 * self.gamma * l2)
        mask = K.greater( res, self.tol)
        return K.switch(mask, res, K.zeros_like(res)) 
    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)
    def get_config(self):
        config = 
        'output_dim': self.output_dim,
        'centers': self.centers_,
        'gamma': self.gamma_,
        'tol': self.tol_
        
        base_config = super(RBFLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

现在我想将模型保存到 tflite。我使用 keras 文件中的 TFLiteConverter,还包括“custom_objects”。

def save_tflite(self, base_name):
        file =base_name +'.h5'
        converter = tf.lite.TFLiteConverter.from_keras_model_file(file, custom_objects='RBFLayer':RBFLayer)
        tflite_model = converter.convert()
        open(base_name+".tflite", "wb").write(tflite_model)

我希望获得 tflite 模型文件,其中包括在训练完整模型(中心、tol、伽玛)时使用的 K.variables。

转换时我收到以下错误消息:

airgorbn.save_tflite( base_name)
Traceback (most recent call last):

  File "<ipython-input-7-cdaa1ec46233>", line 1, in <module>
    airgorbn.save_tflite( base_name)

  File "C:/Users/AIRFI/Hospital/keras_RadialBasis.py", line 158, in save_tflite
    converter = tf.lite.TFLiteConverter.from_keras_model_file(file, custom_objects='RBFLayer':RBFLayer)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\lite\python\lite.py", line 747, in from_keras_model_file
    keras_model = _keras.models.load_model(model_file, custom_objects)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\saving\save.py", line 146, in load_model
    return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\saving\hdf5_format.py", line 212, in load_model_from_hdf5
    custom_objects=custom_objects)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\saving\model_config.py", line 55, in model_from_config
    return deserialize(config, custom_objects=custom_objects)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\layers\serialization.py", line 89, in deserialize
    printable_module_name='layer')

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py", line 192, in deserialize_keras_object
    list(custom_objects.items())))

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\sequential.py", line 353, in from_config
    model.add(layer)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\training\tracking\base.py", line 457, in _method_wrapper
    result = method(self, *args, **kwargs)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\sequential.py", line 154, in add
    'Found: ' + str(layer))

TypeError: The added layer must be an instance of class Layer. Found: <__main__.RBFLayer object at 0x0000017D3A75AC50>

【问题讨论】:

【参考方案1】:

您需要将该层定义为自定义操作。

参考这个https://www.tensorflow.org/lite/guide/ops_custom

【讨论】:

以上是关于来自自定义 keras 层的 tflite 转换器的主要内容,如果未能解决你的问题,请参考以下文章

将 keras 模型从 pb 文件转换为 tflite 文件

处理来自 YOLOv5 TFlite 的输出数据

TFLite 转换器:为 keras 模型实现的 RandomStandardNormal,但不适用于纯 TensorFlow 模型

Yolov3 到 Tensorrt:tf-keras Lambda 层的自定义插件

制作自定义 Keras 层时不能使用未知的输入尺寸(批量大小)

需要内部层输出作为标签的自定义损失函数的 Keras 实现