如何保存图像分类模型并将其用于 android

Posted

技术标签:

【中文标题】如何保存图像分类模型并将其用于 android【英文标题】:How do I save an image classification model and use it for android 【发布时间】:2019-10-20 04:57:49 【问题描述】:

如何使用 Keras 和 Tensorflow 将图像分类模型保存为 .pb 文件及其 label.txt 以便在 android.i 上使用这两个文件。我有一个开始代码,该代码仅保存 .pb 文件但不是 label.txt

我已经完成了洞的事情,但不是 label.txt 这是代码

import pandas as pd 
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import keras
from keras.models import Sequential
from keras.layers import Conv2D,MaxPooling2D,Dense,Flatten,Dropout,Activation
from keras.optimizers import Adam
from keras.callbacks import TensorBoard
from keras.layers.core import Lambda
from keras.optimizers import Adam
import keras 
import keras.backend as k
import tensorflow as tf
from tensorflow.python.framework import graph_util
print(keras.__version__)
print(tf.__version__)
import os
train_df = pd.read_csv('fashionmnist/fashion-mnist_train.csv',sep=',')
test_df = pd.read_csv('fashionmnist/fashion-mnist_test.csv',sep=',')


train_data =np.array(train_df,dtype = 'float32')
test_data = np.array(test_df,dtype = 'float32')
x_train = train_data[:,1:]/255
y_train = train_data[:,0]
x_test = train_data[:,1:]/255
y_test = train_data[:,0]
x_train,x_validate,y_train,y_validate=train_test_split(x_train,y_train,test_size = 0.2,random_state = 12345)
image = x_train[50,:].reshape((28,28))
plt.imshow(image)
plt.show()

image_rows =28
image_cols= 28
batch_size =100
image_shape =(image_rows,image_cols,1)



x_train = x_train.reshape(x_train.shape[0],*image_shape)
x_test = x_test.reshape(x_test.shape[0],*image_shape)
x_validate = x_validate.reshape(x_validate.shape[0],*image_shape)


def build_network(is_training=True):
    model = Sequential()
    model.add(Conv2D(32, (3, 3), input_shape=image_shape,  padding='same',name="1_conv"))
    model.add(Activation('relu'))
    model.add(Conv2D(32, (3, 3), padding='same',name="2_conv"))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2),name="1_pool"))

    model.add(Conv2D(64, (3, 3), padding='same',name="3_conv"))
    model.add(Activation('relu'))
    model.add(Conv2D(64,(3, 3), padding='same',name="4_conv"))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2),name="2_pool"))

    model.add(Conv2D(128,(3, 3),padding='same',name="5_conv"))
    model.add(Activation('relu'))
    model.add(Conv2D(128, (3, 3),padding='same',name="6_conv"))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2),name="3_pool"))

    model.add(Conv2D(256,(3, 3), padding='same',name="7_conv"))
    model.add(Activation('relu'))
    model.add(Conv2D(256, (3, 3), padding='same',name="8_conv"))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2),name="4_pool"))

    model.add(Flatten())
    model.add(Dense(512,name="fc_1"))
    model.add(Activation('relu'))


    if (is_training):
        #model.add(Dense(512, activation='relu'))
        #model.add(Dropout(0.5, name="drop_1"))
        model.add(Lambda(lambda x:k.dropout(x,level=0.5),name="drop_1"))



    model.add(Dense(10,name="fc_2"))
    model.add(Activation('softmax',name="class_result"))
    #model.summary()
    return model


    tf.reset_default_graph()
sess = tf.Session()
k.set_session(sess)
model=build_network()

history_dict = 
model.compile(loss='sparse_categorical_crossentropy',optimizer = Adam(),metrics=['accuracy'])




class TFCheckpointCallback(keras.callbacks.Callback):
    def __init__(self,saver,sess):
        self.saver=saver
        self.sess=sess

    def on_epoch_end(self,epoch,log=None):
        self.saver.save(self.sess,'fMnist/ckpt',global_step=epoch)


tf_saver= tf.train.Saver(max_to_keep=2)
checkpoint_callback= TFCheckpointCallback(tf_saver,sess)
%time
tf_graph=sess.graph
tf.train.write_graph(tf_graph.as_graph_def(),'freeze','fm_graph.pdtxt',as_text=True)
%time
history = model.fit(x_train,
                    y_train,
                    batch_size=batch_size,
                    epochs=50,
                    callbacks=[checkpoint_callback],
                    shuffle=True,
                    verbose=1,
                    validation_data=(x_validate,y_validate)
                   )

sess.close()


model_folder='fMnist/'
def prepare_graph_for_freezing(model_folder):
    model=build_network(is_training=False)
    checkpoint=tf.train.get_checkpoint_state(model_folder)
    input_checkpoint=checkpoint.model_checkpoint_path
    saver=tf.train.Saver()
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        k.set_session(sess)
        saver.restore(sess,input_checkpoint)
        tf.gfile.MakeDirs(model_folder+'freeze')
        saver.save(sess,model_folder + 'freeze/ckpt',global_step=0)


def freeze_graph(model_folder):
    checkpoint =tf.train.get_checkpoint_state(model_folder)
    print(model_folder+'freeze/')
    input_checkpoint = checkpoint.model_checkpoint_path
    absolut_model_folder="/".join(input_checkpoint.split('/')[:-1])
    output_graph=absolut_model_folder + "/fm_freazen_model.pb"
    print(output_graph)
    output_node_name = "class_result/Softmax"
    clear_devices = True
    new_saver=  tf.train.import_meta_graph(input_checkpoint + '.meta',clear_devices=clear_devices)

    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()


    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess2:
        print(input_checkpoint)
        new_saver.restore(sess2,input_checkpoint)

        output_graph_def=graph_util.convert_variables_to_constants(
        sess2,
        input_graph_def,
        output_node_name.split(","))

        with tf.gfile.GFile(output_graph,"wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph."% len(output_graph_def.node))
tf.reset_default_graph()
prepare_graph_for_freezing("freeze/")
freeze_graph("freeze/")

我有检查点和 .pb 文件

但我没有 label.txt 文件

【问题讨论】:

你找到答案了吗? 是的,我得到了答案,但不是来自这个页面。 【参考方案1】:

就 Android 上的图像分类而言,我建议您使用 TensorFlow Lite 而不是直接使用协议缓冲区。

首先,您需要将 Keras 模型 (.h5) 转换为 TensorFlow Lite 模型 (.tflite)。

converter = tf.lite.TFLiteConverter.from_keras_model_file( 'model.h5' )
tflite_buffer = converter.convert()
open( 'tflite_model.tflite' , 'wb' ).write( tflite_buffer )

模型已准备好在 Android 上加载。要检查输入和输出dtypeshape,请参阅this 文件。

现在在 Android 上,首先在 build.gradle 中添加 TensorFlow Lite 依赖项。

dependencies 
...
   implementation 'org.tensorflow:tensorflow-lite:1.13.1'
...

现在我们将模型加载为MappedByteBuffer 对象。

@Throws(IOException::class)

private fun loadModelFile(): MappedByteBuffer 
    val MODEL_ASSETS_PATH = "model.tflite"
    val assetFileDescriptor = assets.openFd(MODEL_ASSETS_PATH)
    val fileInputStream = FileInputStream(assetFileDescriptor.getFileDescriptor())
    val fileChannel = fileInputStream.getChannel()
    val startoffset = assetFileDescriptor.getStartOffset()
    val declaredLength = assetFileDescriptor.getDeclaredLength()
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startoffset, declaredLength)

使用interpreter.run() 方法,我们在给定一些输入的情况下产生推理。看到这个file。此文件包含使用Bitmap.createScaledBitmap 方法调整Bitmap 大小并将其转换为float[][] 的方法

val interpreter = Interpreter( loadModelFile() )
val inputs : Array<FloatArray> = arrayOf( some_input_image. )
val outputs : Array<FloatArray> = arrayOf( floatArrayOf( 0.0f , 0.0f ) )
interpreter.run( inputs , outputs )
val output = outputs[0]

就是这样。 TFLite 比 TensorFlow Mobile 快得多。

注意:TF Lite 仅支持少量操作。由于完全支持与 CNN 相关的操作,因此我们也可以在 Android 和 ios 中使用 TFLite 进行图像分类。

提示:

    要减小 .tflite 文件的大小,请在 Python 中转换模型时使用 post_training_quantize 标志。

    converter = tf.lite.TFLiteConverter.from_keras_model_file( 'model.h5' )
    converter.post_training_quantize = True
    tflite_buffer = converter.convert()
    open( 'tflite_model.tflite' , 'wb' ).write( tflite_buffer )
    

    另外,请尝试使用 Firebase MLKit 在 Firebase 中托管自定义模型。

    我创建了许多使用 TF 对图像和文本进行分类的应用程序。

https://github.com/shubham0204/Spam_Classification_Android_Demo

https://github.com/shubham0204/Skinly_for_Melanoma

【讨论】:

对于labels.txt,可以将文件放在app的assets文件夹中阅读。 先生,谢谢您的最佳解释,但关键是我如何为该模型获取 label.txt 文件(我如何编写该文本文件)。我有 178 个班级,我用这 178 个文件夹图像数据(类)中的每一个都被正确标记,每个类有 5000 个图像。 @haptomee 你有找到获取 label.txt 的方法吗,我发现获取 label.txt 真的很难。

以上是关于如何保存图像分类模型并将其用于 android的主要内容,如果未能解决你的问题,请参考以下文章

如何使用 .pkl 文件预测图像

混淆矩阵和分类报告

对象检测/分类任务的性能指标(用于图像)

Tensorflow2 图像分类-Flowers数据深度学习模型保存读取参数查看和图像预测

Tensorflow2 图像分类-Flowers数据深度学习模型保存读取参数查看和图像预测

Tensorflow2 图像分类-Flowers数据深度学习模型保存读取参数查看和图像预测