diff --git a/deploy/cpp_infer/include/args.h b/deploy/cpp_infer/include/args.h index f7fac9c92c421ca85818b2d04097ce8e55ea117e..e6e76ef927c16f6afe381f64ea8dde4ac99185cf 100644 --- a/deploy/cpp_infer/include/args.h +++ b/deploy/cpp_infer/include/args.h @@ -49,6 +49,11 @@ DECLARE_int32(rec_batch_num); DECLARE_string(rec_char_dict_path); DECLARE_int32(rec_img_h); DECLARE_int32(rec_img_w); +// layout model related +DECLARE_string(layout_model_dir); +DECLARE_string(layout_dict_path); +DECLARE_double(layout_score_threshold); +DECLARE_double(layout_nms_threshold); // structure model related DECLARE_string(table_model_dir); DECLARE_int32(table_max_len); @@ -59,4 +64,5 @@ DECLARE_bool(merge_no_span_structure); DECLARE_bool(det); DECLARE_bool(rec); DECLARE_bool(cls); -DECLARE_bool(table); \ No newline at end of file +DECLARE_bool(table); +DECLARE_bool(layout); \ No newline at end of file diff --git a/deploy/cpp_infer/include/paddleocr.h b/deploy/cpp_infer/include/paddleocr.h index a2c60b14acceaa90a8d8e4a70ccc50f02f254eb6..c2af4b29e579088e9dc1ed817de11be8cc3ff4fe 100644 --- a/deploy/cpp_infer/include/paddleocr.h +++ b/deploy/cpp_infer/include/paddleocr.h @@ -43,21 +43,26 @@ class PPOCR { public: explicit PPOCR(); ~PPOCR(); - std::vector> - ocr(std::vector cv_all_img_names, bool det = true, - bool rec = true, bool cls = true); + + std::vector> ocr(std::vector img_list, + bool det = true, + bool rec = true, + bool cls = true); + std::vector ocr(cv::Mat img, bool det, bool rec, bool cls); + + void reset_timer(); + void benchmark_log(int img_num); protected: - void det(cv::Mat img, std::vector &ocr_results, - std::vector ×); + std::vector time_info_det = {0, 0, 0}; + std::vector time_info_rec = {0, 0, 0}; + std::vector time_info_cls = {0, 0, 0}; + + void det(cv::Mat img, std::vector &ocr_results); void rec(std::vector img_list, - std::vector &ocr_results, - std::vector ×); + std::vector &ocr_results); void cls(std::vector img_list, - std::vector &ocr_results, - std::vector ×); - void log(std::vector &det_times, std::vector &rec_times, - std::vector &cls_times, int img_num); + std::vector &ocr_results); private: DBDetector *detector_ = nullptr; diff --git a/deploy/cpp_infer/include/paddlestructure.h b/deploy/cpp_infer/include/paddlestructure.h index 6d2c8b7d203a05f531b8d038d885061c42897373..bee888a827ff9e01f97cedc06241dcf1ff0a2aea 100644 --- a/deploy/cpp_infer/include/paddlestructure.h +++ b/deploy/cpp_infer/include/paddlestructure.h @@ -31,6 +31,7 @@ #include #include +#include #include #include @@ -42,23 +43,31 @@ class PaddleStructure : public PPOCR { public: explicit PaddleStructure(); ~PaddleStructure(); - std::vector> - structure(std::vector cv_all_img_names, bool layout = false, - bool table = true); + + std::vector structure(cv::Mat img, + bool layout = false, + bool table = true, + bool ocr = false); + + void reset_timer(); + void benchmark_log(int img_num); private: - StructureTableRecognizer *recognizer_ = nullptr; + std::vector time_info_table = {0, 0, 0}; + std::vector time_info_layout = {0, 0, 0}; + + StructureTableRecognizer *table_model_ = nullptr; + StructureLayoutRecognizer *layout_model_ = nullptr; + + void layout(cv::Mat img, + std::vector &structure_result); + + void table(cv::Mat img, StructurePredictResult &structure_result); - void table(cv::Mat img, StructurePredictResult &structure_result, - std::vector &time_info_table, - std::vector &time_info_det, - std::vector &time_info_rec, - std::vector &time_info_cls); std::string rebuild_table(std::vector rec_html_tags, std::vector> rec_boxes, std::vector &ocr_result); - float iou(std::vector &box1, std::vector &box2); float dis(std::vector &box1, std::vector &box2); static bool comparison_dis(const std::vector &dis1, diff --git a/deploy/cpp_infer/include/postprocess_op.h b/deploy/cpp_infer/include/postprocess_op.h index f5db52a6097f0fb916fc96fd8c76095f2ed1a9fa..2bff298c3c726b321aa5a8c351599091c4b09f4c 100644 --- a/deploy/cpp_infer/include/postprocess_op.h +++ b/deploy/cpp_infer/include/postprocess_op.h @@ -92,7 +92,23 @@ private: class TablePostProcessor { public: - void init(std::string label_path, bool merge_no_span_structure = true); + void init(std::string label_path, bool merge_no_span_structure = true) { + this->label_list_ = Utility::ReadDict(label_path); + if (merge_no_span_structure) { + this->label_list_.push_back(""); + std::vector::iterator it; + for (it = this->label_list_.begin(); it != this->label_list_.end();) { + if (*it == "") { + it = this->label_list_.erase(it); + } else { + ++it; + } + } + } + // add_special_char + this->label_list_.insert(this->label_list_.begin(), this->beg); + this->label_list_.push_back(this->end); + } void Run(std::vector &loc_preds, std::vector &structure_probs, std::vector &rec_scores, std::vector &loc_preds_shape, std::vector &structure_probs_shape, @@ -106,4 +122,33 @@ private: std::string beg = "sos"; }; +class PicodetPostProcessor { +public: + void init(std::string label_path, const double score_threshold = 0.4, + const double nms_threshold = 0.5, + const std::vector &fpn_stride = {8, 16, 32, 64}) { + this->label_list_ = Utility::ReadDict(label_path); + this->score_threshold_ = score_threshold; + this->nms_threshold_ = nms_threshold; + this->num_class_ = label_list_.size(); + this->fpn_stride_ = fpn_stride; + } + void Run(std::vector &results, + std::vector> outs, std::vector ori_shape, + std::vector resize_shape, int eg_max); + std::vector fpn_stride_ = {8, 16, 32, 64}; + +private: + StructurePredictResult disPred2Bbox(std::vector bbox_pred, int label, + float score, int x, int y, int stride, + std::vector im_shape, int reg_max); + void nms(std::vector &input_boxes, + float nms_threshold); + + std::vector label_list_; + double score_threshold_ = 0.4; + double nms_threshold_ = 0.5; + int num_class_ = 5; +}; + } // namespace PaddleOCR diff --git a/deploy/cpp_infer/include/preprocess_op.h b/deploy/cpp_infer/include/preprocess_op.h index 078f19d5b808c81e88d7aa464d6bfaca7fe1b14e..46cda1ca0176027422c5caa1874dfaa40d3a1010 100644 --- a/deploy/cpp_infer/include/preprocess_op.h +++ b/deploy/cpp_infer/include/preprocess_op.h @@ -82,4 +82,10 @@ public: const int max_len = 488); }; +class Resize { +public: + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, const int h, + const int w); +}; + } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/include/structure_layout.h b/deploy/cpp_infer/include/structure_layout.h new file mode 100644 index 0000000000000000000000000000000000000000..57f5f4810799df68230ce050ecb1399be619c25b --- /dev/null +++ b/deploy/cpp_infer/include/structure_layout.h @@ -0,0 +1,94 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" +#include "paddle_inference_api.h" +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +using namespace paddle_infer; + +namespace PaddleOCR { + +class StructureLayoutRecognizer { +public: + explicit StructureLayoutRecognizer( + const std::string &model_dir, const bool &use_gpu, const int &gpu_id, + const int &gpu_mem, const int &cpu_math_library_num_threads, + const bool &use_mkldnn, const string &label_path, + const bool &use_tensorrt, const std::string &precision, + const double &layout_score_threshold, + const double &layout_nms_threshold) { + this->use_gpu_ = use_gpu; + this->gpu_id_ = gpu_id; + this->gpu_mem_ = gpu_mem; + this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; + this->use_mkldnn_ = use_mkldnn; + this->use_tensorrt_ = use_tensorrt; + this->precision_ = precision; + + this->post_processor_.init(label_path, layout_score_threshold, + layout_nms_threshold); + LoadModel(model_dir); + } + + // Load Paddle inference model + void LoadModel(const std::string &model_dir); + + void Run(cv::Mat img, std::vector &result, + std::vector ×); + +private: + std::shared_ptr predictor_; + + bool use_gpu_ = false; + int gpu_id_ = 0; + int gpu_mem_ = 4000; + int cpu_math_library_num_threads_ = 4; + bool use_mkldnn_ = false; + + std::vector mean_ = {0.485f, 0.456f, 0.406f}; + std::vector scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f}; + bool is_scale_ = true; + + bool use_tensorrt_ = false; + std::string precision_ = "fp32"; + + // pre-process + Resize resize_op_; + Normalize normalize_op_; + Permute permute_op_; + + // post-process + PicodetPostProcessor post_processor_; + +}; // class StructureTableRecognizer + +} // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/include/utility.h b/deploy/cpp_infer/include/utility.h index 85b280fe25a46be70dba529891c3470a729dfbf1..ddf2eee1c497d55a3fb5335ca68575b22a2e2e8c 100644 --- a/deploy/cpp_infer/include/utility.h +++ b/deploy/cpp_infer/include/utility.h @@ -42,11 +42,13 @@ struct OCRPredictResult { struct StructurePredictResult { std::vector box; + std::vector box_float; std::vector> cell_box; std::string type; std::vector text_res; std::string html; float html_score = -1; + float confidence; }; class Utility { @@ -58,7 +60,7 @@ public: const std::string &save_path); static void VisualizeBboxes(const cv::Mat &srcimg, - const StructurePredictResult &structure_result, + StructurePredictResult &structure_result, const std::string &save_path); template @@ -89,6 +91,12 @@ public: static std::vector xyxyxyxy2xyxy(std::vector> &box); static std::vector xyxyxyxy2xyxy(std::vector &box); + static float fast_exp(float x); + static std::vector + activation_function_softmax(std::vector &src); + static float iou(std::vector &box1, std::vector &box2); + static float iou(std::vector &box1, std::vector &box2); + private: static bool comparison_box(const OCRPredictResult &result1, const OCRPredictResult &result2) { diff --git a/deploy/cpp_infer/src/args.cpp b/deploy/cpp_infer/src/args.cpp index 17e9c8b625baf53c2583a6d778aba552cdd19e97..28066f0b20061059f32e2658fa4ea70fd827acb7 100644 --- a/deploy/cpp_infer/src/args.cpp +++ b/deploy/cpp_infer/src/args.cpp @@ -51,6 +51,13 @@ DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt", DEFINE_int32(rec_img_h, 48, "rec image height"); DEFINE_int32(rec_img_w, 320, "rec image width"); +// layout model related +DEFINE_string(layout_model_dir, "", "Path of table layout inference model."); +DEFINE_string(layout_dict_path, + "../../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt", + "Path of dictionary."); +DEFINE_double(layout_score_threshold, 0.5, "Threshold of score."); +DEFINE_double(layout_nms_threshold, 0.5, "Threshold of nms."); // structure model related DEFINE_string(table_model_dir, "", "Path of table struture inference model."); DEFINE_int32(table_max_len, 488, "max len size of input image."); @@ -65,4 +72,5 @@ DEFINE_string(table_char_dict_path, DEFINE_bool(det, true, "Whether use det in forward."); DEFINE_bool(rec, true, "Whether use rec in forward."); DEFINE_bool(cls, false, "Whether use cls in forward."); -DEFINE_bool(table, false, "Whether use table structure in forward."); \ No newline at end of file +DEFINE_bool(table, false, "Whether use table structure in forward."); +DEFINE_bool(layout, false, "Whether use layout analysis in forward."); \ No newline at end of file diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp index 34ffdc62674ef02b2d30c8e213a783495ceaff99..a639614b3eec4214a8dd9385b0c64db56d43607f 100644 --- a/deploy/cpp_infer/src/main.cpp +++ b/deploy/cpp_infer/src/main.cpp @@ -65,6 +65,14 @@ void check_params() { exit(1); } } + if (FLAGS_layout) { + if (FLAGS_layout_model_dir.empty() || FLAGS_image_dir.empty()) { + std::cout << "Usage[layout]: ./ppocr " + << "--layout_model_dir=/PATH/TO/LAYOUT_INFERENCE_MODEL/ " + << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; + exit(1); + } + } if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" && FLAGS_precision != "int8") { cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. " << endl; @@ -75,71 +83,94 @@ void check_params() { void ocr(std::vector &cv_all_img_names) { PPOCR ocr = PPOCR(); - std::vector> ocr_results = - ocr.ocr(cv_all_img_names, FLAGS_det, FLAGS_rec, FLAGS_cls); + if (FLAGS_benchmark) { + ocr.reset_timer(); + } + std::vector img_list; + std::vector img_names; for (int i = 0; i < cv_all_img_names.size(); ++i) { - if (FLAGS_benchmark) { - cout << cv_all_img_names[i] << '\t'; - if (FLAGS_rec && FLAGS_det) { - Utility::print_result(ocr_results[i]); - } else if (FLAGS_det) { - for (int n = 0; n < ocr_results[i].size(); n++) { - for (int m = 0; m < ocr_results[i][n].box.size(); m++) { - cout << ocr_results[i][n].box[m][0] << ' ' - << ocr_results[i][n].box[m][1] << ' '; - } - } - cout << endl; - } else { - Utility::print_result(ocr_results[i]); - } - } else { - cout << cv_all_img_names[i] << "\n"; - Utility::print_result(ocr_results[i]); - if (FLAGS_visualize && FLAGS_det) { - cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); - if (!srcimg.data) { - std::cerr << "[ERROR] image read failed! image path: " - << cv_all_img_names[i] << endl; - exit(1); - } - std::string file_name = Utility::basename(cv_all_img_names[i]); + cv::Mat img = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); + if (!img.data) { + std::cerr << "[ERROR] image read failed! image path: " + << cv_all_img_names[i] << endl; + continue; + } + img_list.push_back(img); + img_names.push_back(cv_all_img_names[i]); + } - Utility::VisualizeBboxes(srcimg, ocr_results[i], - FLAGS_output + "/" + file_name); - } - cout << "***************************" << endl; + std::vector> ocr_results = + ocr.ocr(img_list, FLAGS_det, FLAGS_rec, FLAGS_cls); + + for (int i = 0; i < img_names.size(); ++i) { + cout << "predict img: " << cv_all_img_names[i] << endl; + Utility::print_result(ocr_results[i]); + if (FLAGS_visualize && FLAGS_det) { + std::string file_name = Utility::basename(img_names[i]); + cv::Mat srcimg = img_list[i]; + Utility::VisualizeBboxes(srcimg, ocr_results[i], + FLAGS_output + "/" + file_name); } } + if (FLAGS_benchmark) { + ocr.benchmark_log(cv_all_img_names.size()); + } } void structure(std::vector &cv_all_img_names) { PaddleOCR::PaddleStructure engine = PaddleOCR::PaddleStructure(); - std::vector> structure_results = - engine.structure(cv_all_img_names, false, FLAGS_table); + + if (FLAGS_benchmark) { + engine.reset_timer(); + } + for (int i = 0; i < cv_all_img_names.size(); i++) { cout << "predict img: " << cv_all_img_names[i] << endl; - cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); - for (int j = 0; j < structure_results[i].size(); j++) { - std::cout << j << "\ttype: " << structure_results[i][j].type + cv::Mat img = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); + if (!img.data) { + std::cerr << "[ERROR] image read failed! image path: " + << cv_all_img_names[i] << endl; + continue; + } + + std::vector structure_results = engine.structure( + img, FLAGS_layout, FLAGS_table, FLAGS_det && FLAGS_rec); + + for (int j = 0; j < structure_results.size(); j++) { + std::cout << j << "\ttype: " << structure_results[j].type << ", region: ["; - std::cout << structure_results[i][j].box[0] << "," - << structure_results[i][j].box[1] << "," - << structure_results[i][j].box[2] << "," - << structure_results[i][j].box[3] << "], res: "; - if (structure_results[i][j].type == "table") { - std::cout << structure_results[i][j].html << std::endl; - std::string file_name = Utility::basename(cv_all_img_names[i]); - - Utility::VisualizeBboxes(srcimg, structure_results[i][j], - FLAGS_output + "/" + std::to_string(j) + "_" + - file_name); + std::cout << structure_results[j].box[0] << "," + << structure_results[j].box[1] << "," + << structure_results[j].box[2] << "," + << structure_results[j].box[3] << "], score: "; + std::cout << structure_results[j].confidence << ", res: "; + + if (structure_results[j].type == "table") { + std::cout << structure_results[j].html << std::endl; + if (structure_results[j].cell_box.size() > 0 && FLAGS_visualize) { + std::string file_name = Utility::basename(cv_all_img_names[i]); + + Utility::VisualizeBboxes(img, structure_results[j], + FLAGS_output + "/" + std::to_string(j) + + "_" + file_name); + } } else { - Utility::print_result(structure_results[i][j].text_res); + cout << "count of ocr result is : " + << structure_results[j].text_res.size() << endl; + if (structure_results[j].text_res.size() > 0) { + cout << "********** print ocr result " + << "**********" << endl; + Utility::print_result(structure_results[j].text_res); + cout << "********** end print ocr result " + << "**********" << endl; + } } } } + if (FLAGS_benchmark) { + engine.benchmark_log(cv_all_img_names.size()); + } } int main(int argc, char **argv) { @@ -157,6 +188,9 @@ int main(int argc, char **argv) { cv::glob(FLAGS_image_dir, cv_all_img_names); std::cout << "total images num: " << cv_all_img_names.size() << endl; + if (!Utility::PathExists(FLAGS_output)) { + Utility::CreateDir(FLAGS_output); + } if (FLAGS_type == "ocr") { ocr(cv_all_img_names); } else if (FLAGS_type == "structure") { diff --git a/deploy/cpp_infer/src/paddleocr.cpp b/deploy/cpp_infer/src/paddleocr.cpp index 1de4fc7e9af8bf63cf68ef42d2a508cdc4b5f9f3..7417e9660a05f5745a3bf633c9637815425b0d83 100644 --- a/deploy/cpp_infer/src/paddleocr.cpp +++ b/deploy/cpp_infer/src/paddleocr.cpp @@ -44,8 +44,71 @@ PPOCR::PPOCR() { } }; -void PPOCR::det(cv::Mat img, std::vector &ocr_results, - std::vector ×) { +std::vector> +PPOCR::ocr(std::vector img_list, bool det, bool rec, bool cls) { + std::vector> ocr_results; + + if (!det) { + std::vector ocr_result; + ocr_result.resize(img_list.size()); + if (cls && this->classifier_ != nullptr) { + this->cls(img_list, ocr_result); + for (int i = 0; i < img_list.size(); i++) { + if (ocr_result[i].cls_label % 2 == 1 && + ocr_result[i].cls_score > this->classifier_->cls_thresh) { + cv::rotate(img_list[i], img_list[i], 1); + } + } + } + if (rec) { + this->rec(img_list, ocr_result); + } + for (int i = 0; i < ocr_result.size(); ++i) { + std::vector ocr_result_tmp; + ocr_result_tmp.push_back(ocr_result[i]); + ocr_results.push_back(ocr_result_tmp); + } + } else { + for (int i = 0; i < img_list.size(); ++i) { + std::vector ocr_result = + this->ocr(img_list[i], true, rec, cls); + ocr_results.push_back(ocr_result); + } + } + return ocr_results; +} + +std::vector PPOCR::ocr(cv::Mat img, bool det, bool rec, + bool cls) { + + std::vector ocr_result; + // det + this->det(img, ocr_result); + // crop image + std::vector img_list; + for (int j = 0; j < ocr_result.size(); j++) { + cv::Mat crop_img; + crop_img = Utility::GetRotateCropImage(img, ocr_result[j].box); + img_list.push_back(crop_img); + } + // cls + if (cls && this->classifier_ != nullptr) { + this->cls(img_list, ocr_result); + for (int i = 0; i < img_list.size(); i++) { + if (ocr_result[i].cls_label % 2 == 1 && + ocr_result[i].cls_score > this->classifier_->cls_thresh) { + cv::rotate(img_list[i], img_list[i], 1); + } + } + } + // rec + if (rec) { + this->rec(img_list, ocr_result); + } + return ocr_result; +} + +void PPOCR::det(cv::Mat img, std::vector &ocr_results) { std::vector>> boxes; std::vector det_times; @@ -58,14 +121,13 @@ void PPOCR::det(cv::Mat img, std::vector &ocr_results, } // sort boex from top to bottom, from left to right Utility::sorted_boxes(ocr_results); - times[0] += det_times[0]; - times[1] += det_times[1]; - times[2] += det_times[2]; + this->time_info_det[0] += det_times[0]; + this->time_info_det[1] += det_times[1]; + this->time_info_det[2] += det_times[2]; } void PPOCR::rec(std::vector img_list, - std::vector &ocr_results, - std::vector ×) { + std::vector &ocr_results) { std::vector rec_texts(img_list.size(), ""); std::vector rec_text_scores(img_list.size(), 0); std::vector rec_times; @@ -75,14 +137,13 @@ void PPOCR::rec(std::vector img_list, ocr_results[i].text = rec_texts[i]; ocr_results[i].score = rec_text_scores[i]; } - times[0] += rec_times[0]; - times[1] += rec_times[1]; - times[2] += rec_times[2]; + this->time_info_rec[0] += rec_times[0]; + this->time_info_rec[1] += rec_times[1]; + this->time_info_rec[2] += rec_times[2]; } void PPOCR::cls(std::vector img_list, - std::vector &ocr_results, - std::vector ×) { + std::vector &ocr_results) { std::vector cls_labels(img_list.size(), 0); std::vector cls_scores(img_list.size(), 0); std::vector cls_times; @@ -92,125 +153,43 @@ void PPOCR::cls(std::vector img_list, ocr_results[i].cls_label = cls_labels[i]; ocr_results[i].cls_score = cls_scores[i]; } - times[0] += cls_times[0]; - times[1] += cls_times[1]; - times[2] += cls_times[2]; + this->time_info_cls[0] += cls_times[0]; + this->time_info_cls[1] += cls_times[1]; + this->time_info_cls[2] += cls_times[2]; } -std::vector> -PPOCR::ocr(std::vector cv_all_img_names, bool det, bool rec, - bool cls) { - std::vector time_info_det = {0, 0, 0}; - std::vector time_info_rec = {0, 0, 0}; - std::vector time_info_cls = {0, 0, 0}; - std::vector> ocr_results; - - if (!det) { - std::vector ocr_result; - // read image - std::vector img_list; - for (int i = 0; i < cv_all_img_names.size(); ++i) { - cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); - if (!srcimg.data) { - std::cerr << "[ERROR] image read failed! image path: " - << cv_all_img_names[i] << endl; - exit(1); - } - img_list.push_back(srcimg); - OCRPredictResult res; - ocr_result.push_back(res); - } - if (cls && this->classifier_ != nullptr) { - this->cls(img_list, ocr_result, time_info_cls); - for (int i = 0; i < img_list.size(); i++) { - if (ocr_result[i].cls_label % 2 == 1 && - ocr_result[i].cls_score > this->classifier_->cls_thresh) { - cv::rotate(img_list[i], img_list[i], 1); - } - } - } - if (rec) { - this->rec(img_list, ocr_result, time_info_rec); - } - for (int i = 0; i < cv_all_img_names.size(); ++i) { - std::vector ocr_result_tmp; - ocr_result_tmp.push_back(ocr_result[i]); - ocr_results.push_back(ocr_result_tmp); - } - } else { - if (!Utility::PathExists(FLAGS_output) && FLAGS_det) { - Utility::CreateDir(FLAGS_output); - } - - for (int i = 0; i < cv_all_img_names.size(); ++i) { - std::vector ocr_result; - if (!FLAGS_benchmark) { - cout << "predict img: " << cv_all_img_names[i] << endl; - } - - cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); - if (!srcimg.data) { - std::cerr << "[ERROR] image read failed! image path: " - << cv_all_img_names[i] << endl; - exit(1); - } - // det - this->det(srcimg, ocr_result, time_info_det); - // crop image - std::vector img_list; - for (int j = 0; j < ocr_result.size(); j++) { - cv::Mat crop_img; - crop_img = Utility::GetRotateCropImage(srcimg, ocr_result[j].box); - img_list.push_back(crop_img); - } - - // cls - if (cls && this->classifier_ != nullptr) { - this->cls(img_list, ocr_result, time_info_cls); - for (int i = 0; i < img_list.size(); i++) { - if (ocr_result[i].cls_label % 2 == 1 && - ocr_result[i].cls_score > this->classifier_->cls_thresh) { - cv::rotate(img_list[i], img_list[i], 1); - } - } - } - // rec - if (rec) { - this->rec(img_list, ocr_result, time_info_rec); - } - ocr_results.push_back(ocr_result); - } - } - if (FLAGS_benchmark) { - this->log(time_info_det, time_info_rec, time_info_cls, - cv_all_img_names.size()); - } - return ocr_results; -} // namespace PaddleOCR +void PPOCR::reset_timer() { + this->time_info_det = {0, 0, 0}; + this->time_info_rec = {0, 0, 0}; + this->time_info_cls = {0, 0, 0}; +} -void PPOCR::log(std::vector &det_times, std::vector &rec_times, - std::vector &cls_times, int img_num) { - if (det_times[0] + det_times[1] + det_times[2] > 0) { +void PPOCR::benchmark_log(int img_num) { + if (this->time_info_det[0] + this->time_info_det[1] + this->time_info_det[2] > + 0) { AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt, FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic", - FLAGS_precision, det_times, img_num); + FLAGS_precision, this->time_info_det, img_num); autolog_det.report(); } - if (rec_times[0] + rec_times[1] + rec_times[2] > 0) { + if (this->time_info_rec[0] + this->time_info_rec[1] + this->time_info_rec[2] > + 0) { AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt, FLAGS_enable_mkldnn, FLAGS_cpu_threads, FLAGS_rec_batch_num, "dynamic", FLAGS_precision, - rec_times, img_num); + this->time_info_rec, img_num); autolog_rec.report(); } - if (cls_times[0] + cls_times[1] + cls_times[2] > 0) { + if (this->time_info_cls[0] + this->time_info_cls[1] + this->time_info_cls[2] > + 0) { AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt, FLAGS_enable_mkldnn, FLAGS_cpu_threads, FLAGS_cls_batch_num, "dynamic", FLAGS_precision, - cls_times, img_num); + this->time_info_cls, img_num); autolog_cls.report(); } } + PPOCR::~PPOCR() { if (this->detector_ != nullptr) { delete this->detector_; diff --git a/deploy/cpp_infer/src/paddlestructure.cpp b/deploy/cpp_infer/src/paddlestructure.cpp index ea69977a1e45b0f7c1235a647d7c56db4d3cbc74..73df39e1ae5f6ff0f65d8149d4c77bcc2fadb566 100644 --- a/deploy/cpp_infer/src/paddlestructure.cpp +++ b/deploy/cpp_infer/src/paddlestructure.cpp @@ -22,8 +22,15 @@ namespace PaddleOCR { PaddleStructure::PaddleStructure() { + if (FLAGS_layout) { + this->layout_model_ = new StructureLayoutRecognizer( + FLAGS_layout_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, + FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_layout_dict_path, + FLAGS_use_tensorrt, FLAGS_precision, FLAGS_layout_score_threshold, + FLAGS_layout_nms_threshold); + } if (FLAGS_table) { - this->recognizer_ = new StructureTableRecognizer( + this->table_model_ = new StructureTableRecognizer( FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path, FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num, @@ -31,68 +38,63 @@ PaddleStructure::PaddleStructure() { } }; -std::vector> -PaddleStructure::structure(std::vector cv_all_img_names, - bool layout, bool table) { - std::vector time_info_det = {0, 0, 0}; - std::vector time_info_rec = {0, 0, 0}; - std::vector time_info_cls = {0, 0, 0}; - std::vector time_info_table = {0, 0, 0}; +std::vector +PaddleStructure::structure(cv::Mat srcimg, bool layout, bool table, bool ocr) { + cv::Mat img; + srcimg.copyTo(img); - std::vector> structure_results; + std::vector structure_results; - if (!Utility::PathExists(FLAGS_output) && FLAGS_det) { - Utility::CreateDir(FLAGS_output); + if (layout) { + this->layout(img, structure_results); + } else { + StructurePredictResult res; + res.type = "table"; + res.box = std::vector(4, 0); + res.box[2] = img.cols; + res.box[3] = img.rows; + structure_results.push_back(res); } - for (int i = 0; i < cv_all_img_names.size(); ++i) { - std::vector structure_result; - cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); - if (!srcimg.data) { - std::cerr << "[ERROR] image read failed! image path: " - << cv_all_img_names[i] << endl; - exit(1); - } - if (layout) { - } else { - StructurePredictResult res; - res.type = "table"; - res.box = std::vector(4, 0); - res.box[2] = srcimg.cols; - res.box[3] = srcimg.rows; - structure_result.push_back(res); - } - cv::Mat roi_img; - for (int i = 0; i < structure_result.size(); i++) { - // crop image - roi_img = Utility::crop_image(srcimg, structure_result[i].box); - if (structure_result[i].type == "table") { - this->table(roi_img, structure_result[i], time_info_table, - time_info_det, time_info_rec, time_info_cls); - } + cv::Mat roi_img; + for (int i = 0; i < structure_results.size(); i++) { + // crop image + roi_img = Utility::crop_image(img, structure_results[i].box); + if (structure_results[i].type == "table" && table) { + this->table(roi_img, structure_results[i]); + } else if (ocr) { + structure_results[i].text_res = this->ocr(roi_img, true, true, false); } - structure_results.push_back(structure_result); } + return structure_results; }; +void PaddleStructure::layout( + cv::Mat img, std::vector &structure_result) { + std::vector layout_times; + this->layout_model_->Run(img, structure_result, layout_times); + + this->time_info_layout[0] += layout_times[0]; + this->time_info_layout[1] += layout_times[1]; + this->time_info_layout[2] += layout_times[2]; +} + void PaddleStructure::table(cv::Mat img, - StructurePredictResult &structure_result, - std::vector &time_info_table, - std::vector &time_info_det, - std::vector &time_info_rec, - std::vector &time_info_cls) { + StructurePredictResult &structure_result) { // predict structure std::vector> structure_html_tags; std::vector structure_scores(1, 0); std::vector>> structure_boxes; - std::vector structure_imes; + std::vector structure_times; std::vector img_list; img_list.push_back(img); - this->recognizer_->Run(img_list, structure_html_tags, structure_scores, - structure_boxes, structure_imes); - time_info_table[0] += structure_imes[0]; - time_info_table[1] += structure_imes[1]; - time_info_table[2] += structure_imes[2]; + + this->table_model_->Run(img_list, structure_html_tags, structure_scores, + structure_boxes, structure_times); + + this->time_info_table[0] += structure_times[0]; + this->time_info_table[1] += structure_times[1]; + this->time_info_table[2] += structure_times[2]; std::vector ocr_result; std::string html; @@ -100,7 +102,7 @@ void PaddleStructure::table(cv::Mat img, for (int i = 0; i < img_list.size(); i++) { // det - this->det(img_list[i], ocr_result, time_info_det); + this->det(img_list[i], ocr_result); // crop image std::vector rec_img_list; std::vector ocr_box; @@ -115,7 +117,7 @@ void PaddleStructure::table(cv::Mat img, rec_img_list.push_back(crop_img); } // rec - this->rec(rec_img_list, ocr_result, time_info_rec); + this->rec(rec_img_list, ocr_result); // rebuild table html = this->rebuild_table(structure_html_tags[i], structure_boxes[i], ocr_result); @@ -150,7 +152,7 @@ PaddleStructure::rebuild_table(std::vector structure_html_tags, structure_box = structure_boxes[j]; } dis_list[j][0] = this->dis(ocr_box, structure_box); - dis_list[j][1] = 1 - this->iou(ocr_box, structure_box); + dis_list[j][1] = 1 - Utility::iou(ocr_box, structure_box); dis_list[j][2] = j; } // find min dis idx @@ -216,28 +218,6 @@ PaddleStructure::rebuild_table(std::vector structure_html_tags, return html_str; } -float PaddleStructure::iou(std::vector &box1, std::vector &box2) { - int area1 = max(0, box1[2] - box1[0]) * max(0, box1[3] - box1[1]); - int area2 = max(0, box2[2] - box2[0]) * max(0, box2[3] - box2[1]); - - // computing the sum_area - int sum_area = area1 + area2; - - // find the each point of intersect rectangle - int x1 = max(box1[0], box2[0]); - int y1 = max(box1[1], box2[1]); - int x2 = min(box1[2], box2[2]); - int y2 = min(box1[3], box2[3]); - - // judge if there is an intersect - if (y1 >= y2 || x1 >= x2) { - return 0.0; - } else { - int intersect = (x2 - x1) * (y2 - y1); - return intersect / (sum_area - intersect + 0.00000001); - } -} - float PaddleStructure::dis(std::vector &box1, std::vector &box2) { int x1_1 = box1[0]; int y1_1 = box1[1]; @@ -256,9 +236,61 @@ float PaddleStructure::dis(std::vector &box1, std::vector &box2) { return dis + min(dis_2, dis_3); } +void PaddleStructure::reset_timer() { + this->time_info_det = {0, 0, 0}; + this->time_info_rec = {0, 0, 0}; + this->time_info_cls = {0, 0, 0}; + this->time_info_table = {0, 0, 0}; + this->time_info_layout = {0, 0, 0}; +} + +void PaddleStructure::benchmark_log(int img_num) { + if (this->time_info_det[0] + this->time_info_det[1] + this->time_info_det[2] > + 0) { + AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt, + FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic", + FLAGS_precision, this->time_info_det, img_num); + autolog_det.report(); + } + if (this->time_info_rec[0] + this->time_info_rec[1] + this->time_info_rec[2] > + 0) { + AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt, + FLAGS_enable_mkldnn, FLAGS_cpu_threads, + FLAGS_rec_batch_num, "dynamic", FLAGS_precision, + this->time_info_rec, img_num); + autolog_rec.report(); + } + if (this->time_info_cls[0] + this->time_info_cls[1] + this->time_info_cls[2] > + 0) { + AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt, + FLAGS_enable_mkldnn, FLAGS_cpu_threads, + FLAGS_cls_batch_num, "dynamic", FLAGS_precision, + this->time_info_cls, img_num); + autolog_cls.report(); + } + if (this->time_info_table[0] + this->time_info_table[1] + + this->time_info_table[2] > + 0) { + AutoLogger autolog_table("table", FLAGS_use_gpu, FLAGS_use_tensorrt, + FLAGS_enable_mkldnn, FLAGS_cpu_threads, + FLAGS_cls_batch_num, "dynamic", FLAGS_precision, + this->time_info_table, img_num); + autolog_table.report(); + } + if (this->time_info_layout[0] + this->time_info_layout[1] + + this->time_info_layout[2] > + 0) { + AutoLogger autolog_layout("layout", FLAGS_use_gpu, FLAGS_use_tensorrt, + FLAGS_enable_mkldnn, FLAGS_cpu_threads, + FLAGS_cls_batch_num, "dynamic", FLAGS_precision, + this->time_info_layout, img_num); + autolog_layout.report(); + } +} + PaddleStructure::~PaddleStructure() { - if (this->recognizer_ != nullptr) { - delete this->recognizer_; + if (this->table_model_ != nullptr) { + delete this->table_model_; } }; diff --git a/deploy/cpp_infer/src/postprocess_op.cpp b/deploy/cpp_infer/src/postprocess_op.cpp index 4b0c693c80467bceb75da2b3fef6e816b0690979..0a4da675b6afef9158b53cf566c80f50345adfee 100644 --- a/deploy/cpp_infer/src/postprocess_op.cpp +++ b/deploy/cpp_infer/src/postprocess_op.cpp @@ -352,25 +352,6 @@ std::vector>> DBPostProcessor::FilterTagDetRes( return root_points; } -void TablePostProcessor::init(std::string label_path, - bool merge_no_span_structure) { - this->label_list_ = Utility::ReadDict(label_path); - if (merge_no_span_structure) { - this->label_list_.push_back(""); - std::vector::iterator it; - for (it = this->label_list_.begin(); it != this->label_list_.end();) { - if (*it == "") { - it = this->label_list_.erase(it); - } else { - ++it; - } - } - } - // add_special_char - this->label_list_.insert(this->label_list_.begin(), this->beg); - this->label_list_.push_back(this->end); -} - void TablePostProcessor::Run( std::vector &loc_preds, std::vector &structure_probs, std::vector &rec_scores, std::vector &loc_preds_shape, @@ -440,4 +421,129 @@ void TablePostProcessor::Run( } } +void PicodetPostProcessor::Run(std::vector &results, + std::vector> outs, + std::vector ori_shape, + std::vector resize_shape, int reg_max) { + int in_h = resize_shape[0]; + int in_w = resize_shape[1]; + float scale_factor_h = resize_shape[0] / float(ori_shape[0]); + float scale_factor_w = resize_shape[1] / float(ori_shape[1]); + + std::vector> bbox_results; + bbox_results.resize(this->num_class_); + for (int i = 0; i < this->fpn_stride_.size(); ++i) { + int feature_h = std::ceil((float)in_h / this->fpn_stride_[i]); + int feature_w = std::ceil((float)in_w / this->fpn_stride_[i]); + for (int idx = 0; idx < feature_h * feature_w; idx++) { + // score and label + float score = 0; + int cur_label = 0; + for (int label = 0; label < this->num_class_; label++) { + if (outs[i][idx * this->num_class_ + label] > score) { + score = outs[i][idx * this->num_class_ + label]; + cur_label = label; + } + } + // bbox + if (score > this->score_threshold_) { + int row = idx / feature_w; + int col = idx % feature_w; + std::vector bbox_pred( + outs[i + this->fpn_stride_.size()].begin() + idx * 4 * reg_max, + outs[i + this->fpn_stride_.size()].begin() + + (idx + 1) * 4 * reg_max); + bbox_results[cur_label].push_back( + this->disPred2Bbox(bbox_pred, cur_label, score, col, row, + this->fpn_stride_[i], resize_shape, reg_max)); + } + } + } + for (int i = 0; i < bbox_results.size(); i++) { + bool flag = bbox_results[i].size() <= 0; + } + for (int i = 0; i < bbox_results.size(); i++) { + bool flag = bbox_results[i].size() <= 0; + if (bbox_results[i].size() <= 0) { + continue; + } + this->nms(bbox_results[i], this->nms_threshold_); + for (auto box : bbox_results[i]) { + box.box_float[0] = box.box_float[0] / scale_factor_w; + box.box_float[2] = box.box_float[2] / scale_factor_w; + box.box_float[1] = box.box_float[1] / scale_factor_h; + box.box_float[3] = box.box_float[3] / scale_factor_h; + box.box = {(int)box.box_float[0], (int)box.box_float[1], + (int)box.box_float[2], (int)box.box_float[3]}; + results.push_back(box); + } + } +} + +StructurePredictResult +PicodetPostProcessor::disPred2Bbox(std::vector bbox_pred, int label, + float score, int x, int y, int stride, + std::vector im_shape, int reg_max) { + float ct_x = (x + 0.5) * stride; + float ct_y = (y + 0.5) * stride; + std::vector dis_pred; + dis_pred.resize(4); + for (int i = 0; i < 4; i++) { + float dis = 0; + std::vector bbox_pred_i(bbox_pred.begin() + i * reg_max, + bbox_pred.begin() + (i + 1) * reg_max); + std::vector dis_after_sm = + Utility::activation_function_softmax(bbox_pred_i); + for (int j = 0; j < reg_max; j++) { + dis += j * dis_after_sm[j]; + } + dis *= stride; + dis_pred[i] = dis; + } + + float xmin_float = (std::max)(ct_x - dis_pred[0], .0f); + float ymin_float = (std::max)(ct_y - dis_pred[1], .0f); + float xmax_float = (std::min)(ct_x + dis_pred[2], (float)im_shape[1]); + float ymax_float = (std::min)(ct_y + dis_pred[3], (float)im_shape[0]); + + StructurePredictResult result_item; + result_item.box_float = {xmin_float, ymin_float, xmax_float, ymax_float}; + result_item.type = this->label_list_[label]; + result_item.confidence = score; + + return result_item; +} + +void PicodetPostProcessor::nms(std::vector &input_boxes, + float nms_threshold) { + std::sort(input_boxes.begin(), input_boxes.end(), + [](StructurePredictResult a, StructurePredictResult b) { + return a.confidence > b.confidence; + }); + std::vector picked(input_boxes.size(), 1); + + for (int i = 0; i < input_boxes.size(); ++i) { + if (picked[i] == 0) { + continue; + } + for (int j = i + 1; j < input_boxes.size(); ++j) { + if (picked[j] == 0) { + continue; + } + float iou = + Utility::iou(input_boxes[i].box_float, input_boxes[j].box_float); + if (iou > nms_threshold) { + picked[j] = 0; + } + } + } + std::vector input_boxes_nms; + for (int i = 0; i < input_boxes.size(); ++i) { + if (picked[i] == 1) { + input_boxes_nms.push_back(input_boxes[i]); + } + } + input_boxes = input_boxes_nms; +} + } // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/preprocess_op.cpp b/deploy/cpp_infer/src/preprocess_op.cpp index ac185e22d68955ef440e22c327b835dbce6c4e1b..eb448a7e8226571eba66e4d28283000bc5fbfefb 100644 --- a/deploy/cpp_infer/src/preprocess_op.cpp +++ b/deploy/cpp_infer/src/preprocess_op.cpp @@ -175,4 +175,9 @@ void TablePadImg::Run(const cv::Mat &img, cv::Mat &resize_img, cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0)); } +void Resize::Run(const cv::Mat &img, cv::Mat &resize_img, const int h, + const int w) { + cv::resize(img, resize_img, cv::Size(w, h)); +} + } // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/structure_layout.cpp b/deploy/cpp_infer/src/structure_layout.cpp new file mode 100644 index 0000000000000000000000000000000000000000..898ead7aef0a7394db144557527f07ed4823e950 --- /dev/null +++ b/deploy/cpp_infer/src/structure_layout.cpp @@ -0,0 +1,144 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace PaddleOCR { + +void StructureLayoutRecognizer::Run(cv::Mat img, + std::vector &result, + std::vector ×) { + std::chrono::duration preprocess_diff = + std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + std::chrono::duration inference_diff = + std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + std::chrono::duration postprocess_diff = + std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + + // preprocess + auto preprocess_start = std::chrono::steady_clock::now(); + + cv::Mat srcimg; + img.copyTo(srcimg); + cv::Mat resize_img; + this->resize_op_.Run(srcimg, resize_img, 800, 608); + this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, + this->is_scale_); + + std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); + this->permute_op_.Run(&resize_img, input.data()); + auto preprocess_end = std::chrono::steady_clock::now(); + preprocess_diff += preprocess_end - preprocess_start; + + // inference. + auto input_names = this->predictor_->GetInputNames(); + auto input_t = this->predictor_->GetInputHandle(input_names[0]); + input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); + auto inference_start = std::chrono::steady_clock::now(); + input_t->CopyFromCpu(input.data()); + + this->predictor_->Run(); + + // Get output tensor + std::vector> out_tensor_list; + std::vector> output_shape_list; + auto output_names = this->predictor_->GetOutputNames(); + for (int j = 0; j < output_names.size(); j++) { + auto output_tensor = this->predictor_->GetOutputHandle(output_names[j]); + std::vector output_shape = output_tensor->shape(); + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, + std::multiplies()); + output_shape_list.push_back(output_shape); + + std::vector out_data; + out_data.resize(out_num); + output_tensor->CopyToCpu(out_data.data()); + out_tensor_list.push_back(out_data); + } + auto inference_end = std::chrono::steady_clock::now(); + inference_diff += inference_end - inference_start; + + // postprocess + auto postprocess_start = std::chrono::steady_clock::now(); + + std::vector bbox_num; + int reg_max = 0; + for (int i = 0; i < out_tensor_list.size(); i++) { + if (i == this->post_processor_.fpn_stride_.size()) { + reg_max = output_shape_list[i][2] / 4; + break; + } + } + std::vector ori_shape = {srcimg.rows, srcimg.cols}; + std::vector resize_shape = {resize_img.rows, resize_img.cols}; + this->post_processor_.Run(result, out_tensor_list, ori_shape, resize_shape, + reg_max); + bbox_num.push_back(result.size()); + + auto postprocess_end = std::chrono::steady_clock::now(); + postprocess_diff += postprocess_end - postprocess_start; + times.push_back(double(preprocess_diff.count() * 1000)); + times.push_back(double(inference_diff.count() * 1000)); + times.push_back(double(postprocess_diff.count() * 1000)); +} + +void StructureLayoutRecognizer::LoadModel(const std::string &model_dir) { + AnalysisConfig config; + if (Utility::PathExists(model_dir + "/inference.pdmodel") && + Utility::PathExists(model_dir + "/inference.pdiparams")) { + config.SetModel(model_dir + "/inference.pdmodel", + model_dir + "/inference.pdiparams"); + } else if (Utility::PathExists(model_dir + "/model.pdmodel") && + Utility::PathExists(model_dir + "/model.pdiparams")) { + config.SetModel(model_dir + "/model.pdmodel", + model_dir + "/model.pdiparams"); + } else { + std::cerr << "[ERROR] not find model.pdiparams or inference.pdiparams in " + << model_dir << endl; + exit(1); + } + + if (this->use_gpu_) { + config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); + if (this->use_tensorrt_) { + auto precision = paddle_infer::Config::Precision::kFloat32; + if (this->precision_ == "fp16") { + precision = paddle_infer::Config::Precision::kHalf; + } + if (this->precision_ == "int8") { + precision = paddle_infer::Config::Precision::kInt8; + } + config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false); + } + } else { + config.DisableGpu(); + if (this->use_mkldnn_) { + config.EnableMKLDNN(); + } + config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); + } + + // false for zero copy tensor + config.SwitchUseFeedFetchOps(false); + // true for multiple input + config.SwitchSpecifyInputNames(true); + + config.SwitchIrOptim(true); + + config.EnableMemoryOptim(); + config.DisableGlogInfo(); + + this->predictor_ = CreatePredictor(config); +} +} // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/utility.cpp b/deploy/cpp_infer/src/utility.cpp index 0e6ba17fc3bab5b5e005f8b5e41640899bee39d0..6b41a523afa18bcea833461309d1505abb8dc04c 100644 --- a/deploy/cpp_infer/src/utility.cpp +++ b/deploy/cpp_infer/src/utility.cpp @@ -66,10 +66,11 @@ void Utility::VisualizeBboxes(const cv::Mat &srcimg, } void Utility::VisualizeBboxes(const cv::Mat &srcimg, - const StructurePredictResult &structure_result, + StructurePredictResult &structure_result, const std::string &save_path) { cv::Mat img_vis; srcimg.copyTo(img_vis); + img_vis = crop_image(img_vis, structure_result.box); for (int n = 0; n < structure_result.cell_box.size(); n++) { if (structure_result.cell_box[n].size() == 8) { cv::Point rook_points[4]; @@ -280,17 +281,17 @@ void Utility::print_result(const std::vector &ocr_result) { } } -cv::Mat Utility::crop_image(cv::Mat &img, std::vector &area) { +cv::Mat Utility::crop_image(cv::Mat &img, std::vector &box) { cv::Mat crop_im; - int crop_x1 = std::max(0, area[0]); - int crop_y1 = std::max(0, area[1]); - int crop_x2 = std::min(img.cols - 1, area[2] - 1); - int crop_y2 = std::min(img.rows - 1, area[3] - 1); + int crop_x1 = std::max(0, box[0]); + int crop_y1 = std::max(0, box[1]); + int crop_x2 = std::min(img.cols - 1, box[2] - 1); + int crop_y2 = std::min(img.rows - 1, box[3] - 1); - crop_im = cv::Mat::zeros(area[3] - area[1], area[2] - area[0], 16); + crop_im = cv::Mat::zeros(box[3] - box[1], box[2] - box[0], 16); cv::Mat crop_im_window = - crop_im(cv::Range(crop_y1 - area[1], crop_y2 + 1 - area[1]), - cv::Range(crop_x1 - area[0], crop_x2 + 1 - area[0])); + crop_im(cv::Range(crop_y1 - box[1], crop_y2 + 1 - box[1]), + cv::Range(crop_x1 - box[0], crop_x2 + 1 - box[0])); cv::Mat roi_img = img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1)); crop_im_window += roi_img; @@ -341,4 +342,78 @@ std::vector Utility::xyxyxyxy2xyxy(std::vector &box) { return box1; } +float Utility::fast_exp(float x) { + union { + uint32_t i; + float f; + } v{}; + v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f); + return v.f; +} + +std::vector +Utility::activation_function_softmax(std::vector &src) { + int length = src.size(); + std::vector dst; + dst.resize(length); + const float alpha = float(*std::max_element(&src[0], &src[0 + length])); + float denominator{0}; + + for (int i = 0; i < length; ++i) { + dst[i] = fast_exp(src[i] - alpha); + denominator += dst[i]; + } + + for (int i = 0; i < length; ++i) { + dst[i] /= denominator; + } + return dst; +} + +float Utility::iou(std::vector &box1, std::vector &box2) { + int area1 = std::max(0, box1[2] - box1[0]) * std::max(0, box1[3] - box1[1]); + int area2 = std::max(0, box2[2] - box2[0]) * std::max(0, box2[3] - box2[1]); + + // computing the sum_area + int sum_area = area1 + area2; + + // find the each point of intersect rectangle + int x1 = std::max(box1[0], box2[0]); + int y1 = std::max(box1[1], box2[1]); + int x2 = std::min(box1[2], box2[2]); + int y2 = std::min(box1[3], box2[3]); + + // judge if there is an intersect + if (y1 >= y2 || x1 >= x2) { + return 0.0; + } else { + int intersect = (x2 - x1) * (y2 - y1); + return intersect / (sum_area - intersect + 0.00000001); + } +} + +float Utility::iou(std::vector &box1, std::vector &box2) { + float area1 = std::max((float)0.0, box1[2] - box1[0]) * + std::max((float)0.0, box1[3] - box1[1]); + float area2 = std::max((float)0.0, box2[2] - box2[0]) * + std::max((float)0.0, box2[3] - box2[1]); + + // computing the sum_area + float sum_area = area1 + area2; + + // find the each point of intersect rectangle + float x1 = std::max(box1[0], box2[0]); + float y1 = std::max(box1[1], box2[1]); + float x2 = std::min(box1[2], box2[2]); + float y2 = std::min(box1[3], box2[3]); + + // judge if there is an intersect + if (y1 >= y2 || x1 >= x2) { + return 0.0; + } else { + float intersect = (x2 - x1) * (y2 - y1); + return intersect / (sum_area - intersect + 0.00000001); + } +} + } // namespace PaddleOCR \ No newline at end of file