diff --git a/deploy/cpp_infer/include/ocr_cls.h b/deploy/cpp_infer/include/ocr_cls.h index a43c80053498843ec0152c96d209057017fff352..742e1f8bb0392859ea4bc3a6a4b4410f6b375826 100644 --- a/deploy/cpp_infer/include/ocr_cls.h +++ b/deploy/cpp_infer/include/ocr_cls.h @@ -42,7 +42,7 @@ public: const int &gpu_id, const int &gpu_mem, const int &cpu_math_library_num_threads, const bool &use_mkldnn, const double &cls_thresh, - const bool &use_tensorrt, const bool &use_fp16) { + const bool &use_tensorrt, const std::string &precision) { this->use_gpu_ = use_gpu; this->gpu_id_ = gpu_id; this->gpu_mem_ = gpu_mem; @@ -51,7 +51,7 @@ public: this->cls_thresh = cls_thresh; this->use_tensorrt_ = use_tensorrt; - this->use_fp16_ = use_fp16; + this->precision_ = precision; LoadModel(model_dir); } @@ -75,7 +75,7 @@ private: std::vector scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; bool is_scale_ = true; bool use_tensorrt_ = false; - bool use_fp16_ = false; + std::string precision_ = "fp32"; // pre-process ClsResizeImg resize_op_; Normalize normalize_op_; diff --git a/deploy/cpp_infer/include/ocr_det.h b/deploy/cpp_infer/include/ocr_det.h index 18318c9c4e37136db62c1338db1b58f82859f037..e5a31ed8e5ab6397c4fa67388252e2baef8b9dd7 100644 --- a/deploy/cpp_infer/include/ocr_det.h +++ b/deploy/cpp_infer/include/ocr_det.h @@ -46,7 +46,7 @@ public: const double &det_db_box_thresh, const double &det_db_unclip_ratio, const bool &use_polygon_score, const bool &visualize, - const bool &use_tensorrt, const bool &use_fp16) { + const bool &use_tensorrt, const std::string &precision) { this->use_gpu_ = use_gpu; this->gpu_id_ = gpu_id; this->gpu_mem_ = gpu_mem; @@ -62,7 +62,7 @@ public: this->visualize_ = visualize; this->use_tensorrt_ = use_tensorrt; - this->use_fp16_ = use_fp16; + this->precision_ = precision; LoadModel(model_dir); } @@ -71,7 +71,7 @@ public: void LoadModel(const std::string &model_dir); // Run predictor - void Run(cv::Mat &img, std::vector>> &boxes); + void Run(cv::Mat &img, std::vector>> &boxes, std::vector *times); private: std::shared_ptr predictor_; @@ -91,7 +91,7 @@ private: bool visualize_ = true; bool use_tensorrt_ = false; - bool use_fp16_ = false; + std::string precision_ = "fp32"; std::vector mean_ = {0.485f, 0.456f, 0.406f}; std::vector scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f}; diff --git a/deploy/cpp_infer/include/ocr_rec.h b/deploy/cpp_infer/include/ocr_rec.h index 25f55ae26a29cc4f93f152cc072bd444aedf6bf2..d585112b051daff7c03060836a4c065ba6e3949c 100644 --- a/deploy/cpp_infer/include/ocr_rec.h +++ b/deploy/cpp_infer/include/ocr_rec.h @@ -44,14 +44,14 @@ public: const int &gpu_id, const int &gpu_mem, const int &cpu_math_library_num_threads, const bool &use_mkldnn, const string &label_path, - const bool &use_tensorrt, const bool &use_fp16) { + const bool &use_tensorrt, const std::string &precision) { this->use_gpu_ = use_gpu; this->gpu_id_ = gpu_id; this->gpu_mem_ = gpu_mem; this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; this->use_mkldnn_ = use_mkldnn; this->use_tensorrt_ = use_tensorrt; - this->use_fp16_ = use_fp16; + this->precision_ = precision; this->label_list_ = Utility::ReadDict(label_path); this->label_list_.insert(this->label_list_.begin(), @@ -64,7 +64,7 @@ public: // Load Paddle inference model void LoadModel(const std::string &model_dir); - void Run(cv::Mat &img); + void Run(cv::Mat &img, std::vector *times); private: std::shared_ptr predictor_; @@ -81,7 +81,7 @@ private: std::vector scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; bool is_scale_ = true; bool use_tensorrt_ = false; - bool use_fp16_ = false; + std::string precision_ = "fp32"; // pre-process CrnnResizeImg resize_op_; Normalize normalize_op_; @@ -90,9 +90,6 @@ private: // post-process PostProcessor post_processor_; - cv::Mat GetRotateCropImage(const cv::Mat &srcimage, - std::vector> box); - }; // class CrnnRecognizer } // namespace PaddleOCR diff --git a/deploy/cpp_infer/include/utility.h b/deploy/cpp_infer/include/utility.h index 6e8173e007279319657250b376de022240bc6f62..678187d3fabfb1c91584226950155b3c47b5f93f 100644 --- a/deploy/cpp_infer/include/utility.h +++ b/deploy/cpp_infer/include/utility.h @@ -47,6 +47,9 @@ public: static void GetAllFiles(const char *dir_name, std::vector &all_inputs); + + static cv::Mat GetRotateCropImage(const cv::Mat &srcimage, + std::vector> box); }; } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp index 830f032d4649c44ed33527c383dd332b494c47c3..5e5c851517d5efaa75f54b7a156563a4a42880d5 100644 --- a/deploy/cpp_infer/src/main.cpp +++ b/deploy/cpp_infer/src/main.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -41,7 +42,9 @@ DEFINE_int32(gpu_mem, 4000, "GPU id when infering with GPU."); DEFINE_int32(cpu_math_library_num_threads, 10, "Num of threads with CPU."); DEFINE_bool(use_mkldnn, false, "Whether use mkldnn with CPU."); DEFINE_bool(use_tensorrt, false, "Whether use tensorrt."); -DEFINE_bool(use_fp16, false, "Whether use fp16 when use tensorrt."); +DEFINE_string(precision, "fp32", "Precision be one of fp32/fp16/int8"); +DEFINE_bool(benchmark, true, "Whether use benchmark."); +DEFINE_string(save_log_path, "./log_output/", "Save benchmark log path."); // detection related DEFINE_string(image_dir, "", "Dir of input image."); DEFINE_string(det_model_dir, "", "Path of det inference model."); @@ -65,6 +68,34 @@ using namespace cv; using namespace PaddleOCR; +void PrintBenchmarkLog(std::string model_name, + int batch_size, + std::string input_shape, + std::vector time_info, + int img_num){ + LOG(INFO) << "----------------------- Config info -----------------------"; + LOG(INFO) << "runtime_device: " << (FLAGS_use_gpu ? "gpu" : "cpu"); + LOG(INFO) << "ir_optim: " << "True"; + LOG(INFO) << "enable_memory_optim: " << "True"; + LOG(INFO) << "enable_tensorrt: " << FLAGS_use_tensorrt; + LOG(INFO) << "enable_mkldnn: " << (FLAGS_use_mkldnn ? "True" : "False"); + LOG(INFO) << "cpu_math_library_num_threads: " << FLAGS_cpu_math_library_num_threads; + LOG(INFO) << "----------------------- Data info -----------------------"; + LOG(INFO) << "batch_size: " << batch_size; + LOG(INFO) << "input_shape: " << input_shape; + LOG(INFO) << "data_num: " << img_num; + LOG(INFO) << "----------------------- Model info -----------------------"; + LOG(INFO) << "model_name: " << model_name; + LOG(INFO) << "precision: " << FLAGS_precision; + LOG(INFO) << "----------------------- Perf info ------------------------"; + LOG(INFO) << "Total time spent(ms): " + << std::accumulate(time_info.begin(), time_info.end(), 0); + LOG(INFO) << "preprocess_time(ms): " << time_info[0] / img_num + << ", inference_time(ms): " << time_info[1] / img_num + << ", postprocess_time(ms): " << time_info[2] / img_num; +} + + static bool PathExists(const std::string& path){ #ifdef _WIN32 struct _stat buffer; @@ -76,88 +107,15 @@ static bool PathExists(const std::string& path){ } -cv::Mat GetRotateCropImage(const cv::Mat &srcimage, - std::vector> box) { - cv::Mat image; - srcimage.copyTo(image); - std::vector> points = box; - - int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]}; - int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]}; - int left = int(*std::min_element(x_collect, x_collect + 4)); - int right = int(*std::max_element(x_collect, x_collect + 4)); - int top = int(*std::min_element(y_collect, y_collect + 4)); - int bottom = int(*std::max_element(y_collect, y_collect + 4)); - - cv::Mat img_crop; - image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop); - - for (int i = 0; i < points.size(); i++) { - points[i][0] -= left; - points[i][1] -= top; - } - - int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) + - pow(points[0][1] - points[1][1], 2))); - int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) + - pow(points[0][1] - points[3][1], 2))); - - cv::Point2f pts_std[4]; - pts_std[0] = cv::Point2f(0., 0.); - pts_std[1] = cv::Point2f(img_crop_width, 0.); - pts_std[2] = cv::Point2f(img_crop_width, img_crop_height); - pts_std[3] = cv::Point2f(0.f, img_crop_height); - - cv::Point2f pointsf[4]; - pointsf[0] = cv::Point2f(points[0][0], points[0][1]); - pointsf[1] = cv::Point2f(points[1][0], points[1][1]); - pointsf[2] = cv::Point2f(points[2][0], points[2][1]); - pointsf[3] = cv::Point2f(points[3][0], points[3][1]); - - cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std); - - cv::Mat dst_img; - cv::warpPerspective(img_crop, dst_img, M, - cv::Size(img_crop_width, img_crop_height), - cv::BORDER_REPLICATE); - - if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) { - cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth()); - cv::transpose(dst_img, srcCopy); - cv::flip(srcCopy, srcCopy, 0); - return srcCopy; - } else { - return dst_img; - } -} - - -int main_det(int argc, char **argv) { - // Parsing command-line - google::ParseCommandLineFlags(&argc, &argv, true); - if (FLAGS_det_model_dir.empty() || FLAGS_image_dir.empty()) { - std::cout << "Usage[det]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " - << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; - exit(1); - } - if (!PathExists(FLAGS_image_dir)) { - std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir << endl; - exit(1); - } - - std::vector cv_all_img_names; - cv::glob(FLAGS_image_dir, cv_all_img_names); - std::cout << "total images num: " << cv_all_img_names.size() << endl; - +int main_det(std::vector cv_all_img_names) { + std::vector time_info = {0, 0, 0}; DBDetector det(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads, FLAGS_use_mkldnn, FLAGS_max_side_len, FLAGS_det_db_thresh, FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio, FLAGS_use_polygon_score, FLAGS_visualize, - FLAGS_use_tensorrt, FLAGS_use_fp16); - - auto start = std::chrono::system_clock::now(); - + FLAGS_use_tensorrt, FLAGS_precision); + for (int i = 0; i < cv_all_img_names.size(); ++i) { LOG(INFO) << "The predict img: " << cv_all_img_names[i]; @@ -167,46 +125,28 @@ int main_det(int argc, char **argv) { exit(1); } std::vector>> boxes; + std::vector det_times; - det.Run(srcimg, boxes); - - auto end = std::chrono::system_clock::now(); - auto duration = - std::chrono::duration_cast(end - start); - std::cout << "Cost " - << double(duration.count()) * - std::chrono::microseconds::period::num / - std::chrono::microseconds::period::den - << "s" << std::endl; + det.Run(srcimg, boxes, &det_times); + + time_info[0] += det_times[0]; + time_info[1] += det_times[1]; + time_info[2] += det_times[2]; } + if (FLAGS_benchmark) { + PrintBenchmarkLog("det", 1, "dynamic", time_info, cv_all_img_names.size()); + } return 0; } -int main_rec(int argc, char **argv) { - // Parsing command-line - google::ParseCommandLineFlags(&argc, &argv, true); - if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) { - std::cout << "Usage[rec]: ./ppocr --rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " - << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; - exit(1); - } - if (!PathExists(FLAGS_image_dir)) { - std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir << endl; - exit(1); - } - - std::vector cv_all_img_names; - cv::glob(FLAGS_image_dir, cv_all_img_names); - std::cout << "total images num: " << cv_all_img_names.size() << endl; - +int main_rec(std::vector cv_all_img_names) { + std::vector time_info = {0, 0, 0}; CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads, FLAGS_use_mkldnn, FLAGS_char_list_file, - FLAGS_use_tensorrt, FLAGS_use_fp16); - - auto start = std::chrono::system_clock::now(); + FLAGS_use_tensorrt, FLAGS_precision); for (int i = 0; i < cv_all_img_names.size(); ++i) { LOG(INFO) << "The predict img: " << cv_all_img_names[i]; @@ -217,65 +157,42 @@ int main_rec(int argc, char **argv) { exit(1); } - rec.Run(srcimg); + std::vector rec_times; + rec.Run(srcimg, &rec_times); - auto end = std::chrono::system_clock::now(); - auto duration = - std::chrono::duration_cast(end - start); - std::cout << "Cost " - << double(duration.count()) * - std::chrono::microseconds::period::num / - std::chrono::microseconds::period::den - << "s" << std::endl; + time_info[0] += rec_times[0]; + time_info[1] += rec_times[1]; + time_info[2] += rec_times[2]; + } + + if (FLAGS_benchmark) { + PrintBenchmarkLog("rec", 1, "dynamic", time_info, cv_all_img_names.size()); } return 0; } -int main_system(int argc, char **argv) { - // Parsing command-line - google::ParseCommandLineFlags(&argc, &argv, true); - if ((FLAGS_det_model_dir.empty() || FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) || - (FLAGS_use_angle_cls && FLAGS_cls_model_dir.empty())) { - std::cout << "Usage[system without angle cls]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " - << "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " - << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; - std::cout << "Usage[system with angle cls]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " - << "--use_angle_cls=true " - << "--cls_model_dir=/PATH/TO/CLS_INFERENCE_MODEL/ " - << "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " - << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; - exit(1); - } - if (!PathExists(FLAGS_image_dir)) { - std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir << endl; - exit(1); - } - - std::vector cv_all_img_names; - cv::glob(FLAGS_image_dir, cv_all_img_names); - std::cout << "total images num: " << cv_all_img_names.size() << endl; - +int main_system(std::vector cv_all_img_names) { DBDetector det(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads, FLAGS_use_mkldnn, FLAGS_max_side_len, FLAGS_det_db_thresh, FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio, FLAGS_use_polygon_score, FLAGS_visualize, - FLAGS_use_tensorrt, FLAGS_use_fp16); + FLAGS_use_tensorrt, FLAGS_precision); Classifier *cls = nullptr; if (FLAGS_use_angle_cls) { cls = new Classifier(FLAGS_cls_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads, FLAGS_use_mkldnn, FLAGS_cls_thresh, - FLAGS_use_tensorrt, FLAGS_use_fp16); + FLAGS_use_tensorrt, FLAGS_precision); } CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads, FLAGS_use_mkldnn, FLAGS_char_list_file, - FLAGS_use_tensorrt, FLAGS_use_fp16); + FLAGS_use_tensorrt, FLAGS_precision); auto start = std::chrono::system_clock::now(); @@ -288,17 +205,19 @@ int main_system(int argc, char **argv) { exit(1); } std::vector>> boxes; - - det.Run(srcimg, boxes); + std::vector det_times; + std::vector rec_times; + + det.Run(srcimg, boxes, &det_times); cv::Mat crop_img; for (int j = 0; j < boxes.size(); j++) { - crop_img = GetRotateCropImage(srcimg, boxes[j]); + crop_img = Utility::GetRotateCropImage(srcimg, boxes[j]); if (cls != nullptr) { crop_img = cls->Run(crop_img); } - rec.Run(crop_img); + rec.Run(crop_img, &rec_times); } auto end = std::chrono::system_clock::now(); @@ -315,22 +234,70 @@ int main_system(int argc, char **argv) { } +void check_params(char* mode) { + if (strcmp(mode, "det")==0) { + if (FLAGS_det_model_dir.empty() || FLAGS_image_dir.empty()) { + std::cout << "Usage[det]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " + << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; + exit(1); + } + } + if (strcmp(mode, "rec")==0) { + if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) { + std::cout << "Usage[rec]: ./ppocr --rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " + << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; + exit(1); + } + } + if (strcmp(mode, "system")==0) { + if ((FLAGS_det_model_dir.empty() || FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) || + (FLAGS_use_angle_cls && FLAGS_cls_model_dir.empty())) { + std::cout << "Usage[system without angle cls]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " + << "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " + << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; + std::cout << "Usage[system with angle cls]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " + << "--use_angle_cls=true " + << "--cls_model_dir=/PATH/TO/CLS_INFERENCE_MODEL/ " + << "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " + << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; + exit(1); + } + } + if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" && FLAGS_precision != "int8") { + cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. " << endl; + exit(1); + } +} + + int main(int argc, char **argv) { - if (strcmp(argv[1], "det")!=0 && strcmp(argv[1], "rec")!=0 && strcmp(argv[1], "system")!=0) { - std::cout << "Please choose one mode of [det, rec, system] !" << std::endl; - return -1; - } - std::cout << "mode: " << argv[1] << endl; - - if (strcmp(argv[1], "det")==0) { - return main_det(argc, argv); - } - if (strcmp(argv[1], "rec")==0) { - return main_rec(argc, argv); - } - if (strcmp(argv[1], "system")==0) { - return main_system(argc, argv); - } + if (argc<=1 || (strcmp(argv[1], "det")!=0 && strcmp(argv[1], "rec")!=0 && strcmp(argv[1], "system")!=0)) { + std::cout << "Please choose one mode of [det, rec, system] !" << std::endl; + return -1; + } + std::cout << "mode: " << argv[1] << endl; + + // Parsing command-line + google::ParseCommandLineFlags(&argc, &argv, true); + check_params(argv[1]); + + if (!PathExists(FLAGS_image_dir)) { + std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir << endl; + exit(1); + } -// return 0; + std::vector cv_all_img_names; + cv::glob(FLAGS_image_dir, cv_all_img_names); + std::cout << "total images num: " << cv_all_img_names.size() << endl; + + if (strcmp(argv[1], "det")==0) { + return main_det(cv_all_img_names); + } + if (strcmp(argv[1], "rec")==0) { + return main_rec(cv_all_img_names); + } + if (strcmp(argv[1], "system")==0) { + return main_system(cv_all_img_names); + } + } diff --git a/deploy/cpp_infer/src/ocr_cls.cpp b/deploy/cpp_infer/src/ocr_cls.cpp index 9199e082e5df42b0c9c42e668d2df37acf4521c4..3b04b6f8248bb17b9e315ae8b777530840015394 100644 --- a/deploy/cpp_infer/src/ocr_cls.cpp +++ b/deploy/cpp_infer/src/ocr_cls.cpp @@ -77,10 +77,16 @@ void Classifier::LoadModel(const std::string &model_dir) { if (this->use_gpu_) { config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); if (this->use_tensorrt_) { + auto precision = paddle_infer::Config::Precision::kFloat32; + if (this->precision_ == "fp16") { + precision = paddle_infer::Config::Precision::kHalf; + } + if (this->precision_ == "int8") { + precision = paddle_infer::Config::Precision::kInt8; + } config.EnableTensorRtEngine( 1 << 20, 10, 3, - this->use_fp16_ ? paddle_infer::Config::Precision::kHalf - : paddle_infer::Config::Precision::kFloat32, + precision, false, false); } } else { diff --git a/deploy/cpp_infer/src/ocr_det.cpp b/deploy/cpp_infer/src/ocr_det.cpp index 58dc4dce8117f81b17e3c88ea02404d474ea9248..a69f5ca1bd3ee7665f8b2f5610c67dd6feb7eb54 100644 --- a/deploy/cpp_infer/src/ocr_det.cpp +++ b/deploy/cpp_infer/src/ocr_det.cpp @@ -26,10 +26,16 @@ void DBDetector::LoadModel(const std::string &model_dir) { if (this->use_gpu_) { config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); if (this->use_tensorrt_) { + auto precision = paddle_infer::Config::Precision::kFloat32; + if (this->precision_ == "fp16") { + precision = paddle_infer::Config::Precision::kHalf; + } + if (this->precision_ == "int8") { + precision = paddle_infer::Config::Precision::kInt8; + } config.EnableTensorRtEngine( 1 << 20, 10, 3, - this->use_fp16_ ? paddle_infer::Config::Precision::kHalf - : paddle_infer::Config::Precision::kFloat32, + precision, false, false); std::map> min_input_shape = { {"x", {1, 3, 50, 50}}, @@ -91,13 +97,16 @@ void DBDetector::LoadModel(const std::string &model_dir) { } void DBDetector::Run(cv::Mat &img, - std::vector>> &boxes) { + std::vector>> &boxes, + std::vector *times) { float ratio_h{}; float ratio_w{}; cv::Mat srcimg; cv::Mat resize_img; img.copyTo(srcimg); + + auto preprocess_start = std::chrono::steady_clock::now(); this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w, this->use_tensorrt_); @@ -106,14 +115,17 @@ void DBDetector::Run(cv::Mat &img, std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); this->permute_op_.Run(&resize_img, input.data()); - + auto preprocess_end = std::chrono::steady_clock::now(); + // Inference. auto input_names = this->predictor_->GetInputNames(); auto input_t = this->predictor_->GetInputHandle(input_names[0]); input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); + auto inference_start = std::chrono::steady_clock::now(); input_t->CopyFromCpu(input.data()); + this->predictor_->Run(); - + std::vector out_data; auto output_names = this->predictor_->GetOutputNames(); auto output_t = this->predictor_->GetOutputHandle(output_names[0]); @@ -123,7 +135,9 @@ void DBDetector::Run(cv::Mat &img, out_data.resize(out_num); output_t->CopyToCpu(out_data.data()); - + auto inference_end = std::chrono::steady_clock::now(); + + auto postprocess_start = std::chrono::steady_clock::now(); int n2 = output_shape[2]; int n3 = output_shape[3]; int n = n2 * n3; @@ -151,7 +165,15 @@ void DBDetector::Run(cv::Mat &img, this->det_db_unclip_ratio_, this->use_polygon_score_); boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg); + auto postprocess_end = std::chrono::steady_clock::now(); std::cout << "Detected boxes num: " << boxes.size() << endl; + + std::chrono::duration preprocess_diff = preprocess_end - preprocess_start; + times->push_back(double(preprocess_diff.count() * 1000)); + std::chrono::duration inference_diff = inference_end - inference_start; + times->push_back(double(inference_diff.count() * 1000)); + std::chrono::duration postprocess_diff = postprocess_end - postprocess_start; + times->push_back(double(postprocess_diff.count() * 1000)); //// visualization if (this->visualize_) { diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp index c4a784f82c789f3ebdc826ccb1d37631c8204368..b64dcea5ae2a68485296c02cdb7689c60ea504f8 100644 --- a/deploy/cpp_infer/src/ocr_rec.cpp +++ b/deploy/cpp_infer/src/ocr_rec.cpp @@ -16,13 +16,13 @@ namespace PaddleOCR { -void CRNNRecognizer::Run(cv::Mat &img) { +void CRNNRecognizer::Run(cv::Mat &img, std::vector *times) { cv::Mat srcimg; img.copyTo(srcimg); cv::Mat resize_img; float wh_ratio = float(srcimg.cols) / float(srcimg.rows); - + auto preprocess_start = std::chrono::steady_clock::now(); this->resize_op_.Run(srcimg, resize_img, wh_ratio, this->use_tensorrt_); this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, @@ -31,11 +31,13 @@ void CRNNRecognizer::Run(cv::Mat &img) { std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); this->permute_op_.Run(&resize_img, input.data()); + auto preprocess_end = std::chrono::steady_clock::now(); // Inference. auto input_names = this->predictor_->GetInputNames(); auto input_t = this->predictor_->GetInputHandle(input_names[0]); input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); + auto inference_start = std::chrono::steady_clock::now(); input_t->CopyFromCpu(input.data()); this->predictor_->Run(); @@ -49,8 +51,10 @@ void CRNNRecognizer::Run(cv::Mat &img) { predict_batch.resize(out_num); output_t->CopyToCpu(predict_batch.data()); + auto inference_end = std::chrono::steady_clock::now(); // ctc decode + auto postprocess_start = std::chrono::steady_clock::now(); std::vector str_res; int argmax_idx; int last_index = 0; @@ -73,11 +77,19 @@ void CRNNRecognizer::Run(cv::Mat &img) { } last_index = argmax_idx; } + auto postprocess_end = std::chrono::steady_clock::now(); score /= count; for (int i = 0; i < str_res.size(); i++) { std::cout << str_res[i]; } std::cout << "\tscore: " << score << std::endl; + + std::chrono::duration preprocess_diff = preprocess_end - preprocess_start; + times->push_back(double(preprocess_diff.count() * 1000)); + std::chrono::duration inference_diff = inference_end - inference_start; + times->push_back(double(inference_diff.count() * 1000)); + std::chrono::duration postprocess_diff = postprocess_end - postprocess_start; + times->push_back(double(postprocess_diff.count() * 1000)); } void CRNNRecognizer::LoadModel(const std::string &model_dir) { @@ -89,10 +101,16 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { if (this->use_gpu_) { config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); if (this->use_tensorrt_) { + auto precision = paddle_infer::Config::Precision::kFloat32; + if (this->precision_ == "fp16") { + precision = paddle_infer::Config::Precision::kHalf; + } + if (this->precision_ == "int8") { + precision = paddle_infer::Config::Precision::kInt8; + } config.EnableTensorRtEngine( 1 << 20, 10, 3, - this->use_fp16_ ? paddle_infer::Config::Precision::kHalf - : paddle_infer::Config::Precision::kFloat32, + precision, false, false); std::map> min_input_shape = { {"x", {1, 3, 32, 10}}}; @@ -126,59 +144,4 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { this->predictor_ = CreatePredictor(config); } -cv::Mat CRNNRecognizer::GetRotateCropImage(const cv::Mat &srcimage, - std::vector> box) { - cv::Mat image; - srcimage.copyTo(image); - std::vector> points = box; - - int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]}; - int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]}; - int left = int(*std::min_element(x_collect, x_collect + 4)); - int right = int(*std::max_element(x_collect, x_collect + 4)); - int top = int(*std::min_element(y_collect, y_collect + 4)); - int bottom = int(*std::max_element(y_collect, y_collect + 4)); - - cv::Mat img_crop; - image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop); - - for (int i = 0; i < points.size(); i++) { - points[i][0] -= left; - points[i][1] -= top; - } - - int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) + - pow(points[0][1] - points[1][1], 2))); - int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) + - pow(points[0][1] - points[3][1], 2))); - - cv::Point2f pts_std[4]; - pts_std[0] = cv::Point2f(0., 0.); - pts_std[1] = cv::Point2f(img_crop_width, 0.); - pts_std[2] = cv::Point2f(img_crop_width, img_crop_height); - pts_std[3] = cv::Point2f(0.f, img_crop_height); - - cv::Point2f pointsf[4]; - pointsf[0] = cv::Point2f(points[0][0], points[0][1]); - pointsf[1] = cv::Point2f(points[1][0], points[1][1]); - pointsf[2] = cv::Point2f(points[2][0], points[2][1]); - pointsf[3] = cv::Point2f(points[3][0], points[3][1]); - - cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std); - - cv::Mat dst_img; - cv::warpPerspective(img_crop, dst_img, M, - cv::Size(img_crop_width, img_crop_height), - cv::BORDER_REPLICATE); - - if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) { - cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth()); - cv::transpose(dst_img, srcCopy); - cv::flip(srcCopy, srcCopy, 0); - return srcCopy; - } else { - return dst_img; - } -} - } // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/utility.cpp b/deploy/cpp_infer/src/utility.cpp index 2cd84f7e8dbdd8144b5337f55b3f3a62ed43d5b3..dba445b747ff3f3c0d2db91061650c369977c4dd 100644 --- a/deploy/cpp_infer/src/utility.cpp +++ b/deploy/cpp_infer/src/utility.cpp @@ -92,4 +92,59 @@ void Utility::GetAllFiles(const char *dir_name, } } +cv::Mat Utility::GetRotateCropImage(const cv::Mat &srcimage, + std::vector> box) { + cv::Mat image; + srcimage.copyTo(image); + std::vector> points = box; + + int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]}; + int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]}; + int left = int(*std::min_element(x_collect, x_collect + 4)); + int right = int(*std::max_element(x_collect, x_collect + 4)); + int top = int(*std::min_element(y_collect, y_collect + 4)); + int bottom = int(*std::max_element(y_collect, y_collect + 4)); + + cv::Mat img_crop; + image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop); + + for (int i = 0; i < points.size(); i++) { + points[i][0] -= left; + points[i][1] -= top; + } + + int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) + + pow(points[0][1] - points[1][1], 2))); + int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) + + pow(points[0][1] - points[3][1], 2))); + + cv::Point2f pts_std[4]; + pts_std[0] = cv::Point2f(0., 0.); + pts_std[1] = cv::Point2f(img_crop_width, 0.); + pts_std[2] = cv::Point2f(img_crop_width, img_crop_height); + pts_std[3] = cv::Point2f(0.f, img_crop_height); + + cv::Point2f pointsf[4]; + pointsf[0] = cv::Point2f(points[0][0], points[0][1]); + pointsf[1] = cv::Point2f(points[1][0], points[1][1]); + pointsf[2] = cv::Point2f(points[2][0], points[2][1]); + pointsf[3] = cv::Point2f(points[3][0], points[3][1]); + + cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std); + + cv::Mat dst_img; + cv::warpPerspective(img_crop, dst_img, M, + cv::Size(img_crop_width, img_crop_height), + cv::BORDER_REPLICATE); + + if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) { + cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth()); + cv::transpose(dst_img, srcCopy); + cv::flip(srcCopy, srcCopy, 0); + return srcCopy; + } else { + return dst_img; + } +} + } // namespace PaddleOCR \ No newline at end of file