模型推理模板

Posted 洪流之源

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了模型推理模板相关的知识,希望对你有一定的参考价值。

infer.h

#pragma once

#include <string>
#include <future>
#include <memory>

// 封装接口类
class Infer 

public:
    virtual std::shared_future<std::string> commit(const std::string& input) = 0;
;

std::shared_ptr<Infer> create_infer(const std::string& file);

infer.cpp

#include "infer.h"
#include <thread>
#include <vector>
#include <condition_variable>
#include <mutex>
#include <string>
#include <future>
#include <queue>
#include <functional>

// 封装接口类

struct Job

    std::shared_ptr<std::promise<std::string>> pro;
    std::string input;
;

class InferImpl : public Infer 

public:
    virtual ~InferImpl() 
    
        stop();
    

    void stop() 
    
        if (running_) 
        
            running_ = false;
            cv_.notify_one();
        

        if (worker_thread_.joinable())
            worker_thread_.join();
    

    bool startup(const std::string& file) 
    

        file_ = file;
        running_ = true; // 启动后,运行状态设置为true

        // 线程传递promise的目的,是获得线程是否初始化成功的状态
        // 而在线程内做初始化,好处是,初始化跟释放在同一个线程内
        // 代码可读性好,资源管理方便
        std::promise<bool> pro;
        worker_thread_ = std::thread(&InferImpl::worker, this, std::ref(pro));
        /*
            注意:这里thread 一构建好后,worker函数就开始执行了
            第一个参数是该线程要执行的worker函数,第二个参数是this指的是class InferImpl,第三个参数指的是传引用,因为我们在worker函数里要修改pro。
         */
        return pro.get_future().get();
    

    virtual std::shared_future<std::string> commit(const std::string& input) override 
    
        Job job;
        job.input = input;
        job.pro.reset(new std::promise<std::string>());
        std::shared_future<std::string> fut = job.pro->get_future();
        
            std::lock_guard<std::mutex> l(lock_);
            jobs_.emplace(std::move(job));
        
        cv_.notify_one();
        return fut;
    

    void worker(std::promise<bool>& pro) 
    
        // load model
        if (file_ != "trtfile") 
        
            // failed
            pro.set_value(false);
            printf("Load model failed: %s\\n", file_.c_str());
            return;
        

        // load success
        pro.set_value(true); // 这里的promise用来负责确认infer初始化成功了

        std::vector<Job> fetched_jobs;
        while (running_) 
        
            
                std::unique_lock<std::mutex> l(lock_);
                // 一直等着,cv_.wait(lock, predicate) 
                // 如果 running不在运行状态 或者说 jobs_有东西 而且接收到了notify one的信号
                cv_.wait(l, [&]() return not running_ || not jobs_.empty(); ); 

                // 如果停止运行,则直接结束循环
                if (not running_) break;

                int batch_size = 5;
                for (int i = 0; i < batch_size && not jobs_.empty(); ++i) 
                   // jobs_不为空的时候
                    fetched_jobs.emplace_back(std::move(jobs_.front())); // 就往里面fetched_jobs里塞东西
                    jobs_.pop();                                         // fetched_jobs塞进来一个,jobs_那边就要pop掉一个。(因为move)
                
            

            // 一次加载一批,并进行批处理
            // forward(fetched_jobs)
            for (auto& job : fetched_jobs) 
            
                job.pro->set_value(job.input + "---processed");
            
            fetched_jobs.clear();
        
        printf("Infer worker done.\\n");
    

private:
    std::atomic<bool> running_ false ;
    std::string file_;
    std::thread worker_thread_;
    std::queue<Job> jobs_;
    std::mutex lock_;
    std::condition_variable cv_;
;

std::shared_ptr<Infer> create_infer(const std::string& file)

    // 实例化一个推理器的实现类(inferImpl),以指针形式返回 
    std::shared_ptr<InferImpl> instance = std::make_shared<InferImpl>();

    // 推理器实现类实例(instance)启动。这里的file是engine file
    if (not instance->startup(file)) 
                        
        instance.reset(); // 如果启动不成功就reset
    

    return instance;

main.cpp

#include "infer.h"

int main()

    auto infer = create_infer("trtfile"); // 创建及初始化
    if (infer == nullptr) 
    
        printf("Infer is nullptr.\\n");
        return -1;
    

    // 将任务提交给推理器(推理器执行commit),同时推理器(infer)也等着获取(get)结果
    printf("commit msg = %s\\n", infer->commit("msg").get().c_str()); 

    return 0;

以上是关于模型推理模板的主要内容,如果未能解决你的问题,请参考以下文章

你或许也想拥有专属于自己的AI模型文件格式(推理部署篇)-

模型推理openvino 推理实践

模型推理谈谈为什么量化能加速推理

模型推理openvino 性能测试

推理实践丨如何使用MindStudio进行Pytorch模型离线推理

还在为模型加速推理发愁吗?不如看看这篇吧。手把手教你把pytorch模型转化为TensorRT,加速推理