将 Keras 模型导出到 .pb 文件并针对推理进行优化会在 Android 上提供随机猜测

Posted

技术标签:

【中文标题】将 Keras 模型导出到 .pb 文件并针对推理进行优化会在 Android 上提供随机猜测【英文标题】:export Keras model to .pb file and optimize for inference gives random guess on Android 【发布时间】:2018-09-03 14:16:30 【问题描述】:

我正在开发一个用于年龄和性别识别的安卓应用程序。我找到了一个有用的model in GitHub。他们正在构建基于first-place winning paper 的 Keras 模型(tensorflow 后端)。他们提供了 Python 模块来训练和构建网络,已经训练过的权重文件可供下载和使用,以及网络摄像头上的工作演示。

我想在演示中使用提供的权重将他们的模型转换为 .pb 文件,以便它也可以在 android 上执行。

我使用this code 进行了与模型相关的细微修改进行转换:

from keras.models import Sequential
from keras.models import model_from_json
from keras import backend as K
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
import os

# Load existing model.
with open("model.json",'r') as f:
    modelJSON = f.read()

model = model_from_json(modelJSON)
model.load_weights("weights.18-4.06.hdf5")
print(model.summary())

# All new operations will be in test mode from now on.
K.set_learning_phase(0)

# Serialize the model and get its weights, for quick re-building.
config = model.get_config()
weights = model.get_weights()

# Re-build a model where the learning phase is now hard-coded to 0.
#new_model = model.from_config(config)
#new_model.set_weights(weights)

temp_dir = "graph"
checkpoint_prefix = os.path.join(temp_dir, "saved_checkpoint")
checkpoint_state_name = "checkpoint_state"
input_graph_name = "input_graph.pb"
output_graph_name = "output_graph.pb"

# Temporary save graph to disk without weights included.
saver = tf.train.Saver()
checkpoint_path = saver.save(K.get_session(), checkpoint_prefix, global_step=0, latest_filename=checkpoint_state_name)
tf.train.write_graph(K.get_session().graph, temp_dir, input_graph_name)

input_graph_path = os.path.join(temp_dir, input_graph_name)
input_saver_def_path = ""
input_binary = False
output_node_names = "dense_1/Softmax,dense_2/Softmax" # model dependent
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join(temp_dir, output_graph_name)
clear_devices = False

# Embed weights inside the graph and save to disk.
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path,
                          output_node_names, restore_op_name,
                          filename_tensor_name, output_graph_path,
                          clear_devices, "")

我直接从演示中生成了 model.json 文件。带有model.json的demo.py文件的main函数代码为:

def main():
    args = get_args()
    depth = args.depth
    k = args.width
    weight_file = args.weight_file

    if not weight_file:
        weight_file = get_file("weights.18-4.06.hdf5", pretrained_model, cache_subdir="pretrained_models",
                               file_hash=modhash, cache_dir=os.path.dirname(os.path.abspath(__file__)))

    # for face detection
    detector = dlib.get_frontal_face_detector()

    # load model and weights
    img_size = 64
    model = WideResNet(img_size, depth=depth, k=k)()
    model.load_weights(weight_file)
    print(model.summary())

    # write model to json
    model_json = model.to_json()
    with open("model.json", "w") as json_file:
        json_file.write(model_json)

    for img in yield_images():
        input_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_h, img_w, _ = np.shape(input_img)

        # detect faces using dlib detector
        detected = detector(input_img, 1)
        faces = np.empty((len(detected), img_size, img_size, 3))

        if len(detected) > 0:
            for i, d in enumerate(detected):
                x1, y1, x2, y2, w, h = d.left(), d.top(), d.right() + 1, d.bottom() + 1, d.width(), d.height()
                xw1 = max(int(x1 - 0.4 * w), 0)
                yw1 = max(int(y1 - 0.4 * h), 0)
                xw2 = min(int(x2 + 0.4 * w), img_w - 1)
                yw2 = min(int(y2 + 0.4 * h), img_h - 1)
                cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
                # cv2.rectangle(img, (xw1, yw1), (xw2, yw2), (255, 0, 0), 2)
                faces[i, :, :, :] = cv2.resize(img[yw1:yw2 + 1, xw1:xw2 + 1, :], (img_size, img_size))

            # predict ages and genders of the detected faces
            results = model.predict(faces)
            predicted_genders = results[0]
            ages = np.arange(0, 101).reshape(101, 1)
            predicted_ages = results[1].dot(ages).flatten()

            # draw results
            for i, d in enumerate(detected):
                label = ", ".format(int(predicted_ages[i]),
                                        "F" if predicted_genders[i][0] > 0.5 else "M")
                draw_label(img, (d.left(), d.top()), label)

        cv2.imshow("result", img)
        key = cv2.waitKey(30)

        if key == 27:
            break


if __name__ == '__main__':
    main()

代码成功编译并生成多个检查点文件以及一个 .pb 文件。

这是模型的图表摘要:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 64, 64, 3)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 64, 64, 16)   432         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 64, 64, 16)   64          conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 64, 64, 16)   0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 64, 64, 128)  18432       activation_1[0][0]               
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 64, 64, 128)  512         conv2d_2[0][0]                   
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 64, 64, 128)  0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 64, 64, 128)  147456      activation_2[0][0]               
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 64, 64, 128)  2048        activation_1[0][0]               
__________________________________________________________________________________________________
add_1 (Add)                     (None, 64, 64, 128)  0           conv2d_3[0][0]                   
                                                                 conv2d_4[0][0]                   
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 64, 64, 128)  512         add_1[0][0]                      
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 64, 64, 128)  0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 128)  147456      activation_3[0][0]               
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 64, 64, 128)  512         conv2d_5[0][0]                   
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 64, 64, 128)  0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 128)  147456      activation_4[0][0]               
__________________________________________________________________________________________________
add_2 (Add)                     (None, 64, 64, 128)  0           conv2d_6[0][0]                   
                                                                 add_1[0][0]                      
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 64, 64, 128)  512         add_2[0][0]                      
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 64, 64, 128)  0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 32, 32, 256)  294912      activation_5[0][0]               
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 32, 32, 256)  1024        conv2d_7[0][0]                   
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 32, 32, 256)  0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 256)  589824      activation_6[0][0]               
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 32, 32, 256)  32768       activation_5[0][0]               
__________________________________________________________________________________________________
add_3 (Add)                     (None, 32, 32, 256)  0           conv2d_8[0][0]                   
                                                                 conv2d_9[0][0]                   
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 32, 32, 256)  1024        add_3[0][0]                      
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 32, 32, 256)  0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 32, 32, 256)  589824      activation_7[0][0]               
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 32, 32, 256)  1024        conv2d_10[0][0]                  
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 32, 32, 256)  0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 256)  589824      activation_8[0][0]               
__________________________________________________________________________________________________
add_4 (Add)                     (None, 32, 32, 256)  0           conv2d_11[0][0]                  
                                                                 add_3[0][0]                      
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 32, 32, 256)  1024        add_4[0][0]                      
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 32, 32, 256)  0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 16, 16, 512)  1179648     activation_9[0][0]               
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 16, 16, 512)  2048        conv2d_12[0][0]                  
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 16, 16, 512)  0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 16, 16, 512)  2359296     activation_10[0][0]              
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 16, 16, 512)  131072      activation_9[0][0]               
__________________________________________________________________________________________________
add_5 (Add)                     (None, 16, 16, 512)  0           conv2d_13[0][0]                  
                                                                 conv2d_14[0][0]                  
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 16, 16, 512)  2048        add_5[0][0]                      
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 16, 16, 512)  0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 16, 16, 512)  2359296     activation_11[0][0]              
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 16, 16, 512)  2048        conv2d_15[0][0]                  
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 16, 16, 512)  0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 16, 16, 512)  2359296     activation_12[0][0]              
__________________________________________________________________________________________________
add_6 (Add)                     (None, 16, 16, 512)  0           conv2d_16[0][0]                  
                                                                 add_5[0][0]                      
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 16, 16, 512)  2048        add_6[0][0]                      
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 16, 16, 512)  0           batch_normalization_13[0][0]     
__________________________________________________________________________________________________
average_pooling2d_1 (AveragePoo (None, 16, 16, 512)  0           activation_13[0][0]              
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 131072)       0           average_pooling2d_1[0][0]        
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 2)            262144      flatten_1[0][0]                  
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 101)          13238272    flatten_1[0][0]                  
==================================================================================================
Total params: 24,463,856
Trainable params: 24,456,656
Non-trainable params: 7,200
__________________________________________________________________________________________________

我采用了输出模型并使用以下脚本来优化推理:

python -m tensorflow.python.tools.optimize_for_inference --input output_graph.pb --output g.pb --input_names=input_1 --output_names=dense_1/Softmax,dense_2/Softmax

在操作过程中,终端给了我很多这样的警告。

 FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (16,), for node batch_normalization_1/FusedBatchNorm
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (128,), for node batch_normalization_2/FusedBatchNorm
WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_3/FusedBatchNorm'
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (128,), for node batch_normalization_4/FusedBatchNorm
WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_5/FusedBatchNorm'
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (256,), for node batch_normalization_6/FusedBatchNorm
WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_7/FusedBatchNorm'
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (256,), for node batch_normalization_8/FusedBatchNorm
WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_9/FusedBatchNorm'
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (512,), for node batch_normalization_10/FusedBatchNorm
WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_11/FusedBatchNorm'
WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (512,), for node batch_normalization_12/FusedBatchNorm
WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_13/FusedBatchNorm'

看来这些警告太可怕了!!

我已经在我的 android 应用程序上尝试了这两个文件。优化后的文件根本不起作用,而未优化的文件是可执行的,但会产生无意义的结果“e.g. GUESSING”。

我知道这个问题有点长,但它是整个工作日的总结,我不想错过任何细节。

我不知道问题出在哪里。是在输出节点名称中、冻结图形、使用权重实例化模型还是在优化推理脚本中。

【问题讨论】:

您会使用固定的批量大小还是需要动态的? 我猜你在问一些关于培训的问题。我没有训练模型。我正在使用此代码来转换预训练的权重而不是重新训练模型。 我知道,但冻结模型似乎不支持动态批量大小(看起来好像它期望批量大小是训练期间使用的那个)。您是否将使用“非优化”.pb 文件与“优化”文件的输出进行了比较? 这对我来说开始有意义了。我想这是一个依赖于训练的问题。我不是这个主题的专家“这是我的第一个项目”所以有些术语是新的。我想我需要一个固定版本来进行推理。顺便说一句,我怎样才能使 .pb 文件动态或固定?感谢您的宝贵时间。 【参考方案1】:

经过研究,随机猜测的问题终于解决了。

问题不是像我最初预期的那样将模型转换为.pb 文件,而是将图像正确地输入到 Android 中的模型。

我再次转换模型。以下几点将总结我的工作。

首先,我从上面问题中引用的 demo.py 中获得了模型。我使用以下代码保存了它: # save the model to .h5 file. model.save('./saved_model/model.h5')

其次,我将.h5生成的文件转换为.pb文件。我使用了this repository中的代码。如果您在超链接中无法访问该链接,请参阅:https://github.com/amir-abdi/keras_to_tensorflow。这个存储库的代码证明了它的可靠性。它将模型转换为.pb 文件并立即对其进行优化以进行推理。太棒了!

第三,我将生成的 .pb 文件放到 android assets 文件夹中,以便使用我的应用程序对其进行配置。

第四,我将预期的图像转换为像素值,并进行了逐位移位以提取颜色。获取有关此代码的帮助以完成此任务。请记住,getPixels 方法保留了颜色通道。因此,如果您需要反转颜色通道,请遵循以下代码。我从 this answer 那里得到了帮助。

    Bitmap bitmap = createScaledBitmap(faces[0], INPUT_SIZE , INPUT_SIZE , true);
    // get pixel values
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());

    for (int i = 0; i < intValues.length; ++i) 

        final int val = intValues[i];

        // extract colors using bit-wise shifting.
        floatValues[i * 3 + 0] = ((val >> 16) & 0xFF );
        floatValues[i * 3 + 1] = ((val >> 8) & 0xFF );
        floatValues[i * 3 + 2] = (val & 0xFF );

        // reverse the color orderings.
        floatValues[i*3 + 2] = Color.red(val);
        floatValues[i*3 + 1] = Color.green(val);
        floatValues[i*3] = Color.blue(val);
    

最后,我可以使用张量流推理方法将图像输入模型,进行推理并输出结果。

【讨论】:

以上是关于将 Keras 模型导出到 .pb 文件并针对推理进行优化会在 Android 上提供随机猜测的主要内容,如果未能解决你的问题,请参考以下文章

将张量流权重导出到 hdf5 文件和模型到 keras model.json

Keras模型的导出和pb文件的转换

修改 tensorflow savedmodel pb 文件以使用自定义操作进行推理

将 keras 模型从 pb 文件转换为 tflite 文件

如何将保存的模型转换或加载到 TensorFlow 或 Keras?

Tensorflow (.pb) 格式到 Keras (.h5)