A simple TensorRT MNIST inference example, with the model built in code
Posted by python算法工程师
TensorRT is NVIDIA's deep neural network inference engine, used to optimize and deploy deep learning models. In this program, TensorRT is used to load an already-trained model and run inference with it.
TRTLogger is a logging class that records TensorRT's runtime log messages.
Matrix is a matrix struct used to store model weights and input/output data. Model is a model struct that holds the loaded model.
The print_image function prints an image's pixel values, which is handy for debugging and inspection. The load_file function loads data from a file, including model weights and input image data. The load_model function loads the model weights; the weight files are named in the format "[index].weight", with index starting from 0 and incrementing. The weight shapes are predefined and stored in the weight_shapes array, where weight_shapes[i][0] is the number of rows of the i-th layer's weights and weight_shapes[i][1] is the number of columns.
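For readers who want to smoke-test the pipeline before exporting real weights, the small standalone sketch below (not part of the original program) writes zero-filled "[index].weight" files in the raw float32 layout that load_model expects; the file-name pattern and shapes are taken from the description above, and real weights must of course come from the trained model.

// Hypothetical helper, separate from the main program below: writes zero-filled
// "[index].weight" files with the predefined shapes, so load_model() has
// something to read during a dry run.
#include <cstdio>
#include <vector>

int main(){
    const int shapes[4][2] = {{1024, 784}, {1024, 1}, {10, 1024}, {10, 1}};
    for(int i = 0; i < 4; ++i){
        char name[32];
        snprintf(name, sizeof(name), "%d.weight", i);
        std::vector<float> w(shapes[i][0] * shapes[i][1], 0.0f);   // rows * cols float32 values
        FILE* f = fopen(name, "wb");
        if(f){
            fwrite(w.data(), sizeof(float), w.size(), f);
            fclose(f);
        }
    }
    return 0;
}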
These functions exist to make the program easier to write and debug, and can be modified and extended for specific application scenarios.
The program also includes a function that converts BMP image data into a matrix suitable for feeding into the network, and a function that converts the network weights into the format TensorRT expects.
The do_trt_build_engine function builds the neural network with the TensorRT API and then serializes the resulting engine to a file.
The do_trt_inference function loads the serialized engine from the file and then uses it to run inference on a set of input images. For each input image it converts the BMP data into a matrix, copies the matrix to the GPU, runs inference with the engine, and copies the output probabilities back to the CPU for display.
The main function first calls load_model to load the trained model and prints the size of each weight matrix.
Next, it calls do_trt_build_engine to convert the model into a TensorRT engine and saves the engine to the file mnist.trtmodel.
Finally, it calls do_trt_inference to run inference on a set of input images and displays each image's predicted label and confidence.
After inference completes, it prints a message indicating that the program has finished and returns 0 to signal a normal exit.
// tensorRT include
#include <NvInfer.h>
#include <NvInferRuntime.h>
// cuda include
#include <cuda_runtime.h>
// system include
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <vector>
#include <string>
#include <fstream>
#include <algorithm>
using namespace std;
#define SIMLOG(type, ...)                                   \
    do{                                                     \
        printf("[%s:%d]%s: ", __FILE__, __LINE__, type);    \
        printf(__VA_ARGS__);                                \
        printf("\n");                                       \
    }while(0)
#define INFO(...) SIMLOG("info", __VA_ARGS__)
inline const char* severity_string(nvinfer1::ILogger::Severity t){
    switch(t){
        case nvinfer1::ILogger::Severity::kINTERNAL_ERROR: return "internal_error";
        case nvinfer1::ILogger::Severity::kERROR:   return "error";
        case nvinfer1::ILogger::Severity::kWARNING: return "warning";
        case nvinfer1::ILogger::Severity::kINFO:    return "info";
        case nvinfer1::ILogger::Severity::kVERBOSE: return "verbose";
        default: return "unknown";
    }
}
class TRTLogger : public nvinfer1::ILogger{
public:
    // forward TensorRT's own log messages (kINFO and more severe) to SIMLOG
    virtual void log(Severity severity, nvinfer1::AsciiChar const* msg) noexcept override{
        if(severity <= Severity::kINFO)
            SIMLOG(severity_string(severity), "%s", msg);
    }
};
struct Matrix{
    vector<float> data;
    int rows = 0, cols = 0;

    void resize(int rows, int cols){
        this->rows = rows;
        this->cols = cols;
        this->data.resize(rows * cols);   // data is a vector<float>, so the element count is rows * cols
    }
    bool empty() const{ return data.empty(); }
    int size()   const{ return rows * cols; }
    float* ptr() const{ return (float*)this->data.data(); }
};
struct Model{
    vector<Matrix> weights;
};
void print_image(const vector<unsigned char>& a, int rows, int cols, const char* format = "%3d"){
    INFO("Matrix[%p], %d x %d", &a, rows, cols);
    char fmt[20];
    sprintf(fmt, "%s,", format);
    for(int i = 0; i < rows; ++i){
        printf("row[%02d]: ", i);
        for(int j = 0; j < cols; ++j){
            // BMP pixel rows are stored bottom-to-top, so flip vertically and print the first channel
            int index = (rows - i - 1) * cols + j;
            printf(fmt, a.data()[index * 3 + 0]);
        }
        printf("\n");
    }
}
vector<unsigned char> load_file(const string& file){
    ifstream in(file, ios::in | ios::binary);
    if (!in.is_open())
        return {};

    in.seekg(0, ios::end);
    size_t length = in.tellg();

    std::vector<uint8_t> data;
    if (length > 0){
        in.seekg(0, ios::beg);
        data.resize(length);
        in.read((char*)&data[0], length);
    }
    in.close();
    return data;
}
bool load_model(Model& model){
    model.weights.resize(4);

    // predefined shapes of the four weight tensors: fc1 weight/bias, fc2 weight/bias
    const int weight_shapes[][2] = {
        {1024, 784},
        {1024, 1},
        {10, 1024},
        {10, 1}
    };

    for(int i = 0; i < (int)model.weights.size(); ++i){
        char weight_name[100];
        sprintf(weight_name, "%d.weight", i);

        auto data = load_file(weight_name);
        if(data.empty()){
            INFO("Load %s failed.", weight_name);
            return false;
        }

        auto& w = model.weights[i];
        int rows = weight_shapes[i][0];
        int cols = weight_shapes[i][1];
        if(data.size() != rows * cols * sizeof(float)){
            INFO("Invalid weight file: %s", weight_name);
            return false;
        }
        w.resize(rows, cols);
        memcpy(w.ptr(), data.data(), data.size());
    }
    return true;
}
Matrix bmp_data_to_normalize_matrix(const vector<unsigned char>& data){
    Matrix output;
    const int std_w = 28;
    const int std_h = 28;
    if(data.size() != std_w * std_h * 3){
        INFO("Invalid bmp file, must be %d x %d @ rgb 3 channels image", std_w, std_h);
        return output;
    }

    output.resize(1, std_w * std_h);
    const unsigned char* begin_ptr = data.data();
    float* output_ptr = output.ptr();
    for(int i = 0; i < std_h; ++i){
        // BMP stores pixel rows bottom-to-top, so read them in reverse order
        const unsigned char* image_row_ptr = begin_ptr + (std_h - i - 1) * std_w * 3;
        float* output_row_ptr = output_ptr + i * std_w;
        for(int j = 0; j < std_w; ++j){
            // take the first channel and normalize with the MNIST mean/std (0.1307 / 0.3081)
            output_row_ptr[j] = (image_row_ptr[j * 3 + 0] / 255.0f - 0.1307f) / 0.3081f;
        }
    }
    return output;
}
nvinfer1::Weights model_weights_to_trt_weights(const Matrix& model_weights){
    nvinfer1::Weights output;
    output.type   = nvinfer1::DataType::kFLOAT;
    output.values = model_weights.ptr();
    output.count  = model_weights.size();
    return output;
}
TRTLogger logger;
void do_trt_build_engine(const Model& model, const string& save_file){
    /*
        Network is:

        image
          |
        linear (fully connected)  input = 784, output = 1024, bias = True
          |
        relu
          |
        linear (fully connected)  input = 1024, output = 10, bias = True
          |
        sigmoid
          |
        prob
    */
    nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
    nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(1);

    // input: one 28x28 image flattened to 784 values, laid out as NCHW = 1 x 784 x 1 x 1
    nvinfer1::ITensor* input = network->addInput("image", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4(1, 784, 1, 1));

    // fc1: 784 -> 1024, followed by ReLU
    nvinfer1::Weights layer1_weight = model_weights_to_trt_weights(model.weights[0]);
    nvinfer1::Weights layer1_bias   = model_weights_to_trt_weights(model.weights[1]);
    auto layer1 = network->addFullyConnected(*input, model.weights[0].rows, layer1_weight, layer1_bias);
    auto relu1  = network->addActivation(*layer1->getOutput(0), nvinfer1::ActivationType::kRELU);

    // fc2: 1024 -> 10, followed by Sigmoid to produce per-class probabilities
    nvinfer1::Weights layer2_weight = model_weights_to_trt_weights(model.weights[2]);
    nvinfer1::Weights layer2_bias   = model_weights_to_trt_weights(model.weights[3]);
    auto layer2 = network->addFullyConnected(*relu1->getOutput(0), model.weights[2].rows, layer2_weight, layer2_bias);
    auto prob   = network->addActivation(*layer2->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
    network->markOutput(*prob->getOutput(0));

    config->setMaxWorkspaceSize(1 << 28);
    builder->setMaxBatchSize(1);

    nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
    if(engine == nullptr){
        INFO("Build engine failed.");
        return;
    }

    // serialize the engine and write it to save_file
    nvinfer1::IHostMemory* model_data = engine->serialize();
    ofstream outf(save_file, ios::binary | ios::out);
    if(outf.is_open()){
        outf.write((const char*)model_data->data(), model_data->size());
        outf.close();
    }else{
        INFO("Open %s failed", save_file.c_str());
    }

    model_data->destroy();
    engine->destroy();
    network->destroy();
    config->destroy();
    builder->destroy();
}
void do_trt_inference(const string& model_file){
    auto engine_data = load_file(model_file);
    if(engine_data.empty()){
        INFO("engine_data is empty");
        return;
    }

    // deserialize the engine and create an execution context plus a CUDA stream
    nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
    nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(engine_data.data(), engine_data.size());
    if(engine == nullptr){
        INFO("Deserialize cuda engine failed.");
        return;
    }
    nvinfer1::IExecutionContext* execution_context = engine->createExecutionContext();
    cudaStream_t stream = nullptr;
    cudaStreamCreate(&stream);

    const char* image_list[] = {"5.bmp", "6.bmp"};
    int num_image = sizeof(image_list) / sizeof(image_list[0]);
    const int num_classes = 10;
    for(int i = 0; i < num_image; ++i){
        // each input is a 28x28, 24-bit BMP; skip the 54-byte BMP file header
        const int bmp_file_head_size = 54;
        auto file_name = image_list[i];
        auto image_data = load_file(file_name);
        if(image_data.empty() || image_data.size() != bmp_file_head_size + 28*28*3){
            INFO("Load image failed: %s", file_name);
            continue;
        }
        image_data.erase(image_data.begin(), image_data.begin() + bmp_file_head_size);
        auto image = bmp_data_to_normalize_matrix(image_data);

        // copy the normalized input to the GPU and allocate space for the 10 output probabilities
        float* image_device_ptr = nullptr;
        cudaMalloc(&image_device_ptr, image.size() * sizeof(float));
        cudaMemcpyAsync(image_device_ptr, image.ptr(), image.size() * sizeof(float), cudaMemcpyHostToDevice, stream);

        float* output_device_ptr = nullptr;
        cudaMalloc(&output_device_ptr, num_classes * sizeof(float));
        float* bindings[] = {image_device_ptr, output_device_ptr};
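        // ------------------------------------------------------------------
        // The post is truncated at this point. The remainder of
        // do_trt_inference() and main() below are a hedged sketch
        // reconstructed from the description at the top of the article
        // (run the engine, copy the probabilities back to the CPU, print
        // the prediction and confidence, then drive everything from main);
        // they are not the author's original lines.
        // ------------------------------------------------------------------
        float output_host[num_classes] = {0};

        // run inference asynchronously on the stream, then copy the class
        // probabilities back to the host and wait for completion
        execution_context->enqueueV2((void**)bindings, stream, nullptr);
        cudaMemcpyAsync(output_host, output_device_ptr, num_classes * sizeof(float), cudaMemcpyDeviceToHost, stream);
        cudaStreamSynchronize(stream);

        // the predicted label is the class with the highest probability
        int label = (int)(std::max_element(output_host, output_host + num_classes) - output_host);
        INFO("%s -> predict = %d, confidence = %f", file_name, label, output_host[label]);

        cudaFree(image_device_ptr);
        cudaFree(output_device_ptr);
    }

    cudaStreamDestroy(stream);
    execution_context->destroy();
    engine->destroy();
    runtime->destroy();
}

int main(){
    // load the trained weights and print the size of each weight matrix
    Model model;
    if(!load_model(model))
        return -1;

    for(int i = 0; i < (int)model.weights.size(); ++i)
        INFO("weights[%d] shape = %d x %d", i, model.weights[i].rows, model.weights[i].cols);

    // build and serialize the engine, then run inference on the test images
    do_trt_build_engine(model, "mnist.trtmodel");
    do_trt_inference("mnist.trtmodel");
    INFO("program done.");
    return 0;
}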