Building a Regression Prediction Model with TensorFlow, Part 8: Connecting the Model to an External Interface

Posted by jimchen1218


In the previous post we covered model compression: converting a model file from the standard TensorFlow format to the tflite format, which shrinks the model dramatically.

In this post we show how to call a tflite-format model from standard C/C++.

The changes, step by step:

I. Changes to the BUILD file:

# Description:
# TensorFlow Lite A/C of Traffic Assist.

package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")

exports_files(glob([
    "testdata/*.txt",
]))


tf_cc_binary(
    name = "ta_ac",
    srcs = [
        "get_ac_settings.h",
        "get_ac_settings_impl.h",
        "ta_ac.cc",
        "arm_caller.cc",
        "ta_ac.h",
    ],
    linkopts = tflite_linkopts() + select({
        "//tensorflow:android": [
            "-pie",  # Android 5.0 and later supports only PIE
            "-lm",  # some builtin ops, e.g., tanh, need -lm
        ],
        "//conditions:default": [],
    }),
    deps = [
        ":data_helpers",
        "//tensorflow/contrib/lite:framework",
        "//tensorflow/contrib/lite:string_util",
        "//tensorflow/contrib/lite/kernels:builtin_ops",
    ],
)

cc_library(
    name = "data_helpers",
    srcs = ["data_helpers.cc"],
    hdrs = [
        "data_helpers.h",
        "data_helpers_impl.h",
        "ta_ac.h",
    ],
    deps = [
        "//tensorflow/contrib/lite:builtin_op_data",
        "//tensorflow/contrib/lite:framework",
        "//tensorflow/contrib/lite:schema_fbs_version",
        "//tensorflow/contrib/lite:string",
        "//tensorflow/contrib/lite:string_util",
        "//tensorflow/contrib/lite/kernels:builtin_ops",
        "//tensorflow/contrib/lite/schema:schema_fbs",
    ],
)

cc_test(
    name = "ta_ac_test",
    srcs = [
        "get_ac_settings.h",
        "get_ac_settings_impl.h",
        "ta_ac_test.cc",
    ],
    data = [
        "testdata/ac_data_input.txt",
    ],
    tags = ["no_oss"],
    deps = [
        ":data_helpers",
        "@com_google_googletest//:gtest",
    ],
)

We add one more source file, arm_caller.cc, to the tf_cc_binary rule; this file simulates the external caller.
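With arm_caller.cc in srcs, the target builds like the other TF Lite examples. A sketch of the build commands, assuming the package sits at tensorflow/contrib/lite/examples/ta_ac (the path used by the #include directives below) and that ./configure has already set up the Android toolchain for the cross build:

bazel build //tensorflow/contrib/lite/examples/ta_ac:ta_ac
# Cross-compile for an Android device (exercises the android branch of the select() above):
bazel build --config=android_arm //tensorflow/contrib/lite/examples/ta_ac:ta_ac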

II. Changes to ta_ac.h:

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_TA_AC_TA_AC_H
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_TA_AC_TA_AC_H

#include <vector>

#include "tensorflow/contrib/lite/string.h"

namespace tflite {
namespace ta_ac {

struct Settings {
  bool verbose = false;
  bool accel = false;
  bool input_floating = false;
  bool profiling = false;
  int loop_count = 1;
  float input_mean = 127.5f;
  float input_std = 127.5f;
  string model_name = "./model_ac.tflite";
  string input_data_name = "./ac_data_input.txt";
  string labels_file_name = "./ac_labels.txt";
  string input_layer_type = "uint8_t";
  int number_of_threads = 4;
};


struct ac_settings {
  float temp = 0.0f;
  int direct = 0;
  int power = 0;
};


extern ac_settings RunInference(Settings* s, std::vector<float> ac_input, int data_width, int data_height, int data_channels);
extern Settings getopt(int argc, char** argv);

} // namespace ta_ac
}  // namespace tflite



#endif  // TENSORFLOW_CONTRIB_LITE_EXAMPLES_TA_AC_TA_AC_H

Inside the namespace we add one struct and two interfaces for external callers:

1. A new struct:

struct ac_settings {
  float temp = 0.0f;  // A/C temperature
  int direct = 0;     // A/C airflow direction
  int power = 0;      // A/C fan power
};

2. The inference interface:

extern ac_settings RunInference(Settings* s, std::vector<float> ac_input, int data_width, int data_height, int data_channels);

3. The parameter-initialization interface, which sets up the options before inference:
extern Settings getopt(int argc, char** argv);

III. Changes to ta_ac.cc:

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include <fcntl.h>      // NOLINT(build/include_order)
#include <getopt.h>     // NOLINT(build/include_order)
#include <sys/time.h>   // NOLINT(build/include_order)
#include <sys/types.h>  // NOLINT(build/include_order)
#include <sys/uio.h>    // NOLINT(build/include_order)
#include <unistd.h>     // NOLINT(build/include_order)

#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/optional_debug_tools.h"
#include "tensorflow/contrib/lite/string_util.h"

#include "tensorflow/contrib/lite/examples/ta_ac/data_helpers.h"
#include "tensorflow/contrib/lite/examples/ta_ac/data_helpers_impl.h"
#include "tensorflow/contrib/lite/examples/ta_ac/get_ac_settings.h"

#define LOG(x) std::cerr

namespace tflite {
namespace ta_ac {

double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }

// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings. It pads with empty strings so the length
// of the result is a multiple of 16, because our model expects that.
TfLiteStatus ReadLabelsFile(const string& file_name,
                            std::vector<string>* result,
                            size_t* found_label_count) {
  std::ifstream file(file_name);
  if (!file) {
    LOG(FATAL) << "Labels file " << file_name << " not found\n";
    return kTfLiteError;
  }
  result->clear();
  string line;
  while (std::getline(file, line)) {
    result->push_back(line);
  }
  *found_label_count = result->size();
  const int padding = 16;
  while (result->size() % padding) {
    result->emplace_back();
  }
  return kTfLiteOk;
}

void PrintProfilingInfo(const profiling::ProfileEvent* e, uint32_t op_index,
                        TfLiteRegistration registration) {
  // output something like
  // time (ms) , Node xxx, OpCode xxx, symbolic name
  //      5.352, Node   5, OpCode   4, DEPTHWISE_CONV_2D

  LOG(INFO) << std::fixed << std::setw(10) << std::setprecision(3)
            << (e->end_timestamp_us - e->begin_timestamp_us) / 1000.0
            << ", Node " << std::setw(3) << std::setprecision(3) << op_index
            << ", OpCode " << std::setw(3) << std::setprecision(3)
            << registration.builtin_code << ", "
            << EnumNameBuiltinOperator(
                   static_cast<BuiltinOperator>(registration.builtin_code))
            << "\n";
}




ac_settings RunInference(Settings* s, std::vector<float> ac_input, int data_width, int data_height, int data_channels) {
  ac_settings ac;
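  // ac collects the model's three outputs; ac_input holds the flattened
  // input features, whose tensor shape is described by data_width,
  // data_height and data_channels.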
  if (s->model_name.empty()) {
    LOG(ERROR) << "no model file name\n";
    exit(-1);
  }

  std::unique_ptr<tflite::FlatBufferModel> model;
  std::unique_ptr<tflite::Interpreter> interpreter;
  model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
  if (!model) {
    LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n";
    exit(-1);
  }
  LOG(INFO) << "Loaded tensorflow lite model " << s->model_name << "\n";
  model->error_reporter();
  //LOG(INFO) << "resolved reporter\n";

  tflite::ops::builtin::BuiltinOpResolver resolver;

  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (!interpreter) {
    LOG(FATAL) << "Failed to construct interpreter\n";
    exit(-1);
  }
  
  interpreter->UseNNAPI(s->accel);
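  // UseNNAPI above asks TF Lite to delegate supported ops to the Android
  // NNAPI; s->accel is driven by the -a/--accelerated flag.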

  if (s->verbose) {
    LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n";
    LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n";
    LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n";
    LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n";

    int t_size = interpreter->tensors_size();
    for (int i = 0; i < t_size; i++) {
      if (interpreter->tensor(i)->name)
        LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
                  << interpreter->tensor(i)->bytes << ", "
                  << interpreter->tensor(i)->type << ", "
                  << interpreter->tensor(i)->params.scale << ", "
                  << interpreter->tensor(i)->params.zero_point << "\n";
    }
  }

  if (s->number_of_threads != -1) {
    interpreter->SetNumThreads(s->number_of_threads);
  }

  int input = interpreter->inputs()[0];

  if (s->verbose) {
    LOG(INFO) << "input: " << input << "\n";
  }

  const std::vector<int> inputs = interpreter->inputs();
  const std::vector<int> outputs = interpreter->outputs();

  if (s->verbose) {
    LOG(INFO) << "number of inputs: " << inputs.size() << "\n";
    LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
  }

  if (interpreter->AllocateTensors() != kTfLiteOk) {
    LOG(FATAL) << "Failed to allocate tensors!";
  }

  if (s->verbose) {
    PrintInterpreterState(interpreter.get());
  }

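  // Next, copy the caller's features into the input tensor, dispatching on
  // the element type the model declares for its input.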
  //LOG(INFO) << "interpreter->tensor(input)->type: " << interpreter->tensor(input)->type << "\n";
  switch (interpreter->tensor(input)->type) {
    case kTfLiteFloat32:
      s->input_floating = true;
      setdata<float>(interpreter->typed_tensor<float>(input), ac_input.data(), data_height, data_width, data_channels, s);
      break;
    case kTfLiteUInt8:
      setdata<uint8_t>(interpreter->typed_tensor<uint8_t>(input), ac_input.data(), data_height, data_width, data_channels, s);
      break;
    default:
      LOG(FATAL) << "cannot handle input type " << interpreter->tensor(input)->type << " yet";
      exit(-1);
  }


  profiling::Profiler* profiler = new profiling::Profiler();
  interpreter->SetProfiler(profiler);

  if (s->profiling) profiler->StartProfiling();

  struct timeval start_time, stop_time;
  gettimeofday(&start_time, nullptr);
  for (int i = 0; i < s->loop_count; i++) {
    if (interpreter->Invoke() != kTfLiteOk) {
      LOG(FATAL) << "Failed to invoke tflite!\n";
    }
  }
  gettimeofday(&stop_time, nullptr);
  LOG(INFO) << "invoked\n";
  LOG(INFO) << "inference time: "
            << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
            << " ms\n";

  if (s->profiling) {
    profiler->StopProfiling();
    auto profile_events = profiler->GetProfileEvents();
    for (int i = 0; i < profile_events.size(); i++) {
      auto op_index = profile_events[i]->event_metadata;
      const auto node_and_registration =
          interpreter->node_and_registration(op_index);
      const TfLiteRegistration registration = node_and_registration->second;
      PrintProfilingInfo(profile_events[i], op_index, registration);
    }
  }

  int output = interpreter->outputs()[0];
  //LOG(INFO) << "RunInference interpreter->tensor(output)->type: " << interpreter->tensor(output)->type << "\n";

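  // The model exposes three float output tensors: temperature, airflow
  // direction and fan power; read the first element of each.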
  float temp = interpreter->typed_output_tensor<float>(0)[0];
  float direct = interpreter->typed_output_tensor<float>(1)[0];
  float power = interpreter->typed_output_tensor<float>(2)[0];

  //LOG(INFO) << "RunInference temp: " << temp << "\n";
  //LOG(INFO) << "RunInference direct: " << direct << "\n";
  //LOG(INFO) << "RunInference power: " << power << "\n";

  LOG(INFO) << "temp:" << int(temp) << " C\n";
  LOG(INFO) << "direct:" << int(direct) << " direction(0:head,1:body,2:leg)\n";
  LOG(INFO) << "power:" << int(power) << " power(0:auto,9:large)\n";
    
  ac.temp = temp;
  ac.direct = direct;
  ac.power = power;
  LOG(INFO) << "RunInference ac.temp: " << ac.temp << "\n";
  LOG(INFO) << "RunInference ac.direct: " << ac.direct << "\n";
  LOG(INFO) << "RunInference ac.power: " << ac.power << "\n";

  return ac;
}



void display_usage() {
  LOG(INFO) << "ta_ac\n"
            << "--accelerated, -a: [0|1], use Android NNAPI or not\n"
            << "--count, -c: loop interpreter->Invoke() for certain times\n"
            << "--input_mean, -b: input mean\n"
            << "--input_std, -s: input standard deviation\n"
            << "--data, -d: data_name.txt\n"
            << "--labels, -l: labels for the model\n"
            << "--tflite_model, -m: model_name.tflite\n"
            << "--profiling, -p: [0|1], profiling or not\n"
            << "--threads, -t: number of threads\n"
            << "--verbose, -v: [0|1] print more information\n"
            << "\n";
}

Settings getopt(int argc, char** argv) {
  Settings s;
  int c;
  while (1) {
    static struct option long_options[] = {
        {"accelerated", required_argument, nullptr, 'a'},
        {"count", required_argument, nullptr, 'c'},
        {"verbose", required_argument, nullptr, 'v'},
        {"data", required_argument, nullptr, 'd'},
        {"labels", required_argument, nullptr, 'l'},
        {"tflite_model", required_argument, nullptr, 'm'},
        {"profiling", required_argument, nullptr, 'p'},
        {"threads", required_argument, nullptr, 't'},
        {"input_mean", required_argument, nullptr, 'b'},
        {"input_std", required_argument, nullptr, 's'},
        {nullptr, 0, nullptr, 0}};

    /* getopt_long stores the option index here. */
    int option_index = 0;

    c = getopt_long(argc, argv, "a:b:c:d:f:l:m:p:s:t:v:", long_options, &option_index);

    /* Detect the end of the options. */
    if (c == -1) break;

    switch (c) {
      case 'a':
        s.accel = strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 'b':
        s.input_mean = strtod(optarg, nullptr);
        break;
      case 'c':
        s.loop_count =
            strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 'd':
        s.input_data_name = optarg;
        break;
      case 'l':
        s.labels_file_name = optarg;
        break;
      case 'm':
        s.model_name = optarg;
        break;
      case 'p':
        s.profiling =
            strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 's':
        s.input_std = strtod(optarg, nullptr);
        break;
      case 't':
        s.number_of_threads = strtol(  // NOLINT(runtime/deprecated_fn)
            optarg, nullptr, 10);
        break;
      case 'v':
        s.verbose =
            strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 'h':
      case '?':
        /* getopt_long already printed an error message. */
        display_usage();
        exit(-1);
      default:
        exit(-1);
    }
  }
  return s;
}

#if 0
int Main(int argc, char** argv) {
  Settings s;
  ac_settings ac;
  s = getopt(argc, argv);

//input data
  int data_width = 4;
  int data_height = 3;
  int data_channels = 1;

  std::vector<float> ac_in(data_height * data_width * data_channels);
// test code {2018,11,16,14.88,31.21549,121.30741,15.18,31.20742,121.44468,14.5,14.4,14});
  ac_in[0] = 2018;
  ac_in[1] = 11;
  ac_in[2] = 16;
  ac_in[3] = 14.88f;
  ac_in[4] = 31.21549f;
  ac_in[5] = 121.30741f;
  ac_in[6] = 15.18;
  ac_in[7] = 31.20742;
  ac_in[8] = 121.44468;
  ac_in[9] = 14.5f;
  ac_in[10] = 14.4f;
  ac_in[11] = 14;

//inference
  ac = RunInference(&s,ac_in,data_width,data_height,data_channels);

//output data
  LOG(INFO) << "Main ac.temp: " << ac.temp << "\n";
  LOG(INFO) << "Main ac.direct: " << ac.direct << "\n";
  LOG(INFO) << "Main ac.power: " << ac.power << "\n";
  return 0;
}
#endif

}  // namespace ta_ac
}  // namespace tflite

#if 0
int main(int argc, char** argv) {
  printf("-----------------------\n");
  printf("-         ta_ac      --\n");
  printf("-     tflite   ok!   --\n");
  printf("-----------------------\n");
  return tflite::ta_ac::Main(argc, argv);
}

#endif

The file now wraps two interfaces, RunInference and getopt, and packages the inference inputs and outputs into structs, which makes the model much easier to call from external code.

The Main function inside the namespace and the C main function are disabled (the #if 0 blocks above); the program's entry point moves into arm_caller.cc.

IV. A new file: arm_caller.cc

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <iomanip>
#include <string>
#include <vector>
#include <fcntl.h>      // NOLINT(build/include_order)
#include <getopt.h>     // NOLINT(build/include_order)
#include <sys/time.h>   // NOLINT(build/include_order)
#include <sys/types.h>  // NOLINT(build/include_order)
#include <sys/uio.h>    // NOLINT(build/include_order)
#include <unistd.h>     // NOLINT(build/include_order)

#include "tensorflow/contrib/lite/examples/ta_ac/ta_ac.h"



int main(int argc, char** argv) {
  printf("-----------------------\n");
  printf("-         ta_ac      --\n");
  printf("-     tflite   ok!   --\n");
  printf("-----------------------\n");
  tflite::ta_ac::Settings s;
  tflite::ta_ac::ac_settings ac;
  s = tflite::ta_ac::getopt(argc, argv);

//input data
  int data_width = 4;
  int data_height = 3;
  int data_channels = 1;

  std::vector<float> ac_in(data_height * data_width * data_channels);
// test code {2018,11,16,14.88,31.21549,121.30741,15.18,31.20742,121.44468,14.5,14.4,14});
  ac_in[0] = 2018;
  ac_in[1] = 11;
  ac_in[2] = 16;
  ac_in[3] = 14.88f;
  ac_in[4] = 31.21549f;
  ac_in[5] = 121.30741f;
  ac_in[6] = 15.18;
  ac_in[7] = 31.20742;
  ac_in[8] = 121.44468;
  ac_in[9] = 14.5f;
  ac_in[10] = 14.4f;
  ac_in[11] = 14;

//inference
  ac = tflite::ta_ac::RunInference(&s,ac_in,data_width,data_height,data_channels);

//output data
  printf("arm_caller ac.temp:%d\n", int(ac.temp));
  printf("arm_caller ac.direct:%d\n", ac.direct);
  printf("arm_caller ac.power:%d\n", ac.power);

  return 0;
}

1. Pull in the standard system headers.

2. Note in particular that the ta_ac.h header must be included:

#include "tensorflow/contrib/lite/examples/ta_ac/ta_ac.h"

Without it the caller will not compile or link against the wrapped interfaces.

3. Simulating the input interface (a file-based loader is sketched after this list):

//input data
  int data_width = 4;
  int data_height = 3;
  int data_channels = 1;
  std::vector<float> ac_in(data_height * data_width * data_channels);
  ac_in[0] = 2018;
  ac_in[1] = 11;
  ac_in[2] = 16;
  ac_in[3] = 14.88f;
  ac_in[4] = 31.21549f;
  ac_in[5] = 121.30741f;
  ac_in[6] = 15.18;
  ac_in[7] = 31.20742;
  ac_in[8] = 121.44468;
  ac_in[9] = 14.5f;
  ac_in[10] = 14.4f;
  ac_in[11] = 14;

4. Parameter initialization:

tflite::ta_ac::Settings s;
tflite::ta_ac::ac_settings ac;
s = tflite::ta_ac::getopt(argc, argv);

5. Calling the inference interface:

  ac = tflite::ta_ac::RunInference(&s,ac_in,data_width,data_height,data_channels);

6. Printing the inference results:

  printf("arm_caller ac.temp:%d\n", int(ac.temp));
  printf("arm_caller ac.direct:%d\n", ac.direct);
  printf("arm_caller ac.power:%d\n", ac.power);
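In a real integration these twelve features would come from sensors or a data feed rather than being hard-coded. As a purely illustrative C++ sketch, the hypothetical loader below reads the first comma-separated record from a text file such as testdata/ac_data_input.txt; this series does not document the file format, so the parsing here is an assumption:

#include <cstdlib>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical loader: read the first comma-separated record of the file
// into a flat feature vector of at most `expected` values. The layout
// (one record per line, comma-separated) is assumed, not specified.
std::vector<float> LoadFeatures(const std::string& path, size_t expected) {
  std::vector<float> features;
  std::ifstream file(path);
  std::string line;
  if (file && std::getline(file, line)) {
    std::stringstream ss(line);
    std::string field;
    while (features.size() < expected && std::getline(ss, field, ',')) {
      features.push_back(std::strtof(field.c_str(), nullptr));
    }
  }
  return features;
}

The hand-filled ac_in vector above could then be replaced with std::vector<float> ac_in = LoadFeatures(s.input_data_name, 12);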
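Once built, the binary can be exercised with the flags that getopt defines; for example, pointing it at the default artifacts from the Settings struct:

./ta_ac --tflite_model=./model_ac.tflite --data=./ac_data_input.txt --threads=4 --verbose=1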
Previous post:

  Building a Regression Prediction Model with TensorFlow, Part 7: Model Compression