diff --git a/deploy/cpp_infer/include/config.h b/deploy/cpp_infer/include/config.h
index 27539ea7934dc192e86bca3ea6bfd7999ee229a3..3faeede1611e1048ff77f300a75b08b3a86c14e4 100644
--- a/deploy/cpp_infer/include/config.h
+++ b/deploy/cpp_infer/include/config.h
@@ -25,9 +25,9 @@
 
 namespace PaddleOCR {
 
-class Config {
+class OCRConfig {
 public:
-  explicit Config(const std::string &config_file) {
+  explicit OCRConfig(const std::string &config_file) {
     config_map_ = LoadConfig(config_file);
 
     this->use_gpu = bool(stoi(config_map_["use_gpu"]));
@@ -41,8 +41,6 @@ public:
 
     this->use_mkldnn = bool(stoi(config_map_["use_mkldnn"]));
 
-    this->use_zero_copy_run = bool(stoi(config_map_["use_zero_copy_run"]));
-
     this->max_side_len = stoi(config_map_["max_side_len"]);
 
     this->det_db_thresh = stod(config_map_["det_db_thresh"]);
@@ -76,8 +74,6 @@ public:
 
   bool use_mkldnn = false;
 
-  bool use_zero_copy_run = false;
-
   int max_side_len = 960;
 
   double det_db_thresh = 0.3;
diff --git a/deploy/cpp_infer/include/ocr_cls.h b/deploy/cpp_infer/include/ocr_cls.h
index 38a37cff3c035eafe3617d83b2cc15ca47f30186..87772cc109b18beb6a31940311389e2f0596b031 100644
--- a/deploy/cpp_infer/include/ocr_cls.h
+++ b/deploy/cpp_infer/include/ocr_cls.h
@@ -30,6 +30,8 @@
 #include <include/preprocess_op.h>
 #include <include/utility.h>
 
+using namespace paddle_infer;
+
 namespace PaddleOCR {
 
 class Classifier {
@@ -37,14 +39,12 @@ public:
   explicit Classifier(const std::string &model_dir, const bool &use_gpu,
                       const int &gpu_id, const int &gpu_mem,
                       const int &cpu_math_library_num_threads,
-                      const bool &use_mkldnn, const bool &use_zero_copy_run,
-                      const double &cls_thresh) {
+                      const bool &use_mkldnn, const double &cls_thresh) {
     this->use_gpu_ = use_gpu;
     this->gpu_id_ = gpu_id;
     this->gpu_mem_ = gpu_mem;
     this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
     this->use_mkldnn_ = use_mkldnn;
-    this->use_zero_copy_run_ = use_zero_copy_run;
 
     this->cls_thresh = cls_thresh;
 
@@ -57,14 +57,13 @@ public:
   cv::Mat Run(cv::Mat &img);
 
 private:
-  std::shared_ptr<PaddlePredictor> predictor_;
+  std::shared_ptr<Predictor> predictor_;
 
   bool use_gpu_ = false;
   int gpu_id_ = 0;
   int gpu_mem_ = 4000;
   int cpu_math_library_num_threads_ = 4;
   bool use_mkldnn_ = false;
-  bool use_zero_copy_run_ = false;
   double cls_thresh = 0.5;
 
   std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
diff --git a/deploy/cpp_infer/include/ocr_det.h b/deploy/cpp_infer/include/ocr_det.h
index 0308d07f3bac67a275452500184e0959b16e8003..d50fd70af5ec04105e993e358e459f1940d36c7f 100644
--- a/deploy/cpp_infer/include/ocr_det.h
+++ b/deploy/cpp_infer/include/ocr_det.h
@@ -32,6 +32,8 @@
 #include <include/postprocess_op.h>
 #include <include/preprocess_op.h>
 
+using namespace paddle_infer;
+
 namespace PaddleOCR {
 
 class DBDetector {
@@ -39,8 +41,8 @@ public:
   explicit DBDetector(const std::string &model_dir, const bool &use_gpu,
                       const int &gpu_id, const int &gpu_mem,
                       const int &cpu_math_library_num_threads,
-                      const bool &use_mkldnn, const bool &use_zero_copy_run,
-                      const int &max_side_len, const double &det_db_thresh,
+                      const bool &use_mkldnn, const int &max_side_len,
+                      const double &det_db_thresh,
                       const double &det_db_box_thresh,
                       const double &det_db_unclip_ratio,
                       const bool &visualize) {
@@ -49,7 +51,6 @@ public:
     this->gpu_mem_ = gpu_mem;
     this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
     this->use_mkldnn_ = use_mkldnn;
-    this->use_zero_copy_run_ = use_zero_copy_run;
 
     this->max_side_len_ = max_side_len;
 
@@ -69,14 +70,13 @@ public:
   void Run(cv::Mat &img,
            std::vector<std::vector<std::vector<int>>> &boxes);
 
 private:
-  std::shared_ptr<PaddlePredictor> predictor_;
+  std::shared_ptr<Predictor> predictor_;
 
   bool use_gpu_ = false;
   int gpu_id_ = 0;
   int gpu_mem_ = 4000;
   int cpu_math_library_num_threads_ = 4;
   bool use_mkldnn_ = false;
-  bool use_zero_copy_run_ = false;
 
   int max_side_len_ = 960;
diff --git a/deploy/cpp_infer/include/ocr_rec.h b/deploy/cpp_infer/include/ocr_rec.h
index 89bcd82cb99a90ddd8e152a034769312d9791e7e..14b77b084a30ade71efe626430cb854d0bfbc1ce 100644
--- a/deploy/cpp_infer/include/ocr_rec.h
+++ b/deploy/cpp_infer/include/ocr_rec.h
@@ -32,6 +32,8 @@
 #include <include/preprocess_op.h>
 #include <include/utility.h>
 
+using namespace paddle_infer;
+
 namespace PaddleOCR {
 
 class CRNNRecognizer {
@@ -39,14 +41,12 @@ public:
   explicit CRNNRecognizer(const std::string &model_dir, const bool &use_gpu,
                           const int &gpu_id, const int &gpu_mem,
                           const int &cpu_math_library_num_threads,
-                          const bool &use_mkldnn, const bool &use_zero_copy_run,
-                          const string &label_path) {
+                          const bool &use_mkldnn, const string &label_path) {
     this->use_gpu_ = use_gpu;
     this->gpu_id_ = gpu_id;
     this->gpu_mem_ = gpu_mem;
     this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
     this->use_mkldnn_ = use_mkldnn;
-    this->use_zero_copy_run_ = use_zero_copy_run;
 
     this->label_list_ = Utility::ReadDict(label_path);
     this->label_list_.insert(this->label_list_.begin(),
@@ -63,14 +63,13 @@ public:
            Classifier *cls);
 
 private:
-  std::shared_ptr<PaddlePredictor> predictor_;
+  std::shared_ptr<Predictor> predictor_;
 
   bool use_gpu_ = false;
   int gpu_id_ = 0;
   int gpu_mem_ = 4000;
   int cpu_math_library_num_threads_ = 4;
   bool use_mkldnn_ = false;
-  bool use_zero_copy_run_ = false;
 
   std::vector<std::string> label_list_;
 
diff --git a/deploy/cpp_infer/src/config.cpp b/deploy/cpp_infer/src/config.cpp
index 52dfa209b049c6d47285bcba40e41de846de610f..303c3c1259515ee8c67fa865bf485ae3338505d6 100644
--- a/deploy/cpp_infer/src/config.cpp
+++ b/deploy/cpp_infer/src/config.cpp
@@ -16,8 +16,8 @@
 
 namespace PaddleOCR {
 
-std::vector<std::string> Config::split(const std::string &str,
-                                       const std::string &delim) {
+std::vector<std::string> OCRConfig::split(const std::string &str,
+                                          const std::string &delim) {
   std::vector<std::string> res;
   if ("" == str)
     return res;
@@ -38,7 +38,7 @@ std::vector<std::string> Config::split(const std::string &str,
 }
 
 std::map<std::string, std::string>
-Config::LoadConfig(const std::string &config_path) {
+OCRConfig::LoadConfig(const std::string &config_path) {
   auto config = Utility::ReadDict(config_path);
 
   std::map<std::string, std::string> dict;
@@ -53,7 +53,7 @@ Config::LoadConfig(const std::string &config_path) {
   return dict;
 }
 
-void Config::PrintConfigInfo() {
+void OCRConfig::PrintConfigInfo() {
   std::cout << "=======Paddle OCR inference config======" << std::endl;
   for (auto iter = config_map_.begin(); iter != config_map_.end(); iter++) {
     std::cout << iter->first << " : " << iter->second << std::endl;
diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp
index 63da62c7d4e0e9592d62ac61ae1888dc35a71ec0..21890d45ce8c6b13e280c87bdfad8ca8e48f8523 100644
--- a/deploy/cpp_infer/src/main.cpp
+++ b/deploy/cpp_infer/src/main.cpp
@@ -42,7 +42,7 @@ int main(int argc, char **argv) {
     exit(1);
   }
 
-  Config config(argv[1]);
+  OCRConfig config(argv[1]);
 
   config.PrintConfigInfo();
 
@@ -50,37 +50,22 @@ int main(int argc, char **argv) {
 
   cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
 
-  DBDetector det(
-      config.det_model_dir, config.use_gpu, config.gpu_id, config.gpu_mem,
-      config.cpu_math_library_num_threads, config.use_mkldnn,
-      config.use_zero_copy_run, config.max_side_len, config.det_db_thresh,
-      config.det_db_box_thresh, config.det_db_unclip_ratio, config.visualize);
+  DBDetector det(config.det_model_dir, config.use_gpu, config.gpu_id,
+                 config.gpu_mem, config.cpu_math_library_num_threads,
+                 config.use_mkldnn, config.max_side_len, config.det_db_thresh,
+                 config.det_db_box_thresh, config.det_db_unclip_ratio,
+                 config.visualize);
 
   Classifier *cls = nullptr;
   if (config.use_angle_cls == true) {
     cls = new Classifier(config.cls_model_dir, config.use_gpu, config.gpu_id,
                          config.gpu_mem, config.cpu_math_library_num_threads,
-                         config.use_mkldnn, config.use_zero_copy_run,
-                         config.cls_thresh);
+                         config.use_mkldnn, config.cls_thresh);
   }
 
   CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
                      config.gpu_mem, config.cpu_math_library_num_threads,
-                     config.use_mkldnn, config.use_zero_copy_run,
-                     config.char_list_file);
-
-#ifdef USE_MKL
-#pragma omp parallel
-  for (auto i = 0; i < 10; i++) {
-    LOG_IF(WARNING,
-           config.cpu_math_library_num_threads != omp_get_num_threads())
-        << "WARNING! MKL is running on " << omp_get_num_threads()
-        << " threads while cpu_math_library_num_threads is set to "
-        << config.cpu_math_library_num_threads
-        << ". Possible reason could be 1. You have set omp_set_num_threads() "
-           "somewhere; 2. MKL is not linked properly";
-  }
-#endif
+                     config.use_mkldnn, config.char_list_file);
 
   auto start = std::chrono::system_clock::now();
   std::vector<std::vector<std::vector<int>>> boxes;
diff --git a/deploy/cpp_infer/src/ocr_cls.cpp b/deploy/cpp_infer/src/ocr_cls.cpp
index fed2023f9f111294a07a9c841f4843404bbd9af2..9757b482d4f407cefd8db5bd611000062f754645 100644
--- a/deploy/cpp_infer/src/ocr_cls.cpp
+++ b/deploy/cpp_infer/src/ocr_cls.cpp
@@ -35,26 +35,16 @@ cv::Mat Classifier::Run(cv::Mat &img) {
   this->permute_op_.Run(&resize_img, input.data());
 
   // Inference.
-  if (this->use_zero_copy_run_) {
-    auto input_names = this->predictor_->GetInputNames();
-    auto input_t = this->predictor_->GetInputTensor(input_names[0]);
-    input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
-    input_t->copy_from_cpu(input.data());
-    this->predictor_->ZeroCopyRun();
-  } else {
-    paddle::PaddleTensor input_t;
-    input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
-    input_t.data =
-        paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
-    input_t.dtype = PaddleDType::FLOAT32;
-    std::vector<paddle::PaddleTensor> outputs;
-    this->predictor_->Run({input_t}, &outputs, 1);
-  }
+  auto input_names = this->predictor_->GetInputNames();
+  auto input_t = this->predictor_->GetInputHandle(input_names[0]);
+  input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
+  input_t->CopyFromCpu(input.data());
+  this->predictor_->Run();
 
   std::vector<float> softmax_out;
   std::vector<int64_t> label_out;
   auto output_names = this->predictor_->GetOutputNames();
-  auto softmax_out_t = this->predictor_->GetOutputTensor(output_names[0]);
+  auto softmax_out_t = this->predictor_->GetOutputHandle(output_names[0]);
 
   auto softmax_shape_out = softmax_out_t->shape();
   int softmax_out_num =
@@ -63,7 +53,7 @@ cv::Mat Classifier::Run(cv::Mat &img) {
 
   softmax_out.resize(softmax_out_num);
-  softmax_out_t->copy_to_cpu(softmax_out.data());
+  softmax_out_t->CopyToCpu(softmax_out.data());
 
   float score = 0;
   int label = 0;
@@ -95,7 +85,7 @@ void Classifier::LoadModel(const std::string &model_dir) {
   }
 
   // false for zero copy tensor
-  config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_);
+  config.SwitchUseFeedFetchOps(false);
   // true for multiple input
   config.SwitchSpecifyInputNames(true);
 
@@ -104,6 +94,6 @@ void Classifier::LoadModel(const std::string &model_dir) {
   config.EnableMemoryOptim();
   config.DisableGlogInfo();
 
-  this->predictor_ = CreatePaddlePredictor(config);
+  this->predictor_ = CreatePredictor(config);
 }
 } // namespace PaddleOCR
diff --git a/deploy/cpp_infer/src/ocr_det.cpp b/deploy/cpp_infer/src/ocr_det.cpp
index e253f9cc89810f4d1adfca5be5186220a873d1a2..c6c93991743b28609e880a9534d3228daf2c5bef 100644
--- a/deploy/cpp_infer/src/ocr_det.cpp
+++ b/deploy/cpp_infer/src/ocr_det.cpp
@@ -17,12 +17,17 @@
 namespace PaddleOCR {
 
 void DBDetector::LoadModel(const std::string &model_dir) {
-  AnalysisConfig config;
+  // AnalysisConfig config;
+  paddle_infer::Config config;
   config.SetModel(model_dir + "/inference.pdmodel",
                   model_dir + "/inference.pdiparams");
 
   if (this->use_gpu_) {
     config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
+    // config.EnableTensorRtEngine(
+    //     1 << 20, 1, 3,
+    //     AnalysisConfig::Precision::kFloat32,
+    //     false, false);
   } else {
     config.DisableGpu();
     if (this->use_mkldnn_) {
@@ -32,10 +37,8 @@ void DBDetector::LoadModel(const std::string &model_dir) {
     }
     config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
   }
-
-  // false for zero copy tensor
-  // true for commom tensor
-  config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_);
+  // use zero_copy_run as default
+  config.SwitchUseFeedFetchOps(false);
   // true for multiple input
   config.SwitchSpecifyInputNames(true);
 
@@ -44,7 +47,7 @@ void DBDetector::LoadModel(const std::string &model_dir) {
   config.EnableMemoryOptim();
   config.DisableGlogInfo();
 
-  this->predictor_ = CreatePaddlePredictor(config);
+  this->predictor_ = CreatePredictor(config);
 }
 
 void DBDetector::Run(cv::Mat &img,
@@ -64,31 +67,21 @@ void DBDetector::Run(cv::Mat &img,
   this->permute_op_.Run(&resize_img, input.data());
 
   // Inference.
-  if (this->use_zero_copy_run_) {
-    auto input_names = this->predictor_->GetInputNames();
-    auto input_t = this->predictor_->GetInputTensor(input_names[0]);
-    input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
-    input_t->copy_from_cpu(input.data());
-    this->predictor_->ZeroCopyRun();
-  } else {
-    paddle::PaddleTensor input_t;
-    input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
-    input_t.data =
-        paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
-    input_t.dtype = PaddleDType::FLOAT32;
-    std::vector<paddle::PaddleTensor> outputs;
-    this->predictor_->Run({input_t}, &outputs, 1);
-  }
+  auto input_names = this->predictor_->GetInputNames();
+  auto input_t = this->predictor_->GetInputHandle(input_names[0]);
+  input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
+  input_t->CopyFromCpu(input.data());
+  this->predictor_->Run();
 
   std::vector<float> out_data;
   auto output_names = this->predictor_->GetOutputNames();
-  auto output_t = this->predictor_->GetOutputTensor(output_names[0]);
+  auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
   std::vector<int> output_shape = output_t->shape();
   int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                 std::multiplies<int>());
 
   out_data.resize(out_num);
-  output_t->copy_to_cpu(out_data.data());
+  output_t->CopyToCpu(out_data.data());
 
   int n2 = output_shape[2];
   int n3 = output_shape[3];
diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp
index d4deb5a17fc47427eb92cda02c270d268cfcafc7..e33695a74d72020f4397b84fcc07e9d9bf01486c 100644
--- a/deploy/cpp_infer/src/ocr_rec.cpp
+++ b/deploy/cpp_infer/src/ocr_rec.cpp
@@ -43,32 +43,22 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
     this->permute_op_.Run(&resize_img, input.data());
 
     // Inference.
-    if (this->use_zero_copy_run_) {
-      auto input_names = this->predictor_->GetInputNames();
-      auto input_t = this->predictor_->GetInputTensor(input_names[0]);
-      input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
-      input_t->copy_from_cpu(input.data());
-      this->predictor_->ZeroCopyRun();
-    } else {
-      paddle::PaddleTensor input_t;
-      input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
-      input_t.data =
-          paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
-      input_t.dtype = PaddleDType::FLOAT32;
-      std::vector<paddle::PaddleTensor> outputs;
-      this->predictor_->Run({input_t}, &outputs, 1);
-    }
+    auto input_names = this->predictor_->GetInputNames();
+    auto input_t = this->predictor_->GetInputHandle(input_names[0]);
+    input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
+    input_t->CopyFromCpu(input.data());
+    this->predictor_->Run();
 
     std::vector<float> predict_batch;
    auto output_names = this->predictor_->GetOutputNames();
-    auto output_t = this->predictor_->GetOutputTensor(output_names[0]);
+    auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
     auto predict_shape = output_t->shape();
 
     int out_num = std::accumulate(predict_shape.begin(), predict_shape.end(),
                                   1, std::multiplies<int>());
     predict_batch.resize(out_num);
-    output_t->copy_to_cpu(predict_batch.data());
+    output_t->CopyToCpu(predict_batch.data());
 
     // ctc decode
     std::vector<std::string> str_res;
@@ -102,7 +92,8 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
 }
 
 void CRNNRecognizer::LoadModel(const std::string &model_dir) {
-  AnalysisConfig config;
+  // AnalysisConfig config;
+  paddle_infer::Config config;
   config.SetModel(model_dir + "/inference.pdmodel",
                   model_dir + "/inference.pdiparams");
 
@@ -118,9 +109,7 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
     config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
   }
 
-  // false for zero copy tensor
-  // true for commom tensor
-  config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_);
+  config.SwitchUseFeedFetchOps(false);
   // true for multiple input
   config.SwitchSpecifyInputNames(true);
 
@@ -129,7 +118,7 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
   config.EnableMemoryOptim();
   config.DisableGlogInfo();
 
-  this->predictor_ = CreatePaddlePredictor(config);
+  this->predictor_ = CreatePredictor(config);
 }
 
 cv::Mat CRNNRecognizer::GetRotateCropImage(const cv::Mat &srcimage,
diff --git a/deploy/cpp_infer/tools/config.txt b/deploy/cpp_infer/tools/config.txt
index f1ab0b1131ef5d55b098667612c019e0fc01c9dc..34f47ed82015b5c27a61a34d1de22f3251e0fd75 100644
--- a/deploy/cpp_infer/tools/config.txt
+++ b/deploy/cpp_infer/tools/config.txt
@@ -1,10 +1,9 @@
 # model load config
-use_gpu 0 
+use_gpu 0
 gpu_id 0
 gpu_mem 4000
 cpu_math_library_num_threads 10
 use_mkldnn 0
-use_zero_copy_run 1
 
 # det config
 max_side_len 960
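
Note on the migration above: all three predictors now go through the Paddle Inference 2.0 zero-copy path unconditionally (paddle_infer::Config + CreatePredictor), which is why the use_zero_copy_run option disappears from the config, the constructors, and every Run() method. The sketch below shows the new call sequence end to end on a CPU predictor. It uses only API calls that appear in the patch; the model file names and the 1x3x224x224 input shape are placeholder assumptions for illustration, not values taken from this repo.

#include <functional>
#include <numeric>
#include <vector>

#include "paddle_inference_api.h"

int main() {
  // Configure the predictor the same way the patched LoadModel() does.
  paddle_infer::Config config;
  config.SetModel("./inference.pdmodel", "./inference.pdiparams");
  config.DisableGpu();
  config.SwitchUseFeedFetchOps(false); // zero-copy tensors; feed/fetch ops off
  config.SwitchSpecifyInputNames(true);
  config.EnableMemoryOptim();
  auto predictor = paddle_infer::CreatePredictor(config);

  // Feed a dummy input through the named input handle.
  // GetInputHandle/CopyFromCpu replace GetInputTensor/copy_from_cpu.
  std::vector<float> input(1 * 3 * 224 * 224, 0.0f);
  auto input_names = predictor->GetInputNames();
  auto input_t = predictor->GetInputHandle(input_names[0]);
  input_t->Reshape({1, 3, 224, 224});
  input_t->CopyFromCpu(input.data());

  predictor->Run(); // replaces ZeroCopyRun()

  // Fetch the first output and copy it back to host memory.
  auto output_names = predictor->GetOutputNames();
  auto output_t = predictor->GetOutputHandle(output_names[0]);
  auto shape = output_t->shape();
  int out_num = std::accumulate(shape.begin(), shape.end(), 1,
                                std::multiplies<int>());
  std::vector<float> out_data(out_num);
  output_t->CopyToCpu(out_data.data());
  return 0;
}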