From a4d73c7b1b19f1e8906d9530b8ca5f696e89a93d Mon Sep 17 00:00:00 2001
From: jack <136876878@qq.com>
Date: Mon, 6 Jul 2020 18:00:24 +0800
Subject: [PATCH] Add IR optimization args

---
 deploy/cpp/demo/classifier.cpp       | 4 +++-
 deploy/cpp/demo/detector.cpp         | 4 +++-
 deploy/cpp/demo/segmenter.cpp        | 4 +++-
 deploy/cpp/include/paddlex/paddlex.h | 9 ++++++---
 deploy/cpp/src/paddlex.cpp           | 5 ++++-
 5 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/deploy/cpp/demo/classifier.cpp b/deploy/cpp/demo/classifier.cpp
index 8b78d7e..6fd354d 100644
--- a/deploy/cpp/demo/classifier.cpp
+++ b/deploy/cpp/demo/classifier.cpp
@@ -37,6 +37,7 @@ DEFINE_int32(batch_size, 1, "Batch size for inference");
 DEFINE_int32(thread_num,
              omp_get_num_procs(),
              "Number of preprocessing threads");
+DEFINE_bool(use_ir_optim, true, "Whether to enable IR optimization");
 
 int main(int argc, char** argv) {
   // Parse command-line flags
@@ -57,7 +58,8 @@ int main(int argc, char** argv) {
              FLAGS_use_gpu,
              FLAGS_use_trt,
              FLAGS_gpu_id,
-             FLAGS_key);
+             FLAGS_key,
+             FLAGS_use_ir_optim);
 
   // Run prediction
   double total_running_time_s = 0.0;
diff --git a/deploy/cpp/demo/detector.cpp b/deploy/cpp/demo/detector.cpp
index 5b4e3a2..54f93d2 100644
--- a/deploy/cpp/demo/detector.cpp
+++ b/deploy/cpp/demo/detector.cpp
@@ -43,6 +43,7 @@ DEFINE_double(threshold,
 DEFINE_int32(thread_num,
              omp_get_num_procs(),
              "Number of preprocessing threads");
+DEFINE_bool(use_ir_optim, true, "Whether to enable IR optimization");
 
 int main(int argc, char** argv) {
   // Parse command-line flags
@@ -62,7 +63,8 @@ int main(int argc, char** argv) {
              FLAGS_use_gpu,
              FLAGS_use_trt,
              FLAGS_gpu_id,
-             FLAGS_key);
+             FLAGS_key,
+             FLAGS_use_ir_optim);
 
   double total_running_time_s = 0.0;
   double total_imread_time_s = 0.0;
diff --git a/deploy/cpp/demo/segmenter.cpp b/deploy/cpp/demo/segmenter.cpp
index 7dd48e5..1ddbb75 100644
--- a/deploy/cpp/demo/segmenter.cpp
+++ b/deploy/cpp/demo/segmenter.cpp
@@ -39,6 +39,7 @@ DEFINE_int32(batch_size, 1, "Batch size for inference");
 DEFINE_int32(thread_num,
              omp_get_num_procs(),
              "Number of preprocessing threads");
+DEFINE_bool(use_ir_optim, true, "Whether to enable IR optimization");
 
 int main(int argc, char** argv) {
   // Parse command-line flags
@@ -59,7 +60,8 @@ int main(int argc, char** argv) {
              FLAGS_use_gpu,
              FLAGS_use_trt,
              FLAGS_gpu_id,
-             FLAGS_key);
+             FLAGS_key,
+             FLAGS_use_ir_optim);
 
   double total_running_time_s = 0.0;
   double total_imread_time_s = 0.0;
diff --git a/deploy/cpp/include/paddlex/paddlex.h b/deploy/cpp/include/paddlex/paddlex.h
index af4d889..e0d0569 100644
--- a/deploy/cpp/include/paddlex/paddlex.h
+++ b/deploy/cpp/include/paddlex/paddlex.h
@@ -72,20 +72,23 @@ class Model {
    * @param use_trt: use TensorRT or not when inferring
    * @param gpu_id: the id of the gpu used when inferring on gpu
    * @param key: the decryption key when using an encrypted model
+   * @param use_ir_optim: whether to enable IR optimization when inferring
    * */
   void Init(const std::string& model_dir,
             bool use_gpu = false,
             bool use_trt = false,
             int gpu_id = 0,
-            std::string key = "") {
-    create_predictor(model_dir, use_gpu, use_trt, gpu_id, key);
+            std::string key = "",
+            bool use_ir_optim = true) {
+    create_predictor(model_dir, use_gpu, use_trt, gpu_id, key, use_ir_optim);
   }
 
   void create_predictor(const std::string& model_dir,
                         bool use_gpu = false,
                         bool use_trt = false,
                         int gpu_id = 0,
-                        std::string key = "");
+                        std::string key = "",
+                        bool use_ir_optim = true);
 
   /*
    * @brief
diff --git a/deploy/cpp/src/paddlex.cpp b/deploy/cpp/src/paddlex.cpp
index bedd83b..cf1dfc9 100644
--- a/deploy/cpp/src/paddlex.cpp
+++ b/deploy/cpp/src/paddlex.cpp
@@ -22,7 +22,8 @@ void Model::create_predictor(const std::string& model_dir,
                              bool use_gpu,
                              bool use_trt,
                              int gpu_id,
-                             std::string key) {
+                             std::string key,
+                             bool use_ir_optim) {
   paddle::AnalysisConfig config;
   std::string model_file = model_dir + OS_PATH_SEP + "__model__";
   std::string params_file = model_dir + OS_PATH_SEP + "__params__";
@@ -63,6 +64,8 @@ void Model::create_predictor(const std::string& model_dir,
   }
   config.SwitchUseFeedFetchOps(false);
   config.SwitchSpecifyInputNames(true);
+  // Enable IR graph optimization
+  config.SwitchIrOptim(use_ir_optim);
   // Enable memory optimization
   config.EnableMemoryOptim();
   if (use_trt) {
-- 
GitLab
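
Usage note: after this change, IR optimization can be toggled per call through the new trailing parameter of Model::Init(). Below is a minimal sketch of a direct call, assuming the Model class lives in the PaddleX namespace as in the demo sources; the model directory path is hypothetical.

#include "paddlex/paddlex.h"

int main() {
  PaddleX::Model model;
  // The new trailing argument is forwarded to AnalysisConfig::SwitchIrOptim();
  // every argument before it keeps its existing meaning and default.
  model.Init("/path/to/inference_model",  // model_dir (hypothetical path)
             false,   // use_gpu
             false,   // use_trt
             0,       // gpu_id
             "",      // key: empty for an unencrypted model
             false);  // use_ir_optim: skip IR graph passes
  return 0;
}

In the demos the same toggle is exposed through gflags, so passing --use_ir_optim=false (or the gflags boolean shorthand --nouse_ir_optim) disables the IR passes without recompiling. The flag defaults to true, which matches the predictor's previous behavior.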