模型推理模板
Posted 洪流之源
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了模型推理模板相关的知识,希望对你有一定的参考价值。
infer.h
#pragma once
#include <string>
#include <future>
#include <memory>
// 封装接口类
class Infer
public:
virtual std::shared_future<std::string> commit(const std::string& input) = 0;
;
std::shared_ptr<Infer> create_infer(const std::string& file);
infer.cpp
#include "infer.h"
#include <thread>
#include <vector>
#include <condition_variable>
#include <mutex>
#include <string>
#include <future>
#include <queue>
#include <functional>
// 封装接口类
struct Job
std::shared_ptr<std::promise<std::string>> pro;
std::string input;
;
class InferImpl : public Infer
public:
virtual ~InferImpl()
stop();
void stop()
if (running_)
running_ = false;
cv_.notify_one();
if (worker_thread_.joinable())
worker_thread_.join();
bool startup(const std::string& file)
file_ = file;
running_ = true; // 启动后,运行状态设置为true
// 线程传递promise的目的,是获得线程是否初始化成功的状态
// 而在线程内做初始化,好处是,初始化跟释放在同一个线程内
// 代码可读性好,资源管理方便
std::promise<bool> pro;
worker_thread_ = std::thread(&InferImpl::worker, this, std::ref(pro));
/*
注意:这里thread 一构建好后,worker函数就开始执行了
第一个参数是该线程要执行的worker函数,第二个参数是this指的是class InferImpl,第三个参数指的是传引用,因为我们在worker函数里要修改pro。
*/
return pro.get_future().get();
virtual std::shared_future<std::string> commit(const std::string& input) override
Job job;
job.input = input;
job.pro.reset(new std::promise<std::string>());
std::shared_future<std::string> fut = job.pro->get_future();
std::lock_guard<std::mutex> l(lock_);
jobs_.emplace(std::move(job));
cv_.notify_one();
return fut;
void worker(std::promise<bool>& pro)
// load model
if (file_ != "trtfile")
// failed
pro.set_value(false);
printf("Load model failed: %s\\n", file_.c_str());
return;
// load success
pro.set_value(true); // 这里的promise用来负责确认infer初始化成功了
std::vector<Job> fetched_jobs;
while (running_)
std::unique_lock<std::mutex> l(lock_);
// 一直等着,cv_.wait(lock, predicate)
// 如果 running不在运行状态 或者说 jobs_有东西 而且接收到了notify one的信号
cv_.wait(l, [&]() return not running_ || not jobs_.empty(); );
// 如果停止运行,则直接结束循环
if (not running_) break;
int batch_size = 5;
for (int i = 0; i < batch_size && not jobs_.empty(); ++i)
// jobs_不为空的时候
fetched_jobs.emplace_back(std::move(jobs_.front())); // 就往里面fetched_jobs里塞东西
jobs_.pop(); // fetched_jobs塞进来一个,jobs_那边就要pop掉一个。(因为move)
// 一次加载一批,并进行批处理
// forward(fetched_jobs)
for (auto& job : fetched_jobs)
job.pro->set_value(job.input + "---processed");
fetched_jobs.clear();
printf("Infer worker done.\\n");
private:
std::atomic<bool> running_ false ;
std::string file_;
std::thread worker_thread_;
std::queue<Job> jobs_;
std::mutex lock_;
std::condition_variable cv_;
;
std::shared_ptr<Infer> create_infer(const std::string& file)
// 实例化一个推理器的实现类(inferImpl),以指针形式返回
std::shared_ptr<InferImpl> instance = std::make_shared<InferImpl>();
// 推理器实现类实例(instance)启动。这里的file是engine file
if (not instance->startup(file))
instance.reset(); // 如果启动不成功就reset
return instance;
main.cpp
#include "infer.h"
int main()
auto infer = create_infer("trtfile"); // 创建及初始化
if (infer == nullptr)
printf("Infer is nullptr.\\n");
return -1;
// 将任务提交给推理器(推理器执行commit),同时推理器(infer)也等着获取(get)结果
printf("commit msg = %s\\n", infer->commit("msg").get().c_str());
return 0;
以上是关于模型推理模板的主要内容,如果未能解决你的问题,请参考以下文章