tflite converter from a custom keras layer
Posted: 2019-09-10 07:10:19

I get a TypeError when I try to convert a Keras .h5 file to TFLite. The new layer is a Gaussian kernel (a radial basis function layer). So that the Keras model can be saved and loaded, I also defined a get_config() method in the custom layer, and saving and loading the model works correctly.
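Written out, the `call` method in the layer below computes one output per center. Take an input $x \in \mathbb{R}^{b}$ and centers $\mu_1, \dots, \mu_a \in \mathbb{R}^{b}$ (the rows of `centers`). This formula is read off the code and restated here for clarity; the question itself does not state it:

$$
\phi_j(x) =
\begin{cases}
\exp\!\left(-\gamma \,\lVert x - \mu_j \rVert_2^{2}\right) & \text{if } \exp\!\left(-\gamma \,\lVert x - \mu_j \rVert_2^{2}\right) > \mathrm{tol},\\[4pt]
0 & \text{otherwise,}
\end{cases}
\qquad j = 1, \dots, a.
$$

The output width is therefore the number of centers $a$, so `output_dim` should equal `len(centers)`.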
class RBFLayer(Layer):
    def __init__(self, output_dim, centers=None, tol=1e-6, gamma=0, **kwargs):
        super(RBFLayer, self).__init__(**kwargs)
        self.centers_ = centers
        self.output_dim = output_dim
        self.gamma_ = gamma
        self.tol_ = tol

    def build(self, input_shape):
        self.mu = K.variable(self.centers_, name='centers')
        self.gamma = K.variable(self.gamma_, name='gamma')
        self.tol = K.constant(self.tol_, name='tol')
        super(RBFLayer, self).build(input_shape)

    def call(self, inputs):  # radial (Gaussian) kernel
        a, b = self.mu.shape
        # Row j of diff is inputs - mu_j; shape (batch, a, b)
        diff = K.reshape(K.tile(inputs, (1, a)) - K.reshape(self.mu, (1, -1)), (-1, a, b))
        l2 = K.sum(K.pow(diff, 2), axis=-1)
        res = K.exp(-1 * self.gamma * l2)
        mask = K.greater(res, self.tol)
        # Zero out activations that are not above tol
        return K.switch(mask, res, K.zeros_like(res))

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

    def get_config(self):
        config = {
            'output_dim': self.output_dim,
            'centers': self.centers_,
            'gamma': self.gamma_,
            'tol': self.tol_,
        }
        base_config = super(RBFLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
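To check the layer's math outside TensorFlow, here is a plain-Python reference of what `call` computes for a batch of 2-D inputs. It mirrors the tile/reshape trick step by step. The function names `tiled_diff` and `rbf_forward` are hypothetical helpers for testing and are not part of the original layer.

```python
import math


def tiled_diff(x, centers):
    """One-sample version of
    K.reshape(K.tile(inputs, (1, a)) - K.reshape(mu, (1, -1)), (-1, a, b)).
    Row j of the result is x - centers[j]."""
    a, b = len(centers), len(centers[0])
    tiled = list(x) * a                            # K.tile(inputs, (1, a)): x repeated a times
    flat_mu = [v for row in centers for v in row]  # K.reshape(mu, (1, -1)), row-major
    flat = [t - m for t, m in zip(tiled, flat_mu)]
    return [flat[j * b:(j + 1) * b] for j in range(a)]  # reshape to (a, b)


def rbf_forward(batch, centers, gamma, tol=1e-6):
    """Reference for RBFLayer.call: exp(-gamma * ||x - mu_j||^2), set to 0 where not > tol."""
    out = []
    for x in batch:
        # K.sum(K.pow(diff, 2), axis=-1): squared distance to each center
        l2 = [sum(d * d for d in row) for row in tiled_diff(x, centers)]
        # K.exp(-1 * gamma * l2)
        res = [math.exp(-gamma * v) for v in l2]
        # K.switch(K.greater(res, tol), res, zeros)
        out.append([r if r > tol else 0.0 for r in res])
    return out
```

Comparing `rbf_forward` with `model.predict` on a few inputs is one way to confirm that the Keras layer, and any converted model, produce the same activations.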
Now I want to save the model as TFLite. I use TFLiteConverter.from_keras_model_file on the .h5 file and also pass custom_objects.
def save_tflite(self, base_name):
    file = base_name + '.h5'
    # custom_objects maps the class name stored in the .h5 config to the Python class
    converter = tf.lite.TFLiteConverter.from_keras_model_file(file, custom_objects={'RBFLayer': RBFLayer})
    tflite_model = converter.convert()
    with open(base_name + ".tflite", "wb") as f:
        f.write(tflite_model)
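To show why both `get_config` and `custom_objects` matter when the .h5 file is reloaded, here is a simplified, hypothetical model of the loading step. The layer config is stored as JSON. On load, the class name is looked up in `custom_objects`, and the layer is rebuilt with `cls(**config)`. `ToyLayer`, `ToyRBF`, `serialize` and `deserialize` are illustrative stand-ins and are not Keras's actual code.

```python
import json


class ToyLayer:
    """Hypothetical stand-in for a Keras base layer (not the real class)."""

    def __init__(self, name=None):
        self.name = name

    def get_config(self):
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        # Rebuild by feeding the saved config back into __init__.
        return cls(**config)


class ToyRBF(ToyLayer):
    def __init__(self, output_dim, centers=None, tol=1e-6, gamma=0, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.centers_ = centers
        self.tol_ = tol
        self.gamma_ = gamma

    def get_config(self):
        config = {"output_dim": self.output_dim, "centers": self.centers_,
                  "gamma": self.gamma_, "tol": self.tol_}
        # Same merge as dict(list(base.items()) + list(config.items()))
        return {**super().get_config(), **config}


def serialize(layer):
    return json.dumps({"class_name": type(layer).__name__, "config": layer.get_config()})


def deserialize(blob, custom_objects=None):
    data = json.loads(blob)
    cls = dict(custom_objects or {}).get(data["class_name"])
    if cls is None:
        # Without custom_objects, the loader cannot resolve the custom class name.
        raise ValueError("Unknown layer: " + data["class_name"])
    return cls.from_config(data["config"])
```

Every `__init__` argument must appear in `get_config`, otherwise `cls(**config)` cannot rebuild the layer with its trained centers and gamma.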
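I expect to get a TFLite model file that includes the K.variables used when training the full model (centers, tol, gamma).
When converting, I get the following error message: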
airgorbn.save_tflite( base_name)
Traceback (most recent call last):
  File "<ipython-input-7-cdaa1ec46233>", line 1, in <module>
    airgorbn.save_tflite( base_name)
  File "C:/Users/AIRFI/Hospital/keras_RadialBasis.py", line 158, in save_tflite
    converter = tf.lite.TFLiteConverter.from_keras_model_file(file, custom_objects={'RBFLayer':RBFLayer})
  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\lite\python\lite.py", line 747, in from_keras_model_file
    keras_model = _keras.models.load_model(model_file, custom_objects)
  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\saving\save.py", line 146, in load_model
    return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\saving\hdf5_format.py", line 212, in load_model_from_hdf5
    custom_objects=custom_objects)
  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\saving\model_config.py", line 55, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\layers\serialization.py", line 89, in deserialize
    printable_module_name='layer')
  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py", line 192, in deserialize_keras_object
    list(custom_objects.items())))
  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\sequential.py", line 353, in from_config
    model.add(layer)
  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\training\tracking\base.py", line 457, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\sequential.py", line 154, in add
    'Found: ' + str(layer))
TypeError: The added layer must be an instance of class Layer. Found: <__main__.RBFLayer object at 0x0000017D3A75AC50>
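The message says the object is an `RBFLayer` but is not an instance of `Layer`. One common way to get this error is an assumption here, because the question does not show its imports: `RBFLayer` subclasses the standalone `keras.layers.Layer`, while `TFLiteConverter` loads the model with `tf.keras`, which has its own, unrelated `Layer` class. The toy model below uses hypothetical stand-in classes to show why the `isinstance` check in `Sequential.add` fails in that situation. It is not the real Keras code.

```python
class KerasLayer:
    """Hypothetical stand-in for standalone keras.layers.Layer."""


class TfKerasLayer:
    """Hypothetical stand-in for tf.keras.layers.Layer (a different class object)."""


class ToySequential:
    """Hypothetical stand-in for the tf.keras Sequential model used by the loader."""

    def __init__(self):
        self.layers = []

    def add(self, layer):
        # The loader only accepts subclasses of *its own* Layer base class.
        if not isinstance(layer, TfKerasLayer):
            raise TypeError("The added layer must be an instance of class Layer. "
                            "Found: " + str(layer))
        self.layers.append(layer)


class RBFFromKeras(KerasLayer):
    """Custom layer built on the standalone-keras base: rejected."""


class RBFFromTfKeras(TfKerasLayer):
    """Custom layer built on the tf.keras base: accepted."""
```

If this is the cause, defining `RBFLayer` with `from tensorflow.keras.layers import Layer` and `from tensorflow.keras import backend as K` (instead of the standalone `keras` package) puts the layer and the loader on the same base class.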
Answer 1: You need to define the layer as a custom operator.
See https://www.tensorflow.org/lite/guide/ops_custom