diff --git a/.gitignore b/.gitignore index 3300be325f1f6c8b2b58301fc87a4f9d241afb84..3a05fb74687f2b12790f2f73fc96cf8a6abb2bd3 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,4 @@ paddleocr.egg-info/ /deploy/android_demo/app/.cxx/ /deploy/android_demo/app/cache/ test_tipc/web/models/ -test_tipc/web/node_modules/ +test_tipc/web/node_modules/ \ No newline at end of file diff --git a/configs/table/SLANet.yml b/configs/table/SLANet.yml new file mode 100644 index 0000000000000000000000000000000000000000..384c95852e815f9780328f63cbbd52fa0ef3deb4 --- /dev/null +++ b/configs/table/SLANet.yml @@ -0,0 +1,143 @@ +Global: + use_gpu: true + epoch_num: 100 + log_smooth_window: 20 + print_batch_step: 20 + save_model_dir: ./output/SLANet + save_epoch_step: 400 + # evaluation is run every 1000 iterations after the 0th iteration + eval_batch_step: [0, 1000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: ./output/SLANet/infer + use_visualdl: False + infer_img: doc/table/table.jpg + # for data or label process + character_dict_path: ppocr/utils/dict/table_structure_dict.txt + character_type: en + max_text_length: &max_text_length 500 + box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy' + infer_mode: False + use_sync_bn: True + save_res_path: 'output/infer' + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + clip_norm: 5.0 + lr: + name: Piecewise + learning_rate: 0.001 + decay_epochs : [40, 50] + values : [0.001, 0.0001, 0.00005] + regularizer: + name: 'L2' + factor: 0.00000 + +Architecture: + model_type: table + algorithm: SLANet + Backbone: + name: PPLCNet + scale: 1.0 + pretrained: true + use_ssld: true + Neck: + name: CSPPAN + out_channels: 96 + Head: + name: SLAHead + hidden_size: 256 + max_text_length: *max_text_length + loc_reg_num: &loc_reg_num 4 + +Loss: + name: SLALoss + structure_weight: 1.0 + loc_weight: 2.0 + loc_loss: smooth_l1 + +PostProcess: + name: TableLabelDecode + merge_no_span_structure: 
&merge_no_span_structure True + +Metric: + name: TableMetric + main_indicator: acc + compute_bbox_metric: False + loc_reg_num: *loc_reg_num + box_format: *box_format + +Train: + dataset: + name: PubTabDataSet + data_dir: train_data/table/pubtabnet/train/ + label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - TableLabelEncode: + learn_empty_box: False + merge_no_span_structure: *merge_no_span_structure + replace_empty_cell_token: False + loc_reg_num: *loc_reg_num + max_text_length: *max_text_length + - TableBoxEncode: + in_box_format: *box_format + out_box_format: *box_format + - ResizeTableImage: + max_len: 488 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - PaddingTableImage: + size: [488, 488] + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ] + loader: + shuffle: True + batch_size_per_card: 48 + drop_last: True + num_workers: 1 + +Eval: + dataset: + name: PubTabDataSet + data_dir: train_data/table/pubtabnet/val/ + label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - TableLabelEncode: + learn_empty_box: False + merge_no_span_structure: *merge_no_span_structure + replace_empty_cell_token: False + loc_reg_num: *loc_reg_num + max_text_length: *max_text_length + - TableBoxEncode: + in_box_format: *box_format + out_box_format: *box_format + - ResizeTableImage: + max_len: 488 + - NormalizeImage: + scale: 1./255. 
+ mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - PaddingTableImage: + size: [488, 488] + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 48 + num_workers: 1 diff --git a/configs/table/table_master.yml b/configs/table/table_master.yml index b8daf3630755e61322665b6fc5f830e4a45875b8..df437f7c95523c5fe12f7166d011b4ad8473628b 100755 --- a/configs/table/table_master.yml +++ b/configs/table/table_master.yml @@ -8,16 +8,15 @@ Global: eval_batch_step: [0, 6259] cal_metric_during_train: true pretrained_model: null - checkpoints: + checkpoints: save_inference_dir: output/table_master/infer use_visualdl: false infer_img: ppstructure/docs/table/table.jpg save_res_path: ./output/table_master character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt infer_mode: false - max_text_length: 500 - process_total_num: 0 - process_cut_num: 0 + max_text_length: &max_text_length 500 + box_format: &box_format 'xywh' # 'xywh', 'xyxy', 'xyxyxyxy' Optimizer: @@ -52,7 +51,8 @@ Architecture: headers: 8 dropout: 0 d_ff: 2024 - max_text_length: 500 + max_text_length: *max_text_length + loc_reg_num: &loc_reg_num 4 Loss: name: TableMasterLoss @@ -61,11 +61,13 @@ Loss: PostProcess: name: TableMasterLabelDecode box_shape: pad + merge_no_span_structure: &merge_no_span_structure True Metric: name: TableMetric main_indicator: acc compute_bbox_metric: False + box_format: *box_format Train: dataset: @@ -78,15 +80,18 @@ Train: channel_first: False - TableMasterLabelEncode: learn_empty_box: False - merge_no_span_structure: True + merge_no_span_structure: *merge_no_span_structure replace_empty_cell_token: True + loc_reg_num: *loc_reg_num + max_text_length: *max_text_length - ResizeTableImage: max_len: 480 resize_bboxes: True - PaddingTableImage: size: [480, 480] - TableBoxEncode: - use_xywh: True + in_box_format: *box_format + out_box_format: 
*box_format - NormalizeImage: scale: 1./255. mean: [0.5, 0.5, 0.5] @@ -112,15 +117,18 @@ Eval: channel_first: False - TableMasterLabelEncode: learn_empty_box: False - merge_no_span_structure: True + merge_no_span_structure: *merge_no_span_structure replace_empty_cell_token: True + loc_reg_num: *loc_reg_num + max_text_length: *max_text_length - ResizeTableImage: max_len: 480 resize_bboxes: True - PaddingTableImage: size: [480, 480] - TableBoxEncode: - use_xywh: True + in_box_format: *box_format + out_box_format: *box_format - NormalizeImage: scale: 1./255. mean: [0.5, 0.5, 0.5] diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml index 66c1c83e124d4e94e1f4036a494dfd80c840f229..16c1457442237fc9711b9c3f6dc47625f242956c 100755 --- a/configs/table/table_mv3.yml +++ b/configs/table/table_mv3.yml @@ -17,10 +17,9 @@ Global: # for data or label process character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_type: en - max_text_length: 800 + max_text_length: &max_text_length 800 + box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy' infer_mode: False - process_total_num: 0 - process_cut_num: 0 Optimizer: name: Adam @@ -44,7 +43,8 @@ Architecture: name: TableAttentionHead hidden_size: 256 loc_type: 2 - max_text_length: 800 + max_text_length: *max_text_length + loc_reg_num: &loc_reg_num 4 Loss: name: TableAttentionLoss @@ -72,6 +72,8 @@ Train: learn_empty_box: False merge_no_span_structure: False replace_empty_cell_token: False + loc_reg_num: *loc_reg_num + max_text_length: *max_text_length - TableBoxEncode: - ResizeTableImage: max_len: 488 @@ -94,8 +96,8 @@ Train: Eval: dataset: name: PubTabDataSet - data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/ - label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl] + data_dir: train_data/table/pubtabnet/val/ + label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl] transforms: - DecodeImage: # load image img_mode: BGR @@ -104,6 +106,8 @@ Eval: 
learn_empty_box: False merge_no_span_structure: False replace_empty_cell_token: False + loc_reg_num: *loc_reg_num + max_text_length: *max_text_length - TableBoxEncode: - ResizeTableImage: max_len: 488 diff --git a/deploy/cpp_infer/include/args.h b/deploy/cpp_infer/include/args.h index 473ff25d981f8409f60a43940aaaec376375adf5..e0dd8bbcd1044fd695c90805bc770de5b47e51cf 100644 --- a/deploy/cpp_infer/include/args.h +++ b/deploy/cpp_infer/include/args.h @@ -30,7 +30,8 @@ DECLARE_string(image_dir); DECLARE_string(type); // detection related DECLARE_string(det_model_dir); -DECLARE_int32(max_side_len); +DECLARE_string(limit_type); +DECLARE_int32(limit_side_len); DECLARE_double(det_db_thresh); DECLARE_double(det_db_box_thresh); DECLARE_double(det_db_unclip_ratio); @@ -48,7 +49,13 @@ DECLARE_int32(rec_batch_num); DECLARE_string(rec_char_dict_path); DECLARE_int32(rec_img_h); DECLARE_int32(rec_img_w); +// structure model related +DECLARE_string(table_model_dir); +DECLARE_int32(table_max_len); +DECLARE_int32(table_batch_num); +DECLARE_string(table_char_dict_path); // forward related DECLARE_bool(det); DECLARE_bool(rec); DECLARE_bool(cls); +DECLARE_bool(table); \ No newline at end of file diff --git a/deploy/cpp_infer/include/ocr_det.h b/deploy/cpp_infer/include/ocr_det.h index 7efd4d8f0f4ccb705fc34695bb9843e0b6af5a9b..d1421b103b28b44e15a7df53a63fd893ca60e529 100644 --- a/deploy/cpp_infer/include/ocr_det.h +++ b/deploy/cpp_infer/include/ocr_det.h @@ -41,8 +41,8 @@ public: explicit DBDetector(const std::string &model_dir, const bool &use_gpu, const int &gpu_id, const int &gpu_mem, const int &cpu_math_library_num_threads, - const bool &use_mkldnn, const int &max_side_len, - const double &det_db_thresh, + const bool &use_mkldnn, const string &limit_type, + const int &limit_side_len, const double &det_db_thresh, const double &det_db_box_thresh, const double &det_db_unclip_ratio, const std::string &det_db_score_mode, @@ -54,7 +54,8 @@ public: this->cpu_math_library_num_threads_ = 
cpu_math_library_num_threads; this->use_mkldnn_ = use_mkldnn; - this->max_side_len_ = max_side_len; + this->limit_type_ = limit_type; + this->limit_side_len_ = limit_side_len; this->det_db_thresh_ = det_db_thresh; this->det_db_box_thresh_ = det_db_box_thresh; @@ -84,7 +85,8 @@ private: int cpu_math_library_num_threads_ = 4; bool use_mkldnn_ = false; - int max_side_len_ = 960; + string limit_type_ = "max"; + int limit_side_len_ = 960; double det_db_thresh_ = 0.3; double det_db_box_thresh_ = 0.5; @@ -106,7 +108,7 @@ private: Permute permute_op_; // post-process - PostProcessor post_processor_; + DBPostProcessor post_processor_; }; } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/include/paddleocr.h b/deploy/cpp_infer/include/paddleocr.h index 6db9d86cb152bfcc708a87c6a98be59d88a5d8db..a2c60b14acceaa90a8d8e4a70ccc50f02f254eb6 100644 --- a/deploy/cpp_infer/include/paddleocr.h +++ b/deploy/cpp_infer/include/paddleocr.h @@ -47,11 +47,7 @@ public: ocr(std::vector cv_all_img_names, bool det = true, bool rec = true, bool cls = true); -private: - DBDetector *detector_ = nullptr; - Classifier *classifier_ = nullptr; - CRNNRecognizer *recognizer_ = nullptr; - +protected: void det(cv::Mat img, std::vector &ocr_results, std::vector ×); void rec(std::vector img_list, @@ -62,6 +58,11 @@ private: std::vector ×); void log(std::vector &det_times, std::vector &rec_times, std::vector &cls_times, int img_num); + +private: + DBDetector *detector_ = nullptr; + Classifier *classifier_ = nullptr; + CRNNRecognizer *recognizer_ = nullptr; }; } // namespace PaddleOCR diff --git a/deploy/cpp_infer/include/paddlestructure.h b/deploy/cpp_infer/include/paddlestructure.h new file mode 100644 index 0000000000000000000000000000000000000000..b30ac045b2a6552b69442b2e8b29673efc820e31 --- /dev/null +++ b/deploy/cpp_infer/include/paddlestructure.h @@ -0,0 +1,79 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" +#include "paddle_inference_api.h" +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +using namespace paddle_infer; + +namespace PaddleOCR { + +class PaddleStructure : public PPOCR { +public: + explicit PaddleStructure(); + ~PaddleStructure(); + std::vector> + structure(std::vector cv_all_img_names, bool layout = false, + bool table = true); + +private: + StructureTableRecognizer *recognizer_ = nullptr; + + void table(cv::Mat img, StructurePredictResult &structure_result, + std::vector &time_info_table, + std::vector &time_info_det, + std::vector &time_info_rec, + std::vector &time_info_cls); + std::string + rebuild_table(std::vector rec_html_tags, + std::vector>> rec_boxes, + std::vector &ocr_result); + + float iou(std::vector> &box1, + std::vector> &box2); + float dis(std::vector> &box1, + std::vector> &box2); + + static bool comparison_dis(const std::vector &dis1, + const std::vector &dis2) { + if (dis1[1] < dis2[1]) { + return true; + } else if (dis1[1] == dis2[1]) { + return dis1[0] < dis2[0]; + } else { + return false; + } + } +}; + +} // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/include/postprocess_op.h b/deploy/cpp_infer/include/postprocess_op.h index 
4a98b151bdcc53e2ab3fbda1dca55dd9746bd86c..77b3f8b660bda29815245b31ab8cac479b24498f 100644 --- a/deploy/cpp_infer/include/postprocess_op.h +++ b/deploy/cpp_infer/include/postprocess_op.h @@ -34,7 +34,7 @@ using namespace std; namespace PaddleOCR { -class PostProcessor { +class DBPostProcessor { public: void GetContourArea(const std::vector> &box, float unclip_ratio, float &distance); @@ -90,4 +90,21 @@ private: } }; +class TablePostProcessor { +public: + void init(std::string label_path); + void + Run(std::vector &loc_preds, std::vector &structure_probs, + std::vector &rec_scores, std::vector &loc_preds_shape, + std::vector &structure_probs_shape, + std::vector> &rec_html_tag_batch, + std::vector>>> &rec_boxes_batch, + std::vector &width_list, std::vector &height_list); + +private: + std::vector label_list_; + std::string end = "eos"; + std::string beg = "sos"; +}; + } // namespace PaddleOCR diff --git a/deploy/cpp_infer/include/preprocess_op.h b/deploy/cpp_infer/include/preprocess_op.h index 31217de301573e078f8e11ef88657f369ede9b31..078f19d5b808c81e88d7aa464d6bfaca7fe1b14e 100644 --- a/deploy/cpp_infer/include/preprocess_op.h +++ b/deploy/cpp_infer/include/preprocess_op.h @@ -48,11 +48,12 @@ class PermuteBatch { public: virtual void Run(const std::vector imgs, float *data); }; - + class ResizeImgType0 { public: - virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len, - float &ratio_h, float &ratio_w, bool use_tensorrt); + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, string limit_type, + int limit_side_len, float &ratio_h, float &ratio_w, + bool use_tensorrt); }; class CrnnResizeImg { @@ -69,4 +70,16 @@ public: const std::vector &rec_image_shape = {3, 48, 192}); }; +class TableResizeImg { +public: + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, + const int max_len = 488); +}; + +class TablePadImg { +public: + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, + const int max_len = 488); +}; + } // namespace 
PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/include/structure_table.h b/deploy/cpp_infer/include/structure_table.h new file mode 100644 index 0000000000000000000000000000000000000000..7449c6cd0e158425bccb75740191dd0b6d6ecc9b --- /dev/null +++ b/deploy/cpp_infer/include/structure_table.h @@ -0,0 +1,100 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" +#include "paddle_inference_api.h" +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +using namespace paddle_infer; + +namespace PaddleOCR { + +class StructureTableRecognizer { +public: + explicit StructureTableRecognizer( + const std::string &model_dir, const bool &use_gpu, const int &gpu_id, + const int &gpu_mem, const int &cpu_math_library_num_threads, + const bool &use_mkldnn, const string &label_path, + const bool &use_tensorrt, const std::string &precision, + const int &table_batch_num, const int &table_max_len) { + this->use_gpu_ = use_gpu; + this->gpu_id_ = gpu_id; + this->gpu_mem_ = gpu_mem; + this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; + this->use_mkldnn_ = use_mkldnn; + this->use_tensorrt_ = use_tensorrt; + this->precision_ = precision; + this->table_batch_num_ = table_batch_num; + 
this->table_max_len_ = table_max_len; + + this->post_processor_.init(label_path); + LoadModel(model_dir); + } + + // Load Paddle inference model + void LoadModel(const std::string &model_dir); + + void Run(std::vector img_list, + std::vector> &rec_html_tags, + std::vector &rec_scores, + std::vector>>> &rec_boxes, + std::vector ×); + +private: + std::shared_ptr predictor_; + + bool use_gpu_ = false; + int gpu_id_ = 0; + int gpu_mem_ = 4000; + int cpu_math_library_num_threads_ = 4; + bool use_mkldnn_ = false; + int table_max_len_ = 488; + + std::vector mean_ = {0.485f, 0.456f, 0.406f}; + std::vector scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f}; + bool is_scale_ = true; + + bool use_tensorrt_ = false; + std::string precision_ = "fp32"; + int table_batch_num_ = 1; + + // pre-process + TableResizeImg resize_op_; + Normalize normalize_op_; + PermuteBatch permute_op_; + TablePadImg pad_op_; + + // post-process + TablePostProcessor post_processor_; + +}; // class StructureTableRecognizer + +} // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/include/utility.h b/deploy/cpp_infer/include/utility.h index eb18c0624492e9b47de156d60611d637d8dca6c3..520804f64529303b5ecec27dc5f0895f1fff5c72 100644 --- a/deploy/cpp_infer/include/utility.h +++ b/deploy/cpp_infer/include/utility.h @@ -40,6 +40,14 @@ struct OCRPredictResult { int cls_label = -1; }; +struct StructurePredictResult { + std::vector box; + std::string type; + std::vector text_res; + std::string html; + float html_score = -1; +}; + class Utility { public: static std::vector ReadDict(const std::string &path); @@ -68,6 +76,22 @@ public: static void CreateDir(const std::string &path); static void print_result(const std::vector &ocr_result); + + static cv::Mat crop_image(cv::Mat &img, std::vector &area); + + static void sorted_boxes(std::vector &ocr_result); + +private: + static bool comparison_box(const OCRPredictResult &result1, + const OCRPredictResult &result2) { + if (result1.box[0][1] < 
result2.box[0][1]) { + return true; + } else if (result1.box[0][1] == result2.box[0][1]) { + return result1.box[0][0] < result2.box[0][0]; + } else { + return false; + } + } }; } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/readme.md b/deploy/cpp_infer/readme.md index a87db7e6596bc2528bfb4a93c3170ebf0482ccad..2afdf79521223c4f473ded8d4f930546fb762c46 100644 --- a/deploy/cpp_infer/readme.md +++ b/deploy/cpp_infer/readme.md @@ -171,6 +171,9 @@ inference/ |-- cls | |--inference.pdiparams | |--inference.pdmodel +|-- table +| |--inference.pdiparams +| |--inference.pdmodel ``` @@ -275,6 +278,17 @@ Specifically, --cls=true \ ``` + +##### 7. table +```shell +./build/ppocr --det_model_dir=inference/det_db \ + --rec_model_dir=inference/rec_rcnn \ + --table_model_dir=inference/table \ + --image_dir=../../ppstructure/docs/table/table.jpg \ + --type=structure \ + --table=true +``` + More parameters are as follows, - Common parameters @@ -293,9 +307,9 @@ More parameters are as follows, |parameter|data type|default|meaning| | :---: | :---: | :---: | :---: | -|det|bool|true|前向是否执行文字检测| -|rec|bool|true|前向是否执行文字识别| -|cls|bool|false|前向是否执行文字方向分类| +|det|bool|true|Whether to perform text detection in the forward direction| +|rec|bool|true|Whether to perform text recognition in the forward direction| +|cls|bool|false|Whether to perform text direction classification in the forward direction| - Detection related parameters @@ -329,6 +343,15 @@ More parameters are as follows, |rec_img_h|int|48|image height of recognition| |rec_img_w|int|320|image width of recognition| +- Table recognition related parameters + +|parameter|data type|default|meaning| +| :---: | :---: | :---: | :---: | +|table_model_dir|string|-|Address of table recognition inference model| +|table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict.txt|dictionary file| +|table_max_len|int|488|The size of the long side of the input image of the table recognition model, the final 
input image size of the network is(table_max_len,table_max_len)| + + * Multi-language inference is also supported in PaddleOCR, you can refer to [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models in PaddleOCR. Specifically, if you want to infer using multi-language models, you just need to modify values of `rec_char_dict_path` and `rec_model_dir`. @@ -344,6 +367,12 @@ predict img: ../../doc/imgs/12.jpg The detection visualized image saved in ./output//12.jpg ``` +- table + +```bash +predict img: ../../ppstructure/docs/table/table.jpg +0 type: table, region: [0,0,371,293], res:
MethodsRPFFPS
SegLink [26]70.086.077.08.9
PixelLink [4]73.283.077.8-
TextSnake [18]73.983.278.31.1
TextField [37]75.987.481.35.2
MSR[38]76.787.481.7-
FTSN [3]77.187.682.0-
LSE[30]81.784.282.9-
CRAFT [2]78.288.282.98.6
MCN [16]798883-
ATRR[35]82.185.283.6-
PAN [34]83.884.484.130.2
DB[12]79.291.584.932.0
DRRG [41]82.3088.0585.08-
Ours (SynText)80.6885.4082.9712.68
Ours (MLT-17)84.5486.6285.5712.31
+``` ## 3. FAQ diff --git a/deploy/cpp_infer/readme_ch.md b/deploy/cpp_infer/readme_ch.md index 8c334851c0d44acd393c6daa79edf25dc9e6fa24..d94c95c8c5a5bf02d5b9fb4f16edd4da8ebe1e3f 100644 --- a/deploy/cpp_infer/readme_ch.md +++ b/deploy/cpp_infer/readme_ch.md @@ -181,6 +181,9 @@ inference/ |-- cls | |--inference.pdiparams | |--inference.pdmodel +|-- table +| |--inference.pdiparams +| |--inference.pdmodel ``` @@ -285,6 +288,16 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir --cls=true \ ``` +##### 7. 表格识别 +```shell +./build/ppocr --det_model_dir=inference/det_db \ + --rec_model_dir=inference/rec_rcnn \ + --table_model_dir=inference/table \ + --image_dir=../../ppstructure/docs/table/table.jpg \ + --type=structure \ + --table=true +``` + 更多支持的可调节参数解释如下: - 通用参数 @@ -328,21 +341,32 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir |cls_thresh|float|0.9|方向分类器的得分阈值| |cls_batch_num|int|1|方向分类器batchsize| -- 识别模型相关 +- 文字识别模型相关 |参数名称|类型|默认参数|意义| | :---: | :---: | :---: | :---: | -|rec_model_dir|string|-|识别模型inference model地址| +|rec_model_dir|string|-|文字识别模型inference model地址| |rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|字典文件| -|rec_batch_num|int|6|识别模型batchsize| -|rec_img_h|int|48|识别模型输入图像高度| -|rec_img_w|int|320|识别模型输入图像宽度| +|rec_batch_num|int|6|文字识别模型batchsize| +|rec_img_h|int|48|文字识别模型输入图像高度| +|rec_img_w|int|320|文字识别模型输入图像宽度| + + +- 表格识别模型相关 + +|参数名称|类型|默认参数|意义| +| :---: | :---: | :---: | :---: | +|table_model_dir|string|-|表格识别模型inference model地址| +|table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict.txt|字典文件| +|table_max_len|int|488|表格识别模型输入图像长边大小,最终网络输入图像大小为(table_max_len,table_max_len)| * PaddleOCR也支持多语言的预测,更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需将修改`rec_char_dict_path`(字典文件路径)以及`rec_model_dir`(inference模型路径)字段即可。 最终屏幕上会输出检测结果如下。 +- ocr + ```bash predict img: ../../doc/imgs/12.jpg ../../doc/imgs/12.jpg @@ -353,6 +377,13 @@ predict img: ../../doc/imgs/12.jpg The detection visualized image saved in ./output//12.jpg ``` +- 
table + +```bash +predict img: ../../ppstructure/docs/table/table.jpg +0 type: table, region: [0,0,371,293], res:
MethodsRPFFPS
SegLink [26]70.086.077.08.9
PixelLink [4]73.283.077.8-
TextSnake [18]73.983.278.31.1
TextField [37]75.987.481.35.2
MSR[38]76.787.481.7-
FTSN [3]77.187.682.0-
LSE[30]81.784.282.9-
CRAFT [2]78.288.282.98.6
MCN [16]798883-
ATRR[35]82.185.283.6-
PAN [34]83.884.484.130.2
DB[12]79.291.584.932.0
DRRG [41]82.3088.0585.08-
Ours (SynText)80.6885.4082.9712.68
Ours (MLT-17)84.5486.6285.5712.31
+``` + ## 3. FAQ diff --git a/deploy/cpp_infer/src/args.cpp b/deploy/cpp_infer/src/args.cpp index 93d0f5ea5fd07bdc3eb44537bc1c0d4e131736d3..df1b9e32a3aacc309d6485114f9b267001f79920 100644 --- a/deploy/cpp_infer/src/args.cpp +++ b/deploy/cpp_infer/src/args.cpp @@ -30,7 +30,8 @@ DEFINE_string( "Perform ocr or structure, the value is selected in ['ocr','structure']."); // detection related DEFINE_string(det_model_dir, "", "Path of det inference model."); -DEFINE_int32(max_side_len, 960, "max_side_len of input image."); +DEFINE_string(limit_type, "max", "limit_type of input image."); +DEFINE_int32(limit_side_len, 960, "limit_side_len of input image."); DEFINE_double(det_db_thresh, 0.3, "Threshold of det_db_thresh."); DEFINE_double(det_db_box_thresh, 0.6, "Threshold of det_db_box_thresh."); DEFINE_double(det_db_unclip_ratio, 1.5, "Threshold of det_db_unclip_ratio."); @@ -50,7 +51,16 @@ DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt", DEFINE_int32(rec_img_h, 48, "rec image height"); DEFINE_int32(rec_img_w, 320, "rec image width"); +// structure model related +DEFINE_string(table_model_dir, "", "Path of table struture inference model."); +DEFINE_int32(table_max_len, 488, "max len size of input image."); +DEFINE_int32(table_batch_num, 1, "table_batch_num."); +DEFINE_string(table_char_dict_path, + "../../ppocr/utils/dict/table_structure_dict.txt", + "Path of dictionary."); + // ocr forward related DEFINE_bool(det, true, "Whether use det in forward."); DEFINE_bool(rec, true, "Whether use rec in forward."); -DEFINE_bool(cls, false, "Whether use cls in forward."); \ No newline at end of file +DEFINE_bool(cls, false, "Whether use cls in forward."); +DEFINE_bool(table, false, "Whether use table structure in forward."); \ No newline at end of file diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp index c4b5b97ea8b2ebf77dd9a3e2af69a1a1ca19ed2a..66412a7b283f84107e117cfd59fb7d7aabff651c 100644 --- a/deploy/cpp_infer/src/main.cpp +++ 
b/deploy/cpp_infer/src/main.cpp @@ -19,6 +19,7 @@ #include #include +#include using namespace PaddleOCR; @@ -32,6 +33,12 @@ void check_params() { } } if (FLAGS_rec) { + std::cout + << "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320'," + "if you are using recognition model with PP-OCRv2 or an older " + "version, " + "please set --rec_image_shape='3,32,320" + << std::endl; if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) { std::cout << "Usage[rec]: ./ppocr " "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " @@ -47,6 +54,17 @@ void check_params() { exit(1); } } + if (FLAGS_table) { + if (FLAGS_table_model_dir.empty() || FLAGS_det_model_dir.empty() || + FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) { + std::cout << "Usage[table]: ./ppocr " + << "--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ " + << "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ " + << "--table_model_dir=/PATH/TO/TABLE_INFERENCE_MODEL/ " + << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; + exit(1); + } + } if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" && FLAGS_precision != "int8") { cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. " << endl; @@ -54,21 +72,7 @@ void check_params() { } } -int main(int argc, char **argv) { - // Parsing command-line - google::ParseCommandLineFlags(&argc, &argv, true); - check_params(); - - if (!Utility::PathExists(FLAGS_image_dir)) { - std::cerr << "[ERROR] image path not exist! 
image_dir: " << FLAGS_image_dir - << endl; - exit(1); - } - - std::vector cv_all_img_names; - cv::glob(FLAGS_image_dir, cv_all_img_names); - std::cout << "total images num: " << cv_all_img_names.size() << endl; - +void ocr(std::vector &cv_all_img_names) { PPOCR ocr = PPOCR(); std::vector> ocr_results = @@ -109,3 +113,49 @@ int main(int argc, char **argv) { } } } + +void structure(std::vector &cv_all_img_names) { + PaddleOCR::PaddleStructure engine = PaddleOCR::PaddleStructure(); + std::vector> structure_results = + engine.structure(cv_all_img_names, false, FLAGS_table); + for (int i = 0; i < cv_all_img_names.size(); i++) { + cout << "predict img: " << cv_all_img_names[i] << endl; + for (int j = 0; j < structure_results[i].size(); j++) { + std::cout << j << "\ttype: " << structure_results[i][j].type + << ", region: ["; + std::cout << structure_results[i][j].box[0] << "," + << structure_results[i][j].box[1] << "," + << structure_results[i][j].box[2] << "," + << structure_results[i][j].box[3] << "], res: "; + if (structure_results[i][j].type == "table") { + std::cout << structure_results[i][j].html << std::endl; + } else { + Utility::print_result(structure_results[i][j].text_res); + } + } + } +} + +int main(int argc, char **argv) { + // Parsing command-line + google::ParseCommandLineFlags(&argc, &argv, true); + check_params(); + + if (!Utility::PathExists(FLAGS_image_dir)) { + std::cerr << "[ERROR] image path not exist! 
image_dir: " << FLAGS_image_dir + << endl; + exit(1); + } + + std::vector cv_all_img_names; + cv::glob(FLAGS_image_dir, cv_all_img_names); + std::cout << "total images num: " << cv_all_img_names.size() << endl; + + if (FLAGS_type == "ocr") { + ocr(cv_all_img_names); + } else if (FLAGS_type == "structure") { + structure(cv_all_img_names); + } else { + std::cout << "only value in ['ocr','structure'] is supported" << endl; + } +} diff --git a/deploy/cpp_infer/src/ocr_det.cpp b/deploy/cpp_infer/src/ocr_det.cpp index 550997e71937d23a7448e8ff1c4ffad579d2931c..56de195186a0d4d6c8b2482eb57c106347485928 100644 --- a/deploy/cpp_infer/src/ocr_det.cpp +++ b/deploy/cpp_infer/src/ocr_det.cpp @@ -47,7 +47,7 @@ void DBDetector::LoadModel(const std::string &model_dir) { {"elementwise_add_7", {1, 56, 2, 2}}, {"nearest_interp_v2_0.tmp_0", {1, 256, 2, 2}}}; std::map> max_input_shape = { - {"x", {1, 3, this->max_side_len_, this->max_side_len_}}, + {"x", {1, 3, 1536, 1536}}, {"conv2d_92.tmp_0", {1, 120, 400, 400}}, {"conv2d_91.tmp_0", {1, 24, 200, 200}}, {"conv2d_59.tmp_0", {1, 96, 400, 400}}, @@ -109,7 +109,8 @@ void DBDetector::Run(cv::Mat &img, img.copyTo(srcimg); auto preprocess_start = std::chrono::steady_clock::now(); - this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w, + this->resize_op_.Run(img, resize_img, this->limit_type_, + this->limit_side_len_, ratio_h, ratio_w, this->use_tensorrt_); this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, diff --git a/deploy/cpp_infer/src/paddleocr.cpp b/deploy/cpp_infer/src/paddleocr.cpp index cd620a9206cad8ec2b1cd5924c714a8a1fa989b6..1de4fc7e9af8bf63cf68ef42d2a508cdc4b5f9f3 100644 --- a/deploy/cpp_infer/src/paddleocr.cpp +++ b/deploy/cpp_infer/src/paddleocr.cpp @@ -23,10 +23,10 @@ PPOCR::PPOCR() { if (FLAGS_det) { this->detector_ = new DBDetector( FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, - FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_max_side_len, - FLAGS_det_db_thresh, 
FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio, - FLAGS_det_db_score_mode, FLAGS_use_dilation, FLAGS_use_tensorrt, - FLAGS_precision); + FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_limit_type, + FLAGS_limit_side_len, FLAGS_det_db_thresh, FLAGS_det_db_box_thresh, + FLAGS_det_db_unclip_ratio, FLAGS_det_db_score_mode, FLAGS_use_dilation, + FLAGS_use_tensorrt, FLAGS_precision); } if (FLAGS_cls && FLAGS_use_angle_cls) { @@ -56,7 +56,8 @@ void PPOCR::det(cv::Mat img, std::vector &ocr_results, res.box = boxes[i]; ocr_results.push_back(res); } - + // sort boex from top to bottom, from left to right + Utility::sorted_boxes(ocr_results); times[0] += det_times[0]; times[1] += det_times[1]; times[2] += det_times[2]; diff --git a/deploy/cpp_infer/src/paddlestructure.cpp b/deploy/cpp_infer/src/paddlestructure.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1ca85a96bbcf09472ce5916375a24a9441a2da53 --- /dev/null +++ b/deploy/cpp_infer/src/paddlestructure.cpp @@ -0,0 +1,272 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "auto_log/autolog.h" +#include +#include + +namespace PaddleOCR { + +PaddleStructure::PaddleStructure() { + if (FLAGS_table) { + this->recognizer_ = new StructureTableRecognizer( + FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, + FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path, + FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num, + FLAGS_table_max_len); + } +}; + +std::vector> +PaddleStructure::structure(std::vector cv_all_img_names, + bool layout, bool table) { + std::vector time_info_det = {0, 0, 0}; + std::vector time_info_rec = {0, 0, 0}; + std::vector time_info_cls = {0, 0, 0}; + std::vector time_info_table = {0, 0, 0}; + + std::vector> structure_results; + + if (!Utility::PathExists(FLAGS_output) && FLAGS_det) { + mkdir(FLAGS_output.c_str(), 0777); + } + for (int i = 0; i < cv_all_img_names.size(); ++i) { + std::vector structure_result; + cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); + if (!srcimg.data) { + std::cerr << "[ERROR] image read failed! 
image path: " + << cv_all_img_names[i] << endl; + exit(1); + } + if (layout) { + } else { + StructurePredictResult res; + res.type = "table"; + res.box = std::vector(4, 0); + res.box[2] = srcimg.cols; + res.box[3] = srcimg.rows; + structure_result.push_back(res); + } + cv::Mat roi_img; + for (int i = 0; i < structure_result.size(); i++) { + // crop image + roi_img = Utility::crop_image(srcimg, structure_result[i].box); + if (structure_result[i].type == "table") { + this->table(roi_img, structure_result[i], time_info_table, + time_info_det, time_info_rec, time_info_cls); + } + } + structure_results.push_back(structure_result); + } + return structure_results; +}; + +void PaddleStructure::table(cv::Mat img, + StructurePredictResult &structure_result, + std::vector &time_info_table, + std::vector &time_info_det, + std::vector &time_info_rec, + std::vector &time_info_cls) { + // predict structure + std::vector> structure_html_tags; + std::vector structure_scores(1, 0); + std::vector>>> structure_boxes; + std::vector structure_imes; + std::vector img_list; + img_list.push_back(img); + this->recognizer_->Run(img_list, structure_html_tags, structure_scores, + structure_boxes, structure_imes); + time_info_table[0] += structure_imes[0]; + time_info_table[1] += structure_imes[1]; + time_info_table[2] += structure_imes[2]; + + std::vector ocr_result; + std::string html; + int expand_pixel = 3; + + for (int i = 0; i < img_list.size(); i++) { + // det + this->det(img_list[i], ocr_result, time_info_det); + // crop image + std::vector rec_img_list; + for (int j = 0; j < ocr_result.size(); j++) { + int x_collect[4] = {ocr_result[j].box[0][0], ocr_result[j].box[1][0], + ocr_result[j].box[2][0], ocr_result[j].box[3][0]}; + int y_collect[4] = {ocr_result[j].box[0][1], ocr_result[j].box[1][1], + ocr_result[j].box[2][1], ocr_result[j].box[3][1]}; + int left = int(*std::min_element(x_collect, x_collect + 4)); + int right = int(*std::max_element(x_collect, x_collect + 4)); + int top = 
int(*std::min_element(y_collect, y_collect + 4)); + int bottom = int(*std::max_element(y_collect, y_collect + 4)); + std::vector box{max(0, left - expand_pixel), + max(0, top - expand_pixel), + min(img_list[i].cols, right + expand_pixel), + min(img_list[i].rows, bottom + expand_pixel)}; + cv::Mat crop_img = Utility::crop_image(img_list[i], box); + rec_img_list.push_back(crop_img); + } + // rec + this->rec(rec_img_list, ocr_result, time_info_rec); + // rebuild table + html = this->rebuild_table(structure_html_tags[i], structure_boxes[i], + ocr_result); + structure_result.html = html; + structure_result.html_score = structure_scores[i]; + } +}; + +std::string PaddleStructure::rebuild_table( + std::vector structure_html_tags, + std::vector>> structure_boxes, + std::vector &ocr_result) { + // match text in same cell + std::vector> matched(structure_boxes.size(), + std::vector()); + + for (int i = 0; i < ocr_result.size(); i++) { + std::vector> dis_list(structure_boxes.size(), + std::vector(3, 100000.0)); + for (int j = 0; j < structure_boxes.size(); j++) { + int x_collect[4] = {ocr_result[i].box[0][0], ocr_result[i].box[1][0], + ocr_result[i].box[2][0], ocr_result[i].box[3][0]}; + int y_collect[4] = {ocr_result[i].box[0][1], ocr_result[i].box[1][1], + ocr_result[i].box[2][1], ocr_result[i].box[3][1]}; + int left = int(*std::min_element(x_collect, x_collect + 4)); + int right = int(*std::max_element(x_collect, x_collect + 4)); + int top = int(*std::min_element(y_collect, y_collect + 4)); + int bottom = int(*std::max_element(y_collect, y_collect + 4)); + std::vector> box(2, std::vector(2, 0)); + box[0][0] = left - 1; + box[0][1] = top - 1; + box[1][0] = right + 1; + box[1][1] = bottom + 1; + + dis_list[j][0] = this->dis(box, structure_boxes[j]); + dis_list[j][1] = 1 - this->iou(box, structure_boxes[j]); + dis_list[j][2] = j; + } + // find min dis idx + std::sort(dis_list.begin(), dis_list.end(), + PaddleStructure::comparison_dis); + 
matched[dis_list[0][2]].push_back(ocr_result[i].text); + } + // get pred html + std::string html_str = ""; + int td_tag_idx = 0; + for (int i = 0; i < structure_html_tags.size(); i++) { + if (structure_html_tags[i].find("") != std::string::npos) { + if (structure_html_tags[i].find("") != std::string::npos) { + html_str += ""; + } + if (matched[td_tag_idx].size() > 0) { + bool b_with = false; + if (matched[td_tag_idx][0].find("") != std::string::npos && + matched[td_tag_idx].size() > 1) { + b_with = true; + html_str += ""; + } + for (int j = 0; j < matched[td_tag_idx].size(); j++) { + std::string content = matched[td_tag_idx][j]; + if (matched[td_tag_idx].size() > 1) { + // remove blank, and + if (content.length() > 0 && content.at(0) == ' ') { + content = content.substr(0); + } + if (content.length() > 2 && content.substr(0, 3) == "") { + content = content.substr(3); + } + if (content.length() > 4 && + content.substr(content.length() - 4) == "") { + content = content.substr(0, content.length() - 4); + } + if (content.empty()) { + continue; + } + // add blank + if (j != matched[td_tag_idx].size() - 1 && + content.at(content.length() - 1) != ' ') { + content += ' '; + } + } + html_str += content; + } + if (b_with) { + html_str += ""; + } + } + if (structure_html_tags[i].find("") != std::string::npos) { + html_str += ""; + } else { + html_str += structure_html_tags[i]; + } + td_tag_idx += 1; + } else { + html_str += structure_html_tags[i]; + } + } + return html_str; +} + +float PaddleStructure::iou(std::vector> &box1, + std::vector> &box2) { + int area1 = max(0, box1[1][0] - box1[0][0]) * max(0, box1[1][1] - box1[0][1]); + int area2 = max(0, box2[1][0] - box2[0][0]) * max(0, box2[1][1] - box2[0][1]); + + // computing the sum_area + int sum_area = area1 + area2; + + // find the each point of intersect rectangle + int x1 = max(box1[0][0], box2[0][0]); + int y1 = max(box1[0][1], box2[0][1]); + int x2 = min(box1[1][0], box2[1][0]); + int y2 = min(box1[1][1], box2[1][1]); 
+ + // judge if there is an intersect + if (y1 >= y2 || x1 >= x2) { + return 0.0; + } else { + int intersect = (x2 - x1) * (y2 - y1); + return intersect / (sum_area - intersect + 0.00000001); + } +} + +float PaddleStructure::dis(std::vector> &box1, + std::vector> &box2) { + int x1_1 = box1[0][0]; + int y1_1 = box1[0][1]; + int x2_1 = box1[1][0]; + int y2_1 = box1[1][1]; + + int x1_2 = box2[0][0]; + int y1_2 = box2[0][1]; + int x2_2 = box2[1][0]; + int y2_2 = box2[1][1]; + + float dis = + abs(x1_2 - x1_1) + abs(y1_2 - y1_1) + abs(x2_2 - x2_1) + abs(y2_2 - y2_1); + float dis_2 = abs(x1_2 - x1_1) + abs(y1_2 - y1_1); + float dis_3 = abs(x2_2 - x2_1) + abs(y2_2 - y2_1); + return dis + min(dis_2, dis_3); +} + +PaddleStructure::~PaddleStructure() { + if (this->recognizer_ != nullptr) { + delete this->recognizer_; + } +}; + +} // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/src/postprocess_op.cpp b/deploy/cpp_infer/src/postprocess_op.cpp index 5374fb1a4eba68d8055a52ec91d97c290832aa9d..551f98a1668124f83ef615f0a41b081508898d6e 100644 --- a/deploy/cpp_infer/src/postprocess_op.cpp +++ b/deploy/cpp_infer/src/postprocess_op.cpp @@ -17,8 +17,8 @@ namespace PaddleOCR { -void PostProcessor::GetContourArea(const std::vector> &box, - float unclip_ratio, float &distance) { +void DBPostProcessor::GetContourArea(const std::vector> &box, + float unclip_ratio, float &distance) { int pts_num = 4; float area = 0.0f; float dist = 0.0f; @@ -35,8 +35,8 @@ void PostProcessor::GetContourArea(const std::vector> &box, distance = area * unclip_ratio / dist; } -cv::RotatedRect PostProcessor::UnClip(std::vector> box, - const float &unclip_ratio) { +cv::RotatedRect DBPostProcessor::UnClip(std::vector> box, + const float &unclip_ratio) { float distance = 1.0; GetContourArea(box, unclip_ratio, distance); @@ -67,7 +67,7 @@ cv::RotatedRect PostProcessor::UnClip(std::vector> box, return res; } -float **PostProcessor::Mat2Vec(cv::Mat mat) { +float 
**DBPostProcessor::Mat2Vec(cv::Mat mat) { auto **array = new float *[mat.rows]; for (int i = 0; i < mat.rows; ++i) array[i] = new float[mat.cols]; @@ -81,7 +81,7 @@ float **PostProcessor::Mat2Vec(cv::Mat mat) { } std::vector> -PostProcessor::OrderPointsClockwise(std::vector> pts) { +DBPostProcessor::OrderPointsClockwise(std::vector> pts) { std::vector> box = pts; std::sort(box.begin(), box.end(), XsortInt); @@ -99,7 +99,7 @@ PostProcessor::OrderPointsClockwise(std::vector> pts) { return rect; } -std::vector> PostProcessor::Mat2Vector(cv::Mat mat) { +std::vector> DBPostProcessor::Mat2Vector(cv::Mat mat) { std::vector> img_vec; std::vector tmp; @@ -113,20 +113,20 @@ std::vector> PostProcessor::Mat2Vector(cv::Mat mat) { return img_vec; } -bool PostProcessor::XsortFp32(std::vector a, std::vector b) { +bool DBPostProcessor::XsortFp32(std::vector a, std::vector b) { if (a[0] != b[0]) return a[0] < b[0]; return false; } -bool PostProcessor::XsortInt(std::vector a, std::vector b) { +bool DBPostProcessor::XsortInt(std::vector a, std::vector b) { if (a[0] != b[0]) return a[0] < b[0]; return false; } -std::vector> PostProcessor::GetMiniBoxes(cv::RotatedRect box, - float &ssid) { +std::vector> +DBPostProcessor::GetMiniBoxes(cv::RotatedRect box, float &ssid) { ssid = std::max(box.size.width, box.size.height); cv::Mat points; @@ -160,8 +160,8 @@ std::vector> PostProcessor::GetMiniBoxes(cv::RotatedRect box, return array; } -float PostProcessor::PolygonScoreAcc(std::vector contour, - cv::Mat pred) { +float DBPostProcessor::PolygonScoreAcc(std::vector contour, + cv::Mat pred) { int width = pred.cols; int height = pred.rows; std::vector box_x; @@ -206,8 +206,8 @@ float PostProcessor::PolygonScoreAcc(std::vector contour, return score; } -float PostProcessor::BoxScoreFast(std::vector> box_array, - cv::Mat pred) { +float DBPostProcessor::BoxScoreFast(std::vector> box_array, + cv::Mat pred) { auto array = box_array; int width = pred.cols; int height = pred.rows; @@ -244,7 +244,7 @@ 
float PostProcessor::BoxScoreFast(std::vector> box_array, return score; } -std::vector>> PostProcessor::BoxesFromBitmap( +std::vector>> DBPostProcessor::BoxesFromBitmap( const cv::Mat pred, const cv::Mat bitmap, const float &box_thresh, const float &det_db_unclip_ratio, const std::string &det_db_score_mode) { const int min_size = 3; @@ -321,9 +321,9 @@ std::vector>> PostProcessor::BoxesFromBitmap( return boxes; } -std::vector>> -PostProcessor::FilterTagDetRes(std::vector>> boxes, - float ratio_h, float ratio_w, cv::Mat srcimg) { +std::vector>> DBPostProcessor::FilterTagDetRes( + std::vector>> boxes, float ratio_h, + float ratio_w, cv::Mat srcimg) { int oriimg_h = srcimg.rows; int oriimg_w = srcimg.cols; @@ -352,4 +352,77 @@ PostProcessor::FilterTagDetRes(std::vector>> boxes, return root_points; } +void TablePostProcessor::init(std::string label_path) { + this->label_list_ = Utility::ReadDict(label_path); + this->label_list_.insert(this->label_list_.begin(), this->beg); + this->label_list_.push_back(this->end); +} + +void TablePostProcessor::Run( + std::vector &loc_preds, std::vector &structure_probs, + std::vector &rec_scores, std::vector &loc_preds_shape, + std::vector &structure_probs_shape, + std::vector> &rec_html_tag_batch, + std::vector>>> &rec_boxes_batch, + std::vector &width_list, std::vector &height_list) { + for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; batch_idx++) { + // image tags and boxs + std::vector rec_html_tags; + std::vector>> rec_boxes; + + float score = 0.f; + int count = 0; + float char_score = 0.f; + int char_idx = 0; + + // step + for (int step_idx = 0; step_idx < structure_probs_shape[1]; step_idx++) { + std::string html_tag; + std::vector> rec_box; + // html tag + int step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) * + structure_probs_shape[2]; + char_idx = int(Utility::argmax( + &structure_probs[step_start_idx], + &structure_probs[step_start_idx + structure_probs_shape[2]])); + char_score = 
float(*std::max_element( + &structure_probs[step_start_idx], + &structure_probs[step_start_idx + structure_probs_shape[2]])); + html_tag = this->label_list_[char_idx]; + + if (step_idx > 0 && html_tag == this->end) { + break; + } + if (html_tag == this->beg) { + continue; + } + count += 1; + score += char_score; + rec_html_tags.push_back(html_tag); + // box + if (html_tag == "" || html_tag == "") { + for (int point_idx = 0; point_idx < loc_preds_shape[2]; + point_idx += 2) { + std::vector point(2, 0); + step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) * + loc_preds_shape[2] + + point_idx; + point[0] = int(loc_preds[step_start_idx] * width_list[batch_idx]); + point[1] = + int(loc_preds[step_start_idx + 1] * height_list[batch_idx]); + rec_box.push_back(point); + } + rec_boxes.push_back(rec_box); + } + } + score /= count; + if (isnan(score) || rec_boxes.size() == 0) { + score = -1; + } + rec_scores.push_back(score); + rec_boxes_batch.push_back(rec_boxes); + rec_html_tag_batch.push_back(rec_html_tags); + } +} + } // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/preprocess_op.cpp b/deploy/cpp_infer/src/preprocess_op.cpp index fff49ba2c2cd0e68f0c1d93e5877ab6276bdc337..ac185e22d68955ef440e22c327b835dbce6c4e1b 100644 --- a/deploy/cpp_infer/src/preprocess_op.cpp +++ b/deploy/cpp_infer/src/preprocess_op.cpp @@ -69,18 +69,28 @@ void Normalize::Run(cv::Mat *im, const std::vector &mean, } void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, - int max_size_len, float &ratio_h, float &ratio_w, - bool use_tensorrt) { + string limit_type, int limit_side_len, float &ratio_h, + float &ratio_w, bool use_tensorrt) { int w = img.cols; int h = img.rows; - float ratio = 1.f; - int max_wh = w >= h ? 
w : h; - if (max_wh > max_size_len) { - if (h > w) { - ratio = float(max_size_len) / float(h); - } else { - ratio = float(max_size_len) / float(w); + if (limit_type == "min") { + int min_wh = min(h, w); + if (min_wh < limit_side_len) { + if (h < w) { + ratio = float(limit_side_len) / float(h); + } else { + ratio = float(limit_side_len) / float(w); + } + } + } else { + int max_wh = max(h, w); + if (max_wh > limit_side_len) { + if (h > w) { + ratio = float(limit_side_len) / float(h); + } else { + ratio = float(limit_side_len) / float(w); + } } } @@ -143,4 +153,26 @@ void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, } } +void TableResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, + const int max_len) { + int w = img.cols; + int h = img.rows; + + int max_wh = w >= h ? w : h; + float ratio = w >= h ? float(max_len) / float(w) : float(max_len) / float(h); + + int resize_h = int(float(h) * ratio); + int resize_w = int(float(w) * ratio); + + cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); +} + +void TablePadImg::Run(const cv::Mat &img, cv::Mat &resize_img, + const int max_len) { + int w = img.cols; + int h = img.rows; + cv::copyMakeBorder(img, resize_img, 0, max_len - h, 0, max_len - w, + cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0)); +} + } // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/structure_table.cpp b/deploy/cpp_infer/src/structure_table.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bbc32580e49d6ed7b29e3f0931eab0b0969b02b9 --- /dev/null +++ b/deploy/cpp_infer/src/structure_table.cpp @@ -0,0 +1,158 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace PaddleOCR { + +void StructureTableRecognizer::Run( + std::vector img_list, + std::vector> &structure_html_tags, + std::vector &structure_scores, + std::vector>>> &structure_boxes, + std::vector ×) { + std::chrono::duration preprocess_diff = + std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + std::chrono::duration inference_diff = + std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + std::chrono::duration postprocess_diff = + std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + + int img_num = img_list.size(); + for (int beg_img_no = 0; beg_img_no < img_num; + beg_img_no += this->table_batch_num_) { + // preprocess + auto preprocess_start = std::chrono::steady_clock::now(); + int end_img_no = min(img_num, beg_img_no + this->table_batch_num_); + int batch_num = end_img_no - beg_img_no; + std::vector norm_img_batch; + std::vector width_list; + std::vector height_list; + for (int ino = beg_img_no; ino < end_img_no; ino++) { + cv::Mat srcimg; + img_list[ino].copyTo(srcimg); + cv::Mat resize_img; + cv::Mat pad_img; + this->resize_op_.Run(srcimg, resize_img, this->table_max_len_); + this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, + this->is_scale_); + this->pad_op_.Run(resize_img, pad_img, this->table_max_len_); + norm_img_batch.push_back(pad_img); + width_list.push_back(srcimg.cols); + height_list.push_back(srcimg.rows); + } + + std::vector input( + batch_num * 3 * this->table_max_len_ * this->table_max_len_, 0.0f); + this->permute_op_.Run(norm_img_batch, 
input.data()); + auto preprocess_end = std::chrono::steady_clock::now(); + preprocess_diff += preprocess_end - preprocess_start; + // inference. + auto input_names = this->predictor_->GetInputNames(); + auto input_t = this->predictor_->GetInputHandle(input_names[0]); + input_t->Reshape( + {batch_num, 3, this->table_max_len_, this->table_max_len_}); + auto inference_start = std::chrono::steady_clock::now(); + input_t->CopyFromCpu(input.data()); + this->predictor_->Run(); + auto output_names = this->predictor_->GetOutputNames(); + auto output_tensor0 = this->predictor_->GetOutputHandle(output_names[0]); + auto output_tensor1 = this->predictor_->GetOutputHandle(output_names[1]); + std::vector predict_shape0 = output_tensor0->shape(); + std::vector predict_shape1 = output_tensor1->shape(); + + int out_num0 = std::accumulate(predict_shape0.begin(), predict_shape0.end(), + 1, std::multiplies()); + int out_num1 = std::accumulate(predict_shape1.begin(), predict_shape1.end(), + 1, std::multiplies()); + std::vector loc_preds; + std::vector structure_probs; + loc_preds.resize(out_num0); + structure_probs.resize(out_num1); + + output_tensor0->CopyToCpu(loc_preds.data()); + output_tensor1->CopyToCpu(structure_probs.data()); + auto inference_end = std::chrono::steady_clock::now(); + inference_diff += inference_end - inference_start; + // postprocess + auto postprocess_start = std::chrono::steady_clock::now(); + std::vector> structure_html_tag_batch; + std::vector structure_score_batch; + std::vector>>> + structure_boxes_batch; + this->post_processor_.Run(loc_preds, structure_probs, structure_score_batch, + predict_shape0, predict_shape1, + structure_html_tag_batch, structure_boxes_batch, + width_list, height_list); + for (int m = 0; m < predict_shape0[0]; m++) { + + structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(), + ""); + structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(), + ""); + 
structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(), + ""); + structure_html_tag_batch[m].push_back("
"); + structure_html_tag_batch[m].push_back(""); + structure_html_tag_batch[m].push_back(""); + structure_html_tags.push_back(structure_html_tag_batch[m]); + structure_scores.push_back(structure_score_batch[m]); + structure_boxes.push_back(structure_boxes_batch[m]); + } + auto postprocess_end = std::chrono::steady_clock::now(); + postprocess_diff += postprocess_end - postprocess_start; + times.push_back(double(preprocess_diff.count() * 1000)); + times.push_back(double(inference_diff.count() * 1000)); + times.push_back(double(postprocess_diff.count() * 1000)); + } +} + +void StructureTableRecognizer::LoadModel(const std::string &model_dir) { + AnalysisConfig config; + config.SetModel(model_dir + "/inference.pdmodel", + model_dir + "/inference.pdiparams"); + + if (this->use_gpu_) { + config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); + if (this->use_tensorrt_) { + auto precision = paddle_infer::Config::Precision::kFloat32; + if (this->precision_ == "fp16") { + precision = paddle_infer::Config::Precision::kHalf; + } + if (this->precision_ == "int8") { + precision = paddle_infer::Config::Precision::kInt8; + } + config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false); + } + } else { + config.DisableGpu(); + if (this->use_mkldnn_) { + config.EnableMKLDNN(); + } + config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); + } + + // false for zero copy tensor + config.SwitchUseFeedFetchOps(false); + // true for multiple input + config.SwitchSpecifyInputNames(true); + + config.SwitchIrOptim(true); + + config.EnableMemoryOptim(); + config.DisableGlogInfo(); + + this->predictor_ = CreatePredictor(config); +} +} // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/utility.cpp b/deploy/cpp_infer/src/utility.cpp index 45b8104626cfc3d128e14ece8ba6763f0986cfe4..4bfc1d091d6124b10c79032beb702ba8727210fc 100644 --- a/deploy/cpp_infer/src/utility.cpp +++ b/deploy/cpp_infer/src/utility.cpp @@ -248,4 +248,33 @@ void Utility::print_result(const 
std::vector &ocr_result) { std::cout << std::endl; } } + +cv::Mat Utility::crop_image(cv::Mat &img, std::vector &area) { + cv::Mat crop_im; + int crop_x1 = std::max(0, area[0]); + int crop_y1 = std::max(0, area[1]); + int crop_x2 = std::min(img.cols - 1, area[2] - 1); + int crop_y2 = std::min(img.rows - 1, area[3] - 1); + + crop_im = cv::Mat::zeros(area[3] - area[1], area[2] - area[0], 16); + cv::Mat crop_im_window = + crop_im(cv::Range(crop_y1 - area[1], crop_y2 + 1 - area[1]), + cv::Range(crop_x1 - area[0], crop_x2 + 1 - area[0])); + cv::Mat roi_img = + img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1)); + crop_im_window += roi_img; + return crop_im; +} + +void Utility::sorted_boxes(std::vector &ocr_result) { + std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box); + + for (int i = 0; i < ocr_result.size() - 1; i++) { + if (abs(ocr_result[i + 1].box[0][1] - ocr_result[i].box[0][1]) < 10 && + (ocr_result[i + 1].box[0][0] < ocr_result[i].box[0][0])) { + std::swap(ocr_result[i], ocr_result[i + 1]); + } + } +} + } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/hubserving/ocr_system/module.py b/deploy/hubserving/ocr_system/module.py index 71a19c6b7049ec1d779377e7c84cbfe7d2820991..dff3abb48010946a9817b832383f1c95b7053970 100644 --- a/deploy/hubserving/ocr_system/module.py +++ b/deploy/hubserving/ocr_system/module.py @@ -118,7 +118,7 @@ class OCRSystem(hub.Module): all_results.append([]) continue starttime = time.time() - dt_boxes, rec_res = self.text_sys(img) + dt_boxes, rec_res, _ = self.text_sys(img) elapse = time.time() - starttime logger.info("Predict time: {}".format(elapse)) diff --git a/deploy/hubserving/readme.md b/deploy/hubserving/readme.md index 183a25912c2c62371e6db6af2fde5c792fbcbecb..8144c2e7cefaed6f64763e414101445b2d80b81a 100755 --- a/deploy/hubserving/readme.md +++ b/deploy/hubserving/readme.md @@ -59,7 +59,8 @@ pip3 install paddlehub==2.1.0 --upgrade -i https://mirror.baidu.com/pypi/simple 
检测模型:./inference/ch_PP-OCRv3_det_infer/ 识别模型:./inference/ch_PP-OCRv3_rec_infer/ 方向分类器:./inference/ch_ppocr_mobile_v2.0_cls_infer/ -表格结构识别模型:./inference/en_ppocr_mobile_v2.0_table_structure_infer/ +版面分析模型:./inference/layout_infer/ +表格结构识别模型:./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/ ``` **模型路径可在`params.py`中查看和修改。** 更多模型可以从PaddleOCR提供的模型库[PP-OCR](../../doc/doc_ch/models_list.md)和[PP-Structure](../../ppstructure/docs/models_list.md)下载,也可以替换成自己训练转换好的模型。 @@ -172,7 +173,7 @@ hub serving start -c deploy/hubserving/ocr_system/config.json ## 3. 发送预测请求 配置好服务端,可使用以下命令发送预测请求,获取预测结果: -```python tools/test_hubserving.py server_url image_path``` +```python tools/test_hubserving.py --server_url=server_url --image_dir=image_path``` 需要给脚本传递2个参数: - **server_url**:服务地址,格式为 diff --git a/deploy/hubserving/readme_en.md b/deploy/hubserving/readme_en.md index 27eccbb5e9c465f20b3725f04aa1652e6829fa3c..06eaaebacb51744844473c0ffe8b189dc545492c 100755 --- a/deploy/hubserving/readme_en.md +++ b/deploy/hubserving/readme_en.md @@ -61,7 +61,8 @@ Before installing the service module, you need to prepare the inference model an text detection model: ./inference/ch_PP-OCRv3_det_infer/ text recognition model: ./inference/ch_PP-OCRv3_rec_infer/ text angle classifier: ./inference/ch_ppocr_mobile_v2.0_cls_infer/ -tanle recognition: ./inference/en_ppocr_mobile_v2.0_table_structure_infer/ +layout analysis model: ./inference/layout_infer/ +table recognition: ./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/ ``` **The model path can be found and modified in `params.py`.** More models provided by PaddleOCR can be obtained from the [model library](../../doc/doc_en/models_list_en.md). You can also use models trained by yourself. @@ -177,7 +178,7 @@ hub serving start -c deploy/hubserving/ocr_system/config.json ## 3. 
Send prediction requests After the service starts, you can use the following command to send a prediction request to obtain the prediction result: ```shell -python tools/test_hubserving.py server_url image_path +python tools/test_hubserving.py --server_url=server_url --image_dir=image_path ``` Two parameters need to be passed to the script: diff --git a/deploy/hubserving/structure_system/module.py b/deploy/hubserving/structure_system/module.py index 92846edc6698d0d75224a2b2a844c572fcb17a56..61c93bb146ab11998bc7ed3350cb2686b73e3d3b 100644 --- a/deploy/hubserving/structure_system/module.py +++ b/deploy/hubserving/structure_system/module.py @@ -119,7 +119,7 @@ class StructureSystem(hub.Module): all_results.append([]) continue starttime = time.time() - res = self.table_sys(img) + res, _ = self.table_sys(img) elapse = time.time() - starttime logger.info("Predict time: {}".format(elapse)) @@ -144,6 +144,6 @@ class StructureSystem(hub.Module): if __name__ == '__main__': structure_system = StructureSystem() structure_system._initialize() - image_path = ['./doc/table/1.png'] + image_path = ['./ppstructure/docs/table/1.png'] res = structure_system.predict(paths=image_path) print(res) diff --git a/deploy/hubserving/structure_system/params.py b/deploy/hubserving/structure_system/params.py index 3cc6a2794f80bcd68e254b82e45a05eb17811f65..fe691fbc2d172cc1ad32115abd5a4ee850d8ab2e 100755 --- a/deploy/hubserving/structure_system/params.py +++ b/deploy/hubserving/structure_system/params.py @@ -23,8 +23,10 @@ def read_params(): cfg = table_read_params() # params for layout parser model - cfg.layout_path_model = 'lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config' - cfg.layout_label_map = None + cfg.layout_model_dir = '' + cfg.layout_dict_path = './ppocr/utils/dict/layout_publaynet_dict.txt' + cfg.layout_score_threshold = 0.5 + cfg.layout_nms_threshold = 0.5 cfg.mode = 'structure' cfg.output = './output' diff --git a/deploy/hubserving/structure_table/module.py 
b/deploy/hubserving/structure_table/module.py index 00393daa037368191201a5afed4aa29a3920c268..b4432b2d7b8764bc0327e7b12fe7887530e825c4 100644 --- a/deploy/hubserving/structure_table/module.py +++ b/deploy/hubserving/structure_table/module.py @@ -118,11 +118,11 @@ class TableSystem(hub.Module): all_results.append([]) continue starttime = time.time() - pred_html = self.table_sys(img) + res, _ = self.table_sys(img) elapse = time.time() - starttime logger.info("Predict time: {}".format(elapse)) - all_results.append({'html': pred_html}) + all_results.append({'html': res['html']}) return all_results @serving @@ -138,6 +138,6 @@ class TableSystem(hub.Module): if __name__ == '__main__': table_system = TableSystem() table_system._initialize() - image_path = ['./doc/table/table.jpg'] + image_path = ['./ppstructure/docs/table/table.jpg'] res = table_system.predict(paths=image_path) print(res) diff --git a/doc/doc_ch/algorithm_rec_sar.md b/doc/doc_ch/algorithm_rec_sar.md index b8304313994754480a89d708e39149d67f828c0d..cfb1de25390bda8c6ba4be1db9101269873e8b5b 100644 --- a/doc/doc_ch/algorithm_rec_sar.md +++ b/doc/doc_ch/algorithm_rec_sar.md @@ -79,7 +79,7 @@ python3 tools/export_model.py -c configs/rec/rec_r31_sar.yml -o Global.pretraine SAR文本识别模型推理,可以执行如下命令: ``` -python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_char_type="ch" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False ``` diff --git a/doc/doc_ch/dataset/table_datasets.md b/doc/doc_ch/dataset/table_datasets.md index 
ae902b23ccf985d522386b7454c7f76a74917502..58f4cf470542ff7ef20f518efb8b6942a3caa2f0 100644 --- a/doc/doc_ch/dataset/table_datasets.md +++ b/doc/doc_ch/dataset/table_datasets.md @@ -3,6 +3,7 @@ - [数据集汇总](#数据集汇总) - [1. PubTabNet数据集](#1-pubtabnet数据集) - [2. 好未来表格识别竞赛数据集](#2-好未来表格识别竞赛数据集) +- [3. WTW中文场景表格数据集](#3-wtw中文场景表格数据集) 这里整理了常用表格识别数据集,持续更新中,欢迎各位小伙伴贡献数据集~ @@ -12,6 +13,7 @@ |---|---|---| | PubTabNet |https://github.com/ibm-aur-nlp/PubTabNet| jsonl格式,可直接用[pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py)加载 | | 好未来表格识别竞赛数据集 |https://ai.100tal.com/dataset| jsonl格式,可直接用[pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py)加载 | +| WTW中文场景表格数据集 |https://github.com/wangwen-whu/WTW-Dataset| 需要进行转换后才能用[pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py)加载 | ## 1. PubTabNet数据集 - **数据简介**:PubTabNet数据集的训练集合中包含50万张图像,验证集合中包含0.9万张图像。部分图像可视化如下所示。 @@ -31,3 +33,12 @@ + +## 3. WTW中文场景表格数据集 +- **数据简介**:WTW中文场景表格数据集包含表格检测和表格数据两部分数据,数据集中同时包含扫描和拍照两种场景的图像。 + +https://github.com/wangwen-whu/WTW-Dataset/blob/main/demo/20210816_210413.gif + +
+ +
diff --git a/doc/doc_ch/table_recognition.md b/doc/doc_ch/table_recognition.md new file mode 100644 index 0000000000000000000000000000000000000000..e076149441eca410a25578fac8214862dfea1020 --- /dev/null +++ b/doc/doc_ch/table_recognition.md @@ -0,0 +1,343 @@ +# 表格识别 + +本文提供了PaddleOCR表格识别模型的全流程指南,包括数据准备、模型训练、调优、评估、预测,各个阶段的详细说明: + +- [1. 数据准备](#1-数据准备) + - [1.1. 数据集格式](#11-数据集格式) + - [1.2. 数据下载](#12-数据下载) + - [1.3. 数据集生成](#13-数据集生成) +- [2. 开始训练](#2-开始训练) + - [2.1. 启动训练](#21-启动训练) + - [2.2. 断点训练](#22-断点训练) + - [2.3. 更换Backbone 训练](#23-更换backbone-训练) + - [2.4. 混合精度训练](#24-混合精度训练) + - [2.5. 分布式训练](#25-分布式训练) + - [2.6. 其他训练环境](#26-其他训练环境) + - [2.7. 模型微调](#27-模型微调) +- [3. 模型评估与预测](#3-模型评估与预测) + - [3.1. 指标评估](#31-指标评估) + - [3.2. 测试表格结构识别效果](#32-测试表格结构识别效果) +- [4. 模型导出与预测](#4-模型导出与预测) + - [4.1 模型导出](#41-模型导出) + - [4.2 模型预测](#42-模型预测) +- [5. FAQ](#5-faq) + +# 1. 数据准备 + +## 1.1. 数据集格式 + +PaddleOCR 表格识别模型数据集格式如下: +```txt +img_label # 每张图片标注经过json.dumps()之后的字符串 +... +img_label +``` + +每一行的json格式为: +```txt +{ + 'filename': 'PMC5755158_010_01.png', # 图像名 + 'split': 'train', # 图像属于训练集还是验证集 + 'imgid': 0, # 图像的index + 'html': { + 'structure': {'tokens': ['<thead>', '<tr>', '<td>', ...]}, # 表格的HTML字符串 + 'cell': [ + { + 'tokens': ['P', 'a', 'd', 'd', 'l', 'e', 'P', 'a', 'd', 'd', 'l', 'e'], # 表格中的单个文本 + 'bbox': [x0, y0, x1, y1] # 表格中的单个文本的坐标 + } + ] + } +} +``` + +训练数据的默认存储路径是 `PaddleOCR/train_data`,如果您的磁盘上已有数据集,只需创建软链接至数据集目录: + +``` +# linux and mac os +ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset +# windows +mklink /d <path/to/paddle_ocr>/train_data/dataset <path/to/dataset> +``` + +## 1.2. 数据下载 + +公开数据集下载可参考 [table_datasets](dataset/table_datasets.md)。 + +## 1.3. 数据集生成 + +使用[TableGeneration](https://github.com/WenmuZhou/TableGeneration)可进行扫描表格图像的生成。 + +TableGeneration是一个开源表格数据集生成工具,其通过浏览器渲染的方式对html字符串进行渲染后获得表格图像。部分样张如下: + +|类型|样例| +|---|---| +|简单表格|![](https://raw.githubusercontent.com/WenmuZhou/TableGeneration/main/imgs/simple.jpg)| +|彩色表格|![](https://raw.githubusercontent.com/WenmuZhou/TableGeneration/main/imgs/color.jpg)| + +# 2. 
开始训练 + +PaddleOCR提供了训练脚本、评估脚本和预测脚本,本节将以 [SLANet](../../configs/table/SLANet.yml) 模型训练PubTabNet英文数据集为例: + +## 2.1. 启动训练 + +*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false* + +``` +# GPU训练 支持单卡,多卡训练 +# 训练日志会自动保存为 "{save_model_dir}" 下的train.log + +#单卡训练(训练周期长,不建议) +python3 tools/train.py -c configs/table/SLANet.yml + +#多卡训练,通过--gpus参数指定卡号 +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/SLANet.yml +``` + +正常启动训练后,会看到以下log输出: + +``` +[2022/08/16 03:07:33] ppocr INFO: epoch: [1/400], global_step: 20, lr: 0.000100, acc: 0.000000, loss: 3.915012, structure_loss: 3.229450, loc_loss: 0.670590, avg_reader_cost: 2.63382 s, avg_batch_cost: 6.32390 s, avg_samples: 48.0, ips: 7.59025 samples/s, eta: 9 days, 2:29:27 +[2022/08/16 03:08:41] ppocr INFO: epoch: [1/400], global_step: 40, lr: 0.000100, acc: 0.000000, loss: 1.750859, structure_loss: 1.082116, loc_loss: 0.652822, avg_reader_cost: 0.02533 s, avg_batch_cost: 3.37251 s, avg_samples: 48.0, ips: 14.23271 samples/s, eta: 6 days, 23:28:43 +[2022/08/16 03:09:46] ppocr INFO: epoch: [1/400], global_step: 60, lr: 0.000100, acc: 0.000000, loss: 1.395154, structure_loss: 0.776803, loc_loss: 0.625030, avg_reader_cost: 0.02550 s, avg_batch_cost: 3.26261 s, avg_samples: 48.0, ips: 14.71214 samples/s, eta: 6 days, 5:11:48 +``` + +log 中自动打印如下信息: + +| 字段 | 含义 | +| :----: | :------: | +| epoch | 当前迭代轮次 | +| global_step | 当前迭代次数 | +| lr | 当前学习率 | +| acc | 当前batch的准确率 | +| loss | 当前损失函数 | +| structure_loss | 表格结构损失值 | +| loc_loss | 单元格坐标损失值 | +| avg_reader_cost | 当前 batch 数据处理耗时 | +| avg_batch_cost | 当前 batch 总耗时 | +| avg_samples | 当前 batch 内的样本数 | +| ips | 每秒处理图片的数量 | + + +PaddleOCR支持训练和评估交替进行, 可以在 `configs/table/SLANet.yml` 中修改 `eval_batch_step` 设置评估频率,默认每1000个iter评估一次。评估过程中默认将最佳acc模型,保存为 `output/SLANet/best_accuracy` 。 + +如果验证集很大,测试将会比较耗时,建议减少评估次数,或训练完再进行评估。 + +**提示:** 可通过 -c 参数选择 `configs/table/` 
路径下的多种模型配置进行训练,PaddleOCR支持的表格识别算法可以参考[前沿算法列表](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/algorithm_overview.md#3-%E8%A1%A8%E6%A0%BC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95): + +**注意,预测/评估时的配置文件请务必与训练一致。** + +## 2.2. 断点训练 + +如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径: +```shell +python3 tools/train.py -c configs/table/SLANet.yml -o Global.checkpoints=./your/trained/model +``` + +**注意**:`Global.checkpoints`的优先级高于`Global.pretrained_model`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrained_model`指定的模型。 + +## 2.3. 更换Backbone 训练 + +PaddleOCR将网络划分为四部分,分别在[ppocr/modeling](../../ppocr/modeling)下。 进入网络的数据将按照顺序(transforms->backbones->necks->heads)依次通过这四个部分。 + +```bash +├── architectures # 网络的组网代码 +├── transforms # 网络的图像变换模块 +├── backbones # 网络的特征提取模块 +├── necks # 网络的特征增强模块 +└── heads # 网络的输出模块 +``` +如果要更换的Backbone 在PaddleOCR中有对应实现,直接修改配置yml文件中`Backbone`部分的参数即可。 + +如果要使用新的Backbone,更换backbones的例子如下: + +1. 在 [ppocr/modeling/backbones](../../ppocr/modeling/backbones) 文件夹下新建文件,如my_backbone.py。 +2. 在 my_backbone.py 文件内添加相关代码,示例代码如下: + +```python +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class MyBackbone(nn.Layer): + def __init__(self, *args, **kwargs): + super(MyBackbone, self).__init__() + # your init code + self.conv = nn.xxxx + + def forward(self, inputs): + # your network forward + y = self.conv(inputs) + return y +``` + +3. 在 [ppocr/modeling/backbones/\__init\__.py](../../ppocr/modeling/backbones/__init__.py)文件内导入添加的`MyBackbone`模块,然后修改配置文件中Backbone进行配置即可使用,格式如下: + +```yaml +Backbone: +name: MyBackbone +args1: args1 +``` + +**注意**:如果要更换网络的其他模块,可以参考[文档](./add_new_algorithm.md)。 + +## 2.4. 
混合精度训练 + +如果您想进一步加快训练速度,可以使用[自动混合精度训练](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_cn.html), 以单机单卡为例,命令如下: + +```shell +python3 tools/train.py -c configs/table/SLANet.yml \ + -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy \ + Global.use_amp=True Global.scale_loss=1024.0 Global.use_dynamic_loss_scaling=True + ``` + +## 2.5. 分布式训练 + +多机多卡训练时,通过 `--ips` 参数设置使用的机器IP地址,通过 `--gpus` 参数设置使用的GPU ID: + +```bash +python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/table/SLANet.yml \ + -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy +``` + +**注意:** (1)采用多机多卡训练时,需要替换上面命令中的ips值为您机器的地址,机器之间需要能够相互ping通;(2)训练时需要在多个机器上分别启动命令。查看机器ip地址的命令为`ifconfig`;(3)更多关于分布式训练的性能优势等信息,请参考:[分布式训练教程](./distributed_training.md)。 + + +## 2.6. 其他训练环境 + +- Windows GPU/CPU +在Windows平台上与Linux平台略有不同: +Windows平台只支持`单卡`的训练与预测,指定GPU进行训练`set CUDA_VISIBLE_DEVICES=0` +在Windows平台,DataLoader只支持单进程模式,因此需要设置 `num_workers` 为0; + +- macOS +不支持GPU模式,需要在配置文件中设置`use_gpu`为False,其余训练评估预测命令与Linux GPU完全相同。 + +- Linux DCU +DCU设备上运行需要设置环境变量 `export HIP_VISIBLE_DEVICES=0,1,2,3`,其余训练评估预测命令与Linux GPU完全相同。 + +## 2.7. 模型微调 + +实际使用过程中,建议加载官方提供的预训练模型,在自己的数据集中进行微调,关于模型的微调方法,请参考:[模型微调教程](./finetune.md)。 + + +# 3. 模型评估与预测 + +## 3.1. 指标评估 + +训练中模型参数默认保存在`Global.save_model_dir`目录下。在评估指标时,需要设置`Global.checkpoints`指向保存的参数文件。评估数据集可以通过 `configs/table/SLANet.yml` 修改Eval中的 `label_file_list` 设置。 + + +``` +# GPU 评估, Global.checkpoints 为待测权重 +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/table/SLANet.yml -o Global.checkpoints={path/to/weights}/best_accuracy +``` + +运行完成后,会输出模型的acc指标,如对英文表格识别模型进行评估,会见到如下输出。 +```bash +[2022/08/16 07:59:55] ppocr INFO: acc:0.7622245132160782 +[2022/08/16 07:59:55] ppocr INFO: fps:30.991640622573044 +``` + +## 3.2. 
测试表格结构识别效果 + +使用 PaddleOCR 训练好的模型,可以通过以下脚本进行快速预测。 + +默认预测图片存储在 `infer_img` 里,通过 `-o Global.checkpoints` 加载训练好的参数文件: + +根据配置文件中设置的 `save_model_dir` 和 `save_epoch_step` 字段,会有以下几种参数被保存下来: + +``` +output/SLANet/ +├── best_accuracy.pdopt +├── best_accuracy.pdparams +├── best_accuracy.states +├── config.yml +├── latest.pdopt +├── latest.pdparams +├── latest.states +└── train.log +``` +其中 best_accuracy.* 是评估集上的最优模型;latest.* 是最后一个epoch的模型。 + +``` +# 预测表格图像 +python3 tools/infer_table.py -c configs/table/SLANet.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=ppstructure/docs/table/table.jpg +``` + +预测图片: + +![](../../ppstructure/docs/table/table.jpg) + +得到输入图像的预测结果: + +``` +['', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '
', '', ''],[[320.0562438964844, 197.83375549316406, 350.0928955078125, 214.4309539794922], ... , [318.959228515625, 271.0166931152344, 353.7394104003906, 286.4538269042969]] +``` + +单元格坐标可视化结果为 + +![](../../ppstructure/docs/imgs/slanet_result.jpg) + +# 4. 模型导出与预测 + +## 4.1 模型导出 + +inference 模型(`paddle.jit.save`保存的模型) +一般是模型训练,把模型结构和模型参数保存在文件中的固化模型,多用于预测部署场景。 +训练过程中保存的模型是checkpoints模型,保存的只有模型的参数,多用于恢复训练等。 +与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。 + +表格识别模型转inference模型与文字检测识别的方式相同,如下: + +``` +# -c 后面设置训练算法的yml配置文件 +# -o 配置可选参数 +# Global.pretrained_model 参数设置待转换的训练模型地址,不用添加文件后缀 .pdmodel,.pdopt或.pdparams。 +# Global.save_inference_dir参数设置转换的模型将保存的地址。 + +python3 tools/export_model.py -c configs/table/SLANet.yml -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy Global.save_inference_dir=./inference/SLANet/ +``` + +转换成功后,在目录下有三个文件: + +``` +inference/SLANet/ + ├── inference.pdiparams # inference模型的参数文件 + ├── inference.pdiparams.info # inference模型的参数信息,可忽略 + └── inference.pdmodel # inference模型的program文件 +``` + +## 4.2 模型预测 + +模型导出后,使用如下命令即可完成inference模型的预测 + +```python +python3.7 table/predict_structure.py \ + --table_model_dir={path/to/inference model} \ + --table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt \ + --image_dir=docs/table/table.jpg \ + --output=../output/table +``` + +预测图片: + +![](../../ppstructure/docs/table/table.jpg) + +得到输入图像的预测结果: + +``` +['', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '
', '', ''],[[320.0562438964844, 197.83375549316406, 350.0928955078125, 214.4309539794922], ... , [318.959228515625, 271.0166931152344, 353.7394104003906, 286.4538269042969]] +``` + +单元格坐标可视化结果为 + +![](../../ppstructure/docs/imgs/slanet_result.jpg) + + +# 5. FAQ + +Q1: 训练模型转inference 模型之后预测效果不一致? + +**A**:此类问题出现较多,问题多是trained model预测时候的预处理、后处理参数和inference model预测的时候的预处理、后处理参数不一致导致的。可以对比训练使用的配置文件中的预处理、后处理和预测时是否存在差异。 diff --git a/doc/doc_en/algorithm_rec_sar_en.md b/doc/doc_en/algorithm_rec_sar_en.md index 24b87c10c3b2839909392bf3de0e0c850112fcdc..5c8319da3bc63dce55b0d5eae749ed4500b9d2f6 100644 --- a/doc/doc_en/algorithm_rec_sar_en.md +++ b/doc/doc_en/algorithm_rec_sar_en.md @@ -79,7 +79,7 @@ python3 tools/export_model.py -c configs/rec/rec_r31_sar.yml -o Global.pretraine For SAR text recognition model inference, the following commands can be executed: ``` -python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_char_type="ch" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False ``` diff --git a/doc/doc_en/dataset/table_datasets_en.md b/doc/doc_en/dataset/table_datasets_en.md index e30147909812a153f311add50f0bef5d1d1e0e32..70ca8309798994c6225ab0c10d4689da2387962b 100644 --- a/doc/doc_en/dataset/table_datasets_en.md +++ b/doc/doc_en/dataset/table_datasets_en.md @@ -3,6 +3,7 @@ - [Dataset Summary](#dataset-summary) - [1. PubTabNet](#1-pubtabnet) - [2. TAL Table Recognition Competition Dataset](#2-tal-table-recognition-competition-dataset) +- [3. 
WTW Chinese scene table dataset](#3-wtw-chinese-scene-table-dataset) Here are the commonly used table recognition datasets, which are being updated continuously. Welcome to contribute datasets~ @@ -12,6 +13,7 @@ Here are the commonly used table recognition datasets, which are being updated c |---|---|---| | PubTabNet |https://github.com/ibm-aur-nlp/PubTabNet| jsonl format, which can be loaded directly with [pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py) | | TAL Table Recognition Competition Dataset |https://ai.100tal.com/dataset| jsonl format, which can be loaded directly with [pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py) | +| WTW Chinese scene table dataset |https://github.com/wangwen-whu/WTW-Dataset| Conversion is required to load with [pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py)| ## 1. PubTabNet - **Data Introduction**:The training set of the PubTabNet dataset contains 500,000 images and the validation set contains 9000 images. Part of the image visualization is shown below. @@ -30,3 +32,11 @@ Here are the commonly used table recognition datasets, which are being updated c + +## 3. WTW Chinese scene table dataset +- **Data Introduction**:The WTW Chinese scene table dataset consists of two parts: table detection and table data. The dataset contains images of two scenes, scanned and photographed. +https://github.com/wangwen-whu/WTW-Dataset/blob/main/demo/20210816_210413.gif + +
+ +
diff --git a/doc/doc_en/quickstart_en.md b/doc/doc_en/quickstart_en.md index c678dc47625f4289a93621144bf5577b059d52b3..9e1de839ff0ed8291f1822186f43cb24c9f9ebce 100644 --- a/doc/doc_en/quickstart_en.md +++ b/doc/doc_en/quickstart_en.md @@ -3,14 +3,14 @@ **Note:** This tutorial mainly introduces the usage of PP-OCR series models, please refer to [PP-Structure Quick Start](../../ppstructure/docs/quickstart_en.md) for the quick use of document analysis related functions. - [1. Installation](#1-installation) - - [1.1 Install PaddlePaddle](#11-install-paddlepaddle) - - [1.2 Install PaddleOCR Whl Package](#12-install-paddleocr-whl-package) + - [1.1 Install PaddlePaddle](#11-install-paddlepaddle) + - [1.2 Install PaddleOCR Whl Package](#12-install-paddleocr-whl-package) - [2. Easy-to-Use](#2-easy-to-use) - - [2.1 Use by Command Line](#21-use-by-command-line) - - [2.1.1 Chinese and English Model](#211-chinese-and-english-model) - - [2.1.2 Multi-language Model](#212-multi-language-model) - - [2.2 Use by Code](#22-use-by-code) - - [2.2.1 Chinese & English Model and Multilingual Model](#221-chinese--english-model-and-multilingual-model) + - [2.1 Use by Command Line](#21-use-by-command-line) + - [2.1.1 Chinese and English Model](#211-chinese-and-english-model) + - [2.1.2 Multi-language Model](#212-multi-language-model) + - [2.2 Use by Code](#22-use-by-code) + - [2.2.1 Chinese & English Model and Multilingual Model](#221-chinese--english-model-and-multilingual-model) - [3. Summary](#3-summary) @@ -51,12 +51,6 @@ pip install "paddleocr>=2.0.1" # Recommend to use version 2.0.1+ Reference: [Solve shapely installation on windows](https://stackoverflow.com/questions/44398265/install-shapely-oserror-winerror-126-the-specified-module-could-not-be-found) -- **For layout analysis users**, run the following command to install **Layout-Parser** - - ```bash - pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl - ``` - ## 2. 
Easy-to-Use diff --git a/doc/doc_en/table_recognition_en.md b/doc/doc_en/table_recognition_en.md new file mode 100644 index 0000000000000000000000000000000000000000..aacf9ca673a5ce281cf7ae49bfead02b2c73db09 --- /dev/null +++ b/doc/doc_en/table_recognition_en.md @@ -0,0 +1,354 @@ +# Table Recognition + +This article provides a full-process guide for the PaddleOCR table recognition model, including data preparation, model training, tuning, evaluation, prediction, and detailed descriptions of each stage: + +- [1. Data Preparation](#1-data-preparation) + - [1.1. DataSet Format](#11-dataset-format) + - [1.2. Data Download](#12-data-download) + - [1.3. Dataset Generation](#13-dataset-generation) +- [2. Training](#2-training) + - [2.1. Start Training](#21-start-training) + - [2.2. Resume Training](#22-resume-training) + - [2.3. Training with New Backbone](#23-training-with-new-backbone) + - [2.4. Mixed Precision Training](#24-mixed-precision-training) + - [2.5. Distributed Training](#25-distributed-training) + - [2.6. Training on other platform(Windows/macOS/Linux DCU)](#26-training-on-other-platformwindowsmacoslinux-dcu) + - [2.7. Fine-tuning](#27-fine-tuning) +- [3. Evaluation and Test](#3-evaluation-and-test) + - [3.1. Evaluation](#31-evaluation) + - [3.2. Test table structure recognition effect](#32-test-table-structure-recognition-effect) +- [4. Model export and prediction](#4-model-export-and-prediction) + - [4.1 Model export](#41-model-export) + - [4.2 Prediction](#42-prediction) +- [5. FAQ](#5-faq) + +# 1. Data Preparation + +## 1.1. DataSet Format + +The format of the PaddleOCR table recognition model dataset is as follows: +```txt +img_label # Each image is marked with a string after json.dumps() +... 
+img_label +``` + +The json format of each line is: +```txt +{ + 'filename': 'PMC5755158_010_01.png', # image name + 'split': 'train', # whether the image belongs to the training set or the validation set + 'imgid': 0, # index of image + 'html': { + 'structure': {'tokens': ['', '', '', ...]}, # HTML string of the table + 'cell': [ + { + 'tokens': ['P', 'a', 'd', 'd', 'l', 'e', 'P', 'a', 'd', 'd', 'l', 'e'], # text in cell + 'bbox': [x0, y0, x1, y1] # bbox of cell + } + ] + } +} +``` + +The default storage path for training data is `PaddleOCR/train_data`. If you already have a dataset on disk, just create a soft link to the dataset directory: + +``` +# linux and mac os +ln -sf /train_data/dataset +# windows +mklink /d /train_data/dataset +``` + +## 1.2. Data Download + +For public dataset downloads, refer to [table_datasets](dataset/table_datasets_en.md). + +## 1.3. Dataset Generation + +Use [TableGeneration](https://github.com/WenmuZhou/TableGeneration) to generate scanned table images. + +TableGeneration is an open source table dataset generation tool, which renders html strings through browser rendering to obtain table images. + +Some samples are as follows: + +|Type|Sample| +|---|---| +|Simple Table|![](https://raw.githubusercontent.com/WenmuZhou/TableGeneration/main/imgs/simple.jpg)| +|Simple Color Table|![](https://raw.githubusercontent.com/WenmuZhou/TableGeneration/main/imgs/color.jpg)| + +# 2. Training + +PaddleOCR provides training scripts, evaluation scripts, and prediction scripts. In this section, the [SLANet](../../configs/table/SLANet.yml) model will be used as an example: + +## 2.1. 
Start Training + +*If you are installing the cpu version, please modify the `use_gpu` field in the configuration file to false* + +``` +# GPU training Support single card and multi-card training +# The training log will be automatically saved as train.log under "{save_model_dir}" + +# specify the single card training(Long training time, not recommended) +python3 tools/train.py -c configs/table/SLANet.yml + +# specify the card number through --gpus +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/SLANet.yml +``` + +After starting training normally, you will see the following log output: + +``` +[2022/08/16 03:07:33] ppocr INFO: epoch: [1/400], global_step: 20, lr: 0.000100, acc: 0.000000, loss: 3.915012, structure_loss: 3.229450, loc_loss: 0.670590, avg_reader_cost: 2.63382 s, avg_batch_cost: 6.32390 s, avg_samples: 48.0, ips: 7.59025 samples/s, eta: 9 days, 2:29:27 +[2022/08/16 03:08:41] ppocr INFO: epoch: [1/400], global_step: 40, lr: 0.000100, acc: 0.000000, loss: 1.750859, structure_loss: 1.082116, loc_loss: 0.652822, avg_reader_cost: 0.02533 s, avg_batch_cost: 3.37251 s, avg_samples: 48.0, ips: 14.23271 samples/s, eta: 6 days, 23:28:43 +[2022/08/16 03:09:46] ppocr INFO: epoch: [1/400], global_step: 60, lr: 0.000100, acc: 0.000000, loss: 1.395154, structure_loss: 0.776803, loc_loss: 0.625030, avg_reader_cost: 0.02550 s, avg_batch_cost: 3.26261 s, avg_samples: 48.0, ips: 14.71214 samples/s, eta: 6 days, 5:11:48 +``` + +The following information is automatically printed in the log: + +| Field | Meaning | +| :----: | :------: | +| epoch | current iteration round | +| global_step | current iteration count | +| lr | current learning rate | +| acc | The accuracy of the current batch | +| loss | current loss function | +| structure_loss | Table Structure Loss Values | +| loc_loss | Cell Coordinate Loss Value | +| avg_reader_cost | Current batch data processing time | +| avg_batch_cost | The total time spent in the current batch | 
+| avg_samples | The number of samples in the current batch | +| ips | Number of images processed per second | + + +PaddleOCR supports alternating training and evaluation. You can modify `eval_batch_step` in `configs/table/SLANet.yml` to set the evaluation frequency. By default, it is evaluated once every 1000 iters. During the evaluation process, the best acc model is saved as `output/SLANet/best_accuracy` by default. + +If the validation set is large, the test will be time-consuming. It is recommended to reduce the number of evaluations, or perform evaluation after training. + +**Tips:** You can use the -c parameter to select various model configurations under the `configs/table/` path for training. For the table recognition algorithms supported by PaddleOCR, please refer to [Table Algorithms List](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_en/algorithm_overview_en.md#3). + +**Note that the configuration file for prediction/evaluation must be the same as the one used for training.** + +## 2.2. Resume Training + +If the training program is interrupted and you want to resume training from the interrupted model, you can specify the path of the model to be loaded via `Global.checkpoints`: + +```shell +python3 tools/train.py -c configs/table/SLANet.yml -o Global.checkpoints=./your/trained/model +``` +**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrained_model`, that is, when two parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is incorrect, the model specified by `Global.pretrained_model` will be loaded. + +## 2.3. Training with New Backbone + +The network part completes the construction of the network, and PaddleOCR divides the network into four parts, which are under [ppocr/modeling](../../ppocr/modeling). 
The data entering the network will pass through these four parts in sequence (transforms->backbones-> +necks->heads). + +```bash +├── architectures # Code for building network +├── transforms # Image Transformation Module +├── backbones # Feature extraction module +├── necks # Feature enhancement module +└── heads # Output module +``` + +If the Backbone to be replaced has a corresponding implementation in PaddleOCR, you can directly modify the parameters in the `Backbone` part of the configuration yml file. + +However, if you want to use a new Backbone, an example of replacing the backbones is as follows: + +1. Create a new file under the [ppocr/modeling/backbones](../../ppocr/modeling/backbones) folder, such as my_backbone.py. +2. Add code in the my_backbone.py file, the sample code is as follows: + +```python +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class MyBackbone(nn.Layer): + def __init__(self, *args, **kwargs): + super(MyBackbone, self).__init__() + # your init code + self.conv = nn.xxxx + + def forward(self, inputs): + # your network forward + y = self.conv(inputs) + return y +``` + +3. Import the added module in the [ppocr/modeling/backbones/\__init\__.py](../../ppocr/modeling/backbones/__init__.py) file. + +After adding the four-part modules of the network, you only need to configure them in the configuration file to use, such as: + +```yaml + Backbone: + name: MyBackbone + args1: args1 +``` + +**NOTE**: More details about replacing the Backbone and other modules can be found in [doc](add_new_algorithm_en.md). + +## 2.4. 
Mixed Precision Training + +If you want to speed up your training further, you can use [Auto Mixed Precision Training](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_cn.html), taking a single machine and a single gpu as an example, the commands are as follows: + +```shell +python3 tools/train.py -c configs/table/SLANet.yml \ + -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy \ + Global.use_amp=True Global.scale_loss=1024.0 Global.use_dynamic_loss_scaling=True + ``` + +## 2.5. Distributed Training + +During multi-machine multi-gpu training, use the `--ips` parameter to set the used machine IP address, and the `--gpus` parameter to set the used GPU ID: + +```bash +python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/table/SLANet.yml \ + -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy +``` + + +**Note:** (1) When using multi-machine and multi-gpu training, you need to replace the ips value in the above command with the address of your machine, and the machines need to be able to ping each other. (2) Training needs to be launched separately on multiple machines. The command to view the ip address of the machine is `ifconfig`. (3) For more details about the distributed training speedup ratio, please refer to [Distributed Training Tutorial](./distributed_training_en.md). + +## 2.6. 
Training on other platform(Windows/macOS/Linux DCU) + +- Windows GPU/CPU +The Windows platform is slightly different from the Linux platform: +Windows platform only supports `single gpu` training and inference, specify GPU for training `set CUDA_VISIBLE_DEVICES=0` +On the Windows platform, DataLoader only supports single-process mode, so you need to set `num_workers` to 0; + +- macOS +GPU mode is not supported, you need to set `use_gpu` to False in the configuration file, and the rest of the training evaluation prediction commands are exactly the same as Linux GPU. + +- Linux DCU +Running on a DCU device requires setting the environment variable `export HIP_VISIBLE_DEVICES=0,1,2,3`, and the rest of the training and evaluation prediction commands are exactly the same as the Linux GPU. + + +## 2.7. Fine-tuning + +In the actual use process, it is recommended to load the officially provided pre-training model and fine-tune it in your own data set. For the fine-tuning method of the table recognition model, please refer to: [Model fine-tuning tutorial](./finetune.md). + + +# 3. Evaluation and Test + +## 3.1. Evaluation + +The model parameters during training are saved in the `Global.save_model_dir` directory by default. When evaluating metrics, you need to set `Global.checkpoints` to point to the saved parameter file. Evaluation datasets can be modified via the `label_file_list` setting in Eval via `configs/table/SLANet.yml`. + +``` +# GPU evaluation, Global.checkpoints is the weight to be tested +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/table/SLANet.yml -o Global.checkpoints={path/to/weights}/best_accuracy +``` + +After the operation is completed, the acc indicator of the model will be output. If you evaluate the English table recognition model, you will see the following output. + +```bash +[2022/08/16 07:59:55] ppocr INFO: acc:0.7622245132160782 +[2022/08/16 07:59:55] ppocr INFO: fps:30.991640622573044 +``` + +## 3.2. 
Test table structure recognition effect + +Using the model trained by PaddleOCR, you can quickly get prediction through the following script. + +The default prediction picture is stored in `infer_img`, and the trained weight is specified via `-o Global.checkpoints`: + + +According to the `save_model_dir` and `save_epoch_step` fields set in the configuration file, the following parameters will be saved: + + +``` +output/SLANet/ +├── best_accuracy.pdopt +├── best_accuracy.pdparams +├── best_accuracy.states +├── config.yml +├── latest.pdopt +├── latest.pdparams +├── latest.states +└── train.log +``` +Among them, best_accuracy.* is the best model on the evaluation set; latest.* is the model of the last epoch. + +``` +# Predict table image +python3 tools/infer_table.py -c configs/table/SLANet.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=ppstructure/docs/table/table.jpg +``` + +Input image: + +![](../../ppstructure/docs/table/table.jpg) + +Get the prediction result of the input image: + +``` +['', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '
', '', ''],[[320.0562438964844, 197.83375549316406, 350.0928955078125, 214.4309539794922], ... , [318.959228515625, 271.0166931152344, 353.7394104003906, 286.4538269042969]] +``` + +The cell coordinates are visualized as + +![](../../ppstructure/docs/imgs/slanet_result.jpg) + +# 4. Model export and prediction + +## 4.1 Model export + +inference model (model saved by `paddle.jit.save`) +It is generally a frozen model, produced after training, that saves both the model structure and the model parameters in a file, and is mostly used in prediction and deployment scenarios. +The model saved during the training process is the checkpoints model, and only the parameters of the model are saved, which are mostly used to resume training. +Compared with the checkpoints model, the inference model will additionally save the structural information of the model. It has superior performance in prediction deployment and inference acceleration, is flexible and convenient, and is suitable for actual system integration. + +The way to convert the table recognition model to the inference model is the same as for text detection and recognition, as follows: + +``` +# -c Set the training algorithm yml configuration file +# -o Set optional parameters +# Global.pretrained_model parameter Set the training model address to be converted without adding the file suffix .pdmodel, .pdopt or .pdparams. +# Global.save_inference_dir Set the address where the converted model will be saved. 
+ +python3 tools/export_model.py -c configs/table/SLANet.yml -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy Global.save_inference_dir=./inference/SLANet/ +``` + +After the conversion is successful, there are three files in the model save directory: + + +``` +inference/SLANet/ + ├── inference.pdiparams # The parameter file of inference model + ├── inference.pdiparams.info # The parameter information of inference model, which can be ignored + └── inference.pdmodel # The program file of model +``` + +## 4.2 Prediction + +After the model is exported, use the following command to complete the prediction of the inference model + +```python +python3.7 table/predict_structure.py \ + --table_model_dir={path/to/inference model} \ + --table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt \ + --image_dir=docs/table/table.jpg \ + --output=../output/table +``` + +Input image: + +![](../../ppstructure/docs/table/table.jpg) + +Get the prediction result of the input image: + +``` +['', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '
', '', ''],[[320.0562438964844, 197.83375549316406, 350.0928955078125, 214.4309539794922], ... , [318.959228515625, 271.0166931152344, 353.7394104003906, 286.4538269042969]] +``` + +The cell coordinates are visualized as + +![](../../ppstructure/docs/imgs/slanet_result.jpg) + + + +# 5. FAQ + +Q1: After the training model is transferred to the inference model, the prediction effect is inconsistent? + +**A**: There are many such problems, and the problems are mostly caused by inconsistent preprocessing and postprocessing parameters when the trained model predicts and the preprocessing and postprocessing parameters when the inference model predicts. You can compare whether there are differences in preprocessing, postprocessing, and prediction in the configuration files used for training. diff --git a/paddleocr.py b/paddleocr.py index 470dc60da3b15195bcd401aff5e50be5a2cfd13e..8e34c4fbc331f798618fc5f33bc00963a577e25a 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -47,14 +47,14 @@ __all__ = [ ] SUPPORT_DET_MODEL = ['DB'] -VERSION = '2.5.0.3' +VERSION = '2.6' SUPPORT_REC_MODEL = ['CRNN', 'SVTR_LCNet'] BASE_DIR = os.path.expanduser("~/.paddleocr/") DEFAULT_OCR_MODEL_VERSION = 'PP-OCRv3' SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2', 'PP-OCRv3'] -DEFAULT_STRUCTURE_MODEL_VERSION = 'PP-STRUCTURE' -SUPPORT_STRUCTURE_MODEL_VERSION = ['PP-STRUCTURE'] +DEFAULT_STRUCTURE_MODEL_VERSION = 'PP-Structurev2' +SUPPORT_STRUCTURE_MODEL_VERSION = ['PP-Structure', 'PP-Structurev2'] MODEL_URLS = { 'OCR': { 'PP-OCRv3': { @@ -263,7 +263,7 @@ MODEL_URLS = { } }, 'STRUCTURE': { - 'PP-STRUCTURE': { + 'PP-Structure': { 'table': { 'en': { 'url': @@ -271,6 +271,27 @@ MODEL_URLS = { 'dict_path': 'ppocr/utils/dict/table_structure_dict.txt' } } + }, + 'PP-Structurev2': { + 'table': { + 'en': { + 'url': + 'https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_infer.tar', + 'dict_path': 'ppocr/utils/dict/table_structure_dict.txt' + }, + 'ch': { + 'url': + 
'https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar', + 'dict_path': 'ppocr/utils/dict/table_structure_dict_ch.txt' + } + }, + 'layout': { + 'ch': { + 'url': + 'https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout_infer.tar', + 'dict_path': 'ppocr/utils/dict/layout_publaynet_dict.txt' + } + } } } } @@ -298,12 +319,15 @@ def parse_args(mMain=True): "--structure_version", type=str, choices=SUPPORT_STRUCTURE_MODEL_VERSION, - default='PP-STRUCTURE', + default='PP-Structurev2', help='Model version, the current model support list is as follows:' - ' 1. STRUCTURE Support en table structure model.') + ' 1. PP-Structure Support en table structure model.' + ' 2. PP-Structurev2 Support ch and en table structure model.') for action in parser._actions: - if action.dest in ['rec_char_dict_path', 'table_char_dict_path']: + if action.dest in [ + 'rec_char_dict_path', 'table_char_dict_path', 'layout_dict_path' + ]: action.default = None if mMain: return parser.parse_args() @@ -477,7 +501,7 @@ class PaddleOCR(predict_system.TextSystem): if isinstance(img, np.ndarray) and len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if det and rec: - dt_boxes, rec_res = self.__call__(img, cls) + dt_boxes, rec_res, _ = self.__call__(img, cls) return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] elif det and not rec: dt_boxes, elapse = self.text_detector(img) @@ -506,6 +530,12 @@ class PPStructure(StructureSystem): if not params.show_log: logger.setLevel(logging.INFO) lang, det_lang = parse_lang(params.lang) + if lang == 'ch': + table_lang = 'ch' + else: + table_lang = 'en' + if params.structure_version == 'PP-Structure': + params.merge_no_span_structure = False # init model dir det_model_config = get_model_config('OCR', params.ocr_version, 'det', @@ -520,14 +550,20 @@ class PPStructure(StructureSystem): params.rec_model_dir, os.path.join(BASE_DIR, 'whl', 'rec', lang), 
rec_model_config['url']) table_model_config = get_model_config( - 'STRUCTURE', params.structure_version, 'table', 'en') + 'STRUCTURE', params.structure_version, 'table', table_lang) params.table_model_dir, table_url = confirm_model_dir_url( params.table_model_dir, os.path.join(BASE_DIR, 'whl', 'table'), table_model_config['url']) + layout_model_config = get_model_config( + 'STRUCTURE', params.structure_version, 'layout', 'ch') + params.layout_model_dir, layout_url = confirm_model_dir_url( + params.layout_model_dir, + os.path.join(BASE_DIR, 'whl', 'layout'), layout_model_config['url']) # download model maybe_download(params.det_model_dir, det_url) maybe_download(params.rec_model_dir, rec_url) maybe_download(params.table_model_dir, table_url) + maybe_download(params.layout_model_dir, layout_url) if params.rec_char_dict_path is None: params.rec_char_dict_path = str( @@ -535,7 +571,9 @@ class PPStructure(StructureSystem): if params.table_char_dict_path is None: params.table_char_dict_path = str( Path(__file__).parent / table_model_config['dict_path']) - + if params.layout_dict_path is None: + params.layout_dict_path = str( + Path(__file__).parent / layout_model_config['dict_path']) logger.debug(params) super().__init__(params) @@ -557,7 +595,7 @@ class PPStructure(StructureSystem): if isinstance(img, np.ndarray) and len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - res = super().__call__(img, return_ocr_result_in_table) + res, _ = super().__call__(img, return_ocr_result_in_table) return res diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 68e5f719be30f2947d6a67a5cb90d1ba0e357309..59cb9b8a253cf04244ebf83511ab412174487a53 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -575,7 +575,7 @@ class TableLabelEncode(AttnLabelEncode): replace_empty_cell_token=False, merge_no_span_structure=False, learn_empty_box=False, - point_num=2, + loc_reg_num=4, **kwargs): self.max_text_len = max_text_length 
self.lower = False @@ -590,6 +590,12 @@ class TableLabelEncode(AttnLabelEncode): line = line.decode('utf-8').strip("\n").strip("\r\n") dict_character.append(line) + if self.merge_no_span_structure: + if "" not in dict_character: + dict_character.append("") + if "" in dict_character: + dict_character.remove("") + dict_character = self.add_special_char(dict_character) self.dict = {} for i, char in enumerate(dict_character): @@ -597,7 +603,7 @@ class TableLabelEncode(AttnLabelEncode): self.idx2char = {v: k for k, v in self.dict.items()} self.character = dict_character - self.point_num = point_num + self.loc_reg_num = loc_reg_num self.pad_idx = self.dict[self.beg_str] self.start_idx = self.dict[self.beg_str] self.end_idx = self.dict[self.end_str] @@ -653,7 +659,7 @@ class TableLabelEncode(AttnLabelEncode): # encode box bboxes = np.zeros( - (self._max_text_len, self.point_num * 2), dtype=np.float32) + (self._max_text_len, self.loc_reg_num), dtype=np.float32) bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32) bbox_idx = 0 @@ -718,11 +724,11 @@ class TableMasterLabelEncode(TableLabelEncode): replace_empty_cell_token=False, merge_no_span_structure=False, learn_empty_box=False, - point_num=2, + loc_reg_num=4, **kwargs): super(TableMasterLabelEncode, self).__init__( max_text_length, character_dict_path, replace_empty_cell_token, - merge_no_span_structure, learn_empty_box, point_num, **kwargs) + merge_no_span_structure, learn_empty_box, loc_reg_num, **kwargs) self.pad_idx = self.dict[self.pad_str] self.unknown_idx = self.dict[self.unknown_str] @@ -743,27 +749,35 @@ class TableMasterLabelEncode(TableLabelEncode): class TableBoxEncode(object): - def __init__(self, use_xywh=False, **kwargs): - self.use_xywh = use_xywh + def __init__(self, in_box_format='xyxy', out_box_format='xyxy', **kwargs): + assert out_box_format in ['xywh', 'xyxy', 'xyxyxyxy'] + self.in_box_format = in_box_format + self.out_box_format = out_box_format def __call__(self, data): img_height, 
def xyxyxyxy2xywh(self, boxes):
    """Convert 4-point boxes to (x, y, w, h) axis-aligned boxes.

    Each input row is (x1, y1, x2, y2, x3, y3, x4, y4). For every row the
    result holds the smallest x and y of the four points followed by the
    horizontal and vertical extent of the points.

    :param boxes: array indexed as ``boxes[:, 0::2]`` (x values) and
        ``boxes[:, 1::2]`` (y values), one box per row
    :return: float array with one (x_min, y_min, w, h) row per input box
    """
    # NOTE(review): (x, y) here is the top-left corner, whereas xyxy2xywh in
    # the same class produces a box centre - confirm which convention the
    # 'xywh' consumers expect.
    xs = boxes[:, 0::2]
    ys = boxes[:, 1::2]
    new_bboxes = np.zeros([len(boxes), 4])
    # Reduce along axis 1 so every box gets its own extrema; a bare .min() /
    # .max() would collapse the whole batch into a single scalar.
    new_bboxes[:, 0] = xs.min(axis=1)  # x1
    new_bboxes[:, 1] = ys.min(axis=1)  # y1
    new_bboxes[:, 2] = xs.max(axis=1) - new_bboxes[:, 0]  # w
    new_bboxes[:, 3] = ys.max(axis=1) - new_bboxes[:, 1]  # h
    return new_bboxes
.basic_loss import DistanceLoss from .combined_loss import CombinedLoss # table loss -from .table_att_loss import TableAttentionLoss +from .table_att_loss import TableAttentionLoss, SLALoss from .table_master_loss import TableMasterLoss # vqa token loss from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss @@ -67,7 +67,8 @@ def build_loss(config): 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', - 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss' + 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss', + 'SLALoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py index 3496c9072553d839017eaa017fe47dfb66fb9d3b..f1771847b46b99d8cf2a3ae69e7e990ee02f26a5 100644 --- a/ppocr/losses/table_att_loss.py +++ b/ppocr/losses/table_att_loss.py @@ -22,65 +22,11 @@ from paddle.nn import functional as F class TableAttentionLoss(nn.Layer): - def __init__(self, - structure_weight, - loc_weight, - use_giou=False, - giou_weight=1.0, - **kwargs): + def __init__(self, structure_weight, loc_weight, **kwargs): super(TableAttentionLoss, self).__init__() self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none') self.structure_weight = structure_weight self.loc_weight = loc_weight - self.use_giou = use_giou - self.giou_weight = giou_weight - - def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'): - ''' - :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] - :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] - :return: loss - ''' - ix1 = paddle.maximum(preds[:, 0], bbox[:, 0]) - iy1 = paddle.maximum(preds[:, 1], bbox[:, 1]) - ix2 = paddle.minimum(preds[:, 2], bbox[:, 2]) - iy2 = paddle.minimum(preds[:, 3], bbox[:, 3]) - - iw = paddle.clip(ix2 - ix1 + 1e-3, 0., 1e10) - ih = paddle.clip(iy2 - iy1 + 1e-3, 
class SLALoss(nn.Layer):
    """Joint loss for table structure tokens and cell box regression.

    The structure term is a mean cross-entropy over the structure tokens;
    the location term is a masked smooth-L1 loss normalised by the number of
    active mask entries.

    NOTE(review): ``loc_loss`` is stored on the instance but ``forward``
    always uses smooth-L1, whatever value is passed - confirm whether other
    choices (the default is 'mse') were meant to be honoured.
    """

    def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
        super(SLALoss, self).__init__()
        self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
        self.structure_weight = structure_weight
        self.loc_weight = loc_weight
        self.loc_loss = loc_loss
        # Guards the normalisation below against an all-zero mask.
        self.eps = 1e-12

    def forward(self, predicts, batch):
        """Compute the weighted structure and location losses.

        :param predicts: dict with 'structure_probs' and 'loc_preds'
        :param batch: sequence whose items 1, 2 and 3 are the structure
            targets, the box targets and the box mask
        :return: dict with 'loss', 'structure_loss' and 'loc_loss'
        """
        # The first position of every target sequence is dropped
        # (presumably the start token - verify against the label encoder).
        token_targets = batch[1].astype("int64")[:, 1:]
        box_targets = batch[2].astype("float32")[:, 1:, :]
        box_mask = batch[3].astype("float32")[:, 1:, :]

        structure_loss = paddle.mean(
            self.loss_func(predicts['structure_probs'],
                           token_targets)) * self.structure_weight

        # Mask both sides so padded positions contribute exactly zero.
        summed_l1 = F.smooth_l1_loss(
            predicts['loc_preds'] * box_mask,
            box_targets * box_mask,
            reduction='sum')
        loc_loss = summed_l1 * self.loc_weight / (box_mask.sum() + self.eps)

        return {
            'loss': structure_loss + loc_loss,
            "structure_loss": structure_loss,
            "loc_loss": loc_loss
        }
def format_box(self, box):
    """Expand a flat box into four [x, y] corner points.

    The layout of ``box`` is selected by ``self.box_format``:

    * 'xyxy'     - (x1, y1, x2, y2), two opposite corners
    * 'xywh'     - (x, y, w, h) with (x, y) the box centre
    * 'xyxyxyxy' - four explicit points, returned pairwise

    Any other format value returns ``box`` unchanged.

    :param box: flat sequence of coordinates in the configured format
    :return: list of four [x, y] points
    """
    if self.box_format == 'xyxy':
        x1, y1, x2, y2 = box
        box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
    elif self.box_format == 'xywh':
        x, y, w, h = box
        # True division: floor division would truncate float sizes and
        # collapse normalized (< 1) widths and heights to zero.
        half_w, half_h = w / 2, h / 2
        x1, y1, x2, y2 = x - half_w, y - half_h, x + half_w, y + half_h
        box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
    elif self.box_format == 'xyxyxyxy':
        x1, y1, x2, y2, x3, y3, x4, y4 = box
        box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
    return box
a/ppocr/modeling/backbones/det_pp_lcnet.py b/ppocr/modeling/backbones/det_pp_lcnet.py new file mode 100644 index 0000000000000000000000000000000000000000..3f719e92bc67452b482e5b2053ee1a09540ffc0e --- /dev/null +++ b/ppocr/modeling/backbones/det_pp_lcnet.py @@ -0,0 +1,271 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import, division, print_function + +import os +import paddle +import paddle.nn as nn +from paddle import ParamAttr +from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear +from paddle.regularizer import L2Decay +from paddle.nn.initializer import KaimingNormal +from paddle.utils.download import get_path_from_url + +MODEL_URLS = { + "PPLCNet_x0.25": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_25_pretrained.pdparams", + "PPLCNet_x0.35": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_35_pretrained.pdparams", + "PPLCNet_x0.5": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_5_pretrained.pdparams", + "PPLCNet_x0.75": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_75_pretrained.pdparams", + "PPLCNet_x1.0": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_0_pretrained.pdparams", + "PPLCNet_x1.5": + 
def make_divisible(v, divisor=8, min_value=None):
    """Round ``v`` to a nearby multiple of ``divisor``.

    The result is never below ``min_value`` (which defaults to ``divisor``)
    and is bumped up by one ``divisor`` step when rounding would fall below
    90% of ``v``.
    """
    lower_bound = divisor if min_value is None else min_value
    rounded = int(v + divisor / 2) // divisor * divisor
    candidate = max(lower_bound, rounded)
    # Rounding down must not remove more than 10% of the requested width.
    return candidate + divisor if candidate < 0.9 * v else candidate
class PPLCNet(nn.Layer):
    """PP-LCNet backbone that returns the outputs of blocks3 to blocks6.

    Args:
        in_channels (int): channel count of the input fed to the stem conv.
        scale (float): width multiplier applied to every channel count in
            ``NET_CONFIG``.
        pretrained (bool): when true, download and load the weights listed
            in ``MODEL_URLS`` for this scale.
        use_ssld (bool): when true, load the ``_ssld_pretrained`` variant of
            the weight file instead.
    """

    def __init__(self,
                 in_channels=3,
                 scale=1.0,
                 pretrained=False,
                 use_ssld=False):
        super().__init__()
        # Advertise the widths the stages really produce. Every block width
        # goes through make_divisible, so a plain int(c * scale) disagrees
        # for scales such as 0.35 (int(22.4) == 22, but the layer has 24
        # channels) and a consumer sized from out_channels would not match.
        self.out_channels = [
            make_divisible(NET_CONFIG[stage][-1][2] * scale)
            for stage in ("blocks3", "blocks4", "blocks5", "blocks6")
        ]
        self.scale = scale

        self.conv1 = ConvBNLayer(
            num_channels=in_channels,
            filter_size=3,
            num_filters=make_divisible(16 * scale),
            stride=2)

        self.blocks2 = self._make_stage("blocks2", scale)
        self.blocks3 = self._make_stage("blocks3", scale)
        self.blocks4 = self._make_stage("blocks4", scale)
        self.blocks5 = self._make_stage("blocks5", scale)
        self.blocks6 = self._make_stage("blocks6", scale)

        if pretrained:
            # MODEL_URLS keys are written with a decimal point
            # ("PPLCNet_x1.0"); float() keeps an integer scale such as 1
            # from producing the missing key "PPLCNet_x1".
            self._load_pretrained(
                MODEL_URLS['PPLCNet_x{}'.format(float(scale))],
                use_ssld=use_ssld)

    @staticmethod
    def _make_stage(stage_name, scale):
        """Build one NET_CONFIG stage as a Sequential of depthwise blocks."""
        return nn.Sequential(* [
            DepthwiseSeparable(
                num_channels=make_divisible(in_c * scale),
                num_filters=make_divisible(out_c * scale),
                dw_size=k,
                stride=s,
                use_se=se) for k, in_c, out_c, s, se in NET_CONFIG[stage_name]
        ])

    def forward(self, x):
        """Return the list of feature maps produced by blocks3..blocks6."""
        outs = []
        x = self.conv1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        outs.append(x)
        x = self.blocks4(x)
        outs.append(x)
        x = self.blocks5(x)
        outs.append(x)
        x = self.blocks6(x)
        outs.append(x)
        return outs

    def _load_pretrained(self, pretrained_url, use_ssld=False):
        """Download the weight file (cached under ~/.paddleclas/weights) and load it."""
        if use_ssld:
            pretrained_url = pretrained_url.replace("_pretrained",
                                                    "_ssld_pretrained")
        print(pretrained_url)
        local_weight_path = get_path_from_url(
            pretrained_url, os.path.expanduser("~/.paddleclas/weights"))
        param_state_dict = paddle.load(local_weight_path)
        self.set_dict(param_state_dict)
        return
def get_para_bias_attr(l2_decay, k):
    """Build the weight and bias ``ParamAttr`` pair for a linear layer.

    When ``l2_decay`` is positive both attributes get an L2 regularizer and
    a uniform initializer on ``[-1/sqrt(k), 1/sqrt(k)]``; otherwise both are
    created with no regularizer and no initializer.

    @param l2_decay: L2 regularization coefficient; values <= 0 disable it
    @param k: value used to scale the uniform initializer range
    @return: ``[weight_attr, bias_attr]``
    """
    # Imported locally because the module's import block does not bring in
    # `math`; without it the l2_decay > 0 branch raises NameError.
    import math

    if l2_decay > 0:
        regularizer = paddle.regularizer.L2Decay(l2_decay)
        stdv = 1.0 / math.sqrt(k * 1.0)
        initializer = nn.initializer.Uniform(-stdv, stdv)
    else:
        regularizer = None
        initializer = None
    weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
    bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
    return [weight_attr, bias_attr]
class SLAHead(nn.Layer):
    # Attention-GRU decoder that emits, at every step, one structure-token
    # prediction and one location regression vector.
    def __init__(self,
                 in_channels,
                 hidden_size,
                 out_channels=30,
                 max_text_length=500,
                 loc_reg_num=4,
                 fc_decay=0.0,
                 **kwargs):
        """
        @param in_channels: input shape; only the last entry is used
        @param hidden_size: hidden_size for RNN and Embedding
        @param out_channels: num_classes to rec (also the one-hot width)
        @param max_text_length: max text pred
        @param loc_reg_num: number of values regressed per step by the loc branch
        @param fc_decay: forwarded as l2_decay to get_para_bias_attr for the linear layers
        """
        super().__init__()
        in_channels = in_channels[-1]
        self.hidden_size = hidden_size
        self.max_text_length = max_text_length
        self.emb = self._char_to_onehot
        self.num_embeddings = out_channels

        # structure
        self.structure_attention_cell = AttentionGRUCell(
            in_channels, hidden_size, self.num_embeddings)
        weight_attr, bias_attr = get_para_bias_attr(
            l2_decay=fc_decay, k=hidden_size)
        # NOTE(review): weight_attr1_1 / bias_attr1_1 are not referenced
        # again in this constructor.
        weight_attr1_1, bias_attr1_1 = get_para_bias_attr(
            l2_decay=fc_decay, k=hidden_size)
        weight_attr1_2, bias_attr1_2 = get_para_bias_attr(
            l2_decay=fc_decay, k=hidden_size)
        self.structure_generator = nn.Sequential(
            nn.Linear(
                self.hidden_size,
                self.hidden_size,
                weight_attr=weight_attr1_2,
                bias_attr=bias_attr1_2),
            nn.Linear(
                hidden_size,
                out_channels,
                weight_attr=weight_attr,
                bias_attr=bias_attr))
        # loc
        weight_attr1, bias_attr1 = get_para_bias_attr(
            l2_decay=fc_decay, k=self.hidden_size)
        weight_attr2, bias_attr2 = get_para_bias_attr(
            l2_decay=fc_decay, k=self.hidden_size)
        # The trailing Sigmoid keeps every regressed value in (0, 1).
        self.loc_generator = nn.Sequential(
            nn.Linear(
                self.hidden_size,
                self.hidden_size,
                weight_attr=weight_attr1,
                bias_attr=bias_attr1),
            nn.Linear(
                self.hidden_size,
                loc_reg_num,
                weight_attr=weight_attr2,
                bias_attr=bias_attr2),
            nn.Sigmoid())

    def forward(self, inputs, targets=None):
        """
        Decode max_text_length + 1 steps from the last input feature map.
        @param inputs: sequence of feature maps; only inputs[-1] is used
        @param targets: when training, targets[0] supplies the previous token for every step
        @return: dict with 'structure_probs' and 'loc_preds', both stacked along axis 1
        """
        fea = inputs[-1]
        batch_size = fea.shape[0]
        # reshape
        # Flatten everything after the channel axis, then move channels last
        # (assumes a 4-D batch/channel/height/width input - TODO confirm).
        fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], -1])
        fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)

        hidden = paddle.zeros((batch_size, self.hidden_size))
        structure_preds = []
        loc_preds = []
        if self.training and targets is not None:
            # Teacher forcing: step i is conditioned on the target token i.
            structure = targets[0]
            for i in range(self.max_text_length + 1):
                hidden, structure_step, loc_step = self._decode(structure[:, i],
                                                                fea, hidden)
                structure_preds.append(structure_step)
                loc_preds.append(loc_step)
        else:
            # Greedy decoding: each step is fed the argmax of the previous
            # step, starting from token index 0.
            pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
            max_text_length = paddle.to_tensor(self.max_text_length)
            # for export
            loc_step, structure_step = None, None
            for i in range(max_text_length + 1):
                hidden, structure_step, loc_step = self._decode(pre_chars, fea,
                                                                hidden)
                pre_chars = structure_step.argmax(axis=1, dtype="int32")
                structure_preds.append(structure_step)
                loc_preds.append(loc_step)
        structure_preds = paddle.stack(structure_preds, axis=1)
        loc_preds = paddle.stack(loc_preds, axis=1)
        # Softmax is applied only outside training; in training mode the raw
        # structure_generator outputs are returned.
        if not self.training:
            structure_preds = F.softmax(structure_preds)
        return {'structure_probs': structure_preds, 'loc_preds': loc_preds}

    def _decode(self, pre_chars, features, hidden):
        """
        Predict table label and coordinates for each step
        @param pre_chars: Table label in previous step
        @param features: feature sequence passed to the attention cell
        @param hidden: hidden status in previous step
        @return: (new hidden state, structure output, loc output) for this step
        """
        emb_feature = self.emb(pre_chars)
        # output shape is b * self.hidden_size
        (output, hidden), alpha = self.structure_attention_cell(
            hidden, features, emb_feature)

        # structure
        structure_step = self.structure_generator(output)
        # loc
        loc_step = self.loc_generator(output)
        return hidden, structure_step, loc_step

    def _char_to_onehot(self, input_char):
        # One-hot encode token indices with width self.num_embeddings.
        input_ont_hot = F.one_hot(input_char, self.num_embeddings)
        return input_ont_hot
b/ppocr/modeling/necks/__init__.py @@ -25,9 +25,10 @@ def build_neck(config): from .fpn import FPN from .fce_fpn import FCEFPN from .pren_fpn import PRENFPN + from .csp_pan import CSPPAN support_dict = [ 'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN', - 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN' + 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN' ] module_name = config.pop('name') diff --git a/ppocr/modeling/necks/csp_pan.py b/ppocr/modeling/necks/csp_pan.py new file mode 100755 index 0000000000000000000000000000000000000000..f4f8547f7d80d25edfe66824aa4f104341ae29ef --- /dev/null +++ b/ppocr/modeling/necks/csp_pan.py @@ -0,0 +1,324 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# The code is based on: +# https://github.com/PaddlePaddle/PaddleDetection/blob/release%2F2.3/ppdet/modeling/necks/csp_pan.py + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr + +__all__ = ['CSPPAN'] + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channel=96, + out_channel=96, + kernel_size=3, + stride=1, + groups=1, + act='leaky_relu'): + super(ConvBNLayer, self).__init__() + initializer = nn.initializer.KaimingUniform() + self.act = act + assert self.act in ['leaky_relu', "hard_swish"] + self.conv = nn.Conv2D( + in_channels=in_channel, + out_channels=out_channel, + kernel_size=kernel_size, + groups=groups, + padding=(kernel_size - 1) // 2, + stride=stride, + weight_attr=ParamAttr(initializer=initializer), + bias_attr=False) + self.bn = nn.BatchNorm2D(out_channel) + + def forward(self, x): + x = self.bn(self.conv(x)) + if self.act == "leaky_relu": + x = F.leaky_relu(x) + elif self.act == "hard_swish": + x = F.hardswish(x) + return x + + +class DPModule(nn.Layer): + """ + Depth-wise and point-wise module. + Args: + in_channel (int): The input channels of this Module. + out_channel (int): The output channels of this Module. + kernel_size (int): The conv2d kernel size of this Module. + stride (int): The conv2d's stride of this Module. + act (str): The activation function of this Module, + Now support `leaky_relu` and `hard_swish`. 
+ """ + + def __init__(self, + in_channel=96, + out_channel=96, + kernel_size=3, + stride=1, + act='leaky_relu'): + super(DPModule, self).__init__() + initializer = nn.initializer.KaimingUniform() + self.act = act + self.dwconv = nn.Conv2D( + in_channels=in_channel, + out_channels=out_channel, + kernel_size=kernel_size, + groups=out_channel, + padding=(kernel_size - 1) // 2, + stride=stride, + weight_attr=ParamAttr(initializer=initializer), + bias_attr=False) + self.bn1 = nn.BatchNorm2D(out_channel) + self.pwconv = nn.Conv2D( + in_channels=out_channel, + out_channels=out_channel, + kernel_size=1, + groups=1, + padding=0, + weight_attr=ParamAttr(initializer=initializer), + bias_attr=False) + self.bn2 = nn.BatchNorm2D(out_channel) + + def act_func(self, x): + if self.act == "leaky_relu": + x = F.leaky_relu(x) + elif self.act == "hard_swish": + x = F.hardswish(x) + return x + + def forward(self, x): + x = self.act_func(self.bn1(self.dwconv(x))) + x = self.act_func(self.bn2(self.pwconv(x))) + return x + + +class DarknetBottleneck(nn.Layer): + """The basic bottleneck block used in Darknet. + Each Block consists of two ConvModules and the input is added to the + final output. Each ConvModule is composed of Conv, BN, and act. + The first convLayer has filter size of 1x1 and the second one has the + filter size of 3x3. + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + expansion (int): The kernel size of the convolution. Default: 0.5 + add_identity (bool): Whether to add identity to the out. + Default: True + use_depthwise (bool): Whether to use depthwise separable convolution. 
+ Default: False + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + expansion=0.5, + add_identity=True, + use_depthwise=False, + act="leaky_relu"): + super(DarknetBottleneck, self).__init__() + hidden_channels = int(out_channels * expansion) + conv_func = DPModule if use_depthwise else ConvBNLayer + self.conv1 = ConvBNLayer( + in_channel=in_channels, + out_channel=hidden_channels, + kernel_size=1, + act=act) + self.conv2 = conv_func( + in_channel=hidden_channels, + out_channel=out_channels, + kernel_size=kernel_size, + stride=1, + act=act) + self.add_identity = \ + add_identity and in_channels == out_channels + + def forward(self, x): + identity = x + out = self.conv1(x) + out = self.conv2(out) + + if self.add_identity: + return out + identity + else: + return out + + +class CSPLayer(nn.Layer): + """Cross Stage Partial Layer. + Args: + in_channels (int): The input channels of the CSP layer. + out_channels (int): The output channels of the CSP layer. + expand_ratio (float): Ratio to adjust the number of channels of the + hidden layer. Default: 0.5 + num_blocks (int): Number of blocks. Default: 1 + add_identity (bool): Whether to add identity in blocks. + Default: True + use_depthwise (bool): Whether to depthwise separable convolution in + blocks. 
Default: False + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + expand_ratio=0.5, + num_blocks=1, + add_identity=True, + use_depthwise=False, + act="leaky_relu"): + super().__init__() + mid_channels = int(out_channels * expand_ratio) + self.main_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act) + self.short_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act) + self.final_conv = ConvBNLayer( + 2 * mid_channels, out_channels, 1, act=act) + + self.blocks = nn.Sequential(* [ + DarknetBottleneck( + mid_channels, + mid_channels, + kernel_size, + 1.0, + add_identity, + use_depthwise, + act=act) for _ in range(num_blocks) + ]) + + def forward(self, x): + x_short = self.short_conv(x) + + x_main = self.main_conv(x) + x_main = self.blocks(x_main) + + x_final = paddle.concat((x_main, x_short), axis=1) + return self.final_conv(x_final) + + +class Channel_T(nn.Layer): + def __init__(self, + in_channels=[116, 232, 464], + out_channels=96, + act="leaky_relu"): + super(Channel_T, self).__init__() + self.convs = nn.LayerList() + for i in range(len(in_channels)): + self.convs.append( + ConvBNLayer( + in_channels[i], out_channels, 1, act=act)) + + def forward(self, x): + outs = [self.convs[i](x[i]) for i in range(len(x))] + return outs + + +class CSPPAN(nn.Layer): + """Path Aggregation Network with CSP module. + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale) + kernel_size (int): The conv2d kernel size of this Module. + num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 1 + use_depthwise (bool): Whether to depthwise separable convolution in + blocks. 
Default: True + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=5, + num_csp_blocks=1, + use_depthwise=True, + act='hard_swish'): + super(CSPPAN, self).__init__() + self.in_channels = in_channels + self.out_channels = [out_channels] * len(in_channels) + conv_func = DPModule if use_depthwise else ConvBNLayer + + self.conv_t = Channel_T(in_channels, out_channels, act=act) + + # build top-down blocks + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + self.top_down_blocks = nn.LayerList() + for idx in range(len(in_channels) - 1, 0, -1): + self.top_down_blocks.append( + CSPLayer( + out_channels * 2, + out_channels, + kernel_size=kernel_size, + num_blocks=num_csp_blocks, + add_identity=False, + use_depthwise=use_depthwise, + act=act)) + + # build bottom-up blocks + self.downsamples = nn.LayerList() + self.bottom_up_blocks = nn.LayerList() + for idx in range(len(in_channels) - 1): + self.downsamples.append( + conv_func( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=2, + act=act)) + self.bottom_up_blocks.append( + CSPLayer( + out_channels * 2, + out_channels, + kernel_size=kernel_size, + num_blocks=num_csp_blocks, + add_identity=False, + use_depthwise=use_depthwise, + act=act)) + + def forward(self, inputs): + """ + Args: + inputs (tuple[Tensor]): input features. + Returns: + tuple[Tensor]: CSPPAN features. 
+ """ + assert len(inputs) == len(self.in_channels) + inputs = self.conv_t(inputs) + + # top-down path + inner_outs = [inputs[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_heigh = inner_outs[0] + feat_low = inputs[idx - 1] + upsample_feat = F.upsample( + feat_heigh, size=paddle.shape(feat_low)[2:4], mode="nearest") + + inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx]( + paddle.concat([upsample_feat, feat_low], 1)) + inner_outs.insert(0, inner_out) + + # bottom-up path + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_height = inner_outs[idx + 1] + downsample_feat = self.downsamples[idx](feat_low) + out = self.bottom_up_blocks[idx](paddle.concat( + [downsample_feat, feat_height], 1)) + outs.append(out) + + return tuple(outs) diff --git a/ppocr/postprocess/table_postprocess.py b/ppocr/postprocess/table_postprocess.py index 4396ec4f701478e7bdcdd8c7752738c5c8ef148d..a47061f935e31b24fdb624df170f8abb38e01f40 100644 --- a/ppocr/postprocess/table_postprocess.py +++ b/ppocr/postprocess/table_postprocess.py @@ -21,9 +21,29 @@ from .rec_postprocess import AttnLabelDecode class TableLabelDecode(AttnLabelDecode): """ """ - def __init__(self, character_dict_path, **kwargs): - super(TableLabelDecode, self).__init__(character_dict_path) - self.td_token = ['', '', ''] + def __init__(self, + character_dict_path, + merge_no_span_structure=False, + **kwargs): + dict_character = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + dict_character.append(line) + + if merge_no_span_structure: + if "" not in dict_character: + dict_character.append("") + if "" in dict_character: + dict_character.remove("") + + dict_character = self.add_special_char(dict_character) + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.character = dict_character + self.td_token = ['', 
''] def __call__(self, preds, batch=None): structure_probs = preds['structure_probs'] @@ -114,18 +134,21 @@ class TableLabelDecode(AttnLabelDecode): def _bbox_decode(self, bbox, shape): h, w, ratio_h, ratio_w, pad_h, pad_w = shape - src_h = h / ratio_h - src_w = w / ratio_w - bbox[0::2] *= src_w - bbox[1::2] *= src_h + bbox[0::2] *= w + bbox[1::2] *= h return bbox class TableMasterLabelDecode(TableLabelDecode): """ """ - def __init__(self, character_dict_path, box_shape='ori', **kwargs): - super(TableMasterLabelDecode, self).__init__(character_dict_path) + def __init__(self, + character_dict_path, + box_shape='ori', + merge_no_span_structure=True, + **kwargs): + super(TableMasterLabelDecode, self).__init__(character_dict_path, + merge_no_span_structure) self.box_shape = box_shape assert box_shape in [ 'ori', 'pad' @@ -157,4 +180,7 @@ class TableMasterLabelDecode(TableLabelDecode): bbox[1::2] *= h bbox[0::2] /= ratio_w bbox[1::2] /= ratio_h + x, y, w, h = bbox + x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2 + bbox = np.array([x1, y1, x2, y2]) return bbox diff --git a/ppocr/utils/dict/table_structure_dict_ch.txt b/ppocr/utils/dict/table_structure_dict_ch.txt new file mode 100644 index 0000000000000000000000000000000000000000..0c59c0e9998a31f9d32f703625aa1c5ca7718c8d --- /dev/null +++ b/ppocr/utils/dict/table_structure_dict_ch.txt @@ -0,0 +1,48 @@ + + + + + + + + + + colspan="2" + colspan="3" + colspan="4" + colspan="5" + colspan="6" + colspan="7" + colspan="8" + colspan="9" + colspan="10" + colspan="11" + colspan="12" + colspan="13" + colspan="14" + colspan="15" + colspan="16" + colspan="17" + colspan="18" + colspan="19" + colspan="20" + rowspan="2" + rowspan="3" + rowspan="4" + rowspan="5" + rowspan="6" + rowspan="7" + rowspan="8" + rowspan="9" + rowspan="10" + rowspan="11" + rowspan="12" + rowspan="13" + rowspan="14" + rowspan="15" + rowspan="16" + rowspan="17" + rowspan="18" + rowspan="19" + rowspan="20" diff --git a/ppocr/utils/network.py 
b/ppocr/utils/network.py index 118d1be364925d9416134cffe21d636fcac753e9..080a5d160116cfdd3b255a883525281d97ee9cc9 100644 --- a/ppocr/utils/network.py +++ b/ppocr/utils/network.py @@ -41,9 +41,7 @@ def download_with_progressbar(url, save_path): def maybe_download(model_storage_directory, url): # using custom model - tar_file_name_list = [ - 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel' - ] + tar_file_name_list = ['.pdiparams', '.pdiparams.info', '.pdmodel'] if not os.path.exists( os.path.join(model_storage_directory, 'inference.pdiparams') ) or not os.path.exists( @@ -57,8 +55,8 @@ def maybe_download(model_storage_directory, url): for member in tarObj.getmembers(): filename = None for tar_file_name in tar_file_name_list: - if tar_file_name in member.name: - filename = tar_file_name + if member.name.endswith(tar_file_name): + filename = 'inference' + tar_file_name if filename is None: continue file = tarObj.extractfile(member) diff --git a/ppocr/utils/visual.py b/ppocr/utils/visual.py index e0fbf06abb471c294cb268520fb99bca1a6b1d61..5bd805ea6e76be37612a142102beab492bece941 100644 --- a/ppocr/utils/visual.py +++ b/ppocr/utils/visual.py @@ -113,14 +113,11 @@ def draw_re_results(image, return np.array(img_new) -def draw_rectangle(img_path, boxes, use_xywh=False): +def draw_rectangle(img_path, boxes): + boxes = np.array(boxes) img = cv2.imread(img_path) img_show = img.copy() for box in boxes.astype(int): - if use_xywh: - x, y, w, h = box - x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2 - else: - x1, y1, x2, y2 = box + x1, y1, x2, y2 = box cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2) return img_show \ No newline at end of file diff --git a/ppstructure/README.md b/ppstructure/README.md index 72670e33575ebe444c78b15fbab4e330389a7498..856de5a4306de987378dafc65e582f280be4bef3 100644 --- a/ppstructure/README.md +++ b/ppstructure/README.md @@ -106,9 +106,9 @@ PP-Structure Series Model List (Updating) |model 
name|description|model size|download| | --- | --- | --- | --- | -|ch_PP-OCRv2_det_slim|[New] Slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection| 3M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)| -|ch_PP-OCRv2_rec_slim|[New] Slim qunatization with distillation lightweight model, supporting Chinese, English, multilingual text recognition| 9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar) | -|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset| 18.6M |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) | +|ch_PP-OCRv3_det_slim|[New] Slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection| 1.1M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_distill_train.tar)| +|ch_PP-OCRv3_rec_slim |[New] Slim quantization with distillation lightweight model, supporting Chinese, English text recognition| 4.9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_train.tar) | +|ch_ppstructure_mobile_v2.0_SLANet|Chinese table recognition model trained on PubTabNet dataset based on SLANet|9.3M|[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar) / [trained 
model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_train.tar) | ### 7.3 DOC-VQA model diff --git a/ppstructure/README_ch.md b/ppstructure/README_ch.md index ddacbb077937f325db0430846b8f05bfda9619cd..64af0cbe53265c85fd9027fe48e82102f4b5ea57 100644 --- a/ppstructure/README_ch.md +++ b/ppstructure/README_ch.md @@ -120,9 +120,10 @@ PP-Structure系列模型列表(更新中) |模型名称|模型简介|模型大小|下载地址| | --- | --- | --- | --- | -|ch_PP-OCRv2_det_slim|【最新】slim量化+蒸馏版超轻量模型,支持中英文、多语种文本检测| 3M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)| -|ch_PP-OCRv2_rec_slim|【最新】slim量化版超轻量模型,支持中英文、数字识别| 9M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar) | -|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) | +|ch_PP-OCRv3_det_slim|【最新】slim量化+蒸馏版超轻量模型,支持中英文、多语种文本检测| 1.1M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_distill_train.tar)| +|ch_PP-OCRv3_rec_slim |【最新】slim量化版超轻量模型,支持中英文、数字识别| 4.9M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_train.tar) | +|ch_ppstructure_mobile_v2.0_SLANet|基于SLANet在PubTabNet数据集上训练的中文表格识别模型|9.3M|[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_train.tar) | + ### 7.3 DocVQA 模型 diff --git 
a/ppstructure/docs/imgs/slanet_result.jpg b/ppstructure/docs/imgs/slanet_result.jpg new file mode 100644 index 0000000000000000000000000000000000000000..011857fbc2295b91a96d938f861d38b8e07421bc Binary files /dev/null and b/ppstructure/docs/imgs/slanet_result.jpg differ diff --git a/ppstructure/docs/imgs/table_ch_result1.jpg b/ppstructure/docs/imgs/table_ch_result1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c75eee40f642d437451fa16bff9cb4a3bdb4f38a Binary files /dev/null and b/ppstructure/docs/imgs/table_ch_result1.jpg differ diff --git a/ppstructure/docs/imgs/table_ch_result2.jpg b/ppstructure/docs/imgs/table_ch_result2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..802871a8e1b6983f304fc73a9fd13404aa02630f Binary files /dev/null and b/ppstructure/docs/imgs/table_ch_result2.jpg differ diff --git a/ppstructure/docs/imgs/table_ch_result3.jpg b/ppstructure/docs/imgs/table_ch_result3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bdd92aa6ee7819c837fd3e2abc38cac915588a71 Binary files /dev/null and b/ppstructure/docs/imgs/table_ch_result3.jpg differ diff --git a/ppstructure/docs/installation.md b/ppstructure/docs/installation.md index 155baf29de5701b58c9342cf82897b23f4ab7e45..3f564cb2ddfe546642e6f92e2c024bbe3a1f7ffc 100644 --- a/ppstructure/docs/installation.md +++ b/ppstructure/docs/installation.md @@ -1,8 +1,7 @@ - [快速安装](#快速安装) - [1. PaddlePaddle 和 PaddleOCR](#1-paddlepaddle-和-paddleocr) - [2. 安装其他依赖](#2-安装其他依赖) - - [2.1 版面分析所需 Layout-Parser](#21-版面分析所需--layout-parser) - - [2.2 VQA所需依赖](#22--vqa所需依赖) + - [2.1 VQA所需依赖](#21--vqa所需依赖) # 快速安装 @@ -12,14 +11,7 @@ ## 2. 
安装其他依赖 -### 2.1 版面分析所需 Layout-Parser - -Layout-Parser 可通过如下命令安装 - -```bash -pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl -``` -### 2.2 VQA所需依赖 +### 2.1 VQA所需依赖 * paddleocr ```bash diff --git a/ppstructure/docs/installation_en.md b/ppstructure/docs/installation_en.md new file mode 100644 index 0000000000000000000000000000000000000000..02b02db0c58f60a5296734b93563510732a7286d --- /dev/null +++ b/ppstructure/docs/installation_en.md @@ -0,0 +1,30 @@ +# Quick installation + +- [1. PaddlePaddle and PaddleOCR](#1) +- [2. Install other dependencies](#2) + - [2.1 VQA](#21) + + + +## 1. PaddlePaddle and PaddleOCR + +Please refer to [PaddleOCR installation documentation](../../doc/doc_en/installation_en.md) + + +## 2. Install other dependencies + + +### 2.1 VQA + +* paddleocr + +```bash +pip3 install paddleocr +``` + +* PaddleNLP +```bash +git clone https://github.com/PaddlePaddle/PaddleNLP -b develop +cd PaddleNLP +pip3 install -e . +``` diff --git a/ppstructure/docs/models_list.md b/ppstructure/docs/models_list.md index 89fa98d3b77a1a17f53e0f5efa770396360c87b1..ef2994cabea38709464780d25b5f32c3b9801b4c 100644 --- a/ppstructure/docs/models_list.md +++ b/ppstructure/docs/models_list.md @@ -34,7 +34,9 @@ |模型名称|模型简介|推理模型大小|下载地址| | --- | --- | --- | --- | -|en_ppocr_mobile_v2.0_table_structure|PubTabNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) | +|en_ppocr_mobile_v2.0_table_structure|基于TableRec-RARE在PubTabNet数据集上训练的英文表格识别模型|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) | 
+|en_ppstructure_mobile_v2.0_SLANet|基于SLANet在PubTabNet数据集上训练的英文表格识别模型|9M|[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_train.tar) | +|ch_ppstructure_mobile_v2.0_SLANet|基于SLANet在PubTabNet数据集上训练的中文表格识别模型|9.3M|[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_train.tar) | diff --git a/ppstructure/docs/models_list_en.md b/ppstructure/docs/models_list_en.md index e133a0bb2a9b017207b5e92ea444aba4633a7457..64a7cdebc3e3c7ac18ae2f61013aa4e8a7c3ead8 100644 --- a/ppstructure/docs/models_list_en.md +++ b/ppstructure/docs/models_list_en.md @@ -35,7 +35,9 @@ If you need to use other OCR models, you can download the model in [PP-OCR model |model| description |inference model size|download| | --- |-----------------------------------------------------------------------------| --- | --- | -|en_ppocr_mobile_v2.0_table_structure| Table structure model for English table scenes trained on PubTabNet dataset |18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) | +|en_ppocr_mobile_v2.0_table_structure| English table recognition model trained on PubTabNet dataset based on TableRec-RARE |18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) | +|en_ppstructure_mobile_v2.0_SLANet|English table recognition model trained on PubTabNet dataset based on SLANet|9M|[inference 
model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_train.tar) | +|ch_ppstructure_mobile_v2.0_SLANet|Chinese table recognition model trained on PubTabNet dataset based on SLANet|9.3M|[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_train.tar) | ## 3. VQA diff --git a/ppstructure/docs/quickstart.md b/ppstructure/docs/quickstart.md index 31e59416247b4f0e6b6d82fb13e0d3841a113a5f..f4645bdfe011a12370bedc7bd7a125b28ded41ff 100644 --- a/ppstructure/docs/quickstart.md +++ b/ppstructure/docs/quickstart.md @@ -1,21 +1,23 @@ # PP-Structure 快速开始 -- [1. 安装依赖包](#1) -- [2. 便捷使用](#2) - - [2.1 命令行使用](#21) - - [2.1.1 版面分析+表格识别](#211) - - [2.1.2 版面分析](#212) - - [2.1.3 表格识别](#213) - - [2.1.4 DocVQA](#214) - - [2.2 代码使用](#22) - - [2.2.1 版面分析+表格识别](#221) - - [2.2.2 版面分析](#222) - - [2.2.3 表格识别](#223) - - [2.2.4 DocVQA](#224) - - [2.3 返回结果说明](#23) - - [2.3.1 版面分析+表格识别](#231) - - [2.3.2 DocVQA](#232) - - [2.4 参数说明](#24) +- [1. 安装依赖包](#1-安装依赖包) +- [2. 
便捷使用](#2-便捷使用) + - [2.1 命令行使用](#21-命令行使用) + - [2.1.1 图像方向分类+版面分析+表格识别](#211-图像方向分类版面分析表格识别) + - [2.1.2 版面分析+表格识别](#212-版面分析表格识别) + - [2.1.3 版面分析](#213-版面分析) + - [2.1.4 表格识别](#214-表格识别) + - [2.1.5 DocVQA](#215-docvqa) + - [2.2 代码使用](#22-代码使用) + - [2.2.1 图像方向分类版面分析表格识别](#221-图像方向分类版面分析表格识别) + - [2.2.2 版面分析+表格识别](#222-版面分析表格识别) + - [2.2.3 版面分析](#223-版面分析) + - [2.2.4 表格识别](#224-表格识别) + - [2.2.5 DocVQA](#225-docvqa) + - [2.3 返回结果说明](#23-返回结果说明) + - [2.3.1 版面分析+表格识别](#231-版面分析表格识别) + - [2.3.2 DocVQA](#232-docvqa) + - [2.4 参数说明](#24-参数说明) @@ -24,8 +26,6 @@ ```bash # 安装 paddleocr,推荐使用2.5+版本 pip3 install "paddleocr>=2.5" -# 安装 版面分析依赖包layoutparser(如不需要版面分析功能,可跳过) -pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl # 安装 DocVQA依赖包paddlenlp(如不需要DocVQA功能,可跳过) pip install paddlenlp @@ -38,25 +38,31 @@ pip install paddlenlp ### 2.1 命令行使用 -#### 2.1.1 版面分析+表格识别 +#### 2.1.1 图像方向分类+版面分析+表格识别 ```bash -paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure +paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --image_orientation=true ``` -#### 2.1.2 版面分析 +#### 2.1.2 版面分析+表格识别 ```bash -paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false +paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure ``` -#### 2.1.3 表格识别 +#### 2.1.3 版面分析 ```bash -paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false +paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false ``` -#### 2.1.4 DocVQA +#### 2.1.4 表格识别 +```bash +paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false +``` + + +#### 2.1.5 DocVQA 请参考:[文档视觉问答](../vqa/README.md)。 @@ -64,14 +70,14 @@ paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structur ### 2.2 代码使用 -#### 2.2.1 版面分析+表格识别 +#### 2.2.1 图像方向分类版面分析表格识别 ```python import os import cv2 
from paddleocr import PPStructure,draw_structure_result,save_structure_res -table_engine = PPStructure(show_log=True) +table_engine = PPStructure(show_log=True, image_orientation=True) save_folder = './output' img_path = 'PaddleOCR/ppstructure/docs/table/1.png' @@ -93,7 +99,36 @@ im_show.save('result.jpg') ``` -#### 2.2.2 版面分析 +#### 2.2.2 版面分析+表格识别 + +```python +import os +import cv2 +from paddleocr import PPStructure,draw_structure_result,save_structure_res + +table_engine = PPStructure(show_log=True) + +save_folder = './output' +img_path = 'PaddleOCR/ppstructure/docs/table/1.png' +img = cv2.imread(img_path) +result = table_engine(img) +save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0]) + +for line in result: + line.pop('img') + print(line) + +from PIL import Image + +font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # PaddleOCR下提供字体包 +image = Image.open(img_path).convert('RGB') +im_show = draw_structure_result(image, result,font_path=font_path) +im_show = Image.fromarray(im_show) +im_show.save('result.jpg') +``` + + +#### 2.2.3 版面分析 ```python import os @@ -113,8 +148,8 @@ for line in result: print(line) ``` - -#### 2.2.3 表格识别 + +#### 2.2.4 表格识别 ```python import os @@ -134,8 +169,8 @@ for line in result: print(line) ``` - -#### 2.2.4 DocVQA + +#### 2.2.5 DocVQA 请参考:[文档视觉问答](../vqa/README.md)。 @@ -156,10 +191,10 @@ PP-Structure的返回结果为一个dict组成的list,示例如下 ``` dict 里各个字段说明如下 -| 字段 | 说明 | -| --------------- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -|type| 图片区域的类型 | -|bbox| 图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y] | +| 字段 | 说明| +| --- |---| +|type| 图片区域的类型 | +|bbox| 图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y]| |res| 
图片区域的OCR或表格识别结果。
表格: 一个dict,字段说明如下
        `html`: 表格的HTML字符串
        在代码使用模式下,前向传入return_ocr_result_in_table=True可以拿到表格中每个文本的检测识别结果,对应为如下字段:
        `boxes`: 文本检测坐标
        `rec_res`: 文本识别结果。
OCR: 一个包含各个单行文字的检测坐标和识别结果的元组 | 运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。 @@ -180,20 +215,26 @@ dict 里各个字段说明如下 ### 2.4 参数说明 -| 字段 | 说明 | 默认值 | -|----------------------|----------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------| -| output | excel和识别结果保存的地址 | ./output/table | -| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 | -| table_model_dir | 表格结构模型 inference 模型地址 | None | -| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt | -| layout_path_model | 版面分析模型模型地址,可以为在线地址或者本地地址,当为本地地址时,需要指定 layout_label_map, 命令行模式下可通过--layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' 指定 | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config | -| layout_label_map | 版面分析模型模型label映射字典 | None | -| model_name_or_path | VQA SER模型地址 | None | -| max_seq_length | VQA SER模型最大支持token长度 | 512 | -| label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt | -| layout | 前向中是否执行版面分析 | True | -| table | 前向中是否执行表格识别 | True | -| ocr | 对于版面分析中的非表格区域,是否执行ocr。当layout为False时会被自动设置为False | True | -| structure_version | 表格结构化模型版本,可选 PP-STRUCTURE。PP-STRUCTURE支持表格结构化模型 | PP-STRUCTURE | +| 字段 | 说明 | 默认值 | +|---|---|---| +| output | 结果保存地址 | ./output/table | +| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 | +| table_model_dir | 表格结构模型 inference 模型地址| None | +| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt | +| merge_no_span_structure | 表格识别模型中,是否对'\'和'\' 进行合并 | False | +| layout_model_dir | 版面分析模型 inference 模型地址 | None | +| layout_dict_path | 版面分析模型字典| ../ppocr/utils/dict/layout_publaynet_dict.txt | +| layout_score_threshold | 版面分析模型检测框阈值| 0.5| +| layout_nms_threshold | 版面分析模型nms阈值| 0.5| +| vqa_algorithm | vqa模型算法| LayoutXLM| +| ser_model_dir | ser模型 inference 模型地址| None| +| ser_dict_path | ser模型字典| 
../train_data/XFUND/class_list_xfun.txt| +| mode | structure or vqa | structure | +| image_orientation | 前向中是否执行图像方向分类 | False | +| layout | 前向中是否执行版面分析 | True | +| table | 前向中是否执行表格识别 | True | +| ocr | 对于版面分析中的非表格区域,是否执行ocr。当layout为False时会被自动设置为False| True | +| recovery | 前向中是否执行版面恢复| False | +| structure_version | 模型版本,可选 PP-structure和PP-structurev2 | PP-structure | 大部分参数和PaddleOCR whl包保持一致,见 [whl包文档](../../doc/doc_ch/whl.md) diff --git a/ppstructure/docs/quickstart_en.md b/ppstructure/docs/quickstart_en.md index 1f78b43ea3334648a37a37745737a6a26e27ece3..b4dee3f02d3c2762ef71720995f4da697ae43622 100644 --- a/ppstructure/docs/quickstart_en.md +++ b/ppstructure/docs/quickstart_en.md @@ -1,21 +1,23 @@ # PP-Structure Quick Start -- [1. Install package](#1) -- [2. Use](#2) - - [2.1 Use by command line](#21) - - [2.1.1 layout analysis + table recognition](#211) - - [2.1.2 layout analysis](#212) - - [2.1.3 table recognition](#213) - - [2.1.4 DocVQA](#214) - - [2.2 Use by code](#22) - - [2.2.1 layout analysis + table recognition](#221) - - [2.2.2 layout analysis](#222) - - [2.2.3 table recognition](#223) - - [2.2.4 DocVQA](#224) - - [2.3 Result description](#23) - - [2.3.1 layout analysis + table recognition](#231) - - [2.3.2 DocVQA](#232) - - [2.4 Parameter Description](#24) +- [1. Install package](#1-install-package) +- [2. 
Use](#2-use) + - [2.1 Use by command line](#21-use-by-command-line) + - [2.1.1 image orientation + layout analysis + table recognition](#211-image-orientation--layout-analysis--table-recognition) + - [2.1.2 layout analysis + table recognition](#212-layout-analysis--table-recognition) + - [2.1.3 layout analysis](#213-layout-analysis) + - [2.1.4 table recognition](#214-table-recognition) + - [2.1.5 DocVQA](#215-docvqa) + - [2.2 Use by code](#22-use-by-code) + - [2.2.1 image orientation + layout analysis + table recognition](#221-image-orientation--layout-analysis--table-recognition) + - [2.2.2 layout analysis + table recognition](#222-layout-analysis--table-recognition) + - [2.2.3 layout analysis](#223-layout-analysis) + - [2.2.4 table recognition](#224-table-recognition) + - [2.2.5 DocVQA](#225-docvqa) + - [2.3 Result description](#23-result-description) + - [2.3.1 layout analysis + table recognition](#231-layout-analysis--table-recognition) + - [2.3.2 DocVQA](#232-docvqa) + - [2.4 Parameter Description](#24-parameter-description) @@ -24,8 +26,6 @@ ```bash # Install paddleocr, version 2.5+ is recommended pip3 install "paddleocr>=2.5" -# Install layoutparser (if you do not use the layout analysis, you can skip it) -pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl # Install the DocVQA dependency package paddlenlp (if you do not use the DocVQA, you can skip it) pip install paddlenlp @@ -38,25 +38,31 @@ pip install paddlenlp ### 2.1 Use by command line -#### 2.1.1 layout analysis + table recognition +#### 2.1.1 image orientation + layout analysis + table recognition ```bash -paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure +paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --image_orientation=true ``` -#### 2.1.2 layout analysis +#### 2.1.2 layout analysis + table recognition ```bash -paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false 
--ocr=false +paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure ``` -#### 2.1.3 table recognition +#### 2.1.3 layout analysis ```bash -paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false +paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false ``` -#### 2.1.4 DocVQA +#### 2.1.4 table recognition +```bash +paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false +``` + + +#### 2.1.5 DocVQA Please refer to: [Documentation Visual Q&A](../vqa/README.md) . @@ -64,14 +70,14 @@ Please refer to: [Documentation Visual Q&A](../vqa/README.md) . ### 2.2 Use by code -#### 2.2.1 layout analysis + table recognition +#### 2.2.1 image orientation + layout analysis + table recognition ```python import os import cv2 from paddleocr import PPStructure,draw_structure_result,save_structure_res -table_engine = PPStructure(show_log=True) +table_engine = PPStructure(show_log=True, image_orientation=True) save_folder = './output' img_path = 'PaddleOCR/ppstructure/docs/table/1.png' @@ -93,7 +99,36 @@ im_show.save('result.jpg') ``` -#### 2.2.2 layout analysis +#### 2.2.2 layout analysis + table recognition + +```python +import os +import cv2 +from paddleocr import PPStructure,draw_structure_result,save_structure_res + +table_engine = PPStructure(show_log=True) + +save_folder = './output' +img_path = 'PaddleOCR/ppstructure/docs/table/1.png' +img = cv2.imread(img_path) +result = table_engine(img) +save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0]) + +for line in result: + line.pop('img') + print(line) + +from PIL import Image + +font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # PaddleOCR下提供字体包 +image = Image.open(img_path).convert('RGB') +im_show = draw_structure_result(image, result,font_path=font_path) +im_show = Image.fromarray(im_show) +im_show.save('result.jpg') +``` + + +#### 2.2.3 layout analysis 
```python import os @@ -113,8 +148,8 @@ for line in result: print(line) ``` - -#### 2.2.3 table recognition + +#### 2.2.4 table recognition ```python import os @@ -134,8 +169,8 @@ for line in result: print(line) ``` - -#### 2.2.4 DocVQA + +#### 2.2.5 DocVQA Please refer to: [Documentation Visual Q&A](../vqa/README.md) . @@ -157,8 +192,8 @@ The return of PP-Structure is a list of dicts, the example is as follows: ``` Each field in dict is described as follows: -| field | description | -| --------------- |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| field | description | +| --- |---| |type| Type of image area. | |bbox| The coordinates of the image area in the original image, respectively [upper left corner x, upper left corner y, lower right corner x, lower right corner y]. | |res| OCR or table recognition result of the image area.
table: a dict with field descriptions as follows:
        `html`: html str of table.
        In the code usage mode, set return_ocr_result_in_table=True when calling to get the detection and recognition results of each text in the table area, corresponding to the following fields:
        `boxes`: text detection boxes.
        `rec_res`: text recognition results.
OCR: A tuple containing the detection boxes and recognition results of each single text. | @@ -180,19 +215,26 @@ Please refer to: [Documentation Visual Q&A](../vqa/README.md) . ### 2.4 Parameter Description -| field | description | default | -|----------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------| -| output | The save path of result | ./output/table | -| table_max_len | When the table structure model predicts, the long side of the image | 488 | -| table_model_dir | the path of table structure model | None | -| table_char_dict_path | the dict path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt | -| layout_path_model | The model path of the layout analysis model, which can be an online address or a local path. When it is a local path, layout_label_map needs to be set. In command line mode, use --layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config | -| layout_label_map | Layout analysis model model label mapping dictionary path | None | -| model_name_or_path | the model path of VQA SER model | None | -| max_seq_length | the max token length of VQA SER model | 512 | -| label_map_path | the label path of VQA SER model | ./vqa/labels/labels_ser.txt | -| layout | Whether to perform layout analysis in forward | True | -| table | Whether to perform table recognition in forward | True | -| ocr | Whether to perform ocr for non-table areas in layout analysis. 
When layout is False, it will be automatically set to False | True | -| structure_version | table structure Model version number, the current model support list is as follows: PP-STRUCTURE support english table structure model | PP-STRUCTURE | +| field | description | default | +|---|---|---| +| output | result save path | ./output/table | +| table_max_len | long side of the image resize in table structure model | 488 | +| table_model_dir | Table structure model inference model path| None | +| table_char_dict_path | The dictionary path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt | +| merge_no_span_structure | In the table recognition model, whether to merge '<td>' and '</td>' | False | +| layout_model_dir | Layout analysis model inference model path| None | +| layout_dict_path | The dictionary path of layout analysis model| ../ppocr/utils/dict/layout_publaynet_dict.txt | +| layout_score_threshold | The box score threshold of layout analysis model| 0.5| +| layout_nms_threshold | The NMS threshold of layout analysis model| 0.5| +| vqa_algorithm | vqa model algorithm| LayoutXLM| +| ser_model_dir | Ser model inference model path| None| +| ser_dict_path | The dictionary path of Ser model| ../train_data/XFUND/class_list_xfun.txt| +| mode | structure or vqa | structure | +| image_orientation | Whether to perform image orientation classification in forward | False | +| layout | Whether to perform layout analysis in forward | True | +| table | Whether to perform table recognition in forward | True | +| ocr | Whether to perform ocr for non-table areas in layout analysis.
When layout is False, it will be automatically set to False| True | +| recovery | Whether to perform layout recovery in forward| False | +| structure_version | Structure version, optional PP-structure and PP-structurev2 | PP-structure | + Most of the parameters are consistent with the PaddleOCR whl package, see [whl package documentation](../../doc/doc_en/whl.md) diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py index d6f2e24240ff783e14dbd61efdd27877f9ec39ff..053a8aac00ffe762dd05d7f8030db9aaa32c0f8a 100644 --- a/ppstructure/predict_system.py +++ b/ppstructure/predict_system.py @@ -18,7 +18,7 @@ import subprocess __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' import cv2 @@ -27,11 +27,11 @@ import numpy as np import time import logging from copy import deepcopy -from attrdict import AttrDict from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger from tools.infer.predict_system import TextSystem +from ppstructure.layout.predict_layout import LayoutPredictor from ppstructure.table.predict_table import TableSystem, to_excel from ppstructure.utility import parse_args, draw_structure_result from ppstructure.recovery.recovery_to_doc import convert_info_docx @@ -42,6 +42,14 @@ logger = get_logger() class StructureSystem(object): def __init__(self, args): self.mode = args.mode + self.recovery = args.recovery + + self.image_orientation_predictor = None + if args.image_orientation: + import paddleclas + self.image_orientation_predictor = paddleclas.PaddleClas( + model_name="text_image_orientation") + if self.mode == 'structure': if not args.show_log: logger.setLevel(logging.INFO) @@ -51,28 +59,14 @@ class StructureSystem(object): "When args.layout is false, args.ocr is 
automatically set to false" ) args.drop_score = 0 - # init layout and ocr model + # init model + self.layout_predictor = None self.text_system = None + self.table_system = None if args.layout: - import layoutparser as lp - config_path = None - model_path = None - if os.path.isdir(args.layout_path_model): - model_path = args.layout_path_model - else: - config_path = args.layout_path_model - self.table_layout = lp.PaddleDetectionLayoutModel( - config_path=config_path, - model_path=model_path, - label_map=args.layout_label_map, - threshold=0.5, - enable_mkldnn=args.enable_mkldnn, - enforce_cpu=not args.use_gpu, - thread_num=args.cpu_threads) + self.layout_predictor = LayoutPredictor(args) if args.ocr: self.text_system = TextSystem(args) - else: - self.table_layout = None if args.table: if self.text_system is not None: self.table_system = TableSystem( @@ -80,39 +74,78 @@ class StructureSystem(object): self.text_system.text_recognizer) else: self.table_system = TableSystem(args) - else: - self.table_system = None elif self.mode == 'vqa': raise NotImplementedError def __call__(self, img, return_ocr_result_in_table=False): + time_dict = { + 'image_orientation': 0, + 'layout': 0, + 'table': 0, + 'table_match': 0, + 'det': 0, + 'rec': 0, + 'vqa': 0, + 'all': 0 + } + start = time.time() + if self.image_orientation_predictor is not None: + tic = time.time() + cls_result = self.image_orientation_predictor.predict( + input_data=img) + cls_res = next(cls_result) + angle = cls_res[0]['label_names'][0] + cv_rotate_code = { + '90': cv2.ROTATE_90_COUNTERCLOCKWISE, + '180': cv2.ROTATE_180, + '270': cv2.ROTATE_90_CLOCKWISE + } + img = cv2.rotate(img, cv_rotate_code[angle]) + toc = time.time() + time_dict['image_orientation'] = toc - tic if self.mode == 'structure': ori_im = img.copy() - if self.table_layout is not None: - layout_res = self.table_layout.detect(img[..., ::-1]) + if self.layout_predictor is not None: + layout_res, elapse = self.layout_predictor(img) + time_dict['layout'] 
+= elapse else: h, w = ori_im.shape[:2] - layout_res = [AttrDict(coordinates=[0, 0, w, h], type='Table')] + layout_res = [dict(bbox=None, label='table')] res_list = [] for region in layout_res: res = '' - x1, y1, x2, y2 = region.coordinates - x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) - roi_img = ori_im[y1:y2, x1:x2, :] - if region.type == 'Table': + if region['bbox'] is not None: + x1, y1, x2, y2 = region['bbox'] + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + roi_img = ori_im[y1:y2, x1:x2, :] + else: + x1, y1, x2, y2 = 0, 0, w, h + roi_img = ori_im + if region['label'] == 'table': if self.table_system is not None: - res = self.table_system(roi_img, - return_ocr_result_in_table) + res, table_time_dict = self.table_system( + roi_img, return_ocr_result_in_table) + time_dict['table'] += table_time_dict['table'] + time_dict['table_match'] += table_time_dict['match'] + time_dict['det'] += table_time_dict['det'] + time_dict['rec'] += table_time_dict['rec'] else: if self.text_system is not None: - if args.recovery: + if self.recovery: wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype) wht_im[y1:y2, x1:x2, :] = roi_img - filter_boxes, filter_rec_res = self.text_system(wht_im) + filter_boxes, filter_rec_res, ocr_time_dict = self.text_system( + wht_im) else: - filter_boxes, filter_rec_res = self.text_system(roi_img) - # remove style char + filter_boxes, filter_rec_res, ocr_time_dict = self.text_system( + roi_img) + time_dict['det'] += ocr_time_dict['det'] + time_dict['rec'] += ocr_time_dict['rec'] + + # remove style char, + # when using the recognition model trained on the PubtabNet dataset, + # it will recognize the text format in the table, such as style_token = [ '', '', '', '', '', '', '', '', '', @@ -125,7 +158,7 @@ class StructureSystem(object): for token in style_token: if token in rec_str: rec_str = rec_str.replace(token, '') - if not args.recovery: + if not self.recovery: box += [x1, y1] res.append({ 'text': rec_str, @@ -133,15 +166,17 @@ class 
StructureSystem(object): 'text_region': box.tolist() }) res_list.append({ - 'type': region.type, + 'type': region['label'].lower(), 'bbox': [x1, y1, x2, y2], 'img': roi_img, 'res': res }) - return res_list + end = time.time() + time_dict['all'] = end - start + return res_list, time_dict elif self.mode == 'vqa': raise NotImplementedError - return None + return None, None def save_structure_res(res, save_folder, img_name): @@ -156,12 +191,12 @@ def save_structure_res(res, save_folder, img_name): roi_img = region.pop('img') f.write('{}\n'.format(json.dumps(region))) - if region['type'] == 'Table' and len(region[ + if region['type'] == 'table' and len(region[ 'res']) > 0 and 'html' in region['res']: excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox'])) to_excel(region['res']['html'], excel_path) - elif region['type'] == 'Figure': + elif region['type'] == 'figure': img_path = os.path.join(excel_save_folder, '{}.jpg'.format(region['bbox'])) cv2.imwrite(img_path, roi_img) @@ -187,8 +222,7 @@ def main(args): if img is None: logger.error("error in loading image:{}".format(image_file)) continue - starttime = time.time() - res = structure_sys(img) + res, time_dict = structure_sys(img) if structure_sys.mode == 'structure': save_structure_res(res, save_folder, img_name) @@ -201,9 +235,8 @@ def main(args): cv2.imwrite(img_save_path, draw_img) logger.info('result save to {}'.format(img_save_path)) if args.recovery: - convert_info_docx(img, res, save_folder, img_name) - elapse = time.time() - starttime - logger.info("Predict time : {:.3f}s".format(elapse)) + convert_info_docx(img, res, save_folder, img_name) + logger.info("Predict time : {:.3f}s".format(time_dict['all'])) if __name__ == "__main__": diff --git a/ppstructure/table/README.md b/ppstructure/table/README.md index b6804c6f09b4ee3d17cd2b81e6cc6642c1c1be9a..f4eff868d292cf63af119827c445a5b85aa4bebb 100644 --- a/ppstructure/table/README.md +++ b/ppstructure/table/README.md @@ -1,126 +1,124 @@ -- 
[Table Recognition](#table-recognition) - - [1. pipeline](#1-pipeline) - - [2. Performance](#2-performance) - - [3. How to use](#3-how-to-use) - - [3.1 quick start](#31-quick-start) - - [3.2 Train](#32-train) - - [3.3 Eval](#33-eval) - - [3.4 Inference](#34-inference) - +English | [简体中文](README_ch.md) # Table Recognition +- [1. pipeline](#1-pipeline) +- [2. Performance](#2-performance) +- [3. Result](#3-result) +- [4. How to use](#4-how-to-use) + - [4.1 Quick start](#41-quick-start) + - [4.2 Train](#42-train) + - [4.3 Calculate TEDS](#43-calculate-teds) +- [5. Reference](#5-reference) + + ## 1. pipeline The table recognition mainly contains three models 1. Single line text detection-DB 2. Single line text recognition-CRNN -3. Table structure and cell coordinate prediction-RARE +3. Table structure and cell coordinate prediction-SLANet The table recognition flow chart is as follows ![tableocr_pipeline](../docs/table/tableocr_pipeline_en.jpg) 1. The coordinates of single-line text is detected by DB model, and then sends it to the recognition model to get the recognition result. -2. The table structure and cell coordinates is predicted by RARE model. +2. The table structure and cell coordinates is predicted by SLANet model. 3. The recognition result of the cell is combined by the coordinates, recognition result of the single line and the coordinates of the cell. 4. The cell recognition result and the table structure together construct the html string of the table. ## 2. 
Performance We evaluated the algorithm on the PubTabNet[1] eval dataset, and the performance is as follows: +|Method|Acc|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|Speed| +| --- | --- | --- | ---| +| EDD[2] |x| 88.3 |x| +| TableRec-RARE(ours) |73.8%| 93.32 |1550ms| +| SLANet(ours) | 76.2%| 94.98 |766ms| + +The performance indicators are explained as follows: +- Acc: The accuracy of the table structure in each image, a wrong token is considered an error. +- TEDS: The accuracy of the model's restoration of table information. This indicator evaluates not only the table structure, but also the text content in the table. +- Speed: The inference speed of a single image when the model runs on the CPU machine and MKL is enabled. + +## 3. Result + +![](../docs/imgs/table_ch_result1.jpg) +![](../docs/imgs/table_ch_result2.jpg) +![](../docs/imgs/table_ch_result3.jpg) -|Method|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)| -| --- | --- | -| EDD[2] | 88.3 | -| Ours | 93.32 | +## 4. How to use -## 3. How to use +### 4.1 Quick start -### 3.1 quick start +Use the following commands to quickly complete the identification of a table. 
```python cd PaddleOCR/ppstructure # download model mkdir inference && cd inference -# Download the detection model of the ultra-lightweight table English OCR model and unzip it -wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar -# Download the recognition model of the ultra-lightweight table English OCR model and unzip it -wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar -# Download the ultra-lightweight English table inch model and unzip it -wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar +# Download the PP-OCRv3 text detection model and unzip it +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar && tar xf ch_PP-OCRv3_det_slim_infer.tar +# Download the PP-OCRv3 text recognition model and unzip it +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar && tar xf ch_PP-OCRv3_rec_slim_infer.tar +# Download the PP-Structurev2 form recognition model and unzip it +wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar cd .. # run -python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=./docs/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ./output/table -``` -Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. 
If you need to identify other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`. - -After running, the excel sheet of each picture will be saved in the directory specified by the output field - -### 3.2 Train - -In this chapter, we only introduce the training of the table structure model, For model training of [text detection](../../doc/doc_en/detection_en.md) and [text recognition](../../doc/doc_en/recognition_en.md), please refer to the corresponding documents - -* data preparation -The training data uses public data set [PubTabNet](https://arxiv.org/abs/1911.10683 ), Can be downloaded from the official [website](https://github.com/ibm-aur-nlp/PubTabNet) 。The PubTabNet data set contains about 500,000 images, as well as annotations in html format。 +python3.7 table/predict_table.py \ + --det_model_dir=inference/ch_PP-OCRv3_det_slim_infer \ + --rec_model_dir=inference/ch_PP-OCRv3_rec_slim_infer \ + --table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \ + --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \ + --table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt \ + --image_dir=docs/table/table.jpg \ + --output=../output/table -* Start training -*If you are installing the cpu version of paddle, please modify the `use_gpu` field in the configuration file to false* -```shell -# single GPU training -python3 tools/train.py -c configs/table/table_mv3.yml -# multi-GPU training -# Set the GPU ID used by the '--gpus' parameter. -python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/table_mv3.yml ``` -In the above instruction, use `-c` to select the training to use the `configs/table/table_mv3.yml` configuration file. -For a detailed explanation of the configuration file, please refer to [config](../../doc/doc_en/config_en.md). 
+After the operation is completed, the excel table of each image will be saved to the directory specified by the output field, and an html file will be produced in the directory to visually view the cell coordinates and the recognized table. -* load trained model and continue training +### 4.2 Train -If you expect to load trained model and continue the training again, you can specify the parameter `Global.checkpoints` as the model path to be loaded. +The training, evaluation and inference process of the text detection model can be referred to [detection](../../doc/doc_en/detection_en.md) -```shell -python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./your/trained/model -``` +The training, evaluation and inference process of the text recognition model can be referred to [recognition](../../doc/doc_en/recognition_en.md) -**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrain_weights`, that is, when two parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrain_weights` will be loaded. +The training, evaluation and inference process of the table recognition model can be referred to [table_recognition](../../doc/doc_en/table_recognition_en.md) -### 3.3 Eval +### 4.3 Calculate TEDS The table uses [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) as the evaluation metric of the model. Before the model evaluation, the three models in the pipeline need to be exported as inference models (we have provided them), and the gt for evaluation needs to be prepared. Examples of gt are as follows: -```json -{"PMC4289340_004_00.png": [ - ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "
", "", "", "
", "", "", "
", "", ""], - [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]], - [["", "F", "e", "a", "t", "u", "r", "e", ""], ["", "G", "b", "3", " ", "+", ""], ["", "G", "b", "3", " ", "-", ""], ["", "P", "a", "t", "i", "e", "n", "t", "s", ""], ["6", "2"], ["4", "5"]] -]} +```txt +PMC5755158_010_01.png
WeaningWeek 15Off-test
Weaning
Week 150.17 ± 0.080.16 ± 0.03
Off-test0.80 ± 0.240.19 ± 0.09
+``` +Each line in gt consists of the file name and the html string of the table. The file name and the html string of the table are separated by `\t`. + +You can also use the following command to generate an evaluation gt file from the annotation file: +```python +python3 ppstructure/table/convert_label2html.py --ori_gt_path /path/to/your_label_file --save_path /path/to/save_file ``` -In gt json, the key is the image name, the value is the corresponding gt, and gt is a list composed of four items, and each item is -1. HTML string list of table structure -2. The coordinates of each cell (not including the empty text in the cell) -3. The text information in each cell (not including the empty text in the cell) Use the following command to evaluate. After the evaluation is completed, the teds indicator will be output. ```python -cd PaddleOCR/ppstructure -python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json +python3 table/eval_table.py \ + --det_model_dir=path/to/det_model_dir \ + --rec_model_dir=path/to/rec_model_dir \ + --table_model_dir=path/to/table_model_dir \ + --image_dir=../doc/table/1.png \ + --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \ + --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ + --det_limit_side_len=736 \ + --det_limit_type=min \ + --gt_path=path/to/gt.txt ``` If the PubLatNet eval dataset is used, it will be output ```bash -teds: 93.32 -``` - -### 3.4 Inference - -```python -cd PaddleOCR/ppstructure -python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png 
--rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table +teds: 94.98 ``` -After running, the excel sheet of each picture will be saved in the directory specified by the output field -Reference +## 5. Reference 1. https://github.com/ibm-aur-nlp/PubTabNet 2. https://arxiv.org/pdf/1911.10683 diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md index a0a64d6b7ebcb272e4b607975170a679abd036ab..badabc7992cc5fc4b474fa59837f6fab5f659415 100644 --- a/ppstructure/table/README_ch.md +++ b/ppstructure/table/README_ch.md @@ -2,22 +2,22 @@ # 表格识别 -- [1. 表格识别 pipeline](#1) -- [2. 性能](#2) -- [3. 使用](#3) - - [3.1 快速开始](#31) - - [3.2 训练](#32) - - [3.3 评估](#33) - - [3.4 预测](#34) +- [1. 表格识别 pipeline](#1-表格识别-pipeline) +- [2. 性能](#2-性能) +- [3. 效果演示](#3-效果演示) +- [4. 使用](#4-使用) + - [4.1 快速开始](#41-快速开始) + - [4.2 训练](#42-训练) + - [4.3 计算TEDS](#43-计算teds) +- [5. Reference](#5-reference) - ## 1. 表格识别 pipeline 表格识别主要包含三个模型 1. 单行文本检测-DB 2. 单行文本识别-CRNN -3. 表格结构和cell坐标预测-RARE +3. 表格结构和cell坐标预测-SLANet 具体流程图如下 @@ -26,111 +26,102 @@ 流程说明: 1. 图片由单行文字检测模型检测到单行文字的坐标,然后送入识别模型拿到识别结果。 -2. 图片由表格结构和cell坐标预测模型拿到表格的结构信息和单元格的坐标信息。 +2. 图片由SLANet模型拿到表格的结构信息和单元格的坐标信息。 3. 由单行文字的坐标、识别结果和单元格的坐标一起组合出单元格的识别结果。 4. 单元格的识别结果和表格结构一起构造表格的html字符串。 - ## 2. 性能 我们在 PubTabNet[1] 评估数据集上对算法进行了评估,性能如下 -|算法|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)| -| --- | --- | -| EDD[2] | 88.3 | -| Ours | 93.32 | +|算法|Acc|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|Speed| +| --- | --- | --- | ---| +| EDD[2] |x| 88.3 |x| +| TableRec-RARE(ours) |73.8%| 93.32 |1550ms| +| SLANet(ours) | 76.2%| 94.98 |766ms| - -## 3. 
使用 +性能指标解释如下: +- Acc: 模型对每张图像里表格结构的识别准确率,错一个token就算错误。 +- TEDS: 模型对表格信息还原的准确度,此指标评价内容不仅包含表格结构,还包含表格内的文字内容。 +- Speed: 模型在CPU机器上,开启MKL的情况下,单张图片的推理速度。 - -### 3.1 快速开始 +## 3. 效果演示 +![](../docs/imgs/table_ch_result1.jpg) +![](../docs/imgs/table_ch_result2.jpg) +![](../docs/imgs/table_ch_result3.jpg) + +## 4. 使用 + +### 4.1 快速开始 + +使用如下命令即可快速完成一张表格的识别。 ```python cd PaddleOCR/ppstructure # 下载模型 mkdir inference && cd inference -# 下载超轻量级表格英文OCR模型的检测模型并解压 -wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar -# 下载超轻量级表格英文OCR模型的识别模型并解压 -wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar -# 下载超轻量级英文表格英寸模型并解压 -wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar +# 下载PP-OCRv3文本检测模型并解压 +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar && tar xf ch_PP-OCRv3_det_slim_infer.tar +# 下载PP-OCRv3文本识别模型并解压 +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar && tar xf ch_PP-OCRv3_rec_slim_infer.tar +# 下载PP-Structurev2表格识别模型并解压 +wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar cd .. 
-# 执行预测 -python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=./docs/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ./output/table +# 执行表格识别 +python table/predict_table.py \ + --det_model_dir=inference/ch_PP-OCRv3_det_slim_infer \ + --rec_model_dir=inference/ch_PP-OCRv3_rec_slim_infer \ + --table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \ + --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \ + --table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt \ + --image_dir=docs/table/table.jpg \ + --output=../output/table ``` -运行完成后,每张图片的excel表格会保存到output字段指定的目录下 +运行完成后,每张图片的excel表格会保存到output字段指定的目录下,同时在该目录下回生产一个html文件,用于可视化查看单元格坐标和识别的表格。 -note: 上述模型是在 PubLayNet 数据集上训练的表格识别模型,仅支持英文扫描场景,如需识别其他场景需要自己训练模型后替换 `det_model_dir`,`rec_model_dir`,`table_model_dir`三个字段即可。 +### 4.2 训练 - -### 3.2 训练 +文本检测模型的训练、评估和推理流程可参考 [detection](../../doc/doc_ch/detection.md) -在这一章节中,我们仅介绍表格结构模型的训练,[文字检测](../../doc/doc_ch/detection.md)和[文字识别](../../doc/doc_ch/recognition.md)的模型训练请参考对应的文档。 +文本识别模型的训练、评估和推理流程可参考 [recognition](../../doc/doc_ch/recognition.md) -* 数据准备 +表格识别模型的训练、评估和推理流程可参考 [table_recognition](../../doc/doc_ch/table_recognition.md) -训练数据使用公开数据集PubTabNet ([论文](https://arxiv.org/abs/1911.10683),[下载地址](https://github.com/ibm-aur-nlp/PubTabNet))。PubTabNet数据集包含约50万张表格数据的图像,以及图像对应的html格式的注释。 +### 4.3 计算TEDS -* 启动训练 - -*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false* -```shell -# 单机单卡训练 -python3 tools/train.py -c configs/table/table_mv3.yml -# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID -python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/table_mv3.yml -``` - -上述指令中,通过-c 
选择训练使用configs/table/table_mv3.yml配置文件。有关配置文件的详细解释,请参考[链接](../../doc/doc_ch/config.md)。 - -* 断点训练 - -如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径: -```shell -python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./your/trained/model +表格使用 [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) 作为模型的评估指标。在进行模型评估之前,需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好),还需要准备评估的gt, gt示例如下: +```txt +PMC5755158_010_01.png
WeaningWeek 15Off-test
Weaning
Week 150.17 ± 0.080.16 ± 0.03
Off-test0.80 ± 0.240.19 ± 0.09
``` +gt每一行都由文件名和表格的html字符串组成,文件名和表格的html字符串之间使用`\t`分隔。 -**注意**:`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。 - - -### 3.3 评估 - -表格使用 [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) 作为模型的评估指标。在进行模型评估之前,需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好),还需要准备评估的gt, gt示例如下: -```json -{"PMC4289340_004_00.png": [ - ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "
", "", "", "
", "", "", "
", "", ""], - [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]], - [["", "F", "e", "a", "t", "u", "r", "e", ""], ["", "G", "b", "3", " ", "+", ""], ["", "G", "b", "3", " ", "-", ""], ["", "P", "a", "t", "i", "e", "n", "t", "s", ""], ["6", "2"], ["4", "5"]] -]} +也可使用如下命令,由标注文件生成评估的gt文件: +```python +python3 ppstructure/table/convert_label2html.py --ori_gt_path /path/to/your_label_file --save_path /path/to/save_file ``` -json 中,key为图片名,value为对应的gt,gt是一个由三个item组成的list,每个item分别为 -1. 表格结构的html字符串list -2. 每个cell的坐标 (不包括cell里文字为空的) -3. 每个cell里的文字信息 (不包括cell里文字为空的) 准备完成后使用如下命令进行评估,评估完成后会输出teds指标。 ```python cd PaddleOCR/ppstructure -python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json +python3 table/eval_table.py \ + --det_model_dir=path/to/det_model_dir \ + --rec_model_dir=path/to/rec_model_dir \ + --table_model_dir=path/to/table_model_dir \ + --image_dir=../doc/table/1.png \ + --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \ + --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ + --det_limit_side_len=736 \ + --det_limit_type=min \ + --gt_path=path/to/gt.txt ``` 如使用PubLatNet评估数据集,将会输出 ```bash -teds: 93.32 -``` - - -### 3.4 预测 - -```python -cd PaddleOCR/ppstructure -python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table +teds: 94.98 ``` -# Reference +## 5. 
Reference 1. https://github.com/ibm-aur-nlp/PubTabNet 2. https://arxiv.org/pdf/1911.10683 diff --git a/ppstructure/table/convert_label2html.py b/ppstructure/table/convert_label2html.py new file mode 100644 index 0000000000000000000000000000000000000000..be16212ac420326a91cf8ab281a77e5990530c0e --- /dev/null +++ b/ppstructure/table/convert_label2html.py @@ -0,0 +1,102 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +conver table label to html +""" + +import json +import argparse +from tqdm import tqdm + + +def save_pred_txt(key, val, tmp_file_path): + with open(tmp_file_path, 'a+', encoding='utf-8') as f: + f.write('{}\t{}\n'.format(key, val)) + + +def skip_char(text, sp_char_list): + """ + skip empty cell + @param text: text in cell + @param sp_char_list: style char and special code + @return: + """ + for sp_char in sp_char_list: + text = text.replace(sp_char, '') + return text + + +def gen_html(img): + ''' + Formats HTML code from tokenized annotation of img + ''' + html_code = img['html']['structure']['tokens'].copy() + to_insert = [i for i, tag in enumerate(html_code) if tag in ('', '>')] + for i, cell in zip(to_insert[::-1], img['html']['cells'][::-1]): + if cell['tokens']: + text = ''.join(cell['tokens']) + # skip empty text + sp_char_list = ['', '', '\u2028', ' ', '', ''] + text_remove_style = skip_char(text, sp_char_list) + if len(text_remove_style) == 0: + continue + html_code.insert(i + 
1, text) + html_code = ''.join(html_code) + html_code = '{}
'.format(html_code) + return html_code + + +def load_gt_data(gt_path): + """ + load gt + @param gt_path: + @return: + """ + data_list = {} + with open(gt_path, 'rb') as f: + lines = f.readlines() + for line in tqdm(lines): + data_line = line.decode('utf-8').strip("\n") + info = json.loads(data_line) + data_list[info['filename']] = info + return data_list + + +def convert(origin_gt_path, save_path): + """ + gen html from label file + @param origin_gt_path: + @param save_path: + @return: + """ + data_dict = load_gt_data(origin_gt_path) + for img_name, gt in tqdm(data_dict.items()): + html = gen_html(gt) + save_pred_txt(img_name, html, save_path) + print('conver finish') + + +def parse_args(): + parser = argparse.ArgumentParser(description="args for paddleserving") + parser.add_argument( + "--ori_gt_path", type=str, required=True, help="label gt path") + parser.add_argument( + "--save_path", type=str, required=True, help="path to save file") + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + convert(args.ori_gt_path, args.save_path) diff --git a/ppstructure/table/eval_table.py b/ppstructure/table/eval_table.py index 87b44d3d9792356ec1cdc65693392c288bf67448..4fc16b5d4c6a0143dcea149508bd6b62730092b6 100755 --- a/ppstructure/table/eval_table.py +++ b/ppstructure/table/eval_table.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import os import sys + __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) import cv2 -import json +import pickle +import paddle from tqdm import tqdm from ppstructure.table.table_metric import TEDS from ppstructure.table.predict_table import TableSystem @@ -33,40 +36,74 @@ def parse_args(): parser.add_argument("--gt_path", type=str) return parser.parse_args() -def main(gt_path, img_root, args): - teds = TEDS(n_jobs=16) +def load_txt(txt_path): + pred_html_dict = {} + if not os.path.exists(txt_path): + return pred_html_dict + with open(txt_path, encoding='utf-8') as f: + lines = f.readlines() + for line in lines: + line = line.strip().split('\t') + img_name, pred_html = line + pred_html_dict[img_name] = pred_html + return pred_html_dict + + +def load_result(path): + data = {} + if os.path.exists(path): + data = pickle.load(open(path, 'rb')) + return data + + +def save_result(path, data): + old_data = load_result(path) + old_data.update(data) + with open(path, 'wb') as f: + pickle.dump(old_data, f) + + +def main(gt_path, img_root, args): + os.makedirs(args.output, exist_ok=True) + # init TableSystem text_sys = TableSystem(args) - jsons_gt = json.load(open(gt_path)) # gt + # load gt and preds html result + gt_html_dict = load_txt(gt_path) + + ocr_result = load_result(os.path.join(args.output, 'ocr.pickle')) + structure_result = load_result( + os.path.join(args.output, 'structure.pickle')) + pred_htmls = [] gt_htmls = [] - for img_name in tqdm(jsons_gt): - # read image - img = cv2.imread(os.path.join(img_root,img_name)) - pred_html = text_sys(img) - pred_htmls.append(pred_html) + for img_name, gt_html in tqdm(gt_html_dict.items()): + img = cv2.imread(os.path.join(img_root, img_name)) + # run ocr and save result + if img_name not in ocr_result: + dt_boxes, rec_res, _, _ = text_sys._ocr(img) + 
ocr_result[img_name] = [dt_boxes, rec_res] + save_result(os.path.join(args.output, 'ocr.pickle'), ocr_result) + # run structure and save result + if img_name not in structure_result: + structure_res, _ = text_sys._structure(img) + structure_result[img_name] = structure_res + save_result( + os.path.join(args.output, 'structure.pickle'), structure_result) + dt_boxes, rec_res = ocr_result[img_name] + structure_res = structure_result[img_name] + # match ocr and structure + pred_html = text_sys.match(structure_res, dt_boxes, rec_res) - gt_structures, gt_bboxes, gt_contents = jsons_gt[img_name] - gt_html, gt = get_gt_html(gt_structures, gt_contents) + pred_htmls.append(pred_html) gt_htmls.append(gt_html) - scores = teds.batch_evaluate_html(gt_htmls, pred_htmls) - logger.info('teds:', sum(scores) / len(scores)) - -def get_gt_html(gt_structures, gt_contents): - end_html = [] - td_index = 0 - for tag in gt_structures: - if '' in tag: - if gt_contents[td_index] != []: - end_html.extend(gt_contents[td_index]) - end_html.append(tag) - td_index += 1 - else: - end_html.append(tag) - return ''.join(end_html), end_html + # compute teds + teds = TEDS(n_jobs=16) + scores = teds.batch_evaluate_html(gt_htmls, pred_htmls) + logger.info('teds: {}'.format(sum(scores) / len(scores))) if __name__ == '__main__': args = parse_args() - main(args.gt_path,args.image_dir, args) + main(args.gt_path, args.image_dir, args) diff --git a/ppstructure/table/matcher.py b/ppstructure/table/matcher.py index c3b56384403f5fd92a8db4b4bb378a6d55e5a76c..9c5bd2630f78527ade4fd1309f22d1731fe838a2 100755 --- a/ppstructure/table/matcher.py +++ b/ppstructure/table/matcher.py @@ -1,11 +1,29 @@ -import json +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from ppstructure.table.table_master_match import deal_eb_token, deal_bb + + def distance(box_1, box_2): - x1, y1, x2, y2 = box_1 - x3, y3, x4, y4 = box_2 - dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2) - dis_2 = abs(x3 - x1) + abs(y3 - y1) - dis_3 = abs(x4- x2) + abs(y4 - y2) - return dis + min(dis_2, dis_3) + x1, y1, x2, y2 = box_1 + x3, y3, x4, y4 = box_2 + dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2) + dis_2 = abs(x3 - x1) + abs(y3 - y1) + dis_3 = abs(x4 - x2) + abs(y4 - y2) + return dis + min(dis_2, dis_3) + def compute_iou(rec1, rec2): """ @@ -18,175 +36,157 @@ def compute_iou(rec1, rec2): # computing area of each rectangles S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) - + # computing the sum_area sum_area = S_rec1 + S_rec2 - + # find the each edge of intersect rectangle left_line = max(rec1[1], rec2[1]) right_line = min(rec1[3], rec2[3]) top_line = max(rec1[0], rec2[0]) bottom_line = min(rec1[2], rec2[2]) - + # judge if there is an intersect if left_line >= right_line or top_line >= bottom_line: return 0.0 else: intersect = (right_line - left_line) * (bottom_line - top_line) - return (intersect / (sum_area - intersect))*1.0 - - - -def matcher_merge(ocr_bboxes, pred_bboxes): - all_dis = [] - ious = [] - matched = {} - for i, gt_box in enumerate(ocr_bboxes): - distances = [] - for j, pred_box in enumerate(pred_bboxes): - # compute l1 distence and IOU between two boxes - distances.append((distance(gt_box, pred_box), 1. 
- compute_iou(gt_box, pred_box))) - sorted_distances = distances.copy() - # select nearest cell - sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0])) - if distances.index(sorted_distances[0]) not in matched.keys(): - matched[distances.index(sorted_distances[0])] = [i] - else: - matched[distances.index(sorted_distances[0])].append(i) - return matched#, sum(ious) / len(ious) - -def complex_num(pred_bboxes): - complex_nums = [] - for bbox in pred_bboxes: - distances = [] - temp_ious = [] - for pred_bbox in pred_bboxes: - if bbox != pred_bbox: - distances.append(distance(bbox, pred_bbox)) - temp_ious.append(compute_iou(bbox, pred_bbox)) - complex_nums.append(temp_ious[distances.index(min(distances))]) - return sum(complex_nums) / len(complex_nums) - -def get_rows(pred_bboxes): - pre_bbox = pred_bboxes[0] - res = [] - step = 0 - for i in range(len(pred_bboxes)): - bbox = pred_bboxes[i] - if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0: - break - else: - res.append(bbox) - step += 1 - for i in range(step): - pred_bboxes.pop(0) - return res, pred_bboxes -def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上 - ys_1 = [] - ys_2 = [] - for box in pred_bboxes: - ys_1.append(box[1]) - ys_2.append(box[3]) - min_y_1 = sum(ys_1) / len(ys_1) - min_y_2 = sum(ys_2) / len(ys_2) - re_boxes = [] - for box in pred_bboxes: - box[1] = min_y_1 - box[3] = min_y_2 - re_boxes.append(box) - return re_boxes - -def matcher_refine_row(gt_bboxes, pred_bboxes): - before_refine_pred_bboxes = pred_bboxes.copy() - pred_bboxes = [] - while(len(before_refine_pred_bboxes) != 0): - row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes) - print(row_bboxes) - pred_bboxes.extend(refine_rows(row_bboxes)) - all_dis = [] - ious = [] - matched = {} - for i, gt_box in enumerate(gt_bboxes): - distances = [] - #temp_ious = [] - for j, pred_box in enumerate(pred_bboxes): - distances.append(distance(gt_box, pred_box)) - #temp_ious.append(compute_iou(gt_box, 
pred_box)) - #all_dis.append(min(distances)) - #ious.append(temp_ious[distances.index(min(distances))]) - if distances.index(min(distances)) not in matched.keys(): - matched[distances.index(min(distances))] = [i] + return (intersect / (sum_area - intersect)) * 1.0 + + +class TableMatch: + def __init__(self, filter_ocr_result=False, use_master=False): + self.filter_ocr_result = filter_ocr_result + self.use_master = use_master + + def __call__(self, structure_res, dt_boxes, rec_res): + pred_structures, pred_bboxes = structure_res + if self.filter_ocr_result: + dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, + rec_res) + matched_index = self.match_result(dt_boxes, pred_bboxes) + if self.use_master: + pred_html, pred = self.get_pred_html_master(pred_structures, + matched_index, rec_res) else: - matched[distances.index(min(distances))].append(i) - return matched#, sum(ious) / len(ious) - - - -#先挑选出一行,再进行匹配 -def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes): - gt_box_index = 0 - delete_gt_bboxes = gt_bboxes.copy() - match_bboxes_ready = [] - matched = {} - while(len(delete_gt_bboxes) != 0): - row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes) - row_bboxes = sorted(row_bboxes, key = lambda key: key[0]) - if len(pred_bboxes_rows) > 0: - match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) - print(row_bboxes) - for i, gt_box in enumerate(row_bboxes): - #print(gt_box) - pred_distances = [] - distances = [] - for pred_bbox in pred_bboxes: - pred_distances.append(distance(gt_box, pred_bbox)) - for j, pred_box in enumerate(match_bboxes_ready): - distances.append(distance(gt_box, pred_box)) - index = pred_distances.index(min(distances)) - #print('index', index) - if index not in matched.keys(): - matched[index] = [gt_box_index] + pred_html, pred = self.get_pred_html(pred_structures, matched_index, + rec_res) + return pred_html + + def match_result(self, dt_boxes, pred_bboxes): + matched = {} + for i, gt_box in enumerate(dt_boxes): + 
distances = [] + for j, pred_box in enumerate(pred_bboxes): + if len(pred_box) == 8: + pred_box = [ + np.min(pred_box[0::2]), np.min(pred_box[1::2]), + np.max(pred_box[0::2]), np.max(pred_box[1::2]) + ] + distances.append((distance(gt_box, pred_box), + 1. - compute_iou(gt_box, pred_box) + )) # compute iou and l1 distance + sorted_distances = distances.copy() + # select det box by iou and l1 distance + sorted_distances = sorted( + sorted_distances, key=lambda item: (item[1], item[0])) + if distances.index(sorted_distances[0]) not in matched.keys(): + matched[distances.index(sorted_distances[0])] = [i] else: - matched[index].append(gt_box_index) - gt_box_index += 1 - return matched - -def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes): - ''' - gt_bboxes: 排序后 - pred_bboxes: - ''' - pre_bbox = gt_bboxes[0] - matched = {} - match_bboxes_ready = [] - match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) - for i, gt_box in enumerate(gt_bboxes): - - pred_distances = [] - for pred_bbox in pred_bboxes: - pred_distances.append(distance(gt_box, pred_bbox)) - distances = [] - gap_pre = gt_box[1] - pre_bbox[1] - gap_pre_1 = gt_box[0] - pre_bbox[2] - #print(gap_pre, len(pred_bboxes_rows)) - if (gap_pre_1 < 0 and len(pred_bboxes_rows) > 0): - match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) - if len(pred_bboxes_rows) == 1: - match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) - if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0: - match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) - if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0: - break - #print(match_bboxes_ready) - for j, pred_box in enumerate(match_bboxes_ready): - distances.append(distance(gt_box, pred_box)) - index = pred_distances.index(min(distances)) - #print(gt_box, index) - #match_bboxes_ready.pop(distances.index(min(distances))) - print(gt_box, match_bboxes_ready[distances.index(min(distances))]) - if index not in matched.keys(): - matched[index] = [i] - else: - 
matched[index].append(i) - pre_bbox = gt_box - return matched + matched[distances.index(sorted_distances[0])].append(i) + return matched + + def get_pred_html(self, pred_structures, matched_index, ocr_contents): + end_html = [] + td_index = 0 + for tag in pred_structures: + if '' in tag: + if '' == tag: + end_html.extend('') + if td_index in matched_index.keys(): + b_with = False + if '' in ocr_contents[matched_index[td_index][ + 0]] and len(matched_index[td_index]) > 1: + b_with = True + end_html.extend('') + for i, td_index_index in enumerate(matched_index[td_index]): + content = ocr_contents[td_index_index][0] + if len(matched_index[td_index]) > 1: + if len(content) == 0: + continue + if content[0] == ' ': + content = content[1:] + if '' in content: + content = content[3:] + if '' in content: + content = content[:-4] + if len(content) == 0: + continue + if i != len(matched_index[ + td_index]) - 1 and ' ' != content[-1]: + content += ' ' + end_html.extend(content) + if b_with: + end_html.extend('') + if '' == tag: + end_html.append('') + else: + end_html.append(tag) + td_index += 1 + else: + end_html.append(tag) + return ''.join(end_html), end_html + + def get_pred_html_master(self, pred_structures, matched_index, + ocr_contents): + end_html = [] + td_index = 0 + for token in pred_structures: + if '' in token: + txt = '' + b_with = False + if td_index in matched_index.keys(): + if '' in ocr_contents[matched_index[td_index][ + 0]] and len(matched_index[td_index]) > 1: + b_with = True + for i, td_index_index in enumerate(matched_index[td_index]): + content = ocr_contents[td_index_index][0] + if len(matched_index[td_index]) > 1: + if len(content) == 0: + continue + if content[0] == ' ': + content = content[1:] + if '' in content: + content = content[3:] + if '' in content: + content = content[:-4] + if len(content) == 0: + continue + if i != len(matched_index[ + td_index]) - 1 and ' ' != content[-1]: + content += ' ' + txt += content + if b_with: + txt = 
'{}'.format(txt) + if '' == token: + token = '{}'.format(txt) + else: + token = '{}'.format(txt) + td_index += 1 + token = deal_eb_token(token) + end_html.append(token) + html = ''.join(end_html) + html = deal_bb(html) + return html, end_html + + def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res): + y1 = pred_bboxes[:, 1::2].min() + new_dt_boxes = [] + new_rec_res = [] + + for box, rec in zip(dt_boxes, rec_res): + if np.max(box[1::2]) < y1: + continue + new_dt_boxes.append(box) + new_rec_res.append(rec) + return new_dt_boxes, new_rec_res diff --git a/ppstructure/table/predict_structure.py b/ppstructure/table/predict_structure.py index 7a7d3169d567493b4707b63c75cec07485cf5acb..7198fb2bcdc4d9d10f884f3a1545f23a1e628454 100755 --- a/ppstructure/table/predict_structure.py +++ b/ppstructure/table/predict_structure.py @@ -16,7 +16,7 @@ import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' @@ -73,12 +73,14 @@ class TableStructurer(object): postprocess_params = { 'name': 'TableLabelDecode', "character_dict_path": args.table_char_dict_path, + 'merge_no_span_structure': args.merge_no_span_structure } else: postprocess_params = { 'name': 'TableMasterLabelDecode', "character_dict_path": args.table_char_dict_path, - 'box_shape': 'pad' + 'box_shape': 'pad', + 'merge_no_span_structure': args.merge_no_span_structure } self.preprocess_op = create_operators(pre_process_list) @@ -87,6 +89,7 @@ class TableStructurer(object): utility.create_predictor(args, 'table', logger) def __call__(self, img): + starttime = time.time() ori_im = img.copy() data = {'image': img} data = transform(data, self.preprocess_op) @@ -95,7 +98,6 @@ class TableStructurer(object): return None, 0 img = np.expand_dims(img, axis=0) img = img.copy() - starttime = time.time() 
self.input_tensor.copy_from_cpu(img) self.predictor.run() @@ -126,7 +128,6 @@ def main(args): table_structurer = TableStructurer(args) count = 0 total_time = 0 - use_xywh = args.table_algorithm in ['TableMaster'] os.makedirs(args.output, exist_ok=True) with open( os.path.join(args.output, 'infer.txt'), mode='w', @@ -146,7 +147,10 @@ def main(args): f_w.write("result: {}, {}\n".format(structure_str_list, bbox_list_str)) - img = draw_rectangle(image_file, bbox_list, use_xywh) + if len(bbox_list) > 0 and len(bbox_list[0]) == 4: + img = draw_rectangle(image_file, pred_res['cell_bbox']) + else: + img = utility.draw_boxes(img, bbox_list) img_save_path = os.path.join(args.output, os.path.basename(image_file)) cv2.imwrite(img_save_path, img) diff --git a/ppstructure/table/predict_table.py b/ppstructure/table/predict_table.py index becc6daef02e7e3e98fcccd3b87a93e725577886..e94347d86144cd66474546e99a2c9dffee4978d9 100644 --- a/ppstructure/table/predict_table.py +++ b/ppstructure/table/predict_table.py @@ -18,20 +18,23 @@ import subprocess __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' import cv2 import copy +import logging import numpy as np import time import tools.infer.predict_rec as predict_rec import tools.infer.predict_det as predict_det import tools.infer.utility as utility +from tools.infer.predict_system import sorted_boxes from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger -from ppstructure.table.matcher import distance, compute_iou +from ppstructure.table.matcher import TableMatch +from ppstructure.table.table_master_match import TableMasterMatcher from 
ppstructure.utility import parse_args import ppstructure.table.predict_structure as predict_strture @@ -55,11 +58,20 @@ def expand(pix, det_box, shape): class TableSystem(object): def __init__(self, args, text_detector=None, text_recognizer=None): + if not args.show_log: + logger.setLevel(logging.INFO) + self.text_detector = predict_det.TextDetector( args) if text_detector is None else text_detector self.text_recognizer = predict_rec.TextRecognizer( args) if text_recognizer is None else text_recognizer + self.table_structurer = predict_strture.TableStructurer(args) + if args.table_algorithm in ['TableMaster']: + self.match = TableMasterMatcher() + else: + self.match = TableMatch(filter_ocr_result=True) + self.benchmark = args.benchmark self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor( args, 'table', logger) @@ -85,145 +97,72 @@ class TableSystem(object): def __call__(self, img, return_ocr_result_in_table=False): result = dict() - ori_im = img.copy() + time_dict = {'det': 0, 'rec': 0, 'table': 0, 'all': 0, 'match': 0} + start = time.time() + + structure_res, elapse = self._structure(copy.deepcopy(img)) + result['cell_bbox'] = structure_res[1].tolist() + time_dict['table'] = elapse + + dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr( + copy.deepcopy(img)) + time_dict['det'] = det_elapse + time_dict['rec'] = rec_elapse + + if return_ocr_result_in_table: + result['boxes'] = dt_boxes #[x.tolist() for x in dt_boxes] + result['rec_res'] = rec_res + + tic = time.time() + pred_html = self.match(structure_res, dt_boxes, rec_res) + toc = time.time() + time_dict['match'] = toc - tic + result['html'] = pred_html + if self.benchmark: + self.autolog.times.end(stamp=True) + end = time.time() + time_dict['all'] = end - start + if self.benchmark: + self.autolog.times.stamp() + return result, time_dict + + def _structure(self, img): if self.benchmark: self.autolog.times.start() structure_res, elapse = 
self.table_structurer(copy.deepcopy(img)) + return structure_res, elapse + + def _ocr(self, img): + h, w = img.shape[:2] if self.benchmark: self.autolog.times.stamp() - dt_boxes, elapse = self.text_detector(copy.deepcopy(img)) + dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img)) dt_boxes = sorted_boxes(dt_boxes) - if return_ocr_result_in_table: - result['boxes'] = [x.tolist() for x in dt_boxes] + r_boxes = [] for box in dt_boxes: - x_min = box[:, 0].min() - 1 - x_max = box[:, 0].max() + 1 - y_min = box[:, 1].min() - 1 - y_max = box[:, 1].max() + 1 + x_min = max(0, box[:, 0].min() - 1) + x_max = min(w, box[:, 0].max() + 1) + y_min = max(0, box[:, 1].min() - 1) + y_max = min(h, box[:, 1].max() + 1) box = [x_min, y_min, x_max, y_max] r_boxes.append(box) dt_boxes = np.array(r_boxes) logger.debug("dt_boxes num : {}, elapse : {}".format( - len(dt_boxes), elapse)) + len(dt_boxes), det_elapse)) if dt_boxes is None: return None, None + img_crop_list = [] for i in range(len(dt_boxes)): det_box = dt_boxes[i] - x0, y0, x1, y1 = expand(2, det_box, ori_im.shape) - text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :] + x0, y0, x1, y1 = expand(2, det_box, img.shape) + text_rect = img[int(y0):int(y1), int(x0):int(x1), :] img_crop_list.append(text_rect) - rec_res, elapse = self.text_recognizer(img_crop_list) + rec_res, rec_elapse = self.text_recognizer(img_crop_list) logger.debug("rec_res num : {}, elapse : {}".format( - len(rec_res), elapse)) - if self.benchmark: - self.autolog.times.stamp() - if return_ocr_result_in_table: - result['rec_res'] = rec_res - pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res) - result['html'] = pred_html - if self.benchmark: - self.autolog.times.end(stamp=True) - return result - - def rebuild_table(self, structure_res, dt_boxes, rec_res): - pred_structures, pred_bboxes = structure_res - dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes,dt_boxes, rec_res) - matched_index = self.match_result(dt_boxes, pred_bboxes) - 
pred_html, pred = self.get_pred_html(pred_structures, matched_index, - rec_res) - return pred_html, pred - - def filter_ocr_result(self, pred_bboxes,dt_boxes, rec_res): - y1 = pred_bboxes[:,1::2].min() - new_dt_boxes = [] - new_rec_res = [] - - for box,rec in zip(dt_boxes, rec_res): - if np.max(box[1::2]) < y1: - continue - new_dt_boxes.append(box) - new_rec_res.append(rec) - return new_dt_boxes, new_rec_res - - - def match_result(self, dt_boxes, pred_bboxes): - matched = {} - for i, gt_box in enumerate(dt_boxes): - # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])] - distances = [] - for j, pred_box in enumerate(pred_bboxes): - distances.append((distance(gt_box, pred_box), - 1. - compute_iou(gt_box, pred_box) - )) # 获取两两cell之间的L1距离和 1- IOU - sorted_distances = distances.copy() - # 根据距离和IOU挑选最"近"的cell - sorted_distances = sorted( - sorted_distances, key=lambda item: (item[1], item[0])) - if distances.index(sorted_distances[0]) not in matched.keys(): - matched[distances.index(sorted_distances[0])] = [i] - else: - matched[distances.index(sorted_distances[0])].append(i) - return matched - - def get_pred_html(self, pred_structures, matched_index, ocr_contents): - end_html = [] - td_index = 0 - for tag in pred_structures: - if '' in tag: - if td_index in matched_index.keys(): - b_with = False - if '' in ocr_contents[matched_index[td_index][ - 0]] and len(matched_index[td_index]) > 1: - b_with = True - end_html.extend('') - for i, td_index_index in enumerate(matched_index[td_index]): - content = ocr_contents[td_index_index][0] - if len(matched_index[td_index]) > 1: - if len(content) == 0: - continue - if content[0] == ' ': - content = content[1:] - if '' in content: - content = content[3:] - if '' in content: - content = content[:-4] - if len(content) == 0: - continue - if i != len(matched_index[ - td_index]) - 1 and ' ' != content[-1]: - content += ' ' - end_html.extend(content) - if b_with: - end_html.extend('') - - 
end_html.append(tag) - td_index += 1 - else: - end_html.append(tag) - return ''.join(end_html), end_html - - -def sorted_boxes(dt_boxes): - """ - Sort text boxes in order from top to bottom, left to right - args: - dt_boxes(array):detected text boxes with shape [4, 2] - return: - sorted boxes(array) with shape [4, 2] - """ - num_boxes = dt_boxes.shape[0] - sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) - _boxes = list(sorted_boxes) - - for i in range(num_boxes - 1): - if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \ - (_boxes[i + 1][0][0] < _boxes[i][0][0]): - tmp = _boxes[i] - _boxes[i] = _boxes[i + 1] - _boxes[i + 1] = tmp - return _boxes + len(rec_res), rec_elapse)) + return dt_boxes, rec_res, det_elapse, rec_elapse def to_excel(html_table, excel_path): @@ -236,8 +175,23 @@ def main(args): image_file_list = image_file_list[args.process_id::args.total_process_num] os.makedirs(args.output, exist_ok=True) - text_sys = TableSystem(args) + table_sys = TableSystem(args) img_num = len(image_file_list) + + f_html = open( + os.path.join(args.output, 'show.html'), mode='w', encoding='utf-8') + f_html.write('\n\n') + f_html.write('\n') + f_html.write( + "" + ) + f_html.write("\n") + f_html.write('') + f_html.write('') + f_html.write('') + f_html.write("\n") + for i, image_file in enumerate(image_file_list): logger.info("[{}/{}] {}".format(i, img_num, image_file)) img, flag = check_and_read_gif(image_file) @@ -249,13 +203,35 @@ def main(args): logger.error("error in loading image:{}".format(image_file)) continue starttime = time.time() - pred_res = text_sys(img) + pred_res, _ = table_sys(img) pred_html = pred_res['html'] logger.info(pred_html) to_excel(pred_html, excel_path) logger.info('excel saved to {}'.format(excel_path)) elapse = time.time() - starttime logger.info("Predict time : {:.3f}s".format(elapse)) + + if len(pred_res['cell_bbox']) > 0 and len(pred_res['cell_bbox'][ + 0]) == 4: + img = predict_strture.draw_rectangle(image_file, + 
pred_res['cell_bbox']) + else: + img = utility.draw_boxes(img, pred_res['cell_bbox']) + img_save_path = os.path.join(args.output, os.path.basename(image_file)) + cv2.imwrite(img_save_path, img) + + f_html.write("\n") + f_html.write(f'\n') + f_html.write('
img name\n') + f_html.write('ori imagetable htmlcell box
{os.path.basename(image_file)}
\n') + f_html.write(f'
' + pred_html.replace( + '
', '').replace('
', '') + + '
\n') + f_html.write( + f'\n') + f_html.write("\n") + f_html.write("\n") + f_html.close() + if args.benchmark: text_sys.autolog.report() diff --git a/ppstructure/table/table_master_match.py b/ppstructure/table/table_master_match.py new file mode 100644 index 0000000000000000000000000000000000000000..7a7208d4a94bb357b1bbce0d664d9d6449a96874 --- /dev/null +++ b/ppstructure/table/table_master_match.py @@ -0,0 +1,953 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/match.py +""" + +import os +import re +import cv2 +import glob +import copy +import math +import pickle +import numpy as np + +from shapely.geometry import Polygon, MultiPoint +""" +Useful function in matching. +""" + + +def remove_empty_bboxes(bboxes): + """ + remove [0., 0., 0., 0.] in structure master bboxes. + len(bboxes.shape) must be 2. 
+ :param bboxes: + :return: + """ + new_bboxes = [] + for bbox in bboxes: + if sum(bbox) == 0.: + continue + new_bboxes.append(bbox) + return np.array(new_bboxes) + + +def xywh2xyxy(bboxes): + if len(bboxes.shape) == 1: + new_bboxes = np.empty_like(bboxes) + new_bboxes[0] = bboxes[0] - bboxes[2] / 2 + new_bboxes[1] = bboxes[1] - bboxes[3] / 2 + new_bboxes[2] = bboxes[0] + bboxes[2] / 2 + new_bboxes[3] = bboxes[1] + bboxes[3] / 2 + return new_bboxes + elif len(bboxes.shape) == 2: + new_bboxes = np.empty_like(bboxes) + new_bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] / 2 + new_bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] / 2 + new_bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2] / 2 + new_bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3] / 2 + return new_bboxes + else: + raise ValueError + + +def xyxy2xywh(bboxes): + if len(bboxes.shape) == 1: + new_bboxes = np.empty_like(bboxes) + new_bboxes[0] = bboxes[0] + (bboxes[2] - bboxes[0]) / 2 + new_bboxes[1] = bboxes[1] + (bboxes[3] - bboxes[1]) / 2 + new_bboxes[2] = bboxes[2] - bboxes[0] + new_bboxes[3] = bboxes[3] - bboxes[1] + return new_bboxes + elif len(bboxes.shape) == 2: + new_bboxes = np.empty_like(bboxes) + new_bboxes[:, 0] = bboxes[:, 0] + (bboxes[:, 2] - bboxes[:, 0]) / 2 + new_bboxes[:, 1] = bboxes[:, 1] + (bboxes[:, 3] - bboxes[:, 1]) / 2 + new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] + new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] + return new_bboxes + else: + raise ValueError + + +def pickle_load(path, prefix='end2end'): + if os.path.isfile(path): + data = pickle.load(open(path, 'rb')) + elif os.path.isdir(path): + data = dict() + search_path = os.path.join(path, '{}_*.pkl'.format(prefix)) + pkls = glob.glob(search_path) + for pkl in pkls: + this_data = pickle.load(open(pkl, 'rb')) + data.update(this_data) + else: + raise ValueError + return data + + +def convert_coord(xyxy): + """ + Convert two points format to four points format. 
+ :param xyxy: + :return: + """ + new_bbox = np.zeros([4, 2], dtype=np.float32) + new_bbox[0, 0], new_bbox[0, 1] = xyxy[0], xyxy[1] + new_bbox[1, 0], new_bbox[1, 1] = xyxy[2], xyxy[1] + new_bbox[2, 0], new_bbox[2, 1] = xyxy[2], xyxy[3] + new_bbox[3, 0], new_bbox[3, 1] = xyxy[0], xyxy[3] + return new_bbox + + +def cal_iou(bbox1, bbox2): + bbox1_poly = Polygon(bbox1).convex_hull + bbox2_poly = Polygon(bbox2).convex_hull + union_poly = np.concatenate((bbox1, bbox2)) + + if not bbox1_poly.intersects(bbox2_poly): + iou = 0 + else: + inter_area = bbox1_poly.intersection(bbox2_poly).area + union_area = MultiPoint(union_poly).convex_hull.area + if union_area == 0: + iou = 0 + else: + iou = float(inter_area) / union_area + return iou + + +def cal_distance(p1, p2): + delta_x = p1[0] - p2[0] + delta_y = p1[1] - p2[1] + d = math.sqrt((delta_x**2) + (delta_y**2)) + return d + + +def is_inside(center_point, corner_point): + """ + Find if center_point inside the bbox(corner_point) or not. + :param center_point: center point (x, y) + :param corner_point: corner point ((x1,y1),(x2,y2)) + :return: + """ + x_flag = False + y_flag = False + if (center_point[0] >= corner_point[0][0]) and ( + center_point[0] <= corner_point[1][0]): + x_flag = True + if (center_point[1] >= corner_point[0][1]) and ( + center_point[1] <= corner_point[1][1]): + y_flag = True + if x_flag and y_flag: + return True + else: + return False + + +def find_no_match(match_list, all_end2end_nums, type='end2end'): + """ + Find out no match end2end bbox in previous match list. + :param match_list: matching pairs. + :param all_end2end_nums: numbers of end2end_xywh + :param type: 'end2end' corresponding to idx 0, 'master' corresponding to idx 1. 
+ :return: no match pse bbox index list + """ + if type == 'end2end': + idx = 0 + elif type == 'master': + idx = 1 + else: + raise ValueError + + no_match_indexs = [] + # m[0] is end2end index m[1] is master index + matched_bbox_indexs = [m[idx] for m in match_list] + for n in range(all_end2end_nums): + if n not in matched_bbox_indexs: + no_match_indexs.append(n) + return no_match_indexs + + +def is_abs_lower_than_threshold(this_bbox, target_bbox, threshold=3): + # only consider y axis, for grouping in row. + delta = abs(this_bbox[1] - target_bbox[1]) + if delta < threshold: + return True + else: + return False + + +def sort_line_bbox(g, bg): + """ + Sorted the bbox in the same line(group) + compare coord 'x' value, where 'y' value is closed in the same group. + :param g: index in the same group + :param bg: bbox in the same group + :return: + """ + + xs = [bg_item[0] for bg_item in bg] + xs_sorted = sorted(xs) + + g_sorted = [None] * len(xs_sorted) + bg_sorted = [None] * len(xs_sorted) + for g_item, bg_item in zip(g, bg): + idx = xs_sorted.index(bg_item[0]) + bg_sorted[idx] = bg_item + g_sorted[idx] = g_item + + return g_sorted, bg_sorted + + +def flatten(sorted_groups, sorted_bbox_groups): + idxs = [] + bboxes = [] + for group, bbox_group in zip(sorted_groups, sorted_bbox_groups): + for g, bg in zip(group, bbox_group): + idxs.append(g) + bboxes.append(bg) + return idxs, bboxes + + +def sort_bbox(end2end_xywh_bboxes, no_match_end2end_indexes): + """ + This function will group the render end2end bboxes in row. 
+ :param end2end_xywh_bboxes: + :param no_match_end2end_indexes: + :return: + """ + groups = [] + bbox_groups = [] + for index, end2end_xywh_bbox in zip(no_match_end2end_indexes, + end2end_xywh_bboxes): + this_bbox = end2end_xywh_bbox + if len(groups) == 0: + groups.append([index]) + bbox_groups.append([this_bbox]) + else: + flag = False + for g, bg in zip(groups, bbox_groups): + # this_bbox is belong to bg's row or not + if is_abs_lower_than_threshold(this_bbox, bg[0]): + g.append(index) + bg.append(this_bbox) + flag = True + break + if not flag: + # this_bbox is not belong to bg's row, create a row. + groups.append([index]) + bbox_groups.append([this_bbox]) + + # sorted bboxes in a group + tmp_groups, tmp_bbox_groups = [], [] + for g, bg in zip(groups, bbox_groups): + g_sorted, bg_sorted = sort_line_bbox(g, bg) + tmp_groups.append(g_sorted) + tmp_bbox_groups.append(bg_sorted) + + # sorted groups, sort by coord y's value. + sorted_groups = [None] * len(tmp_groups) + sorted_bbox_groups = [None] * len(tmp_bbox_groups) + ys = [bg[0][1] for bg in tmp_bbox_groups] + sorted_ys = sorted(ys) + for g, bg in zip(tmp_groups, tmp_bbox_groups): + idx = sorted_ys.index(bg[0][1]) + sorted_groups[idx] = g + sorted_bbox_groups[idx] = bg + + # flatten, get final result + end2end_sorted_idx_list, end2end_sorted_bbox_list \ + = flatten(sorted_groups, sorted_bbox_groups) + + return end2end_sorted_idx_list, end2end_sorted_bbox_list, sorted_groups, sorted_bbox_groups + + +def get_bboxes_list(end2end_result, structure_master_result): + """ + This function is use to convert end2end results and structure master results to + List of xyxy bbox format and List of xywh bbox format + :param end2end_result: bbox's format is xyxy + :param structure_master_result: bbox's format is xywh + :return: 4 kind list of bbox () + """ + # end2end + end2end_xyxy_list = [] + end2end_xywh_list = [] + for end2end_item in end2end_result: + src_bbox = end2end_item['bbox'] + end2end_xyxy_list.append(src_bbox) + 
xywh_bbox = xyxy2xywh(src_bbox) + end2end_xywh_list.append(xywh_bbox) + end2end_xyxy_bboxes = np.array(end2end_xyxy_list) + end2end_xywh_bboxes = np.array(end2end_xywh_list) + + # structure master + src_bboxes = structure_master_result['bbox'] + src_bboxes = remove_empty_bboxes(src_bboxes) + structure_master_xyxy_bboxes = src_bboxes + xywh_bbox = xyxy2xywh(src_bboxes) + structure_master_xywh_bboxes = xywh_bbox + + return end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes + + +def center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes): + """ + Judge end2end Bbox's center point is inside structure master Bbox or not, + if end2end Bbox's center is in structure master Bbox, get matching pair. + :param end2end_xywh_bboxes: + :param structure_master_xyxy_bboxes: + :return: match pairs list, e.g. [[0,1], [1,2], ...] + """ + match_pairs_list = [] + for i, end2end_xywh in enumerate(end2end_xywh_bboxes): + for j, master_xyxy in enumerate(structure_master_xyxy_bboxes): + x_end2end, y_end2end = end2end_xywh[0], end2end_xywh[1] + x_master1, y_master1, x_master2, y_master2 \ + = master_xyxy[0], master_xyxy[1], master_xyxy[2], master_xyxy[3] + center_point_end2end = (x_end2end, y_end2end) + corner_point_master = ((x_master1, y_master1), + (x_master2, y_master2)) + if is_inside(center_point_end2end, corner_point_master): + match_pairs_list.append([i, j]) + return match_pairs_list + + +def iou_rule_match(end2end_xyxy_bboxes, end2end_xyxy_indexes, + structure_master_xyxy_bboxes): + """ + Use iou to find matching list. + choose max iou value bbox as match pair. + :param end2end_xyxy_bboxes: + :param end2end_xyxy_indexes: original end2end indexes. + :param structure_master_xyxy_bboxes: + :return: match pairs list, e.g. [[0,1], [1,2], ...] 
+ """ + match_pair_list = [] + for end2end_xyxy_index, end2end_xyxy in zip(end2end_xyxy_indexes, + end2end_xyxy_bboxes): + max_iou = 0 + max_match = [None, None] + for j, master_xyxy in enumerate(structure_master_xyxy_bboxes): + end2end_4xy = convert_coord(end2end_xyxy) + master_4xy = convert_coord(master_xyxy) + iou = cal_iou(end2end_4xy, master_4xy) + if iou > max_iou: + max_match[0], max_match[1] = end2end_xyxy_index, j + max_iou = iou + + if max_match[0] is None: + # no match + continue + match_pair_list.append(max_match) + return match_pair_list + + +def distance_rule_match(end2end_indexes, end2end_bboxes, master_indexes, + master_bboxes): + """ + Get matching between no-match end2end bboxes and no-match master bboxes. + Use min distance to match. + This rule will only run (no-match end2end nums > 0) and (no-match master nums > 0) + It will Return master_bboxes_nums match-pairs. + :param end2end_indexes: + :param end2end_bboxes: + :param master_indexes: + :param master_bboxes: + :return: match_pairs list, e.g. [[0,1], [1,2], ...] + """ + min_match_list = [] + for j, master_bbox in zip(master_indexes, master_bboxes): + min_distance = np.inf + min_match = [0, 0] # i, j + for i, end2end_bbox in zip(end2end_indexes, end2end_bboxes): + x_end2end, y_end2end = end2end_bbox[0], end2end_bbox[1] + x_master, y_master = master_bbox[0], master_bbox[1] + end2end_point = (x_end2end, y_end2end) + master_point = (x_master, y_master) + dist = cal_distance(master_point, end2end_point) + if dist < min_distance: + min_match[0], min_match[1] = i, j + min_distance = dist + min_match_list.append(min_match) + return min_match_list + + +def extra_match(no_match_end2end_indexes, master_bbox_nums): + """ + This function will create some virtual master bboxes, + and get match with the no match end2end indexes. 
+ :param no_match_end2end_indexes: + :param master_bbox_nums: + :return: + """ + end_nums = len(no_match_end2end_indexes) + master_bbox_nums + extra_match_list = [] + for i in range(master_bbox_nums, end_nums): + end2end_index = no_match_end2end_indexes[i - master_bbox_nums] + extra_match_list.append([end2end_index, i]) + return extra_match_list + + +def get_match_dict(match_list): + """ + Convert match_list to a dict, where key is master bbox's index, value is end2end bbox index. + :param match_list: + :return: + """ + match_dict = dict() + for match_pair in match_list: + end2end_index, master_index = match_pair[0], match_pair[1] + if master_index not in match_dict.keys(): + match_dict[master_index] = [end2end_index] + else: + match_dict[master_index].append(end2end_index) + return match_dict + + +def deal_successive_space(text): + """ + deal successive space character for text + 1. Replace ' '*3 with '' which is real space is text + 2. Remove ' ', which is split token, not true space + 3. Replace '' with ' ', to get real text + :param text: + :return: + """ + text = text.replace(' ' * 3, '') + text = text.replace(' ', '') + text = text.replace('', ' ') + return text + + +def reduce_repeat_bb(text_list, break_token): + """ + convert ['Local', 'government', 'unit'] to ['Local government unit'] + PS: maybe style Local is also exist, too. it can be processed like this. 
+ :param text_list: + :param break_token: + :return: + """ + count = 0 + for text in text_list: + if text.startswith(''): + count += 1 + if count == len(text_list): + new_text_list = [] + for text in text_list: + text = text.replace('', '').replace('', '') + new_text_list.append(text) + return ['' + break_token.join(new_text_list) + ''] + else: + return text_list + + +def get_match_text_dict(match_dict, end2end_info, break_token=' '): + match_text_dict = dict() + for master_index, end2end_index_list in match_dict.items(): + text_list = [ + end2end_info[end2end_index]['text'] + for end2end_index in end2end_index_list + ] + text_list = reduce_repeat_bb(text_list, break_token) + text = break_token.join(text_list) + match_text_dict[master_index] = text + return match_text_dict + + +def merge_span_token(master_token_list): + """ + Merge the span style token (row span or col span). + :param master_token_list: + :return: + """ + new_master_token_list = [] + pointer = 0 + if master_token_list[-1] != '': + master_token_list.append('') + while master_token_list[pointer] != '': + try: + if master_token_list[pointer] == '' + '' + """ + tmp = ''.join(master_token_list[pointer:pointer + 3 + 1]) + pointer += 4 + new_master_token_list.append(tmp) + + elif master_token_list[pointer + 2].startswith( + ' colspan=') or master_token_list[ + pointer + 2].startswith(' rowspan='): + """ + example: + pattern + '' + '' + """ + tmp = ''.join(master_token_list[pointer:pointer + 4 + 1]) + pointer += 5 + new_master_token_list.append(tmp) + + else: + new_master_token_list.append(master_token_list[pointer]) + pointer += 1 + else: + new_master_token_list.append(master_token_list[pointer]) + pointer += 1 + except: + print("Break in merge...") + break + new_master_token_list.append('') + + return new_master_token_list + + +def deal_eb_token(master_token): + """ + post process with , , ... 
+ emptyBboxTokenDict = { + "[]": '', + "[' ']": '', + "['', ' ', '']": '', + "['\\u2028', '\\u2028']": '', + "['', ' ', '']": '', + "['', '']": '', + "['', ' ', '']": '', + "['', '', '', '']": '', + "['', '', ' ', '', '']": '', + "['', '']": '', + "['', ' ', '\\u2028', ' ', '\\u2028', ' ', '']": '', + } + :param master_token: + :return: + """ + master_token = master_token.replace('', '') + master_token = master_token.replace('', ' ') + master_token = master_token.replace('', ' ') + master_token = master_token.replace('', '\u2028\u2028') + master_token = master_token.replace('', ' ') + master_token = master_token.replace('', '') + master_token = master_token.replace('', ' ') + master_token = master_token.replace('', + '') + master_token = master_token.replace('', + ' ') + master_token = master_token.replace('', '') + master_token = master_token.replace('', + ' \u2028 \u2028 ') + return master_token + + +def insert_text_to_token(master_token_list, match_text_dict): + """ + Insert OCR text result to structure token. + :param master_token_list: + :param match_text_dict: + :return: + """ + master_token_list = merge_span_token(master_token_list) + merged_result_list = [] + text_count = 0 + for master_token in master_token_list: + if master_token.startswith(' len(match_text_dict) - 1: + text_count += 1 + continue + elif text_count not in match_text_dict.keys(): + text_count += 1 + continue + else: + master_token = master_token.replace( + '><', '>{}<'.format(match_text_dict[text_count])) + text_count += 1 + master_token = deal_eb_token(master_token) + merged_result_list.append(master_token) + + return ''.join(merged_result_list) + + +def deal_isolate_span(thead_part): + """ + Deal with isolate span cases in this function. + It causes by wrong prediction in structure recognition model. + eg. predict to rowspan="2">. + :param thead_part: + :return: + """ + # 1. find out isolate span tokens. + isolate_pattern = " rowspan=\"(\d)+\" colspan=\"(\d)+\">
|" \ + " colspan=\"(\d)+\" rowspan=\"(\d)+\">
|" \ + " rowspan=\"(\d)+\">
|" \ + " colspan=\"(\d)+\">" + isolate_iter = re.finditer(isolate_pattern, thead_part) + isolate_list = [i.group() for i in isolate_iter] + + # 2. find out span number, by step 1 results. + span_pattern = " rowspan=\"(\d)+\" colspan=\"(\d)+\"|" \ + " colspan=\"(\d)+\" rowspan=\"(\d)+\"|" \ + " rowspan=\"(\d)+\"|" \ + " colspan=\"(\d)+\"" + corrected_list = [] + for isolate_item in isolate_list: + span_part = re.search(span_pattern, isolate_item) + spanStr_in_isolateItem = span_part.group() + # 3. merge the span number into the span token format string. + if spanStr_in_isolateItem is not None: + corrected_item = ''.format(spanStr_in_isolateItem) + corrected_list.append(corrected_item) + else: + corrected_list.append(None) + + # 4. replace original isolated token. + for corrected_item, isolate_item in zip(corrected_list, isolate_list): + if corrected_item is not None: + thead_part = thead_part.replace(isolate_item, corrected_item) + else: + pass + return thead_part + + +def deal_duplicate_bb(thead_part): + """ + Deal duplicate or after replace. + Keep one in a token. + :param thead_part: + :return: + """ + # 1. find out in . + td_pattern = "(.+?)|" \ + "(.+?)|" \ + "(.+?)|" \ + "(.+?)|" \ + "(.*?)" + td_iter = re.finditer(td_pattern, thead_part) + td_list = [t.group() for t in td_iter] + + # 2. is multiply in or not? + new_td_list = [] + for td_item in td_list: + if td_item.count('') > 1 or td_item.count('') > 1: + # multiply in case. + # 1. remove all + td_item = td_item.replace('', '').replace('', '') + # 2. replace -> , -> . + td_item = td_item.replace('', '').replace('', + '') + new_td_list.append(td_item) + else: + new_td_list.append(td_item) + + # 3. replace original thead part. + for td_item, new_td_item in zip(td_list, new_td_list): + thead_part = thead_part.replace(td_item, new_td_item) + return thead_part + + +def deal_bb(result_token): + """ + In our opinion, always occurs in text's context. + This function will find out all tokens in and insert by manual. 
+ :param result_token: + :return: + """ + # find out parts. + thead_pattern = '(.*?)' + if re.search(thead_pattern, result_token) is None: + return result_token + thead_part = re.search(thead_pattern, result_token).group() + origin_thead_part = copy.deepcopy(thead_part) + + # check "rowspan" or "colspan" occur in parts or not . + span_pattern = "|||" + span_iter = re.finditer(span_pattern, thead_part) + span_list = [s.group() for s in span_iter] + has_span_in_head = True if len(span_list) > 0 else False + + if not has_span_in_head: + # not include "rowspan" or "colspan" branch 1. + # 1. replace to , and to + # 2. it is possible to predict text include or by Text-line recognition, + # so we replace to , and to + thead_part = thead_part.replace('', '')\ + .replace('', '')\ + .replace('', '')\ + .replace('', '') + else: + # include "rowspan" or "colspan" branch 2. + # Firstly, we deal rowspan or colspan cases. + # 1. replace > to > + # 2. replace to + # 3. it is possible to predict text include or by Text-line recognition, + # so we replace to , and to + + # Secondly, deal ordinary cases like branch 1 + + # replace ">" to "" + replaced_span_list = [] + for sp in span_list: + replaced_span_list.append(sp.replace('>', '>')) + for sp, rsp in zip(span_list, replaced_span_list): + thead_part = thead_part.replace(sp, rsp) + + # replace "" to "" + thead_part = thead_part.replace('', '') + + # remove duplicated by re.sub + mb_pattern = "()+" + single_b_string = "" + thead_part = re.sub(mb_pattern, single_b_string, thead_part) + + mgb_pattern = "()+" + single_gb_string = "" + thead_part = re.sub(mgb_pattern, single_gb_string, thead_part) + + # ordinary cases like branch 1 + thead_part = thead_part.replace('', '').replace('', + '') + + # convert back to , empty cell has no . 
+ # but space cell( ) is suitable for + thead_part = thead_part.replace('', '') + # deal with duplicated + thead_part = deal_duplicate_bb(thead_part) + # deal with isolate span tokens, which causes by wrong predict by structure prediction. + # eg.PMC5994107_011_00.png + thead_part = deal_isolate_span(thead_part) + # replace original result with new thead part. + result_token = result_token.replace(origin_thead_part, thead_part) + return result_token + + +class Matcher: + def __init__(self, end2end_file, structure_master_file): + """ + This class process the end2end results and structure recognition results. + :param end2end_file: end2end results predict by end2end inference. + :param structure_master_file: structure recognition results predict by structure master inference. + """ + self.end2end_file = end2end_file + self.structure_master_file = structure_master_file + self.end2end_results = pickle_load(end2end_file, prefix='end2end') + self.structure_master_results = pickle_load( + structure_master_file, prefix='structure') + + def match(self): + """ + Match process: + pre-process : convert end2end and structure master results to xyxy, xywh ndnarray format. + 1. Use pseBbox is inside masterBbox judge rule + 2. Use iou between pseBbox and masterBbox rule + 3. 
Use min distance of center point rule + :return: + """ + match_results = dict() + for idx, (file_name, + end2end_result) in enumerate(self.end2end_results.items()): + match_list = [] + if file_name not in self.structure_master_results: + continue + structure_master_result = self.structure_master_results[file_name] + end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes = \ + get_bboxes_list(end2end_result, structure_master_result) + + # rule 1: center rule + center_rule_match_list = \ + center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes) + match_list.extend(center_rule_match_list) + + # rule 2: iou rule + # firstly, find not match index in previous step. + center_no_match_end2end_indexs = \ + find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end') + if len(center_no_match_end2end_indexs) > 0: + center_no_match_end2end_xyxy = end2end_xyxy_bboxes[ + center_no_match_end2end_indexs] + # secondly, iou rule match + iou_rule_match_list = \ + iou_rule_match(center_no_match_end2end_xyxy, center_no_match_end2end_indexs, structure_master_xyxy_bboxes) + match_list.extend(iou_rule_match_list) + + # rule 3: distance rule + # match between no-match end2end bboxes and no-match master bboxes. + # it will return master_bboxes_nums match-pairs. + # firstly, find not match index in previous step. 
+ centerIou_no_match_end2end_indexs = \ + find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end') + centerIou_no_match_master_indexs = \ + find_no_match(match_list, len(structure_master_xywh_bboxes), type='master') + if len(centerIou_no_match_master_indexs) > 0 and len( + centerIou_no_match_end2end_indexs) > 0: + centerIou_no_match_end2end_xywh = end2end_xywh_bboxes[ + centerIou_no_match_end2end_indexs] + centerIou_no_match_master_xywh = structure_master_xywh_bboxes[ + centerIou_no_match_master_indexs] + distance_match_list = distance_rule_match( + centerIou_no_match_end2end_indexs, + centerIou_no_match_end2end_xywh, + centerIou_no_match_master_indexs, + centerIou_no_match_master_xywh) + match_list.extend(distance_match_list) + + # TODO: + # The render no-match pseBbox, insert the last + # After step3 distance rule, a master bbox at least match one end2end bbox. + # But end2end bbox maybe overmuch, because numbers of master bbox will cut by max length. + # For these render end2end bboxes, we will make some virtual master bboxes, and get matching. + # The above extra insert bboxes will be further processed in "formatOutput" function. + # After this operation, it will increase TEDS score. + no_match_end2end_indexes = \ + find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end') + if len(no_match_end2end_indexes) > 0: + no_match_end2end_xywh = end2end_xywh_bboxes[ + no_match_end2end_indexes] + # sort the render no-match end2end bbox in row + end2end_sorted_indexes_list, end2end_sorted_bboxes_list, sorted_groups, sorted_bboxes_groups = \ + sort_bbox(no_match_end2end_xywh, no_match_end2end_indexes) + # make virtual master bboxes, and get matching with the no-match end2end bboxes. 
+ extra_match_list = extra_match( + end2end_sorted_indexes_list, + len(structure_master_xywh_bboxes)) + match_list_add_extra_match = copy.deepcopy(match_list) + match_list_add_extra_match.extend(extra_match_list) + else: + # no no-match end2end bboxes + match_list_add_extra_match = copy.deepcopy(match_list) + sorted_groups = [] + sorted_bboxes_groups = [] + + match_result_dict = { + 'match_list': match_list, + 'match_list_add_extra_match': match_list_add_extra_match, + 'sorted_groups': sorted_groups, + 'sorted_bboxes_groups': sorted_bboxes_groups + } + + # format output + match_result_dict = self._format(match_result_dict, file_name) + + match_results[file_name] = match_result_dict + + return match_results + + def _format(self, match_result, file_name): + """ + Extend the master token(insert virtual master token), and format matching result. + :param match_result: + :param file_name: + :return: + """ + end2end_info = self.end2end_results[file_name] + master_info = self.structure_master_results[file_name] + master_token = master_info['text'] + sorted_groups = match_result['sorted_groups'] + + # creat virtual master token + virtual_master_token_list = [] + for line_group in sorted_groups: + tmp_list = [''] + item_nums = len(line_group) + for _ in range(item_nums): + tmp_list.append('') + tmp_list.append('') + virtual_master_token_list.extend(tmp_list) + + # insert virtual master token + master_token_list = master_token.split(',') + if master_token_list[-1] == '': + # complete predict(no cut by max length) + # This situation insert virtual master token will drop TEDs score in val set. + # So we will not extend virtual token in this situation. 
+ + # fake extend virtual + master_token_list[:-1].extend(virtual_master_token_list) + + # real extend virtual + # master_token_list = master_token_list[:-1] + # master_token_list.extend(virtual_master_token_list) + # master_token_list.append('') + + elif master_token_list[-1] == '': + master_token_list.append('') + master_token_list.extend(virtual_master_token_list) + master_token_list.append('') + else: + master_token_list.extend(virtual_master_token_list) + master_token_list.append('') + + # format output + match_result.setdefault('matched_master_token_list', master_token_list) + return match_result + + def get_merge_result(self, match_results): + """ + Merge the OCR result into structure token to get final results. + :param match_results: + :return: + """ + merged_results = dict() + + # break_token is linefeed token, when one master bbox has multiply end2end bboxes. + break_token = ' ' + + for idx, (file_name, match_info) in enumerate(match_results.items()): + end2end_info = self.end2end_results[file_name] + master_token_list = match_info['matched_master_token_list'] + match_list = match_info['match_list_add_extra_match'] + + match_dict = get_match_dict(match_list) + match_text_dict = get_match_text_dict(match_dict, end2end_info, + break_token) + merged_result = insert_text_to_token(master_token_list, + match_text_dict) + merged_result = deal_bb(merged_result) + + merged_results[file_name] = merged_result + + return merged_results + + +class TableMasterMatcher(Matcher): + def __init__(self): + pass + + def __call__(self, structure_res, dt_boxes, rec_res, img_name=1): + end2end_results = {img_name: []} + for dt_box, res in zip(dt_boxes, rec_res): + d = dict( + bbox=np.array(dt_box), + text=res[0], ) + end2end_results[img_name].append(d) + + self.end2end_results = end2end_results + + structure_master_result_dict = {img_name: {}} + pred_structures, pred_bboxes = structure_res + pred_structures = ','.join(pred_structures[3:-3]) + 
structure_master_result_dict[img_name]['text'] = pred_structures + structure_master_result_dict[img_name]['bbox'] = pred_bboxes + self.structure_master_results = structure_master_result_dict + + # match + match_results = self.match() + merged_results = self.get_merge_result(match_results) + pred_html = merged_results[img_name] + pred_html = '' + pred_html + '
' + return pred_html diff --git a/ppstructure/utility.py b/ppstructure/utility.py index 699f6088a5ac0404c05aa167bc9d9efe1eac12d9..cda4c063bccbd2aff34cf25768866feb4d68dc2d 100644 --- a/ppstructure/utility.py +++ b/ppstructure/utility.py @@ -27,6 +27,8 @@ def init_args(): parser.add_argument("--table_max_len", type=int, default=488) parser.add_argument("--table_algorithm", type=str, default='TableAttn') parser.add_argument("--table_model_dir", type=str) + parser.add_argument( + "--merge_no_span_structure", type=str2bool, default=True) parser.add_argument( "--table_char_dict_path", type=str, @@ -36,14 +38,17 @@ def init_args(): parser.add_argument( "--layout_dict_path", type=str, - default="../ppocr/utils/dict/layout_pubalynet_dict.txt") + default="../ppocr/utils/dict/layout_publaynet_dict.txt") parser.add_argument( "--layout_score_threshold", type=float, default=0.5, help="Threshold of score.") parser.add_argument( - "--layout_nms_threshold", type=float, default=0.5, help="Threshold of nms.") + "--layout_nms_threshold", + type=float, + default=0.5, + help="Threshold of nms.") # params for vqa parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM') parser.add_argument("--ser_model_dir", type=str) @@ -59,6 +64,11 @@ def init_args(): type=str, default='structure', help='structure and vqa is supported') + parser.add_argument( + "--image_orientation", + type=bool, + default=False, + help='Whether to enable image orientation recognition') parser.add_argument( "--layout", type=str2bool, diff --git a/ppstructure/vqa/predict_vqa_token_ser.py b/ppstructure/vqa/predict_vqa_token_ser.py index 855be42de33e5f5f63adf36ea44f0936c3bf5ca8..7647af9d10684bc6621b32e95d55e05948cb59b7 100644 --- a/ppstructure/vqa/predict_vqa_token_ser.py +++ b/ppstructure/vqa/predict_vqa_token_ser.py @@ -41,7 +41,11 @@ logger = get_logger() class SerPredictor(object): def __init__(self, args): self.ocr_engine = PaddleOCR( - use_angle_cls=False, show_log=False, use_gpu=args.use_gpu) + 
use_angle_cls=args.use_angle_cls, + det_model_dir=args.det_model_dir, + rec_model_dir=args.rec_model_dir, + show_log=False, + use_gpu=args.use_gpu) pre_process_list = [{ 'VQATokenLabelEncode': { diff --git a/test_tipc/common_func.sh b/test_tipc/common_func.sh index f7d8a1e04adee9d32332eda8cb5913bbaf168481..1bbf829165323b76341461b297b71102462d83af 100644 --- a/test_tipc/common_func.sh +++ b/test_tipc/common_func.sh @@ -58,10 +58,11 @@ function status_check(){ run_command=$2 run_log=$3 model_name=$4 + log_path=$5 if [ $last_status -eq 0 ]; then - echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command}! \033[0m" | tee -a ${run_log} + echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log} else - echo -e "\033[33m Run failed with command - ${model_name} - ${run_command}! \033[0m" | tee -a ${run_log} + echo -e "\033[33m Run failed with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log} fi } diff --git a/test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt b/test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt deleted file mode 100644 index df88c0e5434511fb48deac699e8f67fc535765d3..0000000000000000000000000000000000000000 --- a/test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt +++ /dev/null @@ -1,58 +0,0 @@ -===========================train_params=========================== -model_name:det_r18_db_v2_0 -python:python3.7 -gpu_list:0|0,1 -Global.use_gpu:True|True -Global.auto_cast:null -Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300 -Global.save_model_dir:./output/ -Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_lite_infer=4 -Global.pretrained_model:null -train_model_name:latest -train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ -null:null -## -trainer:norm_train -norm_train:tools/train.py -c configs/det/det_res18_db_v2.0.yml -o -quant_export:null 
-fpgm_export:null -distill_train:null -null:null -null:null -## -===========================eval_params=========================== -eval:null -null:null -## -===========================infer_params=========================== -Global.save_inference_dir:./output/ -Global.checkpoints: -norm_export:null -quant_export:null -fpgm_export:null -distill_export:null -export1:null -export2:null -## -train_model:null -infer_export:null -infer_quant:False -inference:tools/infer/predict_det.py ---use_gpu:True|False ---enable_mkldnn:False ---cpu_threads:6 ---rec_batch_num:1 ---use_tensorrt:False ---precision:fp32 ---det_model_dir: ---image_dir:./inference/ch_det_data_50/all-sum-510/ ---save_log_path:null ---benchmark:True -null:null -===========================infer_benchmark_params========================== -random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] -===========================train_benchmark_params========================== -batch_size:8|16 -fp_items:fp32|fp16 -epoch:15 ---profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile diff --git a/test_tipc/configs/en_table_structure/table_mv3.yml b/test_tipc/configs/en_table_structure/table_mv3.yml index 5d8e84c95c477a639130a342c6c72345e97701da..6ff31fc262b4380b4cc5258a7b2e098ada39dba0 100755 --- a/test_tipc/configs/en_table_structure/table_mv3.yml +++ b/test_tipc/configs/en_table_structure/table_mv3.yml @@ -19,8 +19,6 @@ Global: character_type: en max_text_length: 800 infer_mode: False - process_total_num: 0 - process_cut_num: 0 Optimizer: name: Adam diff --git a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt index 34082bc193a2ebd8f4c7a9e7c9ce55dc8dbf8e40..5284ffabe2de4eb8bb000e7fb745ef2846ed6b64 100644 --- a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt +++ b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt @@ -52,7 +52,7 @@ null:null 
===========================infer_benchmark_params========================== random_infer_input:[{float32,[3,224,224]}] ===========================train_benchmark_params========================== -batch_size:4 +batch_size:8 fp_items:fp32|fp16 epoch:3 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile diff --git a/test_tipc/configs/table_master/table_master.yml b/test_tipc/configs/table_master/table_master.yml index c519b5b8f464d8843888155387b74a8416821f2f..27f81683b9b7e9475bdfa4ad4862166f4cf9c14d 100644 --- a/test_tipc/configs/table_master/table_master.yml +++ b/test_tipc/configs/table_master/table_master.yml @@ -16,8 +16,6 @@ Global: character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt infer_mode: false max_text_length: 500 - process_total_num: 0 - process_cut_num: 0 Optimizer: @@ -86,7 +84,7 @@ Train: - PaddingTableImage: size: [480, 480] - TableBoxEncode: - use_xywh: True + box_format: 'xywh' - NormalizeImage: scale: 1./255. mean: [0.5, 0.5, 0.5] @@ -120,7 +118,7 @@ Eval: - PaddingTableImage: size: [480, 480] - TableBoxEncode: - use_xywh: True + box_format: 'xywh' - NormalizeImage: scale: 1./255. 
mean: [0.5, 0.5, 0.5] diff --git a/test_tipc/configs/vi_layoutxlm_ser/train_infer_python.txt b/test_tipc/configs/vi_layoutxlm_ser/train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..59d347461171487c186c052e290f6b13236aa5c9 --- /dev/null +++ b/test_tipc/configs/vi_layoutxlm_ser/train_infer_python.txt @@ -0,0 +1,59 @@ +===========================train_params=========================== +model_name:vi_layoutxlm_ser +python:python3.7 +gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:fp32 +Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8 +Architecture.Backbone.checkpoints:null +train_model_name:latest +train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Architecture.Backbone.checkpoints: +norm_export:tools/export_model.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o +quant_export: +fpgm_export: +distill_export:null +export1:null +export2:null +## +infer_model:null +infer_export:null +infer_quant:False +inference:ppstructure/vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output --ocr_order_method=tb-yx +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--ser_model_dir: +--image_dir:./ppstructure/docs/vqa/input/zh_val_42.jpg +null:null 
+--benchmark:False +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] +===========================train_benchmark_params========================== +batch_size:4 +fp_items:fp32|fp16 +epoch:3 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98 diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index 76543f39e4952b40368cdd392acc430dda8fcd9b..259a1159cb326760384645b2aff313b75da6084a 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -106,7 +106,7 @@ if [ ${MODE} = "benchmark_train" ];then ln -s ./icdar2015_benckmark ./icdar2015 cd ../ fi - if [ ${model_name} == "layoutxlm_ser" ]; then + if [ ${model_name} == "layoutxlm_ser" ] || [ ${model_name} == "vi_layoutxlm_ser" ]; then pip install -r ppstructure/vqa/requirements.txt pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate @@ -220,7 +220,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/rec_r32_gaspin_bilstm_att_train.tar --no-check-certificate cd ./pretrain_models/ && tar xf rec_r32_gaspin_bilstm_att_train.tar && cd ../ fi - if [ ${model_name} == "layoutxlm_ser" ]; then + if [ ${model_name} == "layoutxlm_ser" ] || [ ${model_name} == "vi_layoutxlm_ser" ]; then pip install -r ppstructure/vqa/requirements.txt pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate diff --git a/test_tipc/test_inference_cpp.sh b/test_tipc/test_inference_cpp.sh index c0c7c18a38a46b00c839757e303049135a508691..aadaa8b0773632885138806861fc851ede503f3d 100644 --- 
a/test_tipc/test_inference_cpp.sh +++ b/test_tipc/test_inference_cpp.sh @@ -84,7 +84,7 @@ function func_cpp_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done done @@ -117,7 +117,7 @@ function func_cpp_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done diff --git a/test_tipc/test_inference_python.sh b/test_tipc/test_inference_python.sh index 2a31a468f0d54d1979e82c8f0da98cac6f4edcec..e9908df1f6049f9d38524dc6598499ddd2b58af8 100644 --- a/test_tipc/test_inference_python.sh +++ b/test_tipc/test_inference_python.sh @@ -88,7 +88,7 @@ function func_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done done @@ -119,7 +119,7 @@ function func_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done @@ -146,14 +146,15 @@ if [ ${MODE} = "whole_infer" ]; then for infer_model in ${infer_model_dir_list[*]}; do # run export if [ ${infer_run_exports[Count]} != "null" ];then + _save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}_infermodel_${infer_model}.log" save_infer_dir=$(dirname $infer_model) set_export_weight=$(func_set_params "${export_weight}" "${infer_model}") set_save_infer_key=$(func_set_params 
"${save_infer_key}" "${save_infer_dir}") - export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key}" + export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${_save_log_path} 2>&1 " echo ${infer_run_exports[Count]} eval $export_cmd status_export=$? - status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" + status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${_save_log_path}" else save_infer_dir=${infer_model} fi diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh index 78d79d0b8eaac782f98c1e883d091a001443f41a..bace6b2d4684e0ad40ffbd76b37a78ddf1e70722 100644 --- a/test_tipc/test_paddle2onnx.sh +++ b/test_tipc/test_paddle2onnx.sh @@ -66,7 +66,7 @@ function func_paddle2onnx(){ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}" # trans rec set_dirname=$(func_set_params "--model_dir" "${rec_infer_model_dir_value}") set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}") @@ -78,7 +78,7 @@ function func_paddle2onnx(){ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}" elif [[ ${model_name} =~ "det" ]]; then # trans det set_dirname=$(func_set_params "--model_dir" 
"${det_infer_model_dir_value}") @@ -91,7 +91,7 @@ function func_paddle2onnx(){ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}" elif [[ ${model_name} =~ "rec" ]]; then # trans rec set_dirname=$(func_set_params "--model_dir" "${rec_infer_model_dir_value}") @@ -104,7 +104,7 @@ function func_paddle2onnx(){ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}" fi # python inference @@ -127,7 +127,7 @@ function func_paddle2onnx(){ eval $infer_model_cmd last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" "${_save_log_path}" elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then _save_log_path="${LOG_PATH}/paddle2onnx_infer_gpu.log" set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu}") @@ -146,7 +146,7 @@ function func_paddle2onnx(){ eval $infer_model_cmd last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" "${_save_log_path}" else echo "Does not support hardware other than CPU and GPU Currently!" 
fi @@ -158,4 +158,4 @@ echo "################### run test ###################" export Count=0 IFS="|" -func_paddle2onnx \ No newline at end of file +func_paddle2onnx diff --git a/test_tipc/test_ptq_inference_python.sh b/test_tipc/test_ptq_inference_python.sh index e2939fd5e638ad0f6b4c44422a6fec6459903d1c..caf3d506029ee066aa5abebc25b739439b6e9d75 100644 --- a/test_tipc/test_ptq_inference_python.sh +++ b/test_tipc/test_ptq_inference_python.sh @@ -84,7 +84,7 @@ function func_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done done @@ -109,7 +109,7 @@ function func_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done @@ -145,7 +145,7 @@ if [ ${MODE} = "whole_infer" ]; then echo $export_cmd eval $export_cmd status_export=$? 
- status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" + status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}" else save_infer_dir=${infer_model} fi diff --git a/test_tipc/test_serving_infer_cpp.sh b/test_tipc/test_serving_infer_cpp.sh index 0be6a45adf3105f088a96336dddfbe9ac612f19b..10ddecf3fa26805fef7bc6ae10d78ee5e741cd27 100644 --- a/test_tipc/test_serving_infer_cpp.sh +++ b/test_tipc/test_serving_infer_cpp.sh @@ -83,7 +83,7 @@ function func_serving(){ trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client} > ${trans_rec_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}" set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}") python_list=(${python_list}) cd ${serving_dir_value} @@ -95,14 +95,14 @@ function func_serving(){ web_service_cpp_cmd="nohup ${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} > ${server_log_path} 2>&1 &" eval $web_service_cpp_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}" "${server_log_path}" sleep 5s _save_log_path="${LOG_PATH}/cpp_client_cpu.log" cpp_client_cmd="${python_list[0]} ${cpp_client_py} ${det_client_value} ${rec_client_value} > ${_save_log_path} 2>&1" eval $cpp_client_cmd last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" "${_save_log_path}" ps ux | grep -i 
${port_value} | awk '{print $2}' | xargs kill -s 9 else server_log_path="${LOG_PATH}/cpp_server_gpu.log" @@ -114,7 +114,7 @@ function func_serving(){ eval $cpp_client_cmd last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" "${_save_log_path}" ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9 fi done diff --git a/test_tipc/test_serving_infer_python.sh b/test_tipc/test_serving_infer_python.sh index 4b7dfcf785a3c8459cce95d55744dbcd4f97027a..c7d305d5d2dcd2ea1bf5a7c3254eea4231d59879 100644 --- a/test_tipc/test_serving_infer_python.sh +++ b/test_tipc/test_serving_infer_python.sh @@ -126,19 +126,19 @@ function func_serving(){ web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &" eval $web_service_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}" elif [[ ${model_name} =~ "det" ]]; then set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}") web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} > ${server_log_path} 2>&1 &" eval $web_service_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}" elif [[ ${model_name} =~ "rec" ]]; then set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}") web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" 
${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_rec_model_config} > ${server_log_path} 2>&1 &" eval $web_service_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}" fi sleep 2s for pipeline in ${pipeline_py[*]}; do @@ -147,7 +147,7 @@ function func_serving(){ eval $pipeline_cmd last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" "${_save_log_path}" sleep 2s done ps ux | grep -E 'web_service' | awk '{print $2}' | xargs kill -s 9 @@ -177,19 +177,19 @@ function func_serving(){ web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &" eval $web_service_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}" elif [[ ${model_name} =~ "det" ]]; then set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}") web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} > ${server_log_path} 2>&1 &" eval $web_service_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}" elif [[ ${model_name} =~ "rec" ]]; then set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}") web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_rec_model_config} > ${server_log_path} 2>&1 &" eval $web_service_cmd 
last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}" fi sleep 2s for pipeline in ${pipeline_py[*]}; do @@ -198,7 +198,7 @@ function func_serving(){ eval $pipeline_cmd last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" "${_save_log_path}" sleep 2s done ps ux | grep -E 'web_service' | awk '{print $2}' | xargs kill -s 9 diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh index 545cdbba2051c8123ef7f70f2aeb4b4b5a57b7c5..e182fa57f060c81af012a5da89b892bde02b4a2b 100644 --- a/test_tipc/test_train_inference_python.sh +++ b/test_tipc/test_train_inference_python.sh @@ -133,7 +133,7 @@ function func_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done done @@ -164,7 +164,7 @@ function func_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done @@ -201,7 +201,7 @@ if [ ${MODE} = "whole_infer" ]; then echo $export_cmd eval $export_cmd status_export=$? - status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" + status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}" else save_infer_dir=${infer_model} fi @@ -298,7 +298,7 @@ else # run train eval $cmd eval "cat ${save_log}/train.log >> ${save_log}.log" - status_check $? 
"${cmd}" "${status_log}" "${model_name}" + status_check $? "${cmd}" "${status_log}" "${model_name}" "${save_log}.log" set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}") @@ -309,7 +309,7 @@ else eval_log_path="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}_eval.log" eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1} > ${eval_log_path} 2>&1 " eval $eval_cmd - status_check $? "${eval_cmd}" "${status_log}" "${model_name}" + status_check $? "${eval_cmd}" "${status_log}" "${model_name}" "${eval_log_path}" fi # run export model if [ ${run_export} != "null" ]; then @@ -320,7 +320,7 @@ else set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_path}") export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 " eval $export_cmd - status_check $? "${export_cmd}" "${status_log}" "${model_name}" + status_check $? "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}" #run inference eval $env diff --git a/tools/export_model.py b/tools/export_model.py index 2a89b32836a8d5f02552d6ba18f6d05dbd1bf0dc..193988cc1b62a6c4536a8d2ec640e3e5fc81a79c 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -58,6 +58,8 @@ def export_single_model(model, other_shape = [ paddle.static.InputSpec( shape=[None, 3, 48, 160], dtype="float32"), + [paddle.static.InputSpec( + shape=[None], dtype="float32")] ] model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] == "SVTR": @@ -144,7 +146,7 @@ def export_single_model(model, else: infer_shape = [3, -1, -1] if arch_config["model_type"] == "rec": - infer_shape = [3, 48, -1] # for rec model, H must be 32 + infer_shape = [3, 32, -1] # for rec model, H must be 32 if "Transform" in arch_config and arch_config[ "Transform"] is not None and arch_config["Transform"][ "name"] == "TPS": @@ -156,6 +158,8 @@ def export_single_model(model, infer_shape = [3, 
488, 488] if arch_config["algorithm"] == "TableMaster": infer_shape = [3, 480, 480] + if arch_config["algorithm"] == "SLANet": + infer_shape = [3, -1, -1] model = to_static( model, input_spec=[ @@ -248,4 +252,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 6c06ff1f367e2739adcdd910d1a2a97e60ae4c0b..53dab6f26d8b84a224360f2fa6fe5f411eea751f 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -458,7 +458,8 @@ class TextRecognizer(object): valid_ratios = np.concatenate(valid_ratios) inputs = [ norm_img_batch, - valid_ratios, + np.array( + [valid_ratios], dtype=np.float32), ] if self.use_onnx: input_dict = {} diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 625d365f45c578d051974d7174e26246e9bc2442..73b7155baa9f869da928b5be03692c08115489ee 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -65,9 +65,11 @@ class TextSystem(object): self.crop_image_res_index += bbox_num def __call__(self, img, cls=True): + time_dict = {'det': 0, 'rec': 0, 'csl': 0, 'all': 0} + start = time.time() ori_im = img.copy() dt_boxes, elapse = self.text_detector(img) - + time_dict['det'] = elapse logger.debug("dt_boxes num : {}, elapse : {}".format( len(dt_boxes), elapse)) if dt_boxes is None: @@ -83,10 +85,12 @@ class TextSystem(object): if self.use_angle_cls and cls: img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list) + time_dict['cls'] = elapse logger.debug("cls num : {}, elapse : {}".format( len(img_crop_list), elapse)) rec_res, elapse = self.text_recognizer(img_crop_list) + time_dict['rec'] = elapse logger.debug("rec_res num : {}, elapse : {}".format( len(rec_res), elapse)) if self.args.save_crop_res: @@ -98,7 +102,9 @@ class TextSystem(object): if score >= self.drop_score: filter_boxes.append(box) filter_rec_res.append(rec_result) - return filter_boxes, filter_rec_res + 
end = time.time() + time_dict['all'] = end - start + return filter_boxes, filter_rec_res, time_dict def sorted_boxes(dt_boxes): @@ -133,9 +139,11 @@ def main(args): os.makedirs(draw_img_save_dir, exist_ok=True) save_results = [] - logger.info("In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', " - "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320") - + logger.info( + "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', " + "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320" + ) + # warm up 10 times if args.warmup: img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) @@ -155,7 +163,7 @@ def main(args): logger.debug("error in loading image:{}".format(image_file)) continue starttime = time.time() - dt_boxes, rec_res = text_sys(img) + dt_boxes, rec_res, time_dict = text_sys(img) elapse = time.time() - starttime total_time += elapse @@ -198,7 +206,10 @@ def main(args): text_sys.text_detector.autolog.report() text_sys.text_recognizer.autolog.report() - with open(os.path.join(draw_img_save_dir, "system_results.txt"), 'w', encoding='utf-8') as f: + with open( + os.path.join(draw_img_save_dir, "system_results.txt"), + 'w', + encoding='utf-8') as f: f.writelines(save_results) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 55ec0a238a6d82fbdd91ec0563ee2ecf23d18231..81d0196ccd6b86741e73524d9321618f3f5cc34b 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -163,6 +163,8 @@ def create_predictor(args, mode, logger): model_dir = args.ser_model_dir elif mode == "sr": model_dir = args.sr_model_dir + elif mode == 'layout': + model_dir = args.layout_model_dir else: model_dir = args.e2e_model_dir diff --git a/tools/infer_table.py b/tools/infer_table.py index 6c02dd8640c9345c267e56d6e5a0c14bde121b7e..6dde5d67d061f4d0593928759db34bb9b22cde0d 100644 --- a/tools/infer_table.py +++ 
b/tools/infer_table.py @@ -37,6 +37,7 @@ from ppocr.postprocess import build_post_process from ppocr.utils.save_load import load_model from ppocr.utils.utility import get_image_file_list from ppocr.utils.visual import draw_rectangle +from tools.infer.utility import draw_boxes import tools.program as program import cv2 @@ -56,7 +57,6 @@ def main(config, device, logger, vdl_writer): model = build_model(config['Architecture']) algorithm = config['Architecture']['algorithm'] - use_xywh = algorithm in ['TableMaster'] load_model(config, model) @@ -106,9 +106,13 @@ def main(config, device, logger, vdl_writer): f_w.write("result: {}, {}\n".format(structure_str_list, bbox_list_str)) - img = draw_rectangle(file, bbox_list, use_xywh) + if len(bbox_list) > 0 and len(bbox_list[0]) == 4: + img = draw_rectangle(file, bbox_list) + else: + img = draw_boxes(cv2.imread(file), bbox_list) cv2.imwrite( os.path.join(save_res_path, os.path.basename(file)), img) + logger.info('save result to {}'.format(save_res_path)) logger.info("success!") diff --git a/tools/program.py b/tools/program.py index 462589033eae923c58f89da513a8e23b3e717e50..195b09b43da93d8c9285c064ef267e01623a733c 100755 --- a/tools/program.py +++ b/tools/program.py @@ -653,7 +653,7 @@ def preprocess(is_train=False): 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', - 'Gestalt', 'RobustScanner' + 'Gestalt', 'SLANet', 'RobustScanner' ] if use_xpu: diff --git a/tools/train.py b/tools/train.py index b44d76b3832aba24dd8bcad821fe21d22c8b320b..0c881ecae8daf78860829b1419178358c2209f25 100755 --- a/tools/train.py +++ b/tools/train.py @@ -119,6 +119,10 @@ def main(config, device, logger, vdl_writer): config['Loss']['ignore_index'] = char_num - 1 model = build_model(config['Architecture']) + use_sync_bn = config["Global"].get("use_sync_bn", False) + if use_sync_bn: + model = 
paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) + logger.info('convert_sync_batchnorm') model = apply_to_static(model, config, logger)