From 483486af5891db8c4c6d4581ee18c7d272042f05 Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Mon, 13 Jul 2020 08:59:21 +0000
Subject: [PATCH] add config

---
 deploy/cpp_infer/CMakeLists.txt           |  6 +-
 deploy/cpp_infer/include/config.h         | 91 +++++++++++++++++++++++
 deploy/cpp_infer/include/ocr_det.h        | 34 +++++++--
 deploy/cpp_infer/include/ocr_rec.h        | 40 +++++-----
 deploy/cpp_infer/include/postprocess_op.h | 44 +++--------
 deploy/cpp_infer/include/preprocess_op.h  |  3 +-
 deploy/cpp_infer/include/utility.h        | 41 ++++++++++
 deploy/cpp_infer/src/config.cpp           | 64 ++++++++++++++++
 deploy/cpp_infer/src/main.cpp             | 23 ++++--
 deploy/cpp_infer/src/ocr_det.cpp          | 54 +++++++------
 deploy/cpp_infer/src/ocr_rec.cpp          | 60 ++++++---------
 deploy/cpp_infer/src/postprocess_op.cpp   | 27 ++++---
 deploy/cpp_infer/src/preprocess_op.cpp    |  2 +-
 deploy/cpp_infer/src/utility.cpp          | 39 ++++++++++
 deploy/cpp_infer/tools/config.txt         | 17 +++++
 deploy/cpp_infer/tools/run.sh             |  2 +-
 16 files changed, 393 insertions(+), 154 deletions(-)
 create mode 100644 deploy/cpp_infer/include/config.h
 create mode 100644 deploy/cpp_infer/include/utility.h
 create mode 100644 deploy/cpp_infer/src/config.cpp
 create mode 100644 deploy/cpp_infer/src/utility.cpp
 create mode 100644 deploy/cpp_infer/tools/config.txt

diff --git a/deploy/cpp_infer/CMakeLists.txt b/deploy/cpp_infer/CMakeLists.txt
index 8f06068c..1415e2cb 100644
--- a/deploy/cpp_infer/CMakeLists.txt
+++ b/deploy/cpp_infer/CMakeLists.txt
@@ -57,7 +57,8 @@ link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib")
 link_directories("${PADDLE_LIB}/paddle/lib")
 
-add_executable(${DEMO_NAME} src/main.cpp src/ocr_det.cpp src/ocr_rec.cpp src/preprocess_op.cpp src/clipper.cpp src/postprocess_op.cpp )
+AUX_SOURCE_DIRECTORY(./src SRCS)
+add_executable(${DEMO_NAME} ${SRCS})
 
 if(WITH_MKL)
     include_directories("${PADDLE_LIB}/third_party/install/mklml/include")
@@ -81,9 +82,6 @@ else()
             ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX})
 endif()
 
-# user ze
-# set(EXTERNAL_LIB "-lrt -ldl -lpthread -lm -lopencv_world")
-# gry
 set(EXTERNAL_LIB "-lrt -ldl -lpthread -lm")
 
 set(DEPS ${DEPS}
diff --git a/deploy/cpp_infer/include/config.h b/deploy/cpp_infer/include/config.h
new file mode 100644
index 00000000..37c5078a
--- /dev/null
+++ b/deploy/cpp_infer/include/config.h
@@ -0,0 +1,91 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <iostream>
+#include <iomanip>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "include/utility.h"
+
+namespace PaddleOCR {
+
+class Config {
+public:
+  explicit Config(const std::string &config_file) {
+    config_map_ = LoadConfig(config_file);
+
+    this->use_gpu = bool(stoi(config_map_["use_gpu"]));
+
+    this->gpu_id = stoi(config_map_["gpu_id"]);
+
+    this->gpu_mem = stoi(config_map_["gpu_mem"]);
+
+    this->cpu_math_library_num_threads =
+        stoi(config_map_["cpu_math_library_num_threads"]);
+
+    this->max_side_len = stoi(config_map_["max_side_len"]);
+
+    this->det_db_thresh = stod(config_map_["det_db_thresh"]);
+
+    this->det_db_box_thresh = stod(config_map_["det_db_box_thresh"]);
+
+    this->det_db_unclip_ratio = stod(config_map_["det_db_unclip_ratio"]);
+
+    this->det_model_dir.assign(config_map_["det_model_dir"]);
+
+    this->rec_model_dir.assign(config_map_["rec_model_dir"]);
+
+    this->char_list_file.assign(config_map_["char_list_file"]);
+  }
+
+  bool use_gpu = false;
+
+  int gpu_id = 0;
+
+  int gpu_mem = 4000;
+
+  int cpu_math_library_num_threads = 1;
+
+  int max_side_len = 960;
+
+  double det_db_thresh = 0.3;
+
+  double det_db_box_thresh = 0.5;
+
+  double det_db_unclip_ratio = 2.0;
+
+  std::string det_model_dir;
+
+  std::string rec_model_dir;
+
+  std::string char_list_file;
+
+  void PrintConfigInfo();
+
+private:
+  // Load configuration
+  std::map<std::string, std::string> LoadConfig(const std::string &config_file);
+
+  std::vector<std::string> split(const std::string &str,
+                                 const std::string &delim);
+
+  std::map<std::string, std::string> config_map_;
+};
+
+} // namespace PaddleOCR
diff --git a/deploy/cpp_infer/include/ocr_det.h b/deploy/cpp_infer/include/ocr_det.h
index efcaa25c..3208d346 100644
--- a/deploy/cpp_infer/include/ocr_det.h
+++ b/deploy/cpp_infer/include/ocr_det.h
@@ -36,16 +36,29 @@ namespace PaddleOCR {
 
 class DBDetector {
 public:
-  explicit DBDetector(const std::string &model_dir, bool use_gpu = false,
-                      const int gpu_id = 0, const int max_side_len = 960) {
-    LoadModel(model_dir, use_gpu);
+  explicit DBDetector(const std::string &model_dir, const bool &use_gpu = false,
+                      const int &gpu_id = 0, const int &gpu_mem = 4000,
+                      const int &cpu_math_library_num_threads = 4,
+                      const int &max_side_len = 960,
+                      const double &det_db_thresh = 0.3,
+                      const double &det_db_box_thresh = 0.5,
+                      const double &det_db_unclip_ratio = 2.0) {
+    LoadModel(model_dir);
+
+    this->use_gpu_ = use_gpu;
+    this->gpu_id_ = gpu_id;
+    this->gpu_mem_ = gpu_mem;
+    this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
+    this->max_side_len_ = max_side_len;
+
+    this->det_db_thresh_ = det_db_thresh;
+    this->det_db_box_thresh_ = det_db_box_thresh;
+    this->det_db_unclip_ratio_ = det_db_unclip_ratio;
   }
 
   // Load Paddle inference model
-  void LoadModel(const std::string &model_dir, bool use_gpu,
-                 const int min_subgraph_size = 3, const int batch_size = 1,
-                 const int gpu_id = 0);
+  void LoadModel(const std::string &model_dir);
 
   // Run predictor
   void Run(cv::Mat &img, std::vector<std::vector<std::vector<int>>> &boxes);
 
@@ -53,8 +66,17 @@ public:
 private:
   std::shared_ptr<PaddlePredictor> predictor_;
 
+  bool use_gpu_ = false;
+  int gpu_id_ = 0;
+  int gpu_mem_ = 4000;
+  int cpu_math_library_num_threads_ = 4;
+  int max_side_len_ = 960;
+  double det_db_thresh_ = 0.3;
+  double det_db_box_thresh_ = 0.5;
+  double det_db_unclip_ratio_ = 2.0;
+
   std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
   std::vector<float> scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
   bool is_scale_ = true;
 
diff --git a/deploy/cpp_infer/include/ocr_rec.h b/deploy/cpp_infer/include/ocr_rec.h
index 50220136..dfb02f63 100644
--- a/deploy/cpp_infer/include/ocr_rec.h
+++ b/deploy/cpp_infer/include/ocr_rec.h
@@ -29,29 +29,40 @@
 #include <include/postprocess_op.h>
 #include <include/preprocess_op.h>
+#include <include/utility.h>
 
 namespace PaddleOCR {
 
 class CRNNRecognizer {
 public:
-  explicit CRNNRecognizer(const std::string &model_dir,
-                          const string label_path = "./tools/ppocr_keys_v1.txt",
-                          bool use_gpu = false, const int gpu_id = 0) {
-    LoadModel(model_dir, use_gpu);
-
-    this->label_list_ = ReadDict(label_path);
+  explicit CRNNRecognizer(
+      const std::string &model_dir, const bool &use_gpu = false,
+      const int &gpu_id = 0, const int &gpu_mem = 4000,
+      const int &cpu_math_library_num_threads = 4,
+      const string &label_path = "./tools/ppocr_keys_v1.txt") {
+    LoadModel(model_dir);
+
+    this->use_gpu_ = use_gpu;
+    this->gpu_id_ = gpu_id;
+    this->gpu_mem_ = gpu_mem;
+    this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
+
+    this->label_list_ = Utility::ReadDict(label_path);
   }
 
   // Load Paddle inference model
-  void LoadModel(const std::string &model_dir, bool use_gpu,
-                 const int gpu_id = 0, const int min_subgraph_size = 3,
-                 const int batch_size = 1);
+  void LoadModel(const std::string &model_dir);
 
   void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img);
 
 private:
   std::shared_ptr<PaddlePredictor> predictor_;
 
+  bool use_gpu_ = false;
+  int gpu_id_ = 0;
+  int gpu_mem_ = 4000;
+  int cpu_math_library_num_threads_ = 4;
+
   std::vector<std::string> label_list_;
 
   std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
@@ -66,15 +77,8 @@ private:
   // post-process
   PostProcessor post_processor_;
 
-  cv::Mat get_rotate_crop_image(const cv::Mat &srcimage,
-                                std::vector<std::vector<int>> box);
-
-  std::vector<std::string> ReadDict(const std::string &path);
-
-  template <class ForwardIterator>
-  inline size_t argmax(ForwardIterator first, ForwardIterator last) {
-    return std::distance(first, std::max_element(first, last));
-  }
+  cv::Mat GetRotateCropImage(const cv::Mat &srcimage,
+                             std::vector<std::vector<int>> box);
 
 }; // class CrnnRecognizer
diff --git a/deploy/cpp_infer/include/postprocess_op.h b/deploy/cpp_infer/include/postprocess_op.h
index d851e29e..de6a81b2 100644
--- a/deploy/cpp_infer/include/postprocess_op.h
+++ b/deploy/cpp_infer/include/postprocess_op.h
@@ -28,36 +28,17 @@
 #include <numeric>
 
 #include "include/clipper.h"
+#include "include/utility.h"
 
 using namespace std;
 
 namespace PaddleOCR {
 
-inline std::vector<std::string> ReadDict(std::string path) {
-  std::ifstream in(path);
-  std::string filename;
-  std::string line;
-  std::vector<std::string> m_vec;
-  if (in) {
-    while (getline(in, line)) {
-      m_vec.push_back(line);
-    }
-  } else {
-    std::cout << "no such file" << std::endl;
-  }
-  return m_vec;
-}
-
-template <class ForwardIterator>
-inline size_t Argmax(ForwardIterator first, ForwardIterator last) {
-  return std::distance(first, std::max_element(first, last));
-}
-
 class PostProcessor {
 public:
   void GetContourArea(float **box, float unclip_ratio, float &distance);
 
-  cv::RotatedRect unclip(float **box);
+  cv::RotatedRect UnClip(float **box, const float &unclip_ratio);
 
   float **Mat2Vec(cv::Mat mat);
 
@@ -67,23 +48,17 @@ public:
   std::vector<std::vector<int>>
   order_points_clockwise(std::vector<std::vector<int>> pts);
 
-  float **get_mini_boxes(cv::RotatedRect box, float &ssid);
+  float **GetMiniBoxes(cv::RotatedRect box, float &ssid);
 
-  float box_score_fast(float **box_array, cv::Mat pred);
+  float BoxScoreFast(float **box_array, cv::Mat pred);
 
   std::vector<std::vector<std::vector<int>>>
-  boxes_from_bitmap(const cv::Mat pred, const cv::Mat bitmap);
+  BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
+                  const float &box_thresh, const float &det_db_unclip_ratio);
 
   std::vector<std::vector<std::vector<int>>>
-  filter_tag_det_res(std::vector<std::vector<std::vector<int>>> boxes,
-                     float ratio_h, float ratio_w, cv::Mat srcimg);
-
-  template <class ForwardIterator>
-  inline size_t argmax(ForwardIterator first, ForwardIterator last) {
-    return std::distance(first, std::max_element(first, last));
-  }
-
-  // CRNN
+  FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
+                  float ratio_h, float ratio_w, cv::Mat srcimg);
 
 private:
   void quickSort(float **s, int l, int r);
@@ -99,6 +74,7 @@ private:
     return min;
     return x;
   }
+
   inline float clampf(float x, float min, float max) {
     if (x > max)
       return max;
@@ -108,4 +84,4 @@ private:
   }
 };
 
-} // namespace PaddleOCR
\ No newline at end of file
+} // namespace PaddleOCR
diff --git a/deploy/cpp_infer/include/preprocess_op.h b/deploy/cpp_infer/include/preprocess_op.h
index 61f80449..309d7fd4 100644
--- a/deploy/cpp_infer/include/preprocess_op.h
+++ b/deploy/cpp_infer/include/preprocess_op.h
@@ -44,7 +44,6 @@ public:
   virtual void Run(const cv::Mat *im, float *data);
 };
 
-// RGB -> CHW
 class ResizeImgType0 {
 public:
   virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
@@ -54,7 +53,7 @@ public:
 class CrnnResizeImg {
 public:
   virtual void Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
-                   const std::vector<int> rec_image_shape = {3, 32, 320});
+                   const std::vector<int> &rec_image_shape = {3, 32, 320});
 };
 
 } // namespace PaddleOCR
\ No newline at end of file
diff --git a/deploy/cpp_infer/include/utility.h b/deploy/cpp_infer/include/utility.h
new file mode 100644
index 00000000..ebcc3e84
--- /dev/null
+++ b/deploy/cpp_infer/include/utility.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <chrono>
+#include <iomanip>
+#include <iostream>
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include <algorithm>
+#include <cstring>
+#include <fstream>
+#include <numeric>
+
+namespace PaddleOCR {
+
+class Utility {
+public:
+  static std::vector<std::string> ReadDict(const std::string &path);
+
+  template <class ForwardIterator>
+  inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
+    return std::distance(first, std::max_element(first, last));
+  }
+};
+
+} // namespace PaddleOCR
\ No newline at end of file
diff --git a/deploy/cpp_infer/src/config.cpp b/deploy/cpp_infer/src/config.cpp
new file mode 100644
index 00000000..228c874d
--- /dev/null
+++ b/deploy/cpp_infer/src/config.cpp
@@ -0,0 +1,64 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <include/config.h>
+
+namespace PaddleOCR {
+
+std::vector<std::string> Config::split(const std::string &str,
+                                       const std::string &delim) {
+  std::vector<std::string> res;
+  if ("" == str)
+    return res;
+  char *strs = new char[str.length() + 1];
+  std::strcpy(strs, str.c_str());
+
+  char *d = new char[delim.length() + 1];
+  std::strcpy(d, delim.c_str());
+
+  char *p = std::strtok(strs, d);
+  while (p) {
+    std::string s = p;
+    res.push_back(s);
+    p = std::strtok(NULL, d);
+  }
+
+  return res;
+}
+
+std::map<std::string, std::string>
+Config::LoadConfig(const std::string &config_path) {
+  auto config = Utility::ReadDict(config_path);
+
+  std::map<std::string, std::string> dict;
+  for (int i = 0; i < config.size(); i++) {
+    // pass for empty line or comment
+    if (config[i].size() <= 1 or config[i][0] == '#') {
+      continue;
+    }
+    std::vector<std::string> res = split(config[i], " ");
+    dict[res[0]] = res[1];
+  }
+  return dict;
+}
+
+void Config::PrintConfigInfo() {
+  std::cout << "=======Paddle OCR inference config======" << std::endl;
+  for (auto iter = config_map_.begin(); iter != config_map_.end(); iter++) {
+    std::cout << iter->first << " : " << iter->second << std::endl;
+  }
+  std::cout << "=======End of Paddle OCR inference config======" << std::endl;
+}
+
+} // namespace PaddleOCR
\ No newline at end of file
diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp
index 823f3046..b63bdafc 100644
--- a/deploy/cpp_infer/src/main.cpp
+++ b/deploy/cpp_infer/src/main.cpp
@@ -25,6 +25,7 @@
 #include <fstream>
 #include <numeric>
 
+#include <include/config.h>
 #include <include/ocr_det.h>
 #include <include/ocr_rec.h>
 
@@ -33,21 +34,29 @@ using namespace std;
 using namespace cv;
 using namespace PaddleOCR;
 
 int main(int argc, char **argv) {
-  if (argc < 4) {
+  if (argc < 3) {
     std::cerr << "[ERROR] usage: " << argv[0]
-              << " det_model_file rec_model_file image_path\n";
+              << " configure_filepath image_path\n";
     exit(1);
   }
-  std::string det_model_file = argv[1];
-  std::string rec_model_file = argv[2];
-  std::string img_path = argv[3];
+
+  Config config(argv[1]);
+
+  config.PrintConfigInfo();
+
+  std::string img_path(argv[2]);
 
   auto start = std::chrono::system_clock::now();
 
   cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
 
-  DBDetector det(det_model_file);
-  CRNNRecognizer rec(rec_model_file);
+  DBDetector det(config.det_model_dir, config.use_gpu, config.gpu_id,
+                 config.gpu_mem, config.cpu_math_library_num_threads,
+                 config.max_side_len, config.det_db_thresh,
+                 config.det_db_box_thresh, config.det_db_unclip_ratio);
+  CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
+                     config.gpu_mem, config.cpu_math_library_num_threads,
+                     config.char_list_file);
 
   std::vector<std::vector<std::vector<int>>> boxes;
 
   det.Run(srcimg, boxes);
diff --git a/deploy/cpp_infer/src/ocr_det.cpp b/deploy/cpp_infer/src/ocr_det.cpp
index 746c94b2..a449e1b3 100644
--- a/deploy/cpp_infer/src/ocr_det.cpp
+++ b/deploy/cpp_infer/src/ocr_det.cpp
@@ -31,29 +31,28 @@
 
 namespace PaddleOCR {
 
-void DBDetector::LoadModel(const std::string &model_dir, bool use_gpu,
-                           const int gpu_id, const int min_subgraph_size,
-                           const int batch_size) {
+void DBDetector::LoadModel(const std::string &model_dir) {
   AnalysisConfig config;
   config.SetModel(model_dir + "/model", model_dir + "/params");
 
-  // for cpu
-  config.DisableGpu();
-  config.EnableMKLDNN(); // 开启MKLDNN加速
-  config.SetCpuMathLibraryNumThreads(10);
+  if (this->use_gpu_) {
+    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
+  } else {
+    config.DisableGpu();
+    config.EnableMKLDNN(); // enable MKLDNN acceleration
+    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
+  }
 
-  // 使用ZeroCopyTensor,此处必须设置为false
+  // false for zero copy tensor
   config.SwitchUseFeedFetchOps(false);
-  // 若输入为多个,此处必须设置为true
+  // true for multiple input
   config.SwitchSpecifyInputNames(true);
-  // config.SwitchIrDebug(true); //
-  // 可视化调试选项,若开启,则会在每个图优化过程后生成dot文件
-  // config.SwitchIrOptim(false);// 默认为true。如果设置为false,关闭所有优化
-  config.EnableMemoryOptim(); // 开启内存/显存复用
+
+  config.SwitchIrOptim(true);
+
+  config.EnableMemoryOptim();
 
   this->predictor_ = CreatePaddlePredictor(config);
-  // predictor_ = std::move(CreatePaddlePredictor(config)); // PaddleDetection
-  // usage
 }
 
 void DBDetector::Run(cv::Mat &img,
@@ -69,13 +68,13 @@ void DBDetector::Run(cv::Mat &img,
   this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
                           this->is_scale_);
 
-  float *input = new float[1 * 3 * resize_img.rows * resize_img.cols];
-  this->permute_op_.Run(&resize_img, input);
+  std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
+  this->permute_op_.Run(&resize_img, input.data());
 
   auto input_names = this->predictor_->GetInputNames();
   auto input_t = this->predictor_->GetInputTensor(input_names[0]);
   input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
-  input_t->copy_from_cpu(input);
+  input_t->copy_from_cpu(input.data());
 
   this->predictor_->ZeroCopyRun();
 
@@ -93,25 +92,26 @@ void DBDetector::Run(cv::Mat &img,
   int n3 = output_shape[3];
   int n = n2 * n3;
 
-  float *pred = new float[n];
-  unsigned char *cbuf = new unsigned char[n];
+  std::vector<float> pred(n, 0.0);
+  std::vector<unsigned char> cbuf(n, ' ');
 
   for (int i = 0; i < n; i++) {
     pred[i] = float(out_data[i]);
     cbuf[i] = (unsigned char)((out_data[i]) * 255);
   }
 
-  cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char *)cbuf);
-  cv::Mat pred_map(n2, n3, CV_32F, (float *)pred);
+  cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char *)cbuf.data());
+  cv::Mat pred_map(n2, n3, CV_32F, (float *)pred.data());
 
-  const double threshold = 0.3 * 255;
+  const double threshold = this->det_db_thresh_ * 255;
   const double maxvalue = 255;
   cv::Mat bit_map;
   cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
 
-  boxes = post_processor_.boxes_from_bitmap(pred_map, bit_map);
+  boxes = post_processor_.BoxesFromBitmap(
+      pred_map, bit_map, this->det_db_box_thresh_, this->det_db_unclip_ratio_);
 
-  boxes = post_processor_.filter_tag_det_res(boxes, ratio_h, ratio_w, srcimg);
+  boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
 
   //// visualization
   cv::Point rook_points[boxes.size()][4];
@@ -133,10 +133,6 @@ void DBDetector::Run(cv::Mat &img,
 
   std::cout << "The detection visualized image saved in ./det_res.png"
             << std::endl;
-
-  delete[] input;
-  delete[] pred;
-  delete[] cbuf;
 }
 
 } // namespace PaddleOCR
\ No newline at end of file
diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp
index b50020e7..6173c7ce 100644
--- a/deploy/cpp_infer/src/ocr_rec.cpp
+++ b/deploy/cpp_infer/src/ocr_rec.cpp
@@ -41,7 +41,7 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
   std::cout << "The predicted text is :" << std::endl;
   int index = 0;
   for (int i = boxes.size() - 1; i >= 0; i--) {
-    crop_img = get_rotate_crop_image(srcimg, boxes[i]);
+    crop_img = GetRotateCropImage(srcimg, boxes[i]);
 
     float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
 
@@ -50,14 +50,14 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
     this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
                             this->is_scale_);
 
-    float *input = new float[1 * 3 * resize_img.rows * resize_img.cols];
+    std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
 
-    this->permute_op_.Run(&resize_img, input);
+    this->permute_op_.Run(&resize_img, input.data());
 
     auto input_names = this->predictor_->GetInputNames();
     auto input_t = this->predictor_->GetInputTensor(input_names[0]);
     input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
-    input_t->copy_from_cpu(input);
+    input_t->copy_from_cpu(input.data());
 
     this->predictor_->ZeroCopyRun();
 
@@ -104,7 +104,8 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
     float max_value = 0.0f;
 
     for (int n = predict_lod[0][0]; n < predict_lod[0][1] - 1; n++) {
-      argmax_idx = int(argmax(&predict_batch[n * predict_shape[1]],
+      argmax_idx =
+          int(Utility::argmax(&predict_batch[n * predict_shape[1]],
                               &predict_batch[(n + 1) * predict_shape[1]]));
       max_value =
           float(*std::max_element(&predict_batch[n * predict_shape[1]],
@@ -116,37 +117,35 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
     }
     score /= count;
     std::cout << "\tscore: " << score << std::endl;
-
-    delete[] input;
   }
 }
 
-void CRNNRecognizer::LoadModel(const std::string &model_dir, bool use_gpu,
-                               const int gpu_id, const int min_subgraph_size,
-                               const int batch_size) {
+void CRNNRecognizer::LoadModel(const std::string &model_dir) {
   AnalysisConfig config;
   config.SetModel(model_dir + "/model", model_dir + "/params");
 
-  // for cpu
-  config.DisableGpu();
-  config.EnableMKLDNN(); // 开启MKLDNN加速
-  config.SetCpuMathLibraryNumThreads(10);
+  if (this->use_gpu_) {
+    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
+  } else {
+    config.DisableGpu();
+    config.EnableMKLDNN(); // enable MKLDNN acceleration
+    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
+  }
 
-  // 使用ZeroCopyTensor,此处必须设置为false
+  // false for zero copy tensor
   config.SwitchUseFeedFetchOps(false);
-  // 若输入为多个,此处必须设置为true
+  // true for multiple input
   config.SwitchSpecifyInputNames(true);
-  // config.SwitchIrDebug(true); //
-  // 可视化调试选项,若开启,则会在每个图优化过程后生成dot文件
-  // config.SwitchIrOptim(false);// 默认为true。如果设置为false,关闭所有优化
-  config.EnableMemoryOptim(); // 开启内存/显存复用
+
+  config.SwitchIrOptim(true);
+
+  config.EnableMemoryOptim();
 
   this->predictor_ = CreatePaddlePredictor(config);
 }
 
-cv::Mat
-CRNNRecognizer::get_rotate_crop_image(const cv::Mat &srcimage,
-                                      std::vector<std::vector<int>> box) {
+cv::Mat CRNNRecognizer::GetRotateCropImage(const cv::Mat &srcimage,
+                                           std::vector<std::vector<int>> box) {
   cv::Mat image;
   srcimage.copyTo(image);
   std::vector<std::vector<int>> points = box;
@@ -200,19 +199,4 @@ CRNNRecognizer::get_rotate_crop_image(const cv::Mat &srcimage,
   }
 }
 
-std::vector<std::string> CRNNRecognizer::ReadDict(const std::string &path) {
-  std::ifstream in(path);
-  std::string filename;
-  std::string line;
-  std::vector<std::string> m_vec;
-  if (in) {
-    while (getline(in, line)) {
-      m_vec.push_back(line);
-    }
-  } else {
-    std::cout << "no such file" << std::endl;
-  }
-  return m_vec;
-}
-
 } // namespace PaddleOCR
\ No newline at end of file
diff --git a/deploy/cpp_infer/src/postprocess_op.cpp b/deploy/cpp_infer/src/postprocess_op.cpp
index 69036db1..78c37a03 100644
--- a/deploy/cpp_infer/src/postprocess_op.cpp
+++ b/deploy/cpp_infer/src/postprocess_op.cpp
@@ -34,8 +34,7 @@ void PostProcessor::GetContourArea(float **box, float unclip_ratio,
   distance = area * unclip_ratio / dist;
 }
 
-cv::RotatedRect PostProcessor::unclip(float **box) {
-  float unclip_ratio = 2.0;
+cv::RotatedRect PostProcessor::UnClip(float **box, const float &unclip_ratio) {
   float distance = 1.0;
 
   GetContourArea(box, unclip_ratio, distance);
@@ -136,7 +135,7 @@ PostProcessor::order_points_clockwise(std::vector<std::vector<int>> pts) {
   return rect;
 }
 
-float **PostProcessor::get_mini_boxes(cv::RotatedRect box, float &ssid) {
+float **PostProcessor::GetMiniBoxes(cv::RotatedRect box, float &ssid) {
   ssid = box.size.width >= box.size.height ? box.size.height : box.size.width;
 
   cv::Mat points;
@@ -169,7 +168,7 @@ float **PostProcessor::get_mini_boxes(cv::RotatedRect box, float &ssid) {
   return array;
 }
 
-float PostProcessor::box_score_fast(float **box_array, cv::Mat pred) {
+float PostProcessor::BoxScoreFast(float **box_array, cv::Mat pred) {
   auto array = box_array;
   int width = pred.cols;
   int height = pred.rows;
@@ -207,10 +206,11 @@ float PostProcessor::box_score_fast(float **box_array, cv::Mat pred) {
 }
 
 std::vector<std::vector<std::vector<int>>>
-PostProcessor::boxes_from_bitmap(const cv::Mat pred, const cv::Mat bitmap) {
+PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
+                               const float &box_thresh,
+                               const float &det_db_unclip_ratio) {
   const int min_size = 3;
   const int max_candidates = 1000;
-  const float box_thresh = 0.5;
 
   int width = bitmap.cols;
   int height = bitmap.rows;
@@ -229,7 +229,7 @@ PostProcessor::boxes_from_bitmap(const cv::Mat pred, const cv::Mat bitmap) {
   for (int _i = 0; _i < num_contours; _i++) {
     float ssid;
     cv::RotatedRect box = cv::minAreaRect(contours[_i]);
-    auto array = get_mini_boxes(box, ssid);
+    auto array = GetMiniBoxes(box, ssid);
 
     auto box_for_unclip = array;
     // end get_mini_box
@@ -239,17 +239,16 @@ PostProcessor::boxes_from_bitmap(const cv::Mat pred, const cv::Mat bitmap) {
     }
 
     float score;
-    score = box_score_fast(array, pred);
-    // end box_score_fast
+    score = BoxScoreFast(array, pred);
     if (score < box_thresh)
       continue;
 
     // start for unclip
-    cv::RotatedRect points = unclip(box_for_unclip);
+    cv::RotatedRect points = UnClip(box_for_unclip, det_db_unclip_ratio);
     // end for unclip
 
     cv::RotatedRect clipbox = points;
-    auto cliparray = get_mini_boxes(clipbox, ssid);
+    auto cliparray = GetMiniBoxes(clipbox, ssid);
 
     if (ssid < min_size + 2)
       continue;
@@ -273,9 +272,9 @@ PostProcessor::boxes_from_bitmap(const cv::Mat pred, const cv::Mat bitmap) {
   return boxes;
 }
 
-std::vector<std::vector<std::vector<int>>> PostProcessor::filter_tag_det_res(
-    std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
-    float ratio_w, cv::Mat srcimg) {
+std::vector<std::vector<std::vector<int>>>
+PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
+                               float ratio_h, float ratio_w, cv::Mat srcimg) {
   int oriimg_h = srcimg.rows;
   int oriimg_w = srcimg.cols;
 
diff --git a/deploy/cpp_infer/src/preprocess_op.cpp b/deploy/cpp_infer/src/preprocess_op.cpp
index 5fee3c4d..0078063e 100644
--- a/deploy/cpp_infer/src/preprocess_op.cpp
+++ b/deploy/cpp_infer/src/preprocess_op.cpp
@@ -97,7 +97,7 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
 }
 
 void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
-                        const std::vector<int> rec_image_shape) {
+                        const std::vector<int> &rec_image_shape) {
   int imgC, imgH, imgW;
   imgC = rec_image_shape[0];
   imgH = rec_image_shape[1];
diff --git a/deploy/cpp_infer/src/utility.cpp b/deploy/cpp_infer/src/utility.cpp
new file mode 100644
index 00000000..8ca5ac67
--- /dev/null
+++ b/deploy/cpp_infer/src/utility.cpp
@@ -0,0 +1,39 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+#include <ostream>
+#include <vector>
+
+#include <include/utility.h>
+
+namespace PaddleOCR {
+
+std::vector<std::string> Utility::ReadDict(const std::string &path) {
+  std::ifstream in(path);
+  std::string line;
+  std::vector<std::string> m_vec;
+  if (in) {
+    while (getline(in, line)) {
+      m_vec.push_back(line);
+    }
+  } else {
+    std::cout << "no such label file: " << path << ", exit the program..."
+              << std::endl;
+    exit(1);
+  }
+  return m_vec;
+}
+
+} // namespace PaddleOCR
\ No newline at end of file
diff --git a/deploy/cpp_infer/tools/config.txt b/deploy/cpp_infer/tools/config.txt
new file mode 100644
index 00000000..1ce1568e
--- /dev/null
+++ b/deploy/cpp_infer/tools/config.txt
@@ -0,0 +1,17 @@
+# model load config
+use_gpu 0
+gpu_id 0
+gpu_mem 4000
+cpu_math_library_num_threads 1
+
+# det config
+max_side_len 960
+det_db_thresh 0.3
+det_db_box_thresh 0.5
+det_db_unclip_ratio 2.0
+det_model_dir ./inference/det_db
+
+# rec config
+rec_model_dir ./inference/rec_crnn
+char_list_file ./tools/ppocr_keys_v1.txt
+img_path ../../doc/imgs/11.jpg
diff --git a/deploy/cpp_infer/tools/run.sh b/deploy/cpp_infer/tools/run.sh
index dd2441c1..052b363b 100755
--- a/deploy/cpp_infer/tools/run.sh
+++ b/deploy/cpp_infer/tools/run.sh
@@ -1,2 +1,2 @@
-./build/ocr_system ./inference/det_db/ ./inference/rec_crnn/ ../../doc/imgs/12.jpg
+./build/ocr_system ./tools/config.txt ../../doc/imgs/6.jpg
 
-- 
GitLab