diff --git a/deploy/cpp_infer/include/args.h b/deploy/cpp_infer/include/args.h index 473ff25d981f8409f60a43940aaaec376375adf5..e0dd8bbcd1044fd695c90805bc770de5b47e51cf 100644 --- a/deploy/cpp_infer/include/args.h +++ b/deploy/cpp_infer/include/args.h @@ -30,7 +30,8 @@ DECLARE_string(image_dir); DECLARE_string(type); // detection related DECLARE_string(det_model_dir); -DECLARE_int32(max_side_len); +DECLARE_string(limit_type); +DECLARE_int32(limit_side_len); DECLARE_double(det_db_thresh); DECLARE_double(det_db_box_thresh); DECLARE_double(det_db_unclip_ratio); @@ -48,7 +49,13 @@ DECLARE_int32(rec_batch_num); DECLARE_string(rec_char_dict_path); DECLARE_int32(rec_img_h); DECLARE_int32(rec_img_w); +// structure model related +DECLARE_string(table_model_dir); +DECLARE_int32(table_max_len); +DECLARE_int32(table_batch_num); +DECLARE_string(table_char_dict_path); // forward related DECLARE_bool(det); DECLARE_bool(rec); DECLARE_bool(cls); +DECLARE_bool(table); \ No newline at end of file diff --git a/deploy/cpp_infer/include/ocr_det.h b/deploy/cpp_infer/include/ocr_det.h index 7efd4d8f0f4ccb705fc34695bb9843e0b6af5a9b..d1421b103b28b44e15a7df53a63fd893ca60e529 100644 --- a/deploy/cpp_infer/include/ocr_det.h +++ b/deploy/cpp_infer/include/ocr_det.h @@ -41,8 +41,8 @@ public: explicit DBDetector(const std::string &model_dir, const bool &use_gpu, const int &gpu_id, const int &gpu_mem, const int &cpu_math_library_num_threads, - const bool &use_mkldnn, const int &max_side_len, - const double &det_db_thresh, + const bool &use_mkldnn, const string &limit_type, + const int &limit_side_len, const double &det_db_thresh, const double &det_db_box_thresh, const double &det_db_unclip_ratio, const std::string &det_db_score_mode, @@ -54,7 +54,8 @@ public: this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; this->use_mkldnn_ = use_mkldnn; - this->max_side_len_ = max_side_len; + this->limit_type_ = limit_type; + this->limit_side_len_ = limit_side_len; this->det_db_thresh_ = det_db_thresh; this->det_db_box_thresh_ = det_db_box_thresh; @@ -84,7 +85,8 @@ private: int cpu_math_library_num_threads_ = 4; bool use_mkldnn_ = false; - int max_side_len_ = 960; + string limit_type_ = "max"; + int limit_side_len_ = 960; double det_db_thresh_ = 0.3; double det_db_box_thresh_ = 0.5; @@ -106,7 +108,7 @@ private: Permute permute_op_; // post-process - PostProcessor post_processor_; + DBPostProcessor post_processor_; }; } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/include/paddleocr.h b/deploy/cpp_infer/include/paddleocr.h index 6db9d86cb152bfcc708a87c6a98be59d88a5d8db..a2c60b14acceaa90a8d8e4a70ccc50f02f254eb6 100644 --- a/deploy/cpp_infer/include/paddleocr.h +++ b/deploy/cpp_infer/include/paddleocr.h @@ -47,11 +47,7 @@ public: ocr(std::vector cv_all_img_names, bool det = true, bool rec = true, bool cls = true); -private: - DBDetector *detector_ = nullptr; - Classifier *classifier_ = nullptr; - CRNNRecognizer *recognizer_ = nullptr; - +protected: void det(cv::Mat img, std::vector &ocr_results, std::vector ×); void rec(std::vector img_list, @@ -62,6 +58,11 @@ private: std::vector ×); void log(std::vector &det_times, std::vector &rec_times, std::vector &cls_times, int img_num); + +private: + DBDetector *detector_ = nullptr; + Classifier *classifier_ = nullptr; + CRNNRecognizer *recognizer_ = nullptr; }; } // namespace PaddleOCR diff --git a/deploy/cpp_infer/include/paddlestructure.h b/deploy/cpp_infer/include/paddlestructure.h new file mode 100644 index 0000000000000000000000000000000000000000..b30ac045b2a6552b69442b2e8b29673efc820e31 --- /dev/null +++ b/deploy/cpp_infer/include/paddlestructure.h @@ -0,0 +1,79 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" +#include "paddle_inference_api.h" +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +using namespace paddle_infer; + +namespace PaddleOCR { + +class PaddleStructure : public PPOCR { +public: + explicit PaddleStructure(); + ~PaddleStructure(); + std::vector> + structure(std::vector cv_all_img_names, bool layout = false, + bool table = true); + +private: + StructureTableRecognizer *recognizer_ = nullptr; + + void table(cv::Mat img, StructurePredictResult &structure_result, + std::vector &time_info_table, + std::vector &time_info_det, + std::vector &time_info_rec, + std::vector &time_info_cls); + std::string + rebuild_table(std::vector rec_html_tags, + std::vector>> rec_boxes, + std::vector &ocr_result); + + float iou(std::vector> &box1, + std::vector> &box2); + float dis(std::vector> &box1, + std::vector> &box2); + + static bool comparison_dis(const std::vector &dis1, + const std::vector &dis2) { + if (dis1[1] < dis2[1]) { + return true; + } else if (dis1[1] == dis2[1]) { + return dis1[0] < dis2[0]; + } else { + return false; + } + } +}; + +} // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/include/postprocess_op.h b/deploy/cpp_infer/include/postprocess_op.h index 4a98b151bdcc53e2ab3fbda1dca55dd9746bd86c..77b3f8b660bda29815245b31ab8cac479b24498f 100644 --- a/deploy/cpp_infer/include/postprocess_op.h +++ b/deploy/cpp_infer/include/postprocess_op.h @@ -34,7 +34,7 @@ using namespace std; namespace PaddleOCR { -class PostProcessor { +class DBPostProcessor { public: void GetContourArea(const std::vector> &box, float unclip_ratio, float &distance); @@ -90,4 +90,21 @@ private: } }; +class TablePostProcessor { +public: + void init(std::string label_path); + void + Run(std::vector &loc_preds, std::vector &structure_probs, + std::vector &rec_scores, std::vector &loc_preds_shape, + std::vector &structure_probs_shape, + std::vector> &rec_html_tag_batch, + std::vector>>> &rec_boxes_batch, + std::vector &width_list, std::vector &height_list); + +private: + std::vector label_list_; + std::string end = "eos"; + std::string beg = "sos"; +}; + } // namespace PaddleOCR diff --git a/deploy/cpp_infer/include/preprocess_op.h b/deploy/cpp_infer/include/preprocess_op.h index 31217de301573e078f8e11ef88657f369ede9b31..078f19d5b808c81e88d7aa464d6bfaca7fe1b14e 100644 --- a/deploy/cpp_infer/include/preprocess_op.h +++ b/deploy/cpp_infer/include/preprocess_op.h @@ -48,11 +48,12 @@ class PermuteBatch { public: virtual void Run(const std::vector imgs, float *data); }; - + class ResizeImgType0 { public: - virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len, - float &ratio_h, float &ratio_w, bool use_tensorrt); + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, string limit_type, + int limit_side_len, float &ratio_h, float &ratio_w, + bool use_tensorrt); }; class CrnnResizeImg { @@ -69,4 +70,16 @@ public: const std::vector &rec_image_shape = {3, 48, 192}); }; +class TableResizeImg { +public: + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, + const int max_len = 488); +}; + +class TablePadImg { +public: + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, + const int max_len = 488); +}; + } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/include/structure_table.h b/deploy/cpp_infer/include/structure_table.h new file mode 100644 index 0000000000000000000000000000000000000000..7449c6cd0e158425bccb75740191dd0b6d6ecc9b --- /dev/null +++ b/deploy/cpp_infer/include/structure_table.h @@ -0,0 +1,100 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" +#include "paddle_inference_api.h" +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +using namespace paddle_infer; + +namespace PaddleOCR { + +class StructureTableRecognizer { +public: + explicit StructureTableRecognizer( + const std::string &model_dir, const bool &use_gpu, const int &gpu_id, + const int &gpu_mem, const int &cpu_math_library_num_threads, + const bool &use_mkldnn, const string &label_path, + const bool &use_tensorrt, const std::string &precision, + const int &table_batch_num, const int &table_max_len) { + this->use_gpu_ = use_gpu; + this->gpu_id_ = gpu_id; + this->gpu_mem_ = gpu_mem; + this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; + this->use_mkldnn_ = use_mkldnn; + this->use_tensorrt_ = use_tensorrt; + this->precision_ = precision; + this->table_batch_num_ = table_batch_num; + this->table_max_len_ = table_max_len; + + this->post_processor_.init(label_path); + LoadModel(model_dir); + } + + // Load Paddle inference model + void LoadModel(const std::string &model_dir); + + void Run(std::vector img_list, + std::vector> &rec_html_tags, + std::vector &rec_scores, + std::vector>>> &rec_boxes, + std::vector ×); + +private: + std::shared_ptr predictor_; + + bool use_gpu_ = false; + int gpu_id_ = 0; + int gpu_mem_ = 4000; + int cpu_math_library_num_threads_ = 4; + bool use_mkldnn_ = false; + int table_max_len_ = 488; + + std::vector mean_ = {0.485f, 0.456f, 0.406f}; + std::vector scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f}; + bool is_scale_ = true; + + bool use_tensorrt_ = false; + std::string precision_ = "fp32"; + int table_batch_num_ = 1; + + // pre-process + TableResizeImg resize_op_; + Normalize normalize_op_; + PermuteBatch permute_op_; + TablePadImg pad_op_; + + // post-process + TablePostProcessor post_processor_; + +}; // class StructureTableRecognizer + +} // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/include/utility.h b/deploy/cpp_infer/include/utility.h index eb18c0624492e9b47de156d60611d637d8dca6c3..520804f64529303b5ecec27dc5f0895f1fff5c72 100644 --- a/deploy/cpp_infer/include/utility.h +++ b/deploy/cpp_infer/include/utility.h @@ -40,6 +40,14 @@ struct OCRPredictResult { int cls_label = -1; }; +struct StructurePredictResult { + std::vector box; + std::string type; + std::vector text_res; + std::string html; + float html_score = -1; +}; + class Utility { public: static std::vector ReadDict(const std::string &path); @@ -68,6 +76,22 @@ public: static void CreateDir(const std::string &path); static void print_result(const std::vector &ocr_result); + + static cv::Mat crop_image(cv::Mat &img, std::vector &area); + + static void sorted_boxes(std::vector &ocr_result); + +private: + static bool comparison_box(const OCRPredictResult &result1, + const OCRPredictResult &result2) { + if (result1.box[0][1] < result2.box[0][1]) { + return true; + } else if (result1.box[0][1] == result2.box[0][1]) { + return result1.box[0][0] < result2.box[0][0]; + } else { + return false; + } + } }; } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/src/args.cpp b/deploy/cpp_infer/src/args.cpp index 93d0f5ea5fd07bdc3eb44537bc1c0d4e131736d3..df1b9e32a3aacc309d6485114f9b267001f79920 100644 --- a/deploy/cpp_infer/src/args.cpp +++ b/deploy/cpp_infer/src/args.cpp @@ -30,7 +30,8 @@ DEFINE_string( "Perform ocr or structure, the value is selected in ['ocr','structure']."); // detection related DEFINE_string(det_model_dir, "", "Path of det inference model."); -DEFINE_int32(max_side_len, 960, "max_side_len of input image."); +DEFINE_string(limit_type, "max", "limit_type of input image."); +DEFINE_int32(limit_side_len, 960, "limit_side_len of input image."); DEFINE_double(det_db_thresh, 0.3, "Threshold of det_db_thresh."); DEFINE_double(det_db_box_thresh, 0.6, "Threshold of det_db_box_thresh."); DEFINE_double(det_db_unclip_ratio, 1.5, "Threshold of det_db_unclip_ratio."); @@ -50,7 +51,16 @@ DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt", DEFINE_int32(rec_img_h, 48, "rec image height"); DEFINE_int32(rec_img_w, 320, "rec image width"); +// structure model related +DEFINE_string(table_model_dir, "", "Path of table struture inference model."); +DEFINE_int32(table_max_len, 488, "max len size of input image."); +DEFINE_int32(table_batch_num, 1, "table_batch_num."); +DEFINE_string(table_char_dict_path, + "../../ppocr/utils/dict/table_structure_dict.txt", + "Path of dictionary."); + // ocr forward related DEFINE_bool(det, true, "Whether use det in forward."); DEFINE_bool(rec, true, "Whether use rec in forward."); -DEFINE_bool(cls, false, "Whether use cls in forward."); \ No newline at end of file +DEFINE_bool(cls, false, "Whether use cls in forward."); +DEFINE_bool(table, false, "Whether use table structure in forward."); \ No newline at end of file diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp index c4b5b97ea8b2ebf77dd9a3e2af69a1a1ca19ed2a..aa35eca3f138bee69b270cae6b976f3aa9874b33 100644 --- a/deploy/cpp_infer/src/main.cpp +++ b/deploy/cpp_infer/src/main.cpp @@ -19,6 +19,7 @@ #include #include +#include using namespace PaddleOCR; @@ -32,6 +33,12 @@ void check_params() { } } if (FLAGS_rec) { + std::cout + << "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320'," + "if you are using recognition model with PP-OCRv2 or an older " + "version, " + "please set --rec_image_shape='3,32,320" + << std::endl; if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) { std::cout << "Usage[rec]: ./ppocr " "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " @@ -47,6 +54,17 @@ void check_params() { exit(1); } } + if (FLAGS_table) { + if (FLAGS_table_model_dir.empty() || FLAGS_det_model_dir.empty() || + FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) { + std::cout << "Usage[table]: ./ppocr " + << "--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " + << "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " + << "--table_model_dir=/PATH/TO/TABLE_INFERENCE_MODEL/ " + << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; + exit(1); + } + } if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" && FLAGS_precision != "int8") { cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. " << endl; @@ -54,21 +72,7 @@ void check_params() { } } -int main(int argc, char **argv) { - // Parsing command-line - google::ParseCommandLineFlags(&argc, &argv, true); - check_params(); - - if (!Utility::PathExists(FLAGS_image_dir)) { - std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir - << endl; - exit(1); - } - - std::vector cv_all_img_names; - cv::glob(FLAGS_image_dir, cv_all_img_names); - std::cout << "total images num: " << cv_all_img_names.size() << endl; - +void ocr(std::vector &cv_all_img_names) { PPOCR ocr = PPOCR(); std::vector> ocr_results = @@ -109,3 +113,49 @@ int main(int argc, char **argv) { } } } + +void structure(std::vector &cv_all_img_names) { + PaddleOCR::PaddleStructure engine = PaddleOCR::PaddleStructure(); + std::vector> structure_results = + engine.structure(cv_all_img_names, false, FLAGS_table); + for (int i = 0; i < cv_all_img_names.size(); i++) { + cout << cv_all_img_names[i] << "\n"; + for (int j = 0; j < structure_results[i].size(); j++) { + std::cout << j << "\ttype: " << structure_results[i][j].type + << ", region: ["; + std::cout << structure_results[i][j].box[0] << "," + << structure_results[i][j].box[1] << "," + << structure_results[i][j].box[2] << "," + << structure_results[i][j].box[3] << "], res: "; + if (structure_results[i][j].type == "Table") { + std::cout << structure_results[i][j].html << std::endl; + } else { + Utility::print_result(structure_results[i][j].text_res); + } + } + } +} + +int main(int argc, char **argv) { + // Parsing command-line + google::ParseCommandLineFlags(&argc, &argv, true); + check_params(); + + if (!Utility::PathExists(FLAGS_image_dir)) { + std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir + << endl; + exit(1); + } + + std::vector cv_all_img_names; + cv::glob(FLAGS_image_dir, cv_all_img_names); + std::cout << "total images num: " << cv_all_img_names.size() << endl; + + if (FLAGS_type == "ocr") { + ocr(cv_all_img_names); + } else if (FLAGS_type == "structure") { + structure(cv_all_img_names); + } else { + std::cout << "only value in ['ocr','structure'] is supported" << endl; + } +} diff --git a/deploy/cpp_infer/src/ocr_det.cpp b/deploy/cpp_infer/src/ocr_det.cpp index 550997e71937d23a7448e8ff1c4ffad579d2931c..56de195186a0d4d6c8b2482eb57c106347485928 100644 --- a/deploy/cpp_infer/src/ocr_det.cpp +++ b/deploy/cpp_infer/src/ocr_det.cpp @@ -47,7 +47,7 @@ void DBDetector::LoadModel(const std::string &model_dir) { {"elementwise_add_7", {1, 56, 2, 2}}, {"nearest_interp_v2_0.tmp_0", {1, 256, 2, 2}}}; std::map> max_input_shape = { - {"x", {1, 3, this->max_side_len_, this->max_side_len_}}, + {"x", {1, 3, 1536, 1536}}, {"conv2d_92.tmp_0", {1, 120, 400, 400}}, {"conv2d_91.tmp_0", {1, 24, 200, 200}}, {"conv2d_59.tmp_0", {1, 96, 400, 400}}, @@ -109,7 +109,8 @@ void DBDetector::Run(cv::Mat &img, img.copyTo(srcimg); auto preprocess_start = std::chrono::steady_clock::now(); - this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w, + this->resize_op_.Run(img, resize_img, this->limit_type_, + this->limit_side_len_, ratio_h, ratio_w, this->use_tensorrt_); this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, diff --git a/deploy/cpp_infer/src/paddleocr.cpp b/deploy/cpp_infer/src/paddleocr.cpp index cd620a9206cad8ec2b1cd5924c714a8a1fa989b6..1de4fc7e9af8bf63cf68ef42d2a508cdc4b5f9f3 100644 --- a/deploy/cpp_infer/src/paddleocr.cpp +++ b/deploy/cpp_infer/src/paddleocr.cpp @@ -23,10 +23,10 @@ PPOCR::PPOCR() { if (FLAGS_det) { this->detector_ = new DBDetector( FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, - FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_max_side_len, - FLAGS_det_db_thresh, FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio, - FLAGS_det_db_score_mode, FLAGS_use_dilation, FLAGS_use_tensorrt, - FLAGS_precision); + FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_limit_type, + FLAGS_limit_side_len, FLAGS_det_db_thresh, FLAGS_det_db_box_thresh, + FLAGS_det_db_unclip_ratio, FLAGS_det_db_score_mode, FLAGS_use_dilation, + FLAGS_use_tensorrt, FLAGS_precision); } if (FLAGS_cls && FLAGS_use_angle_cls) { @@ -56,7 +56,8 @@ void PPOCR::det(cv::Mat img, std::vector &ocr_results, res.box = boxes[i]; ocr_results.push_back(res); } - + // sort boex from top to bottom, from left to right + Utility::sorted_boxes(ocr_results); times[0] += det_times[0]; times[1] += det_times[1]; times[2] += det_times[2]; diff --git a/deploy/cpp_infer/src/paddlestructure.cpp b/deploy/cpp_infer/src/paddlestructure.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbaa84fe8454cbe33f04a3b4328c6e9bf7c54943 --- /dev/null +++ b/deploy/cpp_infer/src/paddlestructure.cpp @@ -0,0 +1,272 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "auto_log/autolog.h" +#include +#include + +namespace PaddleOCR { + +PaddleStructure::PaddleStructure() { + if (FLAGS_table) { + this->recognizer_ = new StructureTableRecognizer( + FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, + FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path, + FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num, + FLAGS_table_max_len); + } +}; + +std::vector> +PaddleStructure::structure(std::vector cv_all_img_names, + bool layout, bool table) { + std::vector time_info_det = {0, 0, 0}; + std::vector time_info_rec = {0, 0, 0}; + std::vector time_info_cls = {0, 0, 0}; + std::vector time_info_table = {0, 0, 0}; + + std::vector> structure_results; + + if (!Utility::PathExists(FLAGS_output) && FLAGS_det) { + mkdir(FLAGS_output.c_str(), 0777); + } + for (int i = 0; i < cv_all_img_names.size(); ++i) { + std::vector structure_result; + cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); + if (!srcimg.data) { + std::cerr << "[ERROR] image read failed! image path: " + << cv_all_img_names[i] << endl; + exit(1); + } + if (layout) { + } else { + StructurePredictResult res; + res.type = "Table"; + res.box = std::vector(4, 0); + res.box[2] = srcimg.cols; + res.box[3] = srcimg.rows; + structure_result.push_back(res); + } + cv::Mat roi_img; + for (int i = 0; i < structure_result.size(); i++) { + // crop image + roi_img = Utility::crop_image(srcimg, structure_result[i].box); + if (structure_result[i].type == "Table") { + this->table(roi_img, structure_result[i], time_info_table, + time_info_det, time_info_rec, time_info_cls); + } + } + structure_results.push_back(structure_result); + } + return structure_results; +}; + +void PaddleStructure::table(cv::Mat img, + StructurePredictResult &structure_result, + std::vector &time_info_table, + std::vector &time_info_det, + std::vector &time_info_rec, + std::vector &time_info_cls) { + // predict structure + std::vector> structure_html_tags; + std::vector structure_scores(1, 0); + std::vector>>> structure_boxes; + std::vector structure_imes; + std::vector img_list; + img_list.push_back(img); + this->recognizer_->Run(img_list, structure_html_tags, structure_scores, + structure_boxes, structure_imes); + time_info_table[0] += structure_imes[0]; + time_info_table[1] += structure_imes[1]; + time_info_table[2] += structure_imes[2]; + + std::vector ocr_result; + std::string html; + int expand_pixel = 3; + + for (int i = 0; i < img_list.size(); i++) { + // det + this->det(img_list[i], ocr_result, time_info_det); + // crop image + std::vector rec_img_list; + for (int j = 0; j < ocr_result.size(); j++) { + int x_collect[4] = {ocr_result[j].box[0][0], ocr_result[j].box[1][0], + ocr_result[j].box[2][0], ocr_result[j].box[3][0]}; + int y_collect[4] = {ocr_result[j].box[0][1], ocr_result[j].box[1][1], + ocr_result[j].box[2][1], ocr_result[j].box[3][1]}; + int left = int(*std::min_element(x_collect, x_collect + 4)); + int right = int(*std::max_element(x_collect, x_collect + 4)); + int top = int(*std::min_element(y_collect, y_collect + 4)); + int bottom = int(*std::max_element(y_collect, y_collect + 4)); + std::vector box{max(0, left - expand_pixel), + max(0, top - expand_pixel), + min(img_list[i].cols, right + expand_pixel), + min(img_list[i].rows, bottom + expand_pixel)}; + cv::Mat crop_img = Utility::crop_image(img_list[i], box); + rec_img_list.push_back(crop_img); + } + // rec + this->rec(rec_img_list, ocr_result, time_info_rec); + // rebuild table + html = this->rebuild_table(structure_html_tags[i], structure_boxes[i], + ocr_result); + structure_result.html = html; + structure_result.html_score = structure_scores[i]; + } +}; + +std::string PaddleStructure::rebuild_table( + std::vector structure_html_tags, + std::vector>> structure_boxes, + std::vector &ocr_result) { + // match text in same cell + std::vector> matched(structure_boxes.size(), + std::vector()); + + for (int i = 0; i < ocr_result.size(); i++) { + std::vector> dis_list(structure_boxes.size(), + std::vector(3, 100000.0)); + for (int j = 0; j < structure_boxes.size(); j++) { + int x_collect[4] = {ocr_result[i].box[0][0], ocr_result[i].box[1][0], + ocr_result[i].box[2][0], ocr_result[i].box[3][0]}; + int y_collect[4] = {ocr_result[i].box[0][1], ocr_result[i].box[1][1], + ocr_result[i].box[2][1], ocr_result[i].box[3][1]}; + int left = int(*std::min_element(x_collect, x_collect + 4)); + int right = int(*std::max_element(x_collect, x_collect + 4)); + int top = int(*std::min_element(y_collect, y_collect + 4)); + int bottom = int(*std::max_element(y_collect, y_collect + 4)); + std::vector> box(2, std::vector(2, 0)); + box[0][0] = left - 1; + box[0][1] = top - 1; + box[1][0] = right + 1; + box[1][1] = bottom + 1; + + dis_list[j][0] = this->dis(box, structure_boxes[j]); + dis_list[j][1] = 1 - this->iou(box, structure_boxes[j]); + dis_list[j][2] = j; + } + // find min dis idx + std::sort(dis_list.begin(), dis_list.end(), + PaddleStructure::comparison_dis); + matched[dis_list[0][2]].push_back(ocr_result[i].text); + } + // get pred html + std::string html_str = ""; + int td_tag_idx = 0; + for (int i = 0; i < structure_html_tags.size(); i++) { + if (structure_html_tags[i].find("") != std::string::npos) { + if (structure_html_tags[i].find("") != std::string::npos) { + html_str += ""; + } + if (matched[td_tag_idx].size() > 0) { + bool b_with = false; + if (matched[td_tag_idx][0].find("") != std::string::npos && + matched[td_tag_idx].size() > 1) { + b_with = true; + html_str += ""; + } + for (int j = 0; j < matched[td_tag_idx].size(); j++) { + std::string content = matched[td_tag_idx][j]; + if (matched[td_tag_idx].size() > 1) { + // remove blank, and + if (content.length() > 0 && content.at(0) == ' ') { + content = content.substr(0); + } + if (content.length() > 2 && content.substr(0, 3) == "") { + content = content.substr(3); + } + if (content.length() > 4 && + content.substr(content.length() - 4) == "") { + content = content.substr(0, content.length() - 4); + } + if (content.empty()) { + continue; + } + // add blank + if (j != matched[td_tag_idx].size() - 1 && + content.at(content.length() - 1) != ' ') { + content += ' '; + } + } + html_str += content; + } + if (b_with) { + html_str += ""; + } + } + if (structure_html_tags[i].find("") != std::string::npos) { + html_str += ""; + } else { + html_str += structure_html_tags[i]; + } + td_tag_idx += 1; + } else { + html_str += structure_html_tags[i]; + } + } + return html_str; +} + +float PaddleStructure::iou(std::vector> &box1, + std::vector> &box2) { + int area1 = max(0, box1[1][0] - box1[0][0]) * max(0, box1[1][1] - box1[0][1]); + int area2 = max(0, box2[1][0] - box2[0][0]) * max(0, box2[1][1] - box2[0][1]); + + // computing the sum_area + int sum_area = area1 + area2; + + // find the each point of intersect rectangle + int x1 = max(box1[0][0], box2[0][0]); + int y1 = max(box1[0][1], box2[0][1]); + int x2 = min(box1[1][0], box2[1][0]); + int y2 = min(box1[1][1], box2[1][1]); + + // judge if there is an intersect + if (y1 >= y2 || x1 >= x2) { + return 0.0; + } else { + int intersect = (x2 - x1) * (y2 - y1); + return intersect / (sum_area - intersect + 0.00000001); + } +} + +float PaddleStructure::dis(std::vector> &box1, + std::vector> &box2) { + int x1_1 = box1[0][0]; + int y1_1 = box1[0][1]; + int x2_1 = box1[1][0]; + int y2_1 = box1[1][1]; + + int x1_2 = box2[0][0]; + int y1_2 = box2[0][1]; + int x2_2 = box2[1][0]; + int y2_2 = box2[1][1]; + + float dis = + abs(x1_2 - x1_1) + abs(y1_2 - y1_1) + abs(x2_2 - x2_1) + abs(y2_2 - y2_1); + float dis_2 = abs(x1_2 - x1_1) + abs(y1_2 - y1_1); + float dis_3 = abs(x2_2 - x2_1) + abs(y2_2 - y2_1); + return dis + min(dis_2, dis_3); +} + +PaddleStructure::~PaddleStructure() { + if (this->recognizer_ != nullptr) { + delete this->recognizer_; + } +}; + +} // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/src/postprocess_op.cpp b/deploy/cpp_infer/src/postprocess_op.cpp index 5374fb1a4eba68d8055a52ec91d97c290832aa9d..8d7af6474c996067a532e3d3eefb86ea5d6d3e3b 100644 --- a/deploy/cpp_infer/src/postprocess_op.cpp +++ b/deploy/cpp_infer/src/postprocess_op.cpp @@ -17,8 +17,8 @@ namespace PaddleOCR { -void PostProcessor::GetContourArea(const std::vector> &box, - float unclip_ratio, float &distance) { +void DBPostProcessor::GetContourArea(const std::vector> &box, + float unclip_ratio, float &distance) { int pts_num = 4; float area = 0.0f; float dist = 0.0f; @@ -35,8 +35,8 @@ void PostProcessor::GetContourArea(const std::vector> &box, distance = area * unclip_ratio / dist; } -cv::RotatedRect PostProcessor::UnClip(std::vector> box, - const float &unclip_ratio) { +cv::RotatedRect DBPostProcessor::UnClip(std::vector> box, + const float &unclip_ratio) { float distance = 1.0; GetContourArea(box, unclip_ratio, distance); @@ -67,7 +67,7 @@ cv::RotatedRect PostProcessor::UnClip(std::vector> box, return res; } -float **PostProcessor::Mat2Vec(cv::Mat mat) { +float **DBPostProcessor::Mat2Vec(cv::Mat mat) { auto **array = new float *[mat.rows]; for (int i = 0; i < mat.rows; ++i) array[i] = new float[mat.cols]; @@ -81,7 +81,7 @@ float **PostProcessor::Mat2Vec(cv::Mat mat) { } std::vector> -PostProcessor::OrderPointsClockwise(std::vector> pts) { +DBPostProcessor::OrderPointsClockwise(std::vector> pts) { std::vector> box = pts; std::sort(box.begin(), box.end(), XsortInt); @@ -99,7 +99,7 @@ PostProcessor::OrderPointsClockwise(std::vector> pts) { return rect; } -std::vector> PostProcessor::Mat2Vector(cv::Mat mat) { +std::vector> DBPostProcessor::Mat2Vector(cv::Mat mat) { std::vector> img_vec; std::vector tmp; @@ -113,20 +113,20 @@ std::vector> PostProcessor::Mat2Vector(cv::Mat mat) { return img_vec; } -bool PostProcessor::XsortFp32(std::vector a, std::vector b) { +bool DBPostProcessor::XsortFp32(std::vector a, std::vector b) { if (a[0] != b[0]) return a[0] < b[0]; return false; } -bool PostProcessor::XsortInt(std::vector a, std::vector b) { +bool DBPostProcessor::XsortInt(std::vector a, std::vector b) { if (a[0] != b[0]) return a[0] < b[0]; return false; } -std::vector> PostProcessor::GetMiniBoxes(cv::RotatedRect box, - float &ssid) { +std::vector> +DBPostProcessor::GetMiniBoxes(cv::RotatedRect box, float &ssid) { ssid = std::max(box.size.width, box.size.height); cv::Mat points; @@ -160,8 +160,8 @@ std::vector> PostProcessor::GetMiniBoxes(cv::RotatedRect box, return array; } -float PostProcessor::PolygonScoreAcc(std::vector contour, - cv::Mat pred) { +float DBPostProcessor::PolygonScoreAcc(std::vector contour, + cv::Mat pred) { int width = pred.cols; int height = pred.rows; std::vector box_x; @@ -206,8 +206,8 @@ float PostProcessor::PolygonScoreAcc(std::vector contour, return score; } -float PostProcessor::BoxScoreFast(std::vector> box_array, - cv::Mat pred) { +float DBPostProcessor::BoxScoreFast(std::vector> box_array, + cv::Mat pred) { auto array = box_array; int width = pred.cols; int height = pred.rows; @@ -244,7 +244,7 @@ float PostProcessor::BoxScoreFast(std::vector> box_array, return score; } -std::vector>> PostProcessor::BoxesFromBitmap( +std::vector>> DBPostProcessor::BoxesFromBitmap( const cv::Mat pred, const cv::Mat bitmap, const float &box_thresh, const float &det_db_unclip_ratio, const std::string &det_db_score_mode) { const int min_size = 3; @@ -321,9 +321,9 @@ std::vector>> PostProcessor::BoxesFromBitmap( return boxes; } -std::vector>> -PostProcessor::FilterTagDetRes(std::vector>> boxes, - float ratio_h, float ratio_w, cv::Mat srcimg) { +std::vector>> DBPostProcessor::FilterTagDetRes( + std::vector>> boxes, float ratio_h, + float ratio_w, cv::Mat srcimg) { int oriimg_h = srcimg.rows; int oriimg_w = srcimg.cols; @@ -352,4 +352,77 @@ PostProcessor::FilterTagDetRes(std::vector>> boxes, return root_points; } +void TablePostProcessor::init(std::string label_path) { + this->label_list_ = Utility::ReadDict(label_path); + this->label_list_.insert(this->label_list_.begin(), this->beg); + this->label_list_.push_back(this->end); +} + +void TablePostProcessor::Run( + std::vector &loc_preds, std::vector &structure_probs, + std::vector &rec_scores, std::vector &loc_preds_shape, + std::vector &structure_probs_shape, + std::vector> &rec_html_tag_batch, + std::vector>>> &rec_boxes_batch, + std::vector &width_list, std::vector &height_list) { + for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; batch_idx++) { + // image tags and boxs + std::vector rec_html_tags; + std::vector>> rec_boxes; + + float score = 0.f; + int count = 0; + float char_score = 0.f; + int char_idx = 0; + + // step + for (int step_idx = 0; step_idx < structure_probs_shape[1]; step_idx++) { + std::string html_tag; + std::vector> rec_box; + // html tag + int step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) * + structure_probs_shape[2]; + char_idx = int(Utility::argmax( + &structure_probs[step_start_idx], + &structure_probs[step_start_idx + structure_probs_shape[2]])); + char_score = float(*std::max_element( + &structure_probs[step_start_idx], + &structure_probs[step_start_idx + structure_probs_shape[2]])); + html_tag = this->label_list_[char_idx]; + + if (step_idx > 0 && html_tag == this->end) { + break; + } + if (html_tag == this->beg) { + continue; + } + count += 1; + score += char_score; + rec_html_tags.push_back(html_tag); + // box + if (html_tag == "" || html_tag == " point(2, 0); + step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) * + loc_preds_shape[2] + + point_idx; + point[0] = int(loc_preds[step_start_idx] * width_list[batch_idx]); + point[1] = + int(loc_preds[step_start_idx + 1] * height_list[batch_idx]); + rec_box.push_back(point); + } + rec_boxes.push_back(rec_box); + } + } + score /= count; + if (isnan(score) || rec_boxes.size() == 0 || rec_html_tags.size() == 0) { + score = -1; + } + rec_scores.push_back(score); + rec_boxes_batch.push_back(rec_boxes); + rec_html_tag_batch.push_back(rec_html_tags); + } +} + } // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/preprocess_op.cpp b/deploy/cpp_infer/src/preprocess_op.cpp index fff49ba2c2cd0e68f0c1d93e5877ab6276bdc337..ac185e22d68955ef440e22c327b835dbce6c4e1b 100644 --- a/deploy/cpp_infer/src/preprocess_op.cpp +++ b/deploy/cpp_infer/src/preprocess_op.cpp @@ -69,18 +69,28 @@ void Normalize::Run(cv::Mat *im, const std::vector &mean, } void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, - int max_size_len, float &ratio_h, float &ratio_w, - bool use_tensorrt) { + string limit_type, int limit_side_len, float &ratio_h, + float &ratio_w, bool use_tensorrt) { int w = img.cols; int h = img.rows; - float ratio = 1.f; - int max_wh = w >= h ? w : h; - if (max_wh > max_size_len) { - if (h > w) { - ratio = float(max_size_len) / float(h); - } else { - ratio = float(max_size_len) / float(w); + if (limit_type == "min") { + int min_wh = min(h, w); + if (min_wh < limit_side_len) { + if (h < w) { + ratio = float(limit_side_len) / float(h); + } else { + ratio = float(limit_side_len) / float(w); + } + } + } else { + int max_wh = max(h, w); + if (max_wh > limit_side_len) { + if (h > w) { + ratio = float(limit_side_len) / float(h); + } else { + ratio = float(limit_side_len) / float(w); + } } } @@ -143,4 +153,26 @@ void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, } } +void TableResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, + const int max_len) { + int w = img.cols; + int h = img.rows; + + int max_wh = w >= h ? w : h; + float ratio = w >= h ? float(max_len) / float(w) : float(max_len) / float(h); + + int resize_h = int(float(h) * ratio); + int resize_w = int(float(w) * ratio); + + cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); +} + +void TablePadImg::Run(const cv::Mat &img, cv::Mat &resize_img, + const int max_len) { + int w = img.cols; + int h = img.rows; + cv::copyMakeBorder(img, resize_img, 0, max_len - h, 0, max_len - w, + cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0)); +} + } // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/structure_table.cpp b/deploy/cpp_infer/src/structure_table.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bbc32580e49d6ed7b29e3f0931eab0b0969b02b9 --- /dev/null +++ b/deploy/cpp_infer/src/structure_table.cpp @@ -0,0 +1,158 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace PaddleOCR { + +void StructureTableRecognizer::Run( + std::vector img_list, + std::vector> &structure_html_tags, + std::vector &structure_scores, + std::vector>>> &structure_boxes, + std::vector ×) { + std::chrono::duration preprocess_diff = + std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + std::chrono::duration inference_diff = + std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + std::chrono::duration postprocess_diff = + std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + + int img_num = img_list.size(); + for (int beg_img_no = 0; beg_img_no < img_num; + beg_img_no += this->table_batch_num_) { + // preprocess + auto preprocess_start = std::chrono::steady_clock::now(); + int end_img_no = min(img_num, beg_img_no + this->table_batch_num_); + int batch_num = end_img_no - beg_img_no; + std::vector norm_img_batch; + std::vector width_list; + std::vector height_list; + for (int ino = beg_img_no; ino < end_img_no; ino++) { + cv::Mat srcimg; + img_list[ino].copyTo(srcimg); + cv::Mat resize_img; + cv::Mat pad_img; + this->resize_op_.Run(srcimg, resize_img, this->table_max_len_); + this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, + this->is_scale_); + this->pad_op_.Run(resize_img, pad_img, this->table_max_len_); + norm_img_batch.push_back(pad_img); + width_list.push_back(srcimg.cols); + height_list.push_back(srcimg.rows); + } + + std::vector input( + batch_num * 3 * this->table_max_len_ * this->table_max_len_, 0.0f); + this->permute_op_.Run(norm_img_batch, input.data()); + auto preprocess_end = std::chrono::steady_clock::now(); + preprocess_diff += preprocess_end - preprocess_start; + // inference. + auto input_names = this->predictor_->GetInputNames(); + auto input_t = this->predictor_->GetInputHandle(input_names[0]); + input_t->Reshape( + {batch_num, 3, this->table_max_len_, this->table_max_len_}); + auto inference_start = std::chrono::steady_clock::now(); + input_t->CopyFromCpu(input.data()); + this->predictor_->Run(); + auto output_names = this->predictor_->GetOutputNames(); + auto output_tensor0 = this->predictor_->GetOutputHandle(output_names[0]); + auto output_tensor1 = this->predictor_->GetOutputHandle(output_names[1]); + std::vector predict_shape0 = output_tensor0->shape(); + std::vector predict_shape1 = output_tensor1->shape(); + + int out_num0 = std::accumulate(predict_shape0.begin(), predict_shape0.end(), + 1, std::multiplies()); + int out_num1 = std::accumulate(predict_shape1.begin(), predict_shape1.end(), + 1, std::multiplies()); + std::vector loc_preds; + std::vector structure_probs; + loc_preds.resize(out_num0); + structure_probs.resize(out_num1); + + output_tensor0->CopyToCpu(loc_preds.data()); + output_tensor1->CopyToCpu(structure_probs.data()); + auto inference_end = std::chrono::steady_clock::now(); + inference_diff += inference_end - inference_start; + // postprocess + auto postprocess_start = std::chrono::steady_clock::now(); + std::vector> structure_html_tag_batch; + std::vector structure_score_batch; + std::vector>>> + structure_boxes_batch; + this->post_processor_.Run(loc_preds, structure_probs, structure_score_batch, + predict_shape0, predict_shape1, + structure_html_tag_batch, structure_boxes_batch, + width_list, height_list); + for (int m = 0; m < predict_shape0[0]; m++) { + + structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(), + ""); + structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(), + ""); + structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(), + ""); + structure_html_tag_batch[m].push_back("
"); + structure_html_tag_batch[m].push_back(""); + structure_html_tag_batch[m].push_back(""); + structure_html_tags.push_back(structure_html_tag_batch[m]); + structure_scores.push_back(structure_score_batch[m]); + structure_boxes.push_back(structure_boxes_batch[m]); + } + auto postprocess_end = std::chrono::steady_clock::now(); + postprocess_diff += postprocess_end - postprocess_start; + times.push_back(double(preprocess_diff.count() * 1000)); + times.push_back(double(inference_diff.count() * 1000)); + times.push_back(double(postprocess_diff.count() * 1000)); + } +} + +void StructureTableRecognizer::LoadModel(const std::string &model_dir) { + AnalysisConfig config; + config.SetModel(model_dir + "/inference.pdmodel", + model_dir + "/inference.pdiparams"); + + if (this->use_gpu_) { + config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); + if (this->use_tensorrt_) { + auto precision = paddle_infer::Config::Precision::kFloat32; + if (this->precision_ == "fp16") { + precision = paddle_infer::Config::Precision::kHalf; + } + if (this->precision_ == "int8") { + precision = paddle_infer::Config::Precision::kInt8; + } + config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false); + } + } else { + config.DisableGpu(); + if (this->use_mkldnn_) { + config.EnableMKLDNN(); + } + config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); + } + + // false for zero copy tensor + config.SwitchUseFeedFetchOps(false); + // true for multiple input + config.SwitchSpecifyInputNames(true); + + config.SwitchIrOptim(true); + + config.EnableMemoryOptim(); + config.DisableGlogInfo(); + + this->predictor_ = CreatePredictor(config); +} +} // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/utility.cpp b/deploy/cpp_infer/src/utility.cpp index 45b8104626cfc3d128e14ece8ba6763f0986cfe4..4bfc1d091d6124b10c79032beb702ba8727210fc 100644 --- a/deploy/cpp_infer/src/utility.cpp +++ b/deploy/cpp_infer/src/utility.cpp @@ -248,4 +248,33 @@ void Utility::print_result(const std::vector &ocr_result) { std::cout << std::endl; } } + +cv::Mat Utility::crop_image(cv::Mat &img, std::vector &area) { + cv::Mat crop_im; + int crop_x1 = std::max(0, area[0]); + int crop_y1 = std::max(0, area[1]); + int crop_x2 = std::min(img.cols - 1, area[2] - 1); + int crop_y2 = std::min(img.rows - 1, area[3] - 1); + + crop_im = cv::Mat::zeros(area[3] - area[1], area[2] - area[0], 16); + cv::Mat crop_im_window = + crop_im(cv::Range(crop_y1 - area[1], crop_y2 + 1 - area[1]), + cv::Range(crop_x1 - area[0], crop_x2 + 1 - area[0])); + cv::Mat roi_img = + img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1)); + crop_im_window += roi_img; + return crop_im; +} + +void Utility::sorted_boxes(std::vector &ocr_result) { + std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box); + + for (int i = 0; i < ocr_result.size() - 1; i++) { + if (abs(ocr_result[i + 1].box[0][1] - ocr_result[i].box[0][1]) < 10 && + (ocr_result[i + 1].box[0][0] < ocr_result[i].box[0][0])) { + std::swap(ocr_result[i], ocr_result[i + 1]); + } + } +} + } // namespace PaddleOCR \ No newline at end of file diff --git a/ppstructure/table/matcher.py b/ppstructure/table/matcher.py index 6884ea3c4884bb9118b5d75e0fc1f7280a73efd8..9c5bd2630f78527ade4fd1309f22d1731fe838a2 100755 --- a/ppstructure/table/matcher.py +++ b/ppstructure/table/matcher.py @@ -62,8 +62,8 @@ class TableMatch: def __call__(self, structure_res, dt_boxes, rec_res): pred_structures, pred_bboxes = structure_res if self.filter_ocr_result: - dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes, dt_boxes, - rec_res) + dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, + rec_res) matched_index = self.match_result(dt_boxes, pred_bboxes) if self.use_master: pred_html, pred = self.get_pred_html_master(pred_structures, @@ -179,7 +179,7 @@ class TableMatch: html = deal_bb(html) return html, end_html - def filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res): + def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res): y1 = pred_bboxes[:, 1::2].min() new_dt_boxes = [] new_rec_res = [] diff --git a/ppstructure/table/predict_table.py b/ppstructure/table/predict_table.py index 57f0fec01768bb2a3aef3f64a8d65640c774ebdb..b0c7ef589ffcfe4c3c3ec1fd43550813c9d27dcc 100644 --- a/ppstructure/table/predict_table.py +++ b/ppstructure/table/predict_table.py @@ -70,7 +70,7 @@ class TableSystem(object): if args.table_algorithm in ['TableMaster']: self.match = TableMasterMatcher() else: - self.match = TableMatch() + self.match = TableMatch(filter_ocr_result=True) self.benchmark = args.benchmark self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(