diff --git a/MANIFEST.in b/MANIFEST.in index e16f157d6e9dd249d6c6a14ae54313759a6752c4..cd1c9636d4d23cc4d0f745403ec8ca407d1cc1a8 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,7 +1,7 @@ -include LICENSE.txt +include LICENSE include README.md -recursive-include ppocr/utils *.txt utility.py logging.py +recursive-include ppocr/utils *.txt utility.py logging.py network.py recursive-include ppocr/data/ *.py recursive-include ppocr/postprocess *.py recursive-include tools/infer *.py diff --git a/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml b/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml new file mode 100644 index 0000000000000000000000000000000000000000..791b34cf5785d81a0f1346c0ef1ad4485ed3fee8 --- /dev/null +++ b/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml @@ -0,0 +1,160 @@ +Global: + debug: false + use_gpu: true + epoch_num: 800 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec_chinese_lite_distillation_v2.1 + save_epoch_step: 3 + eval_batch_step: [0, 2000] + cal_metric_during_train: true + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: false + infer_img: doc/imgs_words/ch/word_1.jpg + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + character_type: ch + max_text_length: 25 + infer_mode: false + use_space_char: false + distributed: true + save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.0005 + warmup_epoch: 5 + regularizer: + name: L2 + factor: 1.0e-05 +Architecture: + name: DistillationModel + algorithm: Distillation + Models: + Student: + pretrained: + freeze_params: false + return_all_feats: true + model_type: rec + algorithm: CRNN + Transform: + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: small + small_stride: [1, 2, 2, 2] + Neck: + name: SequenceEncoder + encoder_type: rnn + hidden_size: 64 + Head: + name: CTCHead + mid_channels: 96 + fc_decay: 0.00001 + Teacher: + pretrained: + freeze_params: false + return_all_feats: true + model_type: rec + algorithm: CRNN + Transform: + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: small + small_stride: [1, 2, 2, 2] + Neck: + name: SequenceEncoder + encoder_type: rnn + hidden_size: 64 + Head: + name: CTCHead + mid_channels: 96 + fc_decay: 0.00001 + + +Loss: + name: CombinedLoss + loss_config_list: + - DistillationCTCLoss: + weight: 1.0 + model_name_list: ["Student", "Teacher"] + key: head_out + - DistillationDMLLoss: + weight: 1.0 + act: "softmax" + model_name_pairs: + - ["Student", "Teacher"] + key: head_out + - DistillationDistanceLoss: + weight: 1.0 + mode: "l2" + model_name_pairs: + - ["Student", "Teacher"] + key: backbone_out + +PostProcess: + name: DistillationCTCLabelDecode + model_name: ["Student", "Teacher"] + key: head_out + +Metric: + name: DistillationMetric + base_metric_name: RecMetric + main_indicator: acc + key: "Student" + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ + label_file_list: + - ./train_data/train_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - RecAug: + - CTCLabelEncode: + - RecResizeImg: + image_shape: [3, 32, 320] + - KeepKeys: + keep_keys: + - image + - label + - length + loader: + shuffle: true + batch_size_per_card: 128 + drop_last: true + num_sections: 1 + num_workers: 8 +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data + label_file_list: + - ./train_data/val_list.txt + 
transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - CTCLabelEncode: + - RecResizeImg: + image_shape: [3, 32, 320] + - KeepKeys: + keep_keys: + - image + - label + - length + loader: + shuffle: false + drop_last: false + batch_size_per_card: 128 + num_workers: 8 diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml new file mode 100755 index 0000000000000000000000000000000000000000..a74e18d318699685400cc48430c04db3fef70b60 --- /dev/null +++ b/configs/table/table_mv3.yml @@ -0,0 +1,116 @@ +Global: + use_gpu: true + epoch_num: 50 + log_smooth_window: 20 + print_batch_step: 5 + save_model_dir: ./output/table_mv3/ + save_epoch_step: 5 + # evaluation is run every 400 iterations after the 0th iteration + eval_batch_step: [0, 400] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: ppocr/utils/dict/table_structure_dict.txt + character_type: en + max_text_length: 100 + max_elem_length: 500 + max_cell_num: 500 + infer_mode: False + process_total_num: 0 + process_cut_num: 0 + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + clip_norm: 5.0 + lr: + learning_rate: 0.001 + regularizer: + name: 'L2' + factor: 0.00000 + +Architecture: + model_type: table + algorithm: TableAttn + Backbone: + name: MobileNetV3 + scale: 1.0 + model_name: small + disable_se: True + Head: + name: TableAttentionHead + hidden_size: 256 + l2_decay: 0.00001 + loc_type: 2 + +Loss: + name: TableAttentionLoss + structure_weight: 100.0 + loc_weight: 10000.0 + +PostProcess: + name: TableLabelDecode + +Metric: + name: TableMetric + main_indicator: acc + +Train: + dataset: + name: PubTabDataSet + data_dir: train_data/table/pubtabnet/train/ + label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - ResizeTableImage: + max_len: 488 + - TableLabelEncode: + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - PaddingTableImage: + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] + loader: + shuffle: True + batch_size_per_card: 32 + drop_last: True + num_workers: 1 + +Eval: + dataset: + name: PubTabDataSet + data_dir: train_data/table/pubtabnet/val/ + label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - ResizeTableImage: + max_len: 488 + - TableLabelEncode: + - NormalizeImage: + scale: 1./255. 
+ mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - PaddingTableImage: + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 16 + num_workers: 1 diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java index 1c83e2184fe55aedd5022da839ab294b6bbe475c..b4ea34e2a38f91f3ecb1001c6bff3b71496b8f91 100644 --- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java @@ -465,8 +465,12 @@ public class MainActivity extends AppCompatActivity { } public void btn_load_model_click(View view) { - tvStatus.setText("STATUS: load model ......"); - loadModel(); + if (predictor.isLoaded()){ + tvStatus.setText("STATUS: model has been loaded"); + }else{ + tvStatus.setText("STATUS: load model ......"); + loadModel(); + } } public void btn_run_model_click(View view) { diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java index 1c294995c25b7eb3fa6ded17f41f193bddfc3886..b474d8886a10746b8ac181085c62481dfe7a4229 100644 --- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java @@ -194,26 +194,25 @@ public class Predictor { "supported!"); return false; } - int[] channelStride = new int[]{width * height, width * height * 2}; - int p = scaleImage.getPixel(scaleImage.getWidth() - 1, scaleImage.getHeight() - 1); - for (int y = 0; y < height; y++) { - for (int x = 0; x < width; x++) { - int color = scaleImage.getPixel(x, y); - float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f, - (float) blue(color) / 255.0f}; - inputData[y * width + x] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0]; - inputData[y * width + x + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1]; - inputData[y * width + x + channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2]; - } + int[] channelStride = new int[]{width * height, width * height * 2}; + int[] pixels=new int[width*height]; + scaleImage.getPixels(pixels,0,scaleImage.getWidth(),0,0,scaleImage.getWidth(),scaleImage.getHeight()); + for (int i = 0; i < pixels.length; i++) { + int color = pixels[i]; + float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f, + (float) blue(color) / 255.0f}; + inputData[i] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0]; + inputData[i + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1]; + inputData[i+ channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2]; } } else if (channels == 1) { - for (int y = 0; y < height; y++) { - for (int x = 0; x < width; x++) { - int color = inputImage.getPixel(x, y); - float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f; - inputData[y * width + x] = (gray - inputMean[0]) / inputStd[0]; - } + int[] pixels=new int[width*height]; + scaleImage.getPixels(pixels,0,scaleImage.getWidth(),0,0,scaleImage.getWidth(),scaleImage.getHeight()); + for (int i = 0; i < pixels.length; i++) { + int color = pixels[i]; + float gray = (float) 
(red(color) + green(color) + blue(color)) / 3.0f / 255.0f; + inputData[i] = (gray - inputMean[0]) / inputStd[0]; } } else { Log.i(TAG, "Unsupported channel size " + Integer.toString(channels) + ", only channel 1 and 3 is " + diff --git a/deploy/cpp_infer/include/utility.h b/deploy/cpp_infer/include/utility.h index 367e37e434b396ac1eae28961f366dc397ed446f..6e8173e007279319657250b376de022240bc6f62 100644 --- a/deploy/cpp_infer/include/utility.h +++ b/deploy/cpp_infer/include/utility.h @@ -44,6 +44,9 @@ public: inline static size_t argmax(ForwardIterator first, ForwardIterator last) { return std::distance(first, std::max_element(first, last)); } + + static void GetAllFiles(const char *dir_name, + std::vector<std::string> &all_inputs); }; } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/readme.md b/deploy/cpp_infer/readme.md index ee5a9ed4b9aa16b76836dc01096ae132fead56dd..6a57044b0ef81c4600c13180bb33c45b2bf0bc01 100644 --- a/deploy/cpp_infer/readme.md +++ b/deploy/cpp_infer/readme.md @@ -77,7 +77,7 @@ opencv3/ #### 1.2.1 Direct download and installation -* The [official Paddle inference library site](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html) provides Linux inference libraries for different CUDA versions; check the site and choose the appropriate version (*an inference library built from paddle >= 2.0.1 is recommended*). +* The [official Paddle inference library site](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0/guides/05_inference_deployment/inference/build_and_install_lib_cn.html) provides Linux inference libraries for different CUDA versions; check the site and choose the appropriate version (*an inference library built from paddle >= 2.0.1 is recommended*). * After downloading, extract it as follows. @@ -89,10 +89,11 @@ tar -xf paddle_inference.tgz #### 1.2.2 Compiling the inference library from source * To get the latest inference library features, clone the latest Paddle code from GitHub and compile the inference library from source. -* Following the instructions on the [official Paddle inference library site](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/build_and_install_lib_cn.html), fetch the Paddle code from GitHub and compile it to generate the latest inference library. Get the code with git as follows. +* Following the [Paddle inference library build guide](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0/guides/05_inference_deployment/inference/build_and_install_lib_cn.html#congyuanmabianyi), fetch the Paddle code from GitHub and compile it to generate the latest inference library. Get the code with git as follows. ```shell git clone https://github.com/PaddlePaddle/Paddle.git +git checkout release/2.1 ``` * After entering the Paddle directory, compile as follows. @@ -115,7 +116,7 @@ make -j make inference_lib_dist ``` -For more compilation options, see the official Paddle C++ inference library site: [https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/build_and_install_lib_cn.html](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/build_and_install_lib_cn.html). +For more compilation options, see the [documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0/guides/05_inference_deployment/inference/build_and_install_lib_cn.html#congyuanmabianyi). * After compilation, the following files and folders are generated under `build/paddle_inference_install_dir/`. @@ -140,11 +141,11 @@ build/paddle_inference_install_dir/ ``` inference/ |-- det_db -| |--inference.pdparams -| |--inference.pdimodel +| |--inference.pdiparams +| |--inference.pdmodel |-- rec_rcnn -| |--inference.pdparams -| |--inference.pdparams +| |--inference.pdiparams +| |--inference.pdmodel ``` diff --git a/deploy/cpp_infer/readme_en.md b/deploy/cpp_infer/readme_en.md index 913ba1f91668d682c7c3fa614f8997293d52db89..6c0a18db4f76d4e2971cea16130216434ff01d7b 100644 --- a/deploy/cpp_infer/readme_en.md +++ b/deploy/cpp_infer/readme_en.md @@ -78,8 +78,7 @@ opencv3/ #### 1.2.1 Direct download and installation -* Different cuda versions of the Linux inference library (based on GCC 4.8.2) are provided on the -[Paddle inference library official 
website](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/guides/05_inference_deployment/inference/build_and_install_lib_en.html). You can view and select the appropriate version of the inference library on the official website. +[Paddle inference library official website](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0/guides/05_inference_deployment/inference/build_and_install_lib_cn.html). You can view and select the appropriate version of the inference library on the official website. * After downloading, use the following method to uncompress. @@ -97,9 +96,10 @@ Finally you can see the following files in the folder of `paddle_inference/`. ```shell git clone https://github.com/PaddlePaddle/Paddle.git +git checkout release/2.1 ``` -* After entering the Paddle directory, the compilation method is as follows. +* After entering the Paddle directory, the commands to compile the paddle inference library are as follows. ```shell rm -rf build @@ -119,7 +119,7 @@ make -j make inference_lib_dist ``` -For more compilation parameter options, please refer to the official website of the Paddle C++ inference library:[https://www.paddlepaddle.org.cn/documentation/docs/en/develop/guides/05_inference_deployment/inference/build_and_install_lib_en.html](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/guides/05_inference_deployment/inference/build_and_install_lib_en.html). +For more compilation parameter options, please refer to the [document](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0/guides/05_inference_deployment/inference/build_and_install_lib_cn.html#congyuanmabianyi). * After the compilation process, you can see the following files in the folder of `build/paddle_inference_install_dir/`. @@ -144,11 +144,11 @@ Among them, `paddle` is the Paddle library required for C++ prediction later, an ``` inference/ |-- det_db -| |--inference.pdparams -| |--inference.pdimodel +| |--inference.pdiparams +| |--inference.pdmodel |-- rec_rcnn -| |--inference.pdparams -| |--inference.pdparams +| |--inference.pdiparams +| |--inference.pdmodel ``` diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp index 588c8374ab341163835aea2ba6c7132640c74c64..f25e674b489ea92118fe45c63939fca203ce3823 100644 --- a/deploy/cpp_infer/src/main.cpp +++ b/deploy/cpp_infer/src/main.cpp @@ -27,9 +27,12 @@ #include #include +#include #include #include #include +#include +#include using namespace std; using namespace cv; @@ -47,13 +50,8 @@ int main(int argc, char **argv) { config.PrintConfigInfo(); std::string img_path(argv[2]); - - cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR); - - if (!srcimg.data) { - std::cerr << "[ERROR] image read failed! 
image path: " << img_path << "\n"; - exit(1); - } + std::vector<std::string> all_img_names; + Utility::GetAllFiles((char *)img_path.c_str(), all_img_names); DBDetector det(config.det_model_dir, config.use_gpu, config.gpu_id, config.gpu_mem, config.cpu_math_library_num_threads, @@ -76,18 +74,30 @@ config.use_tensorrt, config.use_fp16); auto start = std::chrono::system_clock::now(); - std::vector<std::vector<std::vector<int>>> boxes; - det.Run(srcimg, boxes); - - rec.Run(boxes, srcimg, cls); - auto end = std::chrono::system_clock::now(); - auto duration = - std::chrono::duration_cast<std::chrono::microseconds>(end - start); - std::cout << "Cost " - << double(duration.count()) * - std::chrono::microseconds::period::num / - std::chrono::microseconds::period::den - << "s" << std::endl; + + for (auto img_dir : all_img_names) { + LOG(INFO) << "The predict img: " << img_dir; + + cv::Mat srcimg = cv::imread(img_dir, cv::IMREAD_COLOR); + if (!srcimg.data) { + std::cerr << "[ERROR] image read failed! image path: " << img_path + << "\n"; + exit(1); + } + std::vector<std::vector<std::vector<int>>> boxes; + + det.Run(srcimg, boxes); + + rec.Run(boxes, srcimg, cls); + auto end = std::chrono::system_clock::now(); + auto duration = + std::chrono::duration_cast<std::chrono::microseconds>(end - start); + std::cout << "Cost " + << double(duration.count()) * + std::chrono::microseconds::period::num / + std::chrono::microseconds::period::den + << "s" << std::endl; + } return 0; } diff --git a/deploy/cpp_infer/src/ocr_det.cpp b/deploy/cpp_infer/src/ocr_det.cpp index 9bfee6138577288156496d9b533b4da906ae7268..33ad468a33b42c3d9f25beb19452f2fa6a81db9e 100644 --- a/deploy/cpp_infer/src/ocr_det.cpp +++ b/deploy/cpp_infer/src/ocr_det.cpp @@ -30,6 +30,42 @@ void DBDetector::LoadModel(const std::string &model_dir) { this->use_fp16_ ? paddle_infer::Config::Precision::kHalf : paddle_infer::Config::Precision::kFloat32, false, false); + std::map<std::string, std::vector<int>> min_input_shape = { + {"x", {1, 3, 50, 50}}, + {"conv2d_92.tmp_0", {1, 96, 20, 20}}, + {"conv2d_91.tmp_0", {1, 96, 10, 10}}, + {"nearest_interp_v2_1.tmp_0", {1, 96, 10, 10}}, + {"nearest_interp_v2_2.tmp_0", {1, 96, 20, 20}}, + {"nearest_interp_v2_3.tmp_0", {1, 24, 20, 20}}, + {"nearest_interp_v2_4.tmp_0", {1, 24, 20, 20}}, + {"nearest_interp_v2_5.tmp_0", {1, 24, 20, 20}}, + {"elementwise_add_7", {1, 56, 2, 2}}, + {"nearest_interp_v2_0.tmp_0", {1, 96, 2, 2}}}; + std::map<std::string, std::vector<int>> max_input_shape = { + {"x", {1, 3, this->max_side_len_, this->max_side_len_}}, + {"conv2d_92.tmp_0", {1, 96, 400, 400}}, + {"conv2d_91.tmp_0", {1, 96, 200, 200}}, + {"nearest_interp_v2_1.tmp_0", {1, 96, 200, 200}}, + {"nearest_interp_v2_2.tmp_0", {1, 96, 400, 400}}, + {"nearest_interp_v2_3.tmp_0", {1, 24, 400, 400}}, + {"nearest_interp_v2_4.tmp_0", {1, 24, 400, 400}}, + {"nearest_interp_v2_5.tmp_0", {1, 24, 400, 400}}, + {"elementwise_add_7", {1, 56, 400, 400}}, + {"nearest_interp_v2_0.tmp_0", {1, 96, 400, 400}}}; + std::map<std::string, std::vector<int>> opt_input_shape = { + {"x", {1, 3, 640, 640}}, + {"conv2d_92.tmp_0", {1, 96, 160, 160}}, + {"conv2d_91.tmp_0", {1, 96, 80, 80}}, + {"nearest_interp_v2_1.tmp_0", {1, 96, 80, 80}}, + {"nearest_interp_v2_2.tmp_0", {1, 96, 160, 160}}, + {"nearest_interp_v2_3.tmp_0", {1, 24, 160, 160}}, + {"nearest_interp_v2_4.tmp_0", {1, 24, 160, 160}}, + {"nearest_interp_v2_5.tmp_0", {1, 24, 160, 160}}, + {"elementwise_add_7", {1, 56, 40, 40}}, + {"nearest_interp_v2_0.tmp_0", {1, 96, 40, 40}}}; + + config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, + opt_input_shape); } } else { config.DisableGpu(); @@ -48,7 +84,7 @@ config.SwitchIrOptim(true); config.EnableMemoryOptim(); - config.DisableGlogInfo(); + // config.DisableGlogInfo(); this->predictor_ = CreatePredictor(config); }
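For reference, the same TensorRT dynamic-shape setup is available from Python through `paddle.inference`. Below is a minimal sketch under the assumption that only the network input `x` needs an explicit range (the C++ hunk above additionally pins several internal tensors, and the model paths here are illustrative):

```python
from paddle.inference import Config, PrecisionType, create_predictor

# Illustrative paths; point these at your exported detection model.
config = Config("inference/det_db/inference.pdmodel",
                "inference/det_db/inference.pdiparams")
config.enable_use_gpu(500, 0)  # 500 MB initial GPU memory pool, device 0
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=3,
    precision_mode=PrecisionType.Float32,
    use_static=False,
    use_calib_mode=False)
# min / max / opt ranges, mirroring the values used in the C++ code above.
config.set_trt_dynamic_shape_info(
    {"x": [1, 3, 50, 50]},
    {"x": [1, 3, 960, 960]},
    {"x": [1, 3, 640, 640]})
predictor = create_predictor(config)
```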
diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp index 28cd1cb88216a8aa5e4e6fdc939f7be0169db556..b09282b0283743b530cd5477dbe9c5ff751de93c 100644 --- a/deploy/cpp_infer/src/ocr_rec.cpp +++ b/deploy/cpp_infer/src/ocr_rec.cpp @@ -106,6 +106,15 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { this->use_fp16_ ? paddle_infer::Config::Precision::kHalf : paddle_infer::Config::Precision::kFloat32, false, false); + std::map<std::string, std::vector<int>> min_input_shape = { + {"x", {1, 3, 32, 10}}}; + std::map<std::string, std::vector<int>> max_input_shape = { + {"x", {1, 3, 32, 2000}}}; + std::map<std::string, std::vector<int>> opt_input_shape = { + {"x", {1, 3, 32, 320}}}; + + config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, + opt_input_shape); } } else { config.DisableGpu(); diff --git a/deploy/cpp_infer/src/preprocess_op.cpp b/deploy/cpp_infer/src/preprocess_op.cpp old mode 100755 new mode 100644 index fb7590e359da81e27c52c5a0037b93e19edb77df..23c51c2008dc7280ce4d6c232ed766dbf2a53226 --- a/deploy/cpp_infer/src/preprocess_op.cpp +++ b/deploy/cpp_infer/src/preprocess_op.cpp @@ -47,16 +47,13 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean, e /= 255.0; } (*im).convertTo(*im, CV_32FC3, e); - for (int h = 0; h < im->rows; h++) { - for (int w = 0; w < im->cols; w++) { - im->at<cv::Vec3f>(h, w)[0] = - (im->at<cv::Vec3f>(h, w)[0] - mean[0]) * scale[0]; - im->at<cv::Vec3f>(h, w)[1] = - (im->at<cv::Vec3f>(h, w)[1] - mean[1]) * scale[1]; - im->at<cv::Vec3f>(h, w)[2] = - (im->at<cv::Vec3f>(h, w)[2] - mean[2]) * scale[2]; - } + std::vector<cv::Mat> bgr_channels(3); + cv::split(*im, bgr_channels); + for (auto i = 0; i < bgr_channels.size(); i++) { + bgr_channels[i].convertTo(bgr_channels[i], CV_32FC1, 1.0 * scale[i], + (0.0 - mean[i]) * scale[i]); } + cv::merge(bgr_channels, *im); } void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, @@ -77,19 +74,13 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, int resize_h = int(float(h) * ratio); int resize_w = int(float(w) * ratio); - + resize_h = max(int(round(float(resize_h) / 32) * 32), 32); resize_w = max(int(round(float(resize_w) / 32) * 32), 32); - if (!use_tensorrt) { - cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); - ratio_h = float(resize_h) / float(h); - ratio_w = float(resize_w) / float(w); - } else { - cv::resize(img, resize_img, cv::Size(640, 640)); - ratio_h = float(640) / float(h); - ratio_w = float(640) / float(w); - } + cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); + ratio_h = float(resize_h) / float(h); + ratio_w = float(resize_w) / float(w); } void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio, @@ -108,23 +99,12 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio, resize_w = imgW; else resize_w = int(ceilf(imgH * ratio)); - if (!use_tensorrt) { - cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, - cv::INTER_LINEAR); - cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, - int(imgW - resize_img.cols), cv::BORDER_CONSTANT, - {127, 127, 127}); - } else { - int k = int(img.cols * 32 / img.rows); - if (k >= 100) { - cv::resize(img, resize_img, cv::Size(100, 32), 0.f, 0.f, - cv::INTER_LINEAR); - } else { - cv::resize(img, resize_img, cv::Size(k, 32), 0.f, 0.f, cv::INTER_LINEAR); - cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, int(100 - k), - cv::BORDER_CONSTANT, {127, 127, 127}); - } - } + + cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, + cv::INTER_LINEAR); + cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, + int(imgW - resize_img.cols), cv::BORDER_CONSTANT, + {127, 127, 127}); } void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, @@ -142,15 +122,11 @@ void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, else resize_w = int(ceilf(imgH * ratio)); - if (!use_tensorrt) { - cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, - cv::INTER_LINEAR); - if (resize_w < imgW) { - cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w, - cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0)); - } - } else { - cv::resize(img, resize_img, cv::Size(100, 32), 0.f, 0.f, cv::INTER_LINEAR); + cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, + cv::INTER_LINEAR); + if (resize_w < imgW) { + cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w, + cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0)); } }
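The Normalize rewrite above folds the per-pixel loop into one `convertTo` per channel, with alpha = scale[i] and beta = -mean[i] * scale[i] (scale being 1/std). For comparison, the equivalent preprocessing is a single vectorized expression in Python/numpy — a sketch assuming an HWC image and the ImageNet statistics used by the configs in this patch:

```python
import numpy as np

def normalize(im, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
              is_scale=True):
    """Return (im / 255 - mean) / std, broadcasting over the channel axis."""
    im = im.astype("float32")
    if is_scale:
        im /= 255.0
    mean = np.asarray(mean, dtype="float32")
    std = np.asarray(std, dtype="float32")
    return (im - mean) / std
```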
diff --git a/deploy/cpp_infer/src/utility.cpp b/deploy/cpp_infer/src/utility.cpp index c1c9d9382a06432daca71eb7b08acb8b19b8ee98..2cd84f7e8dbdd8144b5337f55b3f3a62ed43d5b3 100644 --- a/deploy/cpp_infer/src/utility.cpp +++ b/deploy/cpp_infer/src/utility.cpp @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <dirent.h> +#include <include/utility.h> #include <iostream> #include <ostream> +#include <sys/stat.h> +#include <sys/types.h> #include <vector> -#include <include/utility.h> - namespace PaddleOCR { std::vector<std::string> Utility::ReadDict(const std::string &path) { @@ -57,4 +59,37 @@ void Utility::VisualizeBboxes( << std::endl; } +// list all files under a directory +void Utility::GetAllFiles(const char *dir_name, + std::vector<std::string> &all_inputs) { + if (NULL == dir_name) { + std::cout << " dir_name is null ! " << std::endl; + return; + } + struct stat s; + lstat(dir_name, &s); + if (!S_ISDIR(s.st_mode)) { + std::cout << "dir_name is not a valid directory !" << std::endl; + all_inputs.push_back(dir_name); + return; + } else { + struct dirent *filename; // return value for readdir() + DIR *dir; // return value for opendir() + dir = opendir(dir_name); + if (NULL == dir) { + std::cout << "Can not open dir " << dir_name << std::endl; + return; + } + std::cout << "Successfully opened the dir !" << std::endl; + while ((filename = readdir(dir)) != NULL) { + if (strcmp(filename->d_name, ".") == 0 || + strcmp(filename->d_name, "..") == 0) + continue; + // img_dir + std::string("/") + all_inputs[0]; + all_inputs.push_back(dir_name + std::string("/") + + std::string(filename->d_name)); + } + } +} + } // namespace PaddleOCR \ No newline at end of file
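`Utility::GetAllFiles` treats its argument as either a single file or a flat directory and does not recurse. The Python side of this repo covers the same job with `ppocr.utils.utility.get_image_file_list`; a rough Python counterpart of the C++ helper (a sketch, not the repo's implementation) looks like this:

```python
import os

def get_all_files(path):
    """Return [path] for a plain file, else all regular files directly under it."""
    if not os.path.isdir(path):
        return [path]
    return [
        os.path.join(path, name)
        for name in sorted(os.listdir(path))
        if os.path.isfile(os.path.join(path, name))
    ]
```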
diff --git a/deploy/cpp_infer/tools/build.sh b/deploy/cpp_infer/tools/build.sh index 606539487fce82adf817e7a3ee300e3bf890643b..79611300584755e531e6a2f645ab1a9420d3c5ad 100755 --- a/deploy/cpp_infer/tools/build.sh +++ b/deploy/cpp_infer/tools/build.sh @@ -12,9 +12,10 @@ cmake .. \ -DWITH_MKL=ON \ -DWITH_GPU=OFF \ -DWITH_STATIC_LIB=OFF \ - -DUSE_TENSORRT=OFF \ + -DWITH_TENSORRT=OFF \ -DOPENCV_DIR=${OPENCV_DIR} \ -DCUDNN_LIB=${CUDNN_LIB_DIR} \ -DCUDA_LIB=${CUDA_LIB_DIR} \ + -DTENSORRT_DIR=${TENSORRT_DIR} \ make -j diff --git a/deploy/cpp_infer/tools/config.txt b/deploy/cpp_infer/tools/config.txt index 0e5f8472ab62f5fc646738bc2974736a0564b343..d4d66d65225bc9d1d4d62f45550db71fb5d8414e 100644 --- a/deploy/cpp_infer/tools/config.txt +++ b/deploy/cpp_infer/tools/config.txt @@ -20,10 +20,10 @@ cls_thresh 0.9 # rec config rec_model_dir ./inference/ch_ppocr_mobile_v2.0_rec_infer/ -char_list_file ../../ppocr/utils/ppocr_keys_v1.txt +char_list_file ../../ppocr/utils/ppocr_keys_v1.txt # show the detection results -visualize 1 +visualize 0 # use_tensorrt use_tensorrt 0 diff --git a/deploy/hubserving/readme.md b/deploy/hubserving/readme.md index 88f335812de191d860e46f7317bb303a20df8b41..a39ac5a42b905b1efa73c02d7594511c8a7ea103 100755 --- a/deploy/hubserving/readme.md +++ b/deploy/hubserving/readme.md @@ -29,7 +29,7 @@ deploy/hubserving/ocr_system/ ### 1. Prepare the environment ```shell # Install paddlehub -pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple +pip3 install paddlehub==1.8.3 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple ``` ### 2. Download the inference model
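Once a module is installed and `hub serving start` is running, the service accepts base64-encoded images over HTTP. A minimal client sketch (the default port 8868 and the `/predict/ocr_system` route follow the module layout shown above; adjust host, port, and route to your deployment):

```python
import base64
import json

import requests

with open("doc/imgs/11.jpg", "rb") as f:  # any local test image
    img_b64 = base64.b64encode(f.read()).decode("utf8")

resp = requests.post(
    "http://127.0.0.1:8868/predict/ocr_system",
    headers={"Content-type": "application/json"},
    data=json.dumps({"images": [img_b64]}))
print(resp.json())
```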
diff --git a/deploy/hubserving/readme_en.md b/deploy/hubserving/readme_en.md index c948fed1eefe9f5f83f63a82699cdac3548fad52..7d9a8629ef7d27e84e636f029202602a94d1d3f7 100755 --- a/deploy/hubserving/readme_en.md +++ b/deploy/hubserving/readme_en.md @@ -30,7 +30,7 @@ The following steps take the 2-stage series service as an example. If only the d ### 1. Prepare the environment ```shell # Install paddlehub -pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple +pip3 install paddlehub==1.8.3 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple ``` ### 2. Download inference model diff --git a/doc/doc_ch/config.md b/doc/doc_ch/config.md index 3f8f7ff1674d63e721d7ad2ced31bf771b0183eb..74cd238134d1999a6fbd96d0ad053d0304231a0b 100644 --- a/doc/doc_ch/config.md +++ b/doc/doc_ch/config.md @@ -111,9 +111,9 @@ | Parameter | Use | Default | Note | | :---------------------: | :---------------------: | :--------------: | :--------------------: | | **dataset** | Return one sample per iteration | - | - | -| name | dataset class name | SimpleDataSet | Currently supports `SimpleDataSet` and `LMDBDateSet` | +| name | dataset class name | SimpleDataSet | Currently supports `SimpleDataSet` and `LMDBDataSet` | | data_dir | Image folder path | ./train_data | \ | -| label_file_list | Groundtruth file path | ["./train_data/train_list.txt"] | This parameter is not required when dataset is LMDBDateSet | +| label_file_list | Groundtruth file path | ["./train_data/train_list.txt"] | This parameter is not required when dataset is LMDBDataSet | | ratio_list | Ratio of dataset | [1.0] | If label_file_list contains two train lists and ratio_list is [0.4,0.6], 40% is sampled from train_list1 and 60% from train_list2 to form the whole dataset | | transforms | List of methods to transform images and labels | [DecodeImage,CTCLabelEncode,RecResizeImg,KeepKeys] | see [ppocr/data/imaug](../../ppocr/data/imaug) | | **loader** | dataloader related | - | | diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md index 8a7c341cf24738b8af8c974a6da41bcb1b51ce48..0f860065bef9eff8f90c18f120e43dcf0c2a47aa 100644 --- a/doc/doc_ch/recognition.md +++ b/doc/doc_ch/recognition.md @@ -243,7 +243,7 @@ Optimizer: Train: dataset: - # Dataset format; LMDBDateSet and SimpleDataSet are supported + # Dataset format; LMDBDataSet and SimpleDataSet are supported name: SimpleDataSet # Path of dataset data_dir: ./train_data/ @@ -263,7 +263,7 @@ Train: Eval: dataset: - # Dataset format; LMDBDateSet and SimpleDataSet are supported + # Dataset format; LMDBDataSet and SimpleDataSet are supported name: SimpleDataSet # Path of dataset data_dir: ./train_data @@ -393,7 +393,7 @@ Global: Train: dataset: - # Dataset format; LMDBDateSet and SimpleDataSet are supported + # Dataset format; LMDBDataSet and SimpleDataSet are supported name: SimpleDataSet # Path of dataset data_dir: ./train_data/ @@ -403,7 +403,7 @@ Train: Eval: dataset: - # Dataset format; LMDBDateSet and SimpleDataSet are supported + # Dataset format; LMDBDataSet and SimpleDataSet are supported name: SimpleDataSet # Path of dataset data_dir: ./train_data diff --git a/doc/doc_ch/whl.md b/doc/doc_ch/whl.md index 2e93c487c2f2071c7c89c753cf86eef61ce20805..957c6926b15fad3091265da9295f5ad820fe6a26 100644 --- a/doc/doc_ch/whl.md +++ b/doc/doc_ch/whl.md @@ -59,7 +59,7 @@ im_show.save('result.jpg') from paddleocr import PaddleOCR, draw_ocr ocr = PaddleOCR() # need to run only once to download and load model into memory img_path = 'PaddleOCR/doc/imgs/11.jpg' -result = ocr.ocr(img_path) +result = ocr.ocr(img_path, cls=False) for line in result: print(line) @@ -355,3 +355,4 @@ im_show.save('result.jpg') | det | Enable detection in the forward pass | TRUE | | rec | Enable recognition in the forward pass | TRUE | | cls | Enable classification in the forward pass (in command-line mode, use_angle_cls controls whether classification runs) | FALSE | +| show_log | Whether to print information such as det and rec logs | FALSE | diff --git a/doc/doc_en/config_en.md b/doc/doc_en/config_en.md index 28ebb6e830369447395c661cbcc76aaf067a91d9..5e5847c4b298553b2d376b90196b61b7e0286efe 100644 --- a/doc/doc_en/config_en.md +++ b/doc/doc_en/config_en.md @@ -110,9 +110,9 @@ In ppocr, the network is divided into four stages: Transform, Backbone, Neck and | Parameter | Use | Defaults | Note | | :---------------------: | :---------------------: | :--------------: | :--------------------: | | **dataset** | Return one sample per iteration | - | - | -| name | dataset class name | SimpleDataSet | Currently support `SimpleDataSet`, `LMDBDateSet` | +| name | dataset class name | SimpleDataSet | Currently support `SimpleDataSet`, `LMDBDataSet` | | data_dir | Image folder path | ./train_data | \ | -| label_file_list | Groundtruth file path | ["./train_data/train_list.txt"] | 
This parameter is not required when dataset is LMDBDateSet | +| label_file_list | Groundtruth file path | ["./train_data/train_list.txt"] | This parameter is not required when dataset is LMDBDataSet | | ratio_list | Ratio of data set | [1.0] | If there are two train_lists in label_file_list and ratio_list is [0.4,0.6], 40% will be sampled from train_list1, and 60% will be sampled from train_list2 to combine the entire dataset | | transforms | List of methods to transform images and labels | [DecodeImage,CTCLabelEncode,RecResizeImg,KeepKeys] | see [ppocr/data/imaug](../../ppocr/data/imaug) | | **loader** | dataloader related | - | | diff --git a/doc/doc_en/distributed_training.md b/doc/doc_en/distributed_training.md new file mode 100644 index 0000000000000000000000000000000000000000..7a8b71ce308837568c84bf56292f78e9979d3907 --- /dev/null +++ b/doc/doc_en/distributed_training.md @@ -0,0 +1,50 @@ +# Distributed training + +## Introduction + +The high performance of distributed training is one of the core advantages of PaddlePaddle. In classification tasks, distributed training can achieve an almost linear speedup. OCR training tasks generally need massive amounts of training data; the PP-OCR v2.0 recognition model, for example, is trained on an 18-million-sample (1800W) dataset, which is very time-consuming on a single machine. PaddleOCR therefore uses distributed training to speed up the training task. For more information about distributed training, please refer to the [distributed training quick start tutorial](https://fleet-x.readthedocs.io/en/latest/paddle_fleet_rst/parameter_server/ps_quick_start.html). + +## Quick Start + +### Training with a single machine + +Take recognition as an example. After the data is prepared locally, start the training task with the `paddle.distributed.launch` interface. The start command is as follows: + +```shell +python3 -m paddle.distributed.launch \ + --log_dir=./log/ \ + --gpus '0,1,2,3,4,5,6,7' \ + tools/train.py \ + -c configs/rec/rec_mv3_none_bilstm_ctc.yml +``` + +### Training with multiple machines + +Compared with single-machine training, multi-machine training only needs the additional `--ips` parameter in the start command, which holds the IP list of the machines used for distributed training, with the IPs of different machines separated by commas. The start command is as follows: + +```shell +ip_list="192.168.0.1,192.168.0.2" +python3 -m paddle.distributed.launch \ + --log_dir=./log/ \ + --ips="${ip_list}" \ + --gpus="0,1,2,3,4,5,6,7" \ + tools/train.py \ + -c configs/rec/rec_mv3_none_bilstm_ctc.yml +``` + +**Notice:** +* The IP addresses of the different machines must be separated by commas and can be queried with `ifconfig` or `ipconfig`. +* Passwordless login must be configured between the machines, and each machine must be able to `ping` the others directly; otherwise communication between them cannot be established. +* The code, data, and start command must be exactly the same on all machines, and the start command must then be run on every machine. The first machine in `ip_list` is set to `trainer0`, and so on. + + +## Performance comparison + +* Based on a 260k-sample (26W) public recognition dataset (LSVT, RCTW, MTWI), training on a single machine with 8 P40 GPUs versus two such machines, the final time consumption is as follows. 
+ +| Model | Config file | Number of machines | Number of GPUs per machine | Training time | Recognition acc | Speedup ratio | +| :-------: | :------------: | :----------------: | :----------------------------: | :------------------: | :--------------: | :-----------: | +| CRNN | configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml | 1 | 8 | 60h | 66.7% | - | +| CRNN | configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml | 2 | 8 | 40h | 67.0% | 150% | + +It can be seen that the training time is shortened from 60h to 40h: the speedup ratio reaches 150% (60h / 40h), and the scaling efficiency is 75% (60h / (40h * 2)). diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md index 0b3db6a235bdbfeb930d6cf3f7d086829fd32c43..e23166e0caef4f6a246502fa12101f86d61e4eac 100644 --- a/doc/doc_en/recognition_en.md +++ b/doc/doc_en/recognition_en.md @@ -237,7 +237,7 @@ Optimizer: Train: dataset: - # Type of dataset,we support LMDBDateSet and SimpleDataSet + # Type of dataset, we support LMDBDataSet and SimpleDataSet name: SimpleDataSet # Path of dataset data_dir: ./train_data/ @@ -257,7 +257,7 @@ Train: Eval: dataset: - # Type of dataset,we support LMDBDateSet and SimpleDataSet + # Type of dataset, we support LMDBDataSet and SimpleDataSet name: SimpleDataSet # Path of dataset data_dir: ./train_data @@ -394,7 +394,7 @@ Global: Train: dataset: - # Type of dataset,we support LMDBDateSet and SimpleDataSet + # Type of dataset, we support LMDBDataSet and SimpleDataSet name: SimpleDataSet # Path of dataset data_dir: ./train_data/ @@ -404,7 +404,7 @@ Train: Eval: dataset: - # Type of dataset,we support LMDBDateSet and SimpleDataSet + # Type of dataset, we support LMDBDataSet and SimpleDataSet name: SimpleDataSet # Path of dataset data_dir: ./train_data diff --git a/doc/doc_en/whl_en.md b/doc/doc_en/whl_en.md index 69abf085556f466853798077bb116b3986582bcc..b9909f498e830309eaad952df9171cd63b6f5e7b 100644 --- a/doc/doc_en/whl_en.md +++ b/doc/doc_en/whl_en.md @@ -59,7 +59,7 @@ Visualization of results from paddleocr import PaddleOCR,draw_ocr ocr = PaddleOCR(lang='en') # need to run only once to download and load model into memory img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg' -result = ocr.ocr(img_path) +result = ocr.ocr(img_path, cls=False) for line in result: print(line) @@ -362,3 +362,5 @@ im_show.save('result.jpg') | det | Enable detection when `ppocr.ocr` func exec | TRUE | | rec | Enable recognition when `ppocr.ocr` func exec | TRUE | | cls | Enable classification when `ppocr.ocr` func exec (use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE | +| show_log | Whether to print log in det and rec | FALSE | \ No newline at end of file diff --git a/doc/table/1.png b/doc/table/1.png new file mode 100644 index 0000000000000000000000000000000000000000..47df618ab1bef431a5dd94418c01be16b09d31aa Binary files /dev/null and b/doc/table/1.png differ diff --git a/doc/table/PaddleDetection_config.png b/doc/table/PaddleDetection_config.png new file mode 100644 index 0000000000000000000000000000000000000000..d18932b66cc148b7796fe4b319ad9eb82c2a2868 Binary files /dev/null and b/doc/table/PaddleDetection_config.png differ diff --git a/doc/table/paper-image.jpg b/doc/table/paper-image.jpg new file mode 100644 index 0000000000000000000000000000000000000000..db7246b314556d73cd49d049b9b480887b6ef994 Binary files /dev/null and b/doc/table/paper-image.jpg differ diff --git a/doc/table/pipeline.png b/doc/table/pipeline.png new file mode 100644 index 
0000000000000000000000000000000000000000..4acfb3e2ef423402d9fd1fc1b8ad02f0a072049b Binary files /dev/null and b/doc/table/pipeline.png differ diff --git a/doc/table/result_all.jpg b/doc/table/result_all.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3dd9840643989f1049c228c201b43f9ed89a5fcb Binary files /dev/null and b/doc/table/result_all.jpg differ diff --git a/doc/table/result_text.jpg b/doc/table/result_text.jpg new file mode 100644 index 0000000000000000000000000000000000000000..94c9bce4a73b2764bb9791972f62a3a5b37fed45 Binary files /dev/null and b/doc/table/result_text.jpg differ diff --git a/doc/table/tableocr_pipeline.png b/doc/table/tableocr_pipeline.png new file mode 100644 index 0000000000000000000000000000000000000000..731b84da9b832b67db42225379fbe09120cbee6b Binary files /dev/null and b/doc/table/tableocr_pipeline.png differ diff --git a/paddleocr.py b/paddleocr.py index c5da7248d2cc7d778758a87309cfeaedcbd8ceb5..f2a3496897c07f8d969b198441076ef992774c19 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -19,102 +19,101 @@ __dir__ = os.path.dirname(__file__) sys.path.append(os.path.join(__dir__, '')) import cv2 +import logging import numpy as np from pathlib import Path -import tarfile -import requests -from tqdm import tqdm from tools.infer import predict_system from ppocr.utils.logging import get_logger logger = get_logger() from ppocr.utils.utility import check_and_read_gif, get_image_file_list -from tools.infer.utility import draw_ocr +from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url +from tools.infer.utility import draw_ocr, init_args, str2bool __all__ = ['PaddleOCR'] model_urls = { 'det': { 'ch': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', 'en': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar' + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar' }, 'rec': { 'ch': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/ppocr_keys_v1.txt' }, 'en': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/en_dict.txt' }, 'french': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/french_dict.txt' }, 'german': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/german_dict.txt' }, 'korean': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/korean_dict.txt' }, 'japan': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar', + 
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/japan_dict.txt' }, 'chinese_cht': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt' }, 'ta': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/ta_dict.txt' }, 'te': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/te_dict.txt' }, 'ka': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/ka_dict.txt' }, 'latin': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/latin_dict.txt' }, 'arabic': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/arabic_dict.txt' }, 'cyrillic': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/cyrillic_dict.txt' }, 'devanagari': { 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar', + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/devanagari_dict.txt' } }, 'cls': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar' + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar' } SUPPORT_DET_MODEL = ['DB'] @@ -123,150 +122,24 @@ SUPPORT_REC_MODEL = ['CRNN'] BASE_DIR = os.path.expanduser("~/.paddleocr/") -def download_with_progressbar(url, save_path): - response = requests.get(url, stream=True) - total_size_in_bytes = int(response.headers.get('content-length', 0)) - block_size = 1024 # 1 Kibibyte - progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) - with open(save_path, 'wb') as file: - for data in response.iter_content(block_size): - progress_bar.update(len(data)) - file.write(data) - progress_bar.close() - if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes: - logger.error("Something went wrong while downloading models") - sys.exit(0) - - -def maybe_download(model_storage_directory, url): - # using custom model - tar_file_name_list = [ - 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel' - ] - if not os.path.exists( - os.path.join(model_storage_directory, 'inference.pdiparams') - ) or not os.path.exists( - os.path.join(model_storage_directory, 'inference.pdmodel')): - 
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) - print('download {} to {}'.format(url, tmp_path)) - os.makedirs(model_storage_directory, exist_ok=True) - download_with_progressbar(url, tmp_path) - with tarfile.open(tmp_path, 'r') as tarObj: - for member in tarObj.getmembers(): - filename = None - for tar_file_name in tar_file_name_list: - if tar_file_name in member.name: - filename = tar_file_name - if filename is None: - continue - file = tarObj.extractfile(member) - with open( - os.path.join(model_storage_directory, filename), - 'wb') as f: - f.write(file.read()) - os.remove(tmp_path) - - -def parse_args(mMain=True, add_help=True): +def parse_args(mMain=True): import argparse - - def str2bool(v): - return v.lower() in ("true", "t", "1") - + parser = init_args() + parser.add_help = mMain + parser.add_argument("--lang", type=str, default='ch') + parser.add_argument("--det", type=str2bool, default=True) + parser.add_argument("--rec", type=str2bool, default=True) + + for action in parser._actions: + if action.dest == 'rec_char_dict_path': + action.default = None if mMain: - parser = argparse.ArgumentParser(add_help=add_help) - # params for prediction engine - parser.add_argument("--use_gpu", type=str2bool, default=True) - parser.add_argument("--ir_optim", type=str2bool, default=True) - parser.add_argument("--use_tensorrt", type=str2bool, default=False) - parser.add_argument("--gpu_mem", type=int, default=8000) - - # params for text detector - parser.add_argument("--image_dir", type=str) - parser.add_argument("--det_algorithm", type=str, default='DB') - parser.add_argument("--det_model_dir", type=str, default=None) - parser.add_argument("--det_limit_side_len", type=float, default=960) - parser.add_argument("--det_limit_type", type=str, default='max') - - # DB parmas - parser.add_argument("--det_db_thresh", type=float, default=0.3) - parser.add_argument("--det_db_box_thresh", type=float, default=0.5) - parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) - parser.add_argument("--use_dilation", type=bool, default=False) - parser.add_argument("--det_db_score_mode", type=str, default="fast") - - # EAST parmas - parser.add_argument("--det_east_score_thresh", type=float, default=0.8) - parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) - parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) - - # params for text recognizer - parser.add_argument("--rec_algorithm", type=str, default='CRNN') - parser.add_argument("--rec_model_dir", type=str, default=None) - parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") - parser.add_argument("--rec_char_type", type=str, default='ch') - parser.add_argument("--rec_batch_num", type=int, default=6) - parser.add_argument("--max_text_length", type=int, default=25) - parser.add_argument("--rec_char_dict_path", type=str, default=None) - parser.add_argument("--use_space_char", type=bool, default=True) - parser.add_argument("--drop_score", type=float, default=0.5) - - # params for text classifier - parser.add_argument("--cls_model_dir", type=str, default=None) - parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") - parser.add_argument("--label_list", type=list, default=['0', '180']) - parser.add_argument("--cls_batch_num", type=int, default=6) - parser.add_argument("--cls_thresh", type=float, default=0.9) - - parser.add_argument("--enable_mkldnn", type=bool, default=False) - parser.add_argument("--use_zero_copy_run", type=bool, default=False) - 
parser.add_argument("--use_pdserving", type=str2bool, default=False) - - parser.add_argument("--lang", type=str, default='ch') - parser.add_argument("--det", type=str2bool, default=True) - parser.add_argument("--rec", type=str2bool, default=True) - parser.add_argument("--use_angle_cls", type=str2bool, default=False) return parser.parse_args() else: - return argparse.Namespace( - use_gpu=True, - ir_optim=True, - use_tensorrt=False, - gpu_mem=8000, - image_dir='', - det_algorithm='DB', - det_model_dir=None, - det_limit_side_len=960, - det_limit_type='max', - det_db_thresh=0.3, - det_db_box_thresh=0.5, - det_db_unclip_ratio=1.6, - use_dilation=False, - det_db_score_mode="fast", - det_east_score_thresh=0.8, - det_east_cover_thresh=0.1, - det_east_nms_thresh=0.2, - rec_algorithm='CRNN', - rec_model_dir=None, - rec_image_shape="3, 32, 320", - rec_char_type='ch', - rec_batch_num=6, - max_text_length=25, - rec_char_dict_path=None, - use_space_char=True, - drop_score=0.5, - cls_model_dir=None, - cls_image_shape="3, 48, 192", - label_list=['0', '180'], - cls_batch_num=6, - cls_thresh=0.9, - enable_mkldnn=False, - use_zero_copy_run=False, - use_pdserving=False, - lang='ch', - det=True, - rec=True, - use_angle_cls=False) + inference_args_dict = {} + for action in parser._actions: + inference_args_dict[action.dest] = action.default + return argparse.Namespace(**inference_args_dict) class PaddleOCR(predict_system.TextSystem): @@ -276,10 +149,12 @@ class PaddleOCR(predict_system.TextSystem): args: **kwargs: other params show in paddleocr --help """ - postprocess_params = parse_args(mMain=False, add_help=False) - postprocess_params.__dict__.update(**kwargs) - self.use_angle_cls = postprocess_params.use_angle_cls - lang = postprocess_params.lang + params = parse_args(mMain=False) + params.__dict__.update(**kwargs) + if not params.show_log: + logger.setLevel(logging.INFO) + self.use_angle_cls = params.use_angle_cls + lang = params.lang latin_lang = [ 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', @@ -305,48 +180,47 @@ class PaddleOCR(predict_system.TextSystem): lang = "devanagari" assert lang in model_urls[ 'rec'], 'param lang must in {}, but got {}'.format( - model_urls['rec'].keys(), lang) + model_urls['rec'].keys(), lang) if lang == "ch": det_lang = "ch" else: det_lang = "en" use_inner_dict = False - if postprocess_params.rec_char_dict_path is None: + if params.rec_char_dict_path is None: use_inner_dict = True - postprocess_params.rec_char_dict_path = model_urls['rec'][lang][ + params.rec_char_dict_path = model_urls['rec'][lang][ 'dict_path'] # init model dir - if postprocess_params.det_model_dir is None: - postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION, - 'det', det_lang) - if postprocess_params.rec_model_dir is None: - postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION, - 'rec', lang) - if postprocess_params.cls_model_dir is None: - postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls') - print(postprocess_params) + params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir, + os.path.join(BASE_DIR, VERSION, 'det', det_lang), + model_urls['det'][det_lang]) + params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir, + os.path.join(BASE_DIR, VERSION, 'rec', lang), + model_urls['rec'][lang]['url']) + params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir, + os.path.join(BASE_DIR, VERSION, 'cls'), + model_urls['cls']) # 
download model - maybe_download(postprocess_params.det_model_dir, - model_urls['det'][det_lang]) - maybe_download(postprocess_params.rec_model_dir, - model_urls['rec'][lang]['url']) - maybe_download(postprocess_params.cls_model_dir, model_urls['cls']) + maybe_download(params.det_model_dir, det_url) + maybe_download(params.rec_model_dir, rec_url) + maybe_download(params.cls_model_dir, cls_url) - if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: + if params.det_algorithm not in SUPPORT_DET_MODEL: logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) sys.exit(0) - if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL: + if params.rec_algorithm not in SUPPORT_REC_MODEL: logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL)) sys.exit(0) if use_inner_dict: - postprocess_params.rec_char_dict_path = str( - Path(__file__).parent / postprocess_params.rec_char_dict_path) + params.rec_char_dict_path = str( + Path(__file__).parent / params.rec_char_dict_path) + print(params) # init det_model and rec_model - super().__init__(postprocess_params) + super().__init__(params) - def ocr(self, img, det=True, rec=True, cls=False): + def ocr(self, img, det=True, rec=True, cls=True): """ ocr with paddleocr args: @@ -358,9 +232,7 @@ if isinstance(img, list) and det == True: logger.error('When input a list of images, det must be false') exit(0) - if cls == False: - self.use_angle_cls = False - elif cls == True and self.use_angle_cls == False: + if cls == True and self.use_angle_cls == False: logger.warning( 'Since the angle classifier is not initialized, the angle classifier will not be used during the forward process' ) @@ -382,7 +254,7 @@ if isinstance(img, np.ndarray) and len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if det and rec: - dt_boxes, rec_res = self.__call__(img) + dt_boxes, rec_res = self.__call__(img, cls) return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] elif det and not rec: dt_boxes, elapse = self.text_detector(img) @@ -392,7 +264,7 @@ else: if not isinstance(img, list): img = [img] - if self.use_angle_cls: + if self.use_angle_cls and cls: img, cls_res, elapse = self.text_classifier(img) if not rec: return cls_res @@ -404,7 +276,7 @@ def main(): # for cmd args = parse_args(mMain=True) image_dir = args.image_dir - if image_dir.startswith('http'): + if is_link(image_dir): download_with_progressbar(image_dir, 'tmp.jpg') image_file_list = ['tmp.jpg'] else:
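With this refactor, model directories and download URLs are resolved by `confirm_model_dir_url`/`maybe_download` from `ppocr.utils.network`, and `ocr()` now runs the angle classifier by default. A quick usage sketch of the resulting API (the image path is illustrative; models are fetched to `~/.paddleocr/` on first use):

```python
from paddleocr import PaddleOCR

ocr = PaddleOCR(use_angle_cls=True, lang="ch", show_log=False)
result = ocr.ocr("doc/imgs/11.jpg")  # cls=True is now the default
for box, (text, score) in result:
    print(text, score)
```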
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index 728b8317f54687ee76b519cba18f4d7807493821..e860c5a6986f495e6384d9df93c24795c04a0d5f 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.lmdb_dataset import LMDBDataSet from ppocr.data.pgnet_dataset import PGDataSet +from ppocr.data.pubtab_dataset import PubTabDataSet __all__ = ['build_dataloader', 'transform', 'create_operators'] @@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp) def build_dataloader(config, mode, device, logger, seed=None): config = copy.deepcopy(config) - support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet'] + support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet'] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( 'DataSet only support {}'.format(support_dict)) diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index a808fd586b0676751da1ee31d379179b026fd51d..ff084a725a27a909fcc1b29d7dc3b309fa0623a2 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -29,6 +29,7 @@ from .label_ops import * from .east_process import * from .sast_process import * from .pg_process import * +from .gen_table_mask import * def transform(data, ops=None): diff --git a/ppocr/data/imaug/gen_table_mask.py b/ppocr/data/imaug/gen_table_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..08e35d5d1df7f9663b4e008451308d0ee409cf5a --- /dev/null +++ b/ppocr/data/imaug/gen_table_mask.py @@ -0,0 +1,244 @@ +""" +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import sys +import six +import cv2 +import numpy as np + + +class GenTableMask(object): + """ gen table mask """ + + def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs): + self.shrink_h_max = shrink_h_max + self.shrink_w_max = shrink_w_max + self.mask_type = mask_type + + def projection(self, erosion, h, w, split_threshold=0): + # horizontal projection + projection_map = np.ones_like(erosion) + project_val_array = [0 for _ in range(0, h)] + + for j in range(0, h): + for i in range(0, w): + if erosion[j, i] == 255: + project_val_array[j] += 1 + # get split points from the projection array + start_idx = 0 # index where a text region starts + end_idx = 0 # index where a blank region starts + in_text = False # whether currently inside a text region + box_list = [] + for i in range(len(project_val_array)): + if in_text == False and project_val_array[i] > split_threshold: # entered a text region + in_text = True + start_idx = i + elif project_val_array[i] <= split_threshold and in_text == True: # entered a blank region + end_idx = i + in_text = False + if end_idx - start_idx <= 2: + continue + box_list.append((start_idx, end_idx + 1)) + + if in_text: + box_list.append((start_idx, h - 1)) + # draw the projection histogram + for j in range(0, h): + for i in range(0, project_val_array[j]): + projection_map[j, i] = 0 + return box_list, projection_map + + def projection_cx(self, box_img): + box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY) + h, w = box_gray_img.shape + # binarize the grayscale image + ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV) + # vertical erosion + if h < w: + kernel = np.ones((2, 1), np.uint8) + erode = cv2.erode(thresh1, kernel, iterations=1) + else: + erode = thresh1 + # horizontal dilation + kernel = np.ones((1, 5), np.uint8) + erosion = cv2.dilate(erode, kernel, iterations=1) + # horizontal projection + projection_map = np.ones_like(erosion) + project_val_array = [0 for _ in range(0, h)] + + for j in range(0, h): + for i in range(0, w): + if erosion[j, i] == 255: + project_val_array[j] += 1 + # get split points from the projection array + start_idx = 0 # index where a text region starts + end_idx = 0 # index where a blank region starts + in_text = False # whether currently inside a text region + box_list = [] + split_threshold = 0 + for i in range(len(project_val_array)): + if in_text == 
False and project_val_array[i] > spilt_threshold: # 进入字符区了 + in_text = True + start_idx = i + elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 + end_idx = i + in_text = False + if end_idx - start_idx <= 2: + continue + box_list.append((start_idx, end_idx + 1)) + + if in_text: + box_list.append((start_idx, h - 1)) + # 绘制投影直方图 + for j in range(0, h): + for i in range(0, project_val_array[j]): + projection_map[j, i] = 0 + split_bbox_list = [] + if len(box_list) > 1: + for i, (h_start, h_end) in enumerate(box_list): + if i == 0: + h_start = 0 + if i == len(box_list): + h_end = h + word_img = erosion[h_start:h_end + 1, :] + word_h, word_w = word_img.shape + w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h) + w_start, w_end = w_split_list[0][0], w_split_list[-1][1] + if h_start > 0: + h_start -= 1 + h_end += 1 + word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :] + split_bbox_list.append([w_start, h_start, w_end, h_end]) + else: + split_bbox_list.append([0, 0, w, h]) + return split_bbox_list + + def shrink_bbox(self, bbox): + left, top, right, bottom = bbox + sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max) + sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max) + left_new = left + sh_w + right_new = right - sh_w + top_new = top + sh_h + bottom_new = bottom - sh_h + if left_new >= right_new: + left_new = left + right_new = right + if top_new >= bottom_new: + top_new = top + bottom_new = bottom + return [left_new, top_new, right_new, bottom_new] + + def __call__(self, data): + img = data['image'] + cells = data['cells'] + height, width = img.shape[0:2] + if self.mask_type == 1: + mask_img = np.zeros((height, width), dtype=np.float32) + else: + mask_img = np.zeros((height, width, 3), dtype=np.float32) + cell_num = len(cells) + for cno in range(cell_num): + if "bbox" in cells[cno]: + bbox = cells[cno]['bbox'] + left, top, right, bottom = bbox + box_img = img[top:bottom, left:right, :].copy() + split_bbox_list = self.projection_cx(box_img) + for sno in range(len(split_bbox_list)): + split_bbox_list[sno][0] += left + split_bbox_list[sno][1] += top + split_bbox_list[sno][2] += left + split_bbox_list[sno][3] += top + + for sno in range(len(split_bbox_list)): + left, top, right, bottom = split_bbox_list[sno] + left, top, right, bottom = self.shrink_bbox([left, top, right, bottom]) + if self.mask_type == 1: + mask_img[top:bottom, left:right] = 1.0 + data['mask_img'] = mask_img + else: + mask_img[top:bottom, left:right, :] = (255, 255, 255) + data['image'] = mask_img + return data + +class ResizeTableImage(object): + def __init__(self, max_len, **kwargs): + super(ResizeTableImage, self).__init__() + self.max_len = max_len + + def get_img_bbox(self, cells): + bbox_list = [] + if len(cells) == 0: + return bbox_list + cell_num = len(cells) + for cno in range(cell_num): + if "bbox" in cells[cno]: + bbox = cells[cno]['bbox'] + bbox_list.append(bbox) + return bbox_list + + def resize_img_table(self, img, bbox_list, max_len): + height, width = img.shape[0:2] + ratio = max_len / (max(height, width) * 1.0) + resize_h = int(height * ratio) + resize_w = int(width * ratio) + img_new = cv2.resize(img, (resize_w, resize_h)) + bbox_list_new = [] + for bno in range(len(bbox_list)): + left, top, right, bottom = bbox_list[bno].copy() + left = int(left * ratio) + top = int(top * ratio) + right = int(right * ratio) + bottom = int(bottom * ratio) + bbox_list_new.append([left, top, right, bottom]) + return img_new, bbox_list_new + + def 
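For intuition, the row-projection split that GenTableMask.projection implements can be sketched standalone in NumPy (toy input; split_rows is a hypothetical helper name, not part of the patch):

import numpy as np

def split_rows(binary, threshold=0):
    # rows whose white-pixel count exceeds the threshold are text; contiguous
    # text rows become one band, and bands of 2 px or less are dropped
    counts = (binary == 255).sum(axis=1)
    bands, start, in_text = [], 0, False
    for i, c in enumerate(counts):
        if not in_text and c > threshold:
            in_text, start = True, i
        elif in_text and c <= threshold:
            in_text = False
            if i - start > 2:
                bands.append((start, i + 1))
    if in_text:
        bands.append((start, len(counts) - 1))
    return bands

img = np.zeros((10, 8), dtype=np.uint8)
img[1:5] = 255          # one text band spanning rows 1-4
print(split_rows(img))  # [(1, 6)], end padded by one row as in the original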
+class ResizeTableImage(object): + def __init__(self, max_len, **kwargs): + super(ResizeTableImage, self).__init__() + self.max_len = max_len +
+ def get_img_bbox(self, cells): + bbox_list = [] + if len(cells) == 0: + return bbox_list + cell_num = len(cells) + for cno in range(cell_num): + if "bbox" in cells[cno]: + bbox = cells[cno]['bbox'] + bbox_list.append(bbox) + return bbox_list +
+ def resize_img_table(self, img, bbox_list, max_len): + height, width = img.shape[0:2] + ratio = max_len / (max(height, width) * 1.0) + resize_h = int(height * ratio) + resize_w = int(width * ratio) + img_new = cv2.resize(img, (resize_w, resize_h)) + bbox_list_new = [] + for bno in range(len(bbox_list)): + left, top, right, bottom = bbox_list[bno].copy() + left = int(left * ratio) + top = int(top * ratio) + right = int(right * ratio) + bottom = int(bottom * ratio) + bbox_list_new.append([left, top, right, bottom]) + return img_new, bbox_list_new +
+ def __call__(self, data): + img = data['image'] + if 'cells' not in data: + cells = [] + else: + cells = data['cells'] + bbox_list = self.get_img_bbox(cells) + img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len) + data['image'] = img_new + cell_num = len(cells) + bno = 0 + for cno in range(cell_num): + if "bbox" in data['cells'][cno]: + data['cells'][cno]['bbox'] = bbox_list_new[bno] + bno += 1 + data['max_len'] = self.max_len + return data +
+class PaddingTableImage(object): + def __init__(self, **kwargs): + super(PaddingTableImage, self).__init__() + + def __call__(self, data): + img = data['image'] + max_len = data['max_len'] + padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32) + height, width = img.shape[0:2] + padding_img[0:height, 0:width, :] = img.copy() + data['image'] = padding_img + return data + \ No newline at end of file
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index bba3209f7560f19b74a54c102caf697319814803..e25cce79b553f127afc0167f18b6f663ceb617d7 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -351,3 +351,162 @@ class SRNLabelEncode(BaseRecLabelEncode): assert False, "Unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end return idx +
+class TableLabelEncode(object): + """ Convert between text-label and text-index """ + def __init__(self, + max_text_length, + max_elem_length, + max_cell_num, + character_dict_path, + span_weight = 1.0, + **kwargs): + self.max_text_length = max_text_length + self.max_elem_length = max_elem_length + self.max_cell_num = max_cell_num + list_character, list_elem = self.load_char_elem_dict(character_dict_path) + list_character = self.add_special_char(list_character) + list_elem = self.add_special_char(list_elem) + self.dict_character = {} + for i, char in enumerate(list_character): + self.dict_character[char] = i + self.dict_elem = {} + for i, elem in enumerate(list_elem): + self.dict_elem[elem] = i + self.span_weight = span_weight +
+ def load_char_elem_dict(self, character_dict_path): + list_character = [] + list_elem = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + substr = lines[0].decode('utf-8').strip("\n").split("\t") + character_num = int(substr[0]) + elem_num = int(substr[1]) + for cno in range(1, 1+character_num): + character = lines[cno].decode('utf-8').strip("\n") + list_character.append(character) + for eno in range(1+character_num, 1+character_num+elem_num): + elem = lines[eno].decode('utf-8').strip("\n") + list_elem.append(elem) + return list_character, list_elem +
+ def add_special_char(self, list_character): + self.beg_str = "sos" + self.end_str = "eos" + list_character = [self.beg_str] + list_character + [self.end_str] + return list_character +
+ def get_span_idx_list(self): + span_idx_list = [] + for elem in self.dict_elem: + if 'span' in elem: + span_idx_list.append(self.dict_elem[elem]) + return span_idx_list +
+ def __call__(self, data): + cells = data['cells'] + structure = data['structure']['tokens'] + structure = self.encode(structure, 'elem') + if structure is None: + return None + elem_num = len(structure) + structure = [0] + structure + [len(self.dict_elem) - 1] + structure = structure + [0] * (self.max_elem_length + 2 - len(structure)) + structure = np.array(structure) + data['structure'] = structure + elem_char_idx1 = self.dict_elem['<td>'] + elem_char_idx2 = self.dict_elem['<td'] + span_idx_list = self.get_span_idx_list() + td_idx_list = np.logical_or(structure == elem_char_idx1, structure == elem_char_idx2) + td_idx_list = np.where(td_idx_list)[0] + structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32) + bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32) + bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32) + img_height, img_width, img_ch = data['image'].shape + if len(span_idx_list) > 0: + span_weight = len(td_idx_list) * 1.0 / len(span_idx_list) + span_weight = min(max(span_weight, 1.0), self.span_weight)
+ for cno in range(len(cells)): + if 'bbox' in cells[cno]: + bbox = cells[cno]['bbox'].copy() + bbox[0] = bbox[0] * 1.0 / img_width + bbox[1] = bbox[1] * 1.0 / img_height + bbox[2] = bbox[2] * 1.0 / img_width + bbox[3] = bbox[3] * 1.0 / img_height + td_idx = td_idx_list[cno] + bbox_list[td_idx] = bbox + bbox_list_mask[td_idx] = 1.0 + cand_span_idx = td_idx + 1 + if cand_span_idx < (self.max_elem_length + 2): + if structure[cand_span_idx] in span_idx_list: + structure_mask[cand_span_idx] = span_weight + + data['bbox_list'] = bbox_list + data['bbox_list_mask'] = bbox_list_mask + data['structure_mask'] = structure_mask + char_beg_idx = self.get_beg_end_flag_idx('beg', 'char') + char_end_idx = self.get_beg_end_flag_idx('end', 'char') + elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem') + elem_end_idx = self.get_beg_end_flag_idx('end', 'elem') + data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx, + elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length, + self.max_elem_length, self.max_cell_num, elem_num]) + return data +
+ def encode(self, text, char_or_elem): + """convert text-label into text-index. + """ + if char_or_elem == "char": + max_len = self.max_text_length + current_dict = self.dict_character + else: + max_len = self.max_elem_length + current_dict = self.dict_elem + if len(text) > max_len: + return None + if len(text) == 0: + if char_or_elem == "char": + return [self.dict_character['space']] + else: + return None + text_list = [] + for char in text: + if char not in current_dict: + return None + text_list.append(current_dict[char]) + if len(text_list) == 0: + if char_or_elem == "char": + return [self.dict_character['space']] + else: + return None + return text_list +
+ def get_ignored_tokens(self, char_or_elem): + beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) + end_idx = self.get_beg_end_flag_idx("end", char_or_elem) + return [beg_idx, end_idx] +
+ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): + if char_or_elem == "char": + if beg_or_end == "beg": + idx = np.array(self.dict_character[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict_character[self.end_str]) + else: + assert False, "Unsupported type %s in get_beg_end_flag_idx of char" \ + % beg_or_end + elif char_or_elem == "elem": + if beg_or_end == "beg": + idx = np.array(self.dict_elem[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict_elem[self.end_str]) + else: + assert False, "Unsupported type %s in get_beg_end_flag_idx of elem" \ + % beg_or_end + else: + assert False, "Unsupported type %s in char_or_elem" \ + % char_or_elem + return idx + \ No newline at end of file
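TableLabelEncode reads both vocabularies from a single file: the first line holds the character and element counts separated by a tab, followed by that many character lines and then that many structure-token lines. A hedged sketch of the layout load_char_elem_dict expects (toy contents for illustration; the real token set ships in ppocr/utils/dict/table_structure_dict.txt):

# build a miniature dict file matching the parser above (illustration only)
lines = ['2\t3',                        # character_num TAB elem_num
         'a', 'b',                      # 2 character lines
         '<table>', '<td>', '</td>']    # 3 structure-token lines
with open('toy_table_dict.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(lines))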
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index 9c48b09647527cf718113ea1b5df152ff7befa04..2535b4420c503f2e9e9cc5a677ef70c4dd9c36be 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -81,7 +81,7 @@ class NormalizeImage(object): assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage" data['image'] = ( - img.astype('float32') * self.scale - self.mean) / self.std + img.astype('float32') * self.scale - self.mean) / self.std return data
@@ -163,7 +163,7 @@ class DetResizeForTest(object): img, (ratio_h, ratio_w) """ limit_side_len = self.limit_side_len - h, w, _ = img.shape + h, w, c = img.shape # limit the max side if self.limit_type == 'max': if max(h, w) > limit_side_len: if h > w: ratio = float(limit_side_len) / h else: ratio = float(limit_side_len) / w else: ratio = 1.
- else: + elif self.limit_type == 'min': if min(h, w) < limit_side_len: if h < w: ratio = float(limit_side_len) / h @@ -182,6 +182,10 @@ class DetResizeForTest(object): ratio = float(limit_side_len) / w else: ratio = 1. + elif self.limit_type == 'resize_long': + ratio = float(limit_side_len) / max(h, w) + else: + raise Exception('unsupported limit_type: {}'.format(self.limit_type)) resize_h = int(h * ratio) resize_w = int(w * ratio)
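The three limit_type modes above differ only in how the scale ratio is chosen; as standalone arithmetic (assumed sample values, outside the class):

h, w, limit_side_len = 720, 1280, 960
ratio_max = float(limit_side_len) / max(h, w) if max(h, w) > limit_side_len else 1.  # 'max': shrink the long side down to the limit
ratio_min = float(limit_side_len) / min(h, w) if min(h, w) < limit_side_len else 1.  # 'min': grow the short side up to the limit
ratio_long = float(limit_side_len) / max(h, w)                                       # 'resize_long': always pin the long side
print(int(h * ratio_max), int(w * ratio_max))  # 540 960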
diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..78b76c5afb8c96bc96730c7b8ad76b4bafa31c67 --- /dev/null +++ b/ppocr/data/pubtab_dataset.py @@ -0,0 +1,107 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+import numpy as np +import os +import random +from paddle.io import Dataset +import json + +from .imaug import transform, create_operators + +
+class PubTabDataSet(Dataset): + def __init__(self, config, mode, logger, seed=None): + super(PubTabDataSet, self).__init__() + self.logger = logger + + global_config = config['Global'] + dataset_config = config[mode]['dataset'] + loader_config = config[mode]['loader'] + + label_file_path = dataset_config.pop('label_file_path') + + self.data_dir = dataset_config['data_dir'] + self.do_shuffle = loader_config['shuffle'] + self.do_hard_select = False + if 'hard_select' in loader_config: + self.do_hard_select = loader_config['hard_select'] + self.hard_prob = loader_config['hard_prob'] + if self.do_hard_select: + self.img_select_prob = self.load_hard_select_prob() + self.table_select_type = None + if 'table_select_type' in loader_config: + self.table_select_type = loader_config['table_select_type'] + self.table_select_prob = loader_config['table_select_prob'] + + self.seed = seed + logger.info("Initialize indexes of dataset: %s" % label_file_path) + with open(label_file_path, "rb") as f: + self.data_lines = f.readlines() + self.data_idx_order_list = list(range(len(self.data_lines))) + if mode.lower() == "train": + self.shuffle_data_random() + self.ops = create_operators(dataset_config['transforms'], global_config) +
+ def shuffle_data_random(self): + if self.do_shuffle: + random.seed(self.seed) + random.shuffle(self.data_lines) + return +
+ def __getitem__(self, idx): + try: + data_line = self.data_lines[idx] + data_line = data_line.decode('utf-8').strip("\n") + info = json.loads(data_line) + file_name = info['filename'] + select_flag = True + if self.do_hard_select: + prob = self.img_select_prob[file_name] + if prob < random.uniform(0, 1): + select_flag = False + + if self.table_select_type: + structure = info['html']['structure']['tokens'].copy() + structure_str = ''.join(structure) + table_type = "simple" + if 'colspan' in structure_str or 'rowspan' in structure_str: + table_type = "complex" + if table_type == "complex": + if self.table_select_prob < random.uniform(0, 1): + select_flag = False + + if select_flag: + cells = info['html']['cells'].copy() + structure = info['html']['structure'].copy()
+ img_path = os.path.join(self.data_dir, file_name) + data = {'img_path': img_path, 'cells': cells, 'structure': structure} + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) + with open(data['img_path'], 'rb') as f: + img = f.read() + data['image'] = img + outs = transform(data, self.ops) + else: + outs = None + except Exception as e: + self.logger.error( "Error occurred while parsing line {}: {}".format( data_line, e)) + outs = None + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs +
+ def __len__(self): + return len(self.data_idx_order_list)
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 223ae6b1da996478ac607e29dd37173ca51d9903..025ae7ca5cc604eea59423ca7f523c37c1492e35 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -13,28 +13,39 @@ # limitations under the License. import copy +import paddle +import paddle.nn as nn
+# det loss +from .det_db_loss import DBLoss +from .det_east_loss import EASTLoss +from .det_sast_loss import SASTLoss -def build_loss(config): - # det loss - from .det_db_loss import DBLoss - from .det_east_loss import EASTLoss - from .det_sast_loss import SASTLoss
+# rec loss +from .rec_ctc_loss import CTCLoss +from .rec_att_loss import AttentionLoss +from .rec_srn_loss import SRNLoss + +# cls loss +from .cls_loss import ClsLoss + +# e2e loss +from .e2e_pg_loss import PGLoss - # rec loss - from .rec_ctc_loss import CTCLoss - from .rec_att_loss import AttentionLoss - from .rec_srn_loss import SRNLoss
+# basic loss function +from .basic_loss import DistanceLoss - # cls loss - from .cls_loss import ClsLoss +# combined loss function +from .combined_loss import CombinedLoss - # e2e loss - from .e2e_pg_loss import PGLoss +# table loss +from .table_att_loss import TableAttentionLoss +
+def build_loss(config): support_dict = [ 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', - 'SRNLoss', 'PGLoss'] - + 'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss' + ] config = copy.deepcopy(config) module_name = config.pop('name') assert module_name in support_dict, Exception('loss only support {}'.format(
diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..fa3ceda1b747aad3c4b275611b1257bf6950f013 --- /dev/null +++ b/ppocr/losses/basic_loss.py @@ -0,0 +1,102 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License.
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddle.nn import L1Loss +from paddle.nn import MSELoss as L2Loss +from paddle.nn import SmoothL1Loss + +
+class CELoss(nn.Layer): + def __init__(self, epsilon=None): + super().__init__() + if epsilon is not None and (epsilon <= 0 or epsilon >= 1): + epsilon = None + self.epsilon = epsilon +
+ def _labelsmoothing(self, target, class_num): + if target.shape[-1] != class_num: + one_hot_target = F.one_hot(target, class_num) + else: + one_hot_target = target + soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon) + soft_target = paddle.reshape(soft_target, shape=[-1, class_num]) + return soft_target +
+ def forward(self, x, label): + if self.epsilon is not None: + class_num = x.shape[-1] + label = self._labelsmoothing(label, class_num) + x = -F.log_softmax(x, axis=-1) + loss = paddle.sum(x * label, axis=-1) + else: + if label.shape[-1] == x.shape[-1]: + label = F.softmax(label, axis=-1) + soft_label = True + else: + soft_label = False + loss = F.cross_entropy(x, label=label, soft_label=soft_label) + return loss + +
+class DMLLoss(nn.Layer): + """ + DMLLoss + """ + + def __init__(self, act=None): + super().__init__() + if act is not None: + assert act in ["softmax", "sigmoid"] + if act == "softmax": + self.act = nn.Softmax(axis=-1) + elif act == "sigmoid": + self.act = nn.Sigmoid() + else: + self.act = None +
+ def forward(self, out1, out2): + if self.act is not None: + out1 = self.act(out1) + out2 = self.act(out2) + + log_out1 = paddle.log(out1) + log_out2 = paddle.log(out2) + loss = (F.kl_div( + log_out1, out2, reduction='batchmean') + F.kl_div( + log_out2, out1, reduction='batchmean')) / 2.0 + return loss + +
+class DistanceLoss(nn.Layer): + """ + DistanceLoss: + mode: loss mode + """ + + def __init__(self, mode="l2", **kargs): + super().__init__() + assert mode in ["l1", "l2", "smooth_l1"] + if mode == "l1": + self.loss_func = nn.L1Loss(**kargs) + elif mode == "l2": + self.loss_func = nn.MSELoss(**kargs) + elif mode == "smooth_l1": + self.loss_func = nn.SmoothL1Loss(**kargs) +
+ def forward(self, x, y): + return self.loss_func(x, y)
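DMLLoss above averages the two directional KL terms between the student and teacher distributions; a quick numeric sanity check with Paddle (toy logits, values assumed):

import paddle
import paddle.nn.functional as F

p = F.softmax(paddle.to_tensor([[2.0, 0.5]]), axis=-1)
q = F.softmax(paddle.to_tensor([[1.0, 1.0]]), axis=-1)
# same formula as DMLLoss.forward: mean of the two KL terms
dml = (F.kl_div(paddle.log(p), q, reduction='batchmean') +
       F.kl_div(paddle.log(q), p, reduction='batchmean')) / 2.0
print(float(dml))  # small positive value; 0.0 only when p == q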
diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..54da70174cba7bf5ca35e8fbf5aa137a437ae29c --- /dev/null +++ b/ppocr/losses/combined_loss.py @@ -0,0 +1,58 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +
+import paddle +import paddle.nn as nn + +from .distillation_loss import DistillationCTCLoss +from .distillation_loss import DistillationDMLLoss +from .distillation_loss import DistillationDistanceLoss + +
+class CombinedLoss(nn.Layer): + """ + CombinedLoss: + a combination of loss functions + """ + + def __init__(self, loss_config_list=None): + super().__init__() + self.loss_func = [] + self.loss_weight = [] + assert isinstance(loss_config_list, list), ( + 'loss_config_list should be a list') + for config in loss_config_list: + assert isinstance(config, + dict) and len(config) == 1, "yaml format error" + name = list(config)[0] + param = config[name] + assert "weight" in param, "weight must be in param, but param only contains {}".format( + param.keys()) + self.loss_weight.append(param.pop("weight")) + self.loss_func.append(eval(name)(**param)) +
+ def forward(self, input, batch, **kargs): + loss_dict = {} + for idx, loss_func in enumerate(self.loss_func): + loss = loss_func(input, batch, **kargs) + if isinstance(loss, paddle.Tensor): + loss = {"loss_{}_{}".format(str(loss), idx): loss} + weight = self.loss_weight[idx] + loss = { + "{}_{}".format(key, idx): loss[key] * weight + for key in loss + } + loss_dict.update(loss) + loss_dict["loss"] = paddle.add_n(list(loss_dict.values())) + return loss_dict
diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1e8aa0d8602e3ddd49913e6a572914859377ca42 --- /dev/null +++ b/ppocr/losses/distillation_loss.py @@ -0,0 +1,108 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License.
+ +import paddle +import paddle.nn as nn + +from .rec_ctc_loss import CTCLoss +from .basic_loss import DMLLoss +from .basic_loss import DistanceLoss + + +class DistillationDMLLoss(DMLLoss): + """ + """ + + def __init__(self, model_name_pairs=[], act=None, key=None, + name="loss_dml"): + super().__init__(act=act) + assert isinstance(model_name_pairs, list) + self.key = key + self.model_name_pairs = model_name_pairs + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss = super().forward(out1, out2) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, idx)] = loss + return loss_dict + + +class DistillationCTCLoss(CTCLoss): + def __init__(self, model_name_list=[], key=None, name="loss_ctc"): + super().__init__() + self.model_name_list = model_name_list + self.key = key + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, model_name in enumerate(self.model_name_list): + out = predicts[model_name] + if self.key is not None: + out = out[self.key] + loss = super().forward(out, batch) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}".format(self.name, model_name, + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, model_name)] = loss + return loss_dict + + +class DistillationDistanceLoss(DistanceLoss): + """ + """ + + def __init__(self, + mode="l2", + model_name_pairs=[], + key=None, + name="loss_distance", + **kargs): + super().__init__(mode=mode, **kargs) + assert isinstance(model_name_pairs, list) + self.key = key + self.model_name_pairs = model_name_pairs + self.name = name + "_l2" + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss = super().forward(out1, out2) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[ + key] + else: + loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1], + idx)] = loss + return loss_dict diff --git a/ppocr/losses/rec_ctc_loss.py b/ppocr/losses/rec_ctc_loss.py index 425de58710a61fde2034a88707a3032e02007d13..6c0b56ff84db4ff23786fb781d461bf9fbc86ef2 100755 --- a/ppocr/losses/rec_ctc_loss.py +++ b/ppocr/losses/rec_ctc_loss.py @@ -25,7 +25,7 @@ class CTCLoss(nn.Layer): super(CTCLoss, self).__init__() self.loss_func = nn.CTCLoss(blank=0, reduction='none') - def __call__(self, predicts, batch): + def forward(self, predicts, batch): predicts = predicts.transpose((1, 0, 2)) N, B, _ = predicts.shape preds_lengths = paddle.to_tensor([N] * B, dtype='int64') diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fd99e6952aacc0182a482ca5ae5ddaf959a026 --- /dev/null +++ b/ppocr/losses/table_att_loss.py @@ -0,0 +1,109 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +
+from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle import fluid +
+class TableAttentionLoss(nn.Layer): + def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs): + super(TableAttentionLoss, self).__init__() + self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none') + self.structure_weight = structure_weight + self.loc_weight = loc_weight + self.use_giou = use_giou + self.giou_weight = giou_weight +
+ def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'): + ''' + :param preds: [[x1, y1, x2, y2], [x1, y1, x2, y2], ...] + :param bbox: [[x1, y1, x2, y2], [x1, y1, x2, y2], ...] + :return: loss + ''' + ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0]) + iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1]) + ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2]) + iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3]) + + iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10) + ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10) + + # overlap + inters = iw * ih + + # union + uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3 + ) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * ( + bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps + + # ious + ious = inters / uni + + ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0]) + ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1]) + ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2]) + ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3]) + ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10) + eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10) + + # enclose area + enclose = ew * eh + eps + giou = ious - (enclose - uni) / enclose + + loss = 1 - giou + + if reduction == 'mean': + loss = paddle.mean(loss) + elif reduction == 'sum': + loss = paddle.sum(loss) + else: + raise NotImplementedError + return loss +
+ def forward(self, predicts, batch): + structure_probs = predicts['structure_probs'] + structure_targets = batch[1].astype("int64") + structure_targets = structure_targets[:, 1:] + if len(batch) == 6: + structure_mask = batch[5].astype("int64") + structure_mask = structure_mask[:, 1:] + structure_mask = paddle.reshape(structure_mask, [-1]) + structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]]) + structure_targets = paddle.reshape(structure_targets, [-1]) + structure_loss = self.loss_func(structure_probs, structure_targets) + + if len(batch) == 6: + structure_loss = structure_loss * structure_mask + +# structure_loss = paddle.sum(structure_loss) * self.structure_weight + structure_loss = paddle.mean(structure_loss) * self.structure_weight + + loc_preds = predicts['loc_preds'] + loc_targets = batch[2].astype("float32") + loc_targets_mask = batch[4].astype("float32") + loc_targets = loc_targets[:, 1:, :] + loc_targets_mask = loc_targets_mask[:, 1:, :] + loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight + if self.use_giou: + loc_loss_giou =
self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight + total_loss = structure_loss + loc_loss + loc_loss_giou + return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou} + else: + total_loss = structure_loss + loc_loss + return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss} \ No newline at end of file diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index f913010dbd994633d3df1cf996abb994d246a11a..64f62e51cdf922773c03bb784a4edffdc17f506f 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -19,20 +19,23 @@ from __future__ import unicode_literals import copy -__all__ = ['build_metric'] +__all__ = ["build_metric"] +from .det_metric import DetMetric +from .rec_metric import RecMetric +from .cls_metric import ClsMetric +from .e2e_metric import E2EMetric +from .distillation_metric import DistillationMetric +from .table_metric import TableMetric def build_metric(config): - from .det_metric import DetMetric - from .rec_metric import RecMetric - from .cls_metric import ClsMetric - from .e2e_metric import E2EMetric - - support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric'] + support_dict = [ + "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric" + ] config = copy.deepcopy(config) - module_name = config.pop('name') + module_name = config.pop("name") assert module_name in support_dict, Exception( - 'metric only support {}'.format(support_dict)) + "metric only support {}".format(support_dict)) module_class = eval(module_name)(**config) return module_class diff --git a/ppocr/metrics/distillation_metric.py b/ppocr/metrics/distillation_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..a7d3d095a7d384bf8cdc69b97f8109c359ac2b5b --- /dev/null +++ b/ppocr/metrics/distillation_metric.py @@ -0,0 +1,76 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import importlib +import copy + +from .rec_metric import RecMetric +from .det_metric import DetMetric +from .e2e_metric import E2EMetric +from .cls_metric import ClsMetric + +
+class DistillationMetric(object): + def __init__(self, + key=None, + base_metric_name="RecMetric", + main_indicator='acc', + **kwargs): + self.main_indicator = main_indicator + self.key = key + self.main_indicator = main_indicator + self.base_metric_name = base_metric_name + self.kwargs = kwargs + self.metrics = None +
+ def _init_metrics(self, preds): + self.metrics = dict() + mod = importlib.import_module(__name__) + for key in preds: + self.metrics[key] = getattr(mod, self.base_metric_name)( + main_indicator=self.main_indicator, **self.kwargs) + self.metrics[key].reset() +
+ def __call__(self, preds, *args, **kwargs): + assert isinstance(preds, dict) + if self.metrics is None: + self._init_metrics(preds) + output = dict() + for key in preds: + metric = self.metrics[key].__call__(preds[key], *args, **kwargs) + for sub_key in metric: + output["{}_{}".format(key, sub_key)] = metric[sub_key] + return output +
+ def get_metric(self): + """ + return metrics { + 'acc': 0, + 'norm_edit_dis': 0, + } + """ + output = dict() + for key in self.metrics: + metric = self.metrics[key].get_metric() + # main indicator + if key == self.key: + output.update(metric) + else: + for sub_key in metric: + output["{}_{}".format(key, sub_key)] = metric[sub_key] + return output +
+ def reset(self): + for key in self.metrics: + self.metrics[key].reset()
diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..80d1c789ecc3979bd4c33620af91ccd28012f7a8 --- /dev/null +++ b/ppocr/metrics/table_metric.py @@ -0,0 +1,50 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+import numpy as np +class TableMetric(object): + def __init__(self, main_indicator='acc', **kwargs): + self.main_indicator = main_indicator + self.reset() + + def __call__(self, pred, batch, *args, **kwargs): + structure_probs = pred['structure_probs'].numpy() + structure_labels = batch[1] + correct_num = 0 + all_num = 0 + structure_probs = np.argmax(structure_probs, axis=2) + structure_labels = structure_labels[:, 1:] + batch_size = structure_probs.shape[0] + for bno in range(batch_size): + all_num += 1 + if (structure_probs[bno] == structure_labels[bno]).all(): + correct_num += 1 + self.correct_num += correct_num + self.all_num += all_num + return { + 'acc': correct_num * 1.0 / all_num, + } + + def get_metric(self): + """ + return metrics { + 'acc': 0, + } + """ + acc = 1.0 * self.correct_num / self.all_num + self.reset() + return {'acc': acc} + + def reset(self): + self.correct_num = 0 + self.all_num = 0 diff --git a/ppocr/modeling/architectures/__init__.py b/ppocr/modeling/architectures/__init__.py index 86eaf7c9fb3c1147f60c7652243184121c62bcea..e9a01cf0281b91d29f2cce88375be3aaf43feb2e 100755 --- a/ppocr/modeling/architectures/__init__.py +++ b/ppocr/modeling/architectures/__init__.py @@ -13,12 +13,20 @@ # limitations under the License. import copy +import importlib + +from .base_model import BaseModel +from .distillation_model import DistillationModel __all__ = ['build_model'] + def build_model(config): - from .base_model import BaseModel - config = copy.deepcopy(config) - module_class = BaseModel(config) - return module_class \ No newline at end of file + if not "name" in config: + arch = BaseModel(config) + else: + name = config.pop("name") + mod = importlib.import_module(__name__) + arch = getattr(mod, name)(config) + return arch diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index 09b6e0346d998e3b90762e6163e8a34b48daff36..03fbcee8465df9c8bb7845ea62fc0ac04917caa0 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,7 +32,6 @@ class BaseModel(nn.Layer): config (dict): the super parameters for module. """ super(BaseModel, self).__init__() - in_channels = config.get('in_channels', 3) model_type = config['model_type'] # build transfrom, @@ -68,14 +67,20 @@ class BaseModel(nn.Layer): config["Head"]['in_channels'] = in_channels self.head = build_head(config["Head"]) + self.return_all_feats = config.get("return_all_feats", False) + def forward(self, x, data=None): + y = dict() if self.use_transform: x = self.transform(x) x = self.backbone(x) + y["backbone_out"] = x if self.use_neck: x = self.neck(x) - if data is None: - x = self.head(x) + y["neck_out"] = x + x = self.head(x, targets=data) + y["head_out"] = x + if self.return_all_feats: + return y else: - x = self.head(x, data) - return x + return x diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2e512331afcfc20e422dbef4ba1a4acd581df9e7 --- /dev/null +++ b/ppocr/modeling/architectures/distillation_model.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn +from ppocr.modeling.transforms import build_transform +from ppocr.modeling.backbones import build_backbone +from ppocr.modeling.necks import build_neck +from ppocr.modeling.heads import build_head +from .base_model import BaseModel +from ppocr.utils.save_load import init_model + +__all__ = ['DistillationModel'] + + +class DistillationModel(nn.Layer): + def __init__(self, config): + """ + the module for OCR distillation. + args: + config (dict): the super parameters for module. + """ + super().__init__() + self.model_list = [] + self.model_name_list = [] + for key in config["Models"]: + model_config = config["Models"][key] + freeze_params = False + pretrained = None + if "freeze_params" in model_config: + freeze_params = model_config.pop("freeze_params") + if "pretrained" in model_config: + pretrained = model_config.pop("pretrained") + model = BaseModel(model_config) + if pretrained is not None: + init_model(model, path=pretrained) + if freeze_params: + for param in model.parameters(): + param.trainable = False + self.model_list.append(self.add_sublayer(key, model)) + self.model_name_list.append(key) + + def forward(self, x): + result_dict = dict() + for idx, model_name in enumerate(self.model_name_list): + result_dict[model_name] = self.model_list[idx](x) + return result_dict diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index fe2c9bc30a4f2abd1ba7d3d6989b9ef9b20c1f4f..13b70b203371b3be58ee82c6808d744bf6098333 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -29,6 +29,10 @@ def build_backbone(config, model_type): elif model_type == 'e2e': from .e2e_resnet_vd_pg import ResNet support_dict = ['ResNet'] + elif model_type == "table": + from .table_resnet_vd import ResNet + from .table_mobilenet_v3 import MobileNetV3 + support_dict = ['ResNet', 'MobileNetV3'] else: raise NotImplementedError diff --git a/ppocr/modeling/backbones/det_mobilenet_v3.py b/ppocr/modeling/backbones/det_mobilenet_v3.py index bb451bbec9327e2624ab0d501a7adf4355dc3407..05113ea8419aa302c952adfd74e9083055c35dca 100755 --- a/ppocr/modeling/backbones/det_mobilenet_v3.py +++ b/ppocr/modeling/backbones/det_mobilenet_v3.py @@ -102,8 +102,7 @@ class MobileNetV3(nn.Layer): padding=1, groups=1, if_act=True, - act='hardswish', - name='conv1') + act='hardswish') self.stages = [] self.out_channels = [] @@ -125,8 +124,7 @@ class MobileNetV3(nn.Layer): kernel_size=k, stride=s, use_se=se, - act=nl, - name="conv" + str(i + 2))) + act=nl)) inplanes = make_divisible(scale * c) i += 1 block_list.append( @@ -138,8 +136,7 @@ class MobileNetV3(nn.Layer): padding=0, groups=1, if_act=True, - act='hardswish', - name='conv_last')) + act='hardswish')) self.stages.append(nn.Sequential(*block_list)) self.out_channels.append(make_divisible(scale * cls_ch_squeeze)) for i, stage in 
enumerate(self.stages): @@ -163,8 +160,7 @@ class ConvBNLayer(nn.Layer): padding, groups=1, if_act=True, - act=None, - name=None): + act=None): super(ConvBNLayer, self).__init__() self.if_act = if_act self.act = act @@ -175,16 +171,9 @@ class ConvBNLayer(nn.Layer): stride=stride, padding=padding, groups=groups, - weight_attr=ParamAttr(name=name + '_weights'), bias_attr=False) - self.bn = nn.BatchNorm( - num_channels=out_channels, - act=None, - param_attr=ParamAttr(name=name + "_bn_scale"), - bias_attr=ParamAttr(name=name + "_bn_offset"), - moving_mean_name=name + "_bn_mean", - moving_variance_name=name + "_bn_variance") + self.bn = nn.BatchNorm(num_channels=out_channels, act=None) def forward(self, x): x = self.conv(x) @@ -209,8 +198,7 @@ class ResidualUnit(nn.Layer): kernel_size, stride, use_se, - act=None, - name=''): + act=None): super(ResidualUnit, self).__init__() self.if_shortcut = stride == 1 and in_channels == out_channels self.if_se = use_se @@ -222,8 +210,7 @@ class ResidualUnit(nn.Layer): stride=1, padding=0, if_act=True, - act=act, - name=name + "_expand") + act=act) self.bottleneck_conv = ConvBNLayer( in_channels=mid_channels, out_channels=mid_channels, @@ -232,10 +219,9 @@ class ResidualUnit(nn.Layer): padding=int((kernel_size - 1) // 2), groups=mid_channels, if_act=True, - act=act, - name=name + "_depthwise") + act=act) if self.if_se: - self.mid_se = SEModule(mid_channels, name=name + "_se") + self.mid_se = SEModule(mid_channels) self.linear_conv = ConvBNLayer( in_channels=mid_channels, out_channels=out_channels, @@ -243,8 +229,7 @@ class ResidualUnit(nn.Layer): stride=1, padding=0, if_act=False, - act=None, - name=name + "_linear") + act=None) def forward(self, inputs): x = self.expand_conv(inputs) @@ -258,7 +243,7 @@ class ResidualUnit(nn.Layer): class SEModule(nn.Layer): - def __init__(self, in_channels, reduction=4, name=""): + def __init__(self, in_channels, reduction=4): super(SEModule, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2D(1) self.conv1 = nn.Conv2D( @@ -266,17 +251,13 @@ class SEModule(nn.Layer): out_channels=in_channels // reduction, kernel_size=1, stride=1, - padding=0, - weight_attr=ParamAttr(name=name + "_1_weights"), - bias_attr=ParamAttr(name=name + "_1_offset")) + padding=0) self.conv2 = nn.Conv2D( in_channels=in_channels // reduction, out_channels=in_channels, kernel_size=1, stride=1, - padding=0, - weight_attr=ParamAttr(name + "_2_weights"), - bias_attr=ParamAttr(name=name + "_2_offset")) + padding=0) def forward(self, inputs): outputs = self.avg_pool(inputs) diff --git a/ppocr/modeling/backbones/rec_mobilenet_v3.py b/ppocr/modeling/backbones/rec_mobilenet_v3.py index 1ff17159680372b00e6943e180e5fb638b39ec58..c5dcfdd5a3ad1f2c356f488a89e0f1e660a4a832 100644 --- a/ppocr/modeling/backbones/rec_mobilenet_v3.py +++ b/ppocr/modeling/backbones/rec_mobilenet_v3.py @@ -96,8 +96,7 @@ class MobileNetV3(nn.Layer): padding=1, groups=1, if_act=True, - act='hardswish', - name='conv1') + act='hardswish') i = 0 block_list = [] inplanes = make_divisible(inplanes * scale) @@ -110,8 +109,7 @@ class MobileNetV3(nn.Layer): kernel_size=k, stride=s, use_se=se, - act=nl, - name='conv' + str(i + 2))) + act=nl)) inplanes = make_divisible(scale * c) i += 1 self.blocks = nn.Sequential(*block_list) @@ -124,8 +122,7 @@ class MobileNetV3(nn.Layer): padding=0, groups=1, if_act=True, - act='hardswish', - name='conv_last') + act='hardswish') self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) self.out_channels = make_divisible(scale * cls_ch_squeeze) diff --git 
a/ppocr/modeling/backbones/table_mobilenet_v3.py b/ppocr/modeling/backbones/table_mobilenet_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..daa87f976038d8d5eeafadceb869b9232ba22cd9 --- /dev/null +++ b/ppocr/modeling/backbones/table_mobilenet_v3.py @@ -0,0 +1,287 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + +__all__ = ['MobileNetV3'] + + +def make_divisible(v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class MobileNetV3(nn.Layer): + def __init__(self, + in_channels=3, + model_name='large', + scale=0.5, + disable_se=False, + **kwargs): + """ + the MobilenetV3 backbone network for detection module. + Args: + params(dict): the super parameters for build network + """ + super(MobileNetV3, self).__init__() + + self.disable_se = disable_se + + if model_name == "large": + cfg = [ + # k, exp, c, se, nl, s, + [3, 16, 16, False, 'relu', 1], + [3, 64, 24, False, 'relu', 2], + [3, 72, 24, False, 'relu', 1], + [5, 72, 40, True, 'relu', 2], + [5, 120, 40, True, 'relu', 1], + [5, 120, 40, True, 'relu', 1], + [3, 240, 80, False, 'hardswish', 2], + [3, 200, 80, False, 'hardswish', 1], + [3, 184, 80, False, 'hardswish', 1], + [3, 184, 80, False, 'hardswish', 1], + [3, 480, 112, True, 'hardswish', 1], + [3, 672, 112, True, 'hardswish', 1], + [5, 672, 160, True, 'hardswish', 2], + [5, 960, 160, True, 'hardswish', 1], + [5, 960, 160, True, 'hardswish', 1], + ] + cls_ch_squeeze = 960 + elif model_name == "small": + cfg = [ + # k, exp, c, se, nl, s, + [3, 16, 16, True, 'relu', 2], + [3, 72, 24, False, 'relu', 2], + [3, 88, 24, False, 'relu', 1], + [5, 96, 40, True, 'hardswish', 2], + [5, 240, 40, True, 'hardswish', 1], + [5, 240, 40, True, 'hardswish', 1], + [5, 120, 48, True, 'hardswish', 1], + [5, 144, 48, True, 'hardswish', 1], + [5, 288, 96, True, 'hardswish', 2], + [5, 576, 96, True, 'hardswish', 1], + [5, 576, 96, True, 'hardswish', 1], + ] + cls_ch_squeeze = 576 + else: + raise NotImplementedError("mode[" + model_name + + "_model] is not implemented!") + + supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25] + assert scale in supported_scale, \ + "supported scale are {} but input scale is {}".format(supported_scale, scale) + inplanes = 16 + # conv1 + self.conv = ConvBNLayer( + in_channels=in_channels, + out_channels=make_divisible(inplanes * scale), + kernel_size=3, + stride=2, + padding=1, + groups=1, + if_act=True, + act='hardswish', + name='conv1') + + self.stages = [] + self.out_channels = [] + block_list = [] + i = 0 + inplanes = make_divisible(inplanes * scale) + for (k, exp, c, se, nl, s) in cfg: + se = se and not self.disable_se + start_idx 
= 2 if model_name == 'large' else 0 + if s == 2 and i > start_idx: + self.out_channels.append(inplanes) + self.stages.append(nn.Sequential(*block_list)) + block_list = [] + block_list.append( + ResidualUnit( + in_channels=inplanes, + mid_channels=make_divisible(scale * exp), + out_channels=make_divisible(scale * c), + kernel_size=k, + stride=s, + use_se=se, + act=nl, + name="conv" + str(i + 2))) + inplanes = make_divisible(scale * c) + i += 1 + block_list.append( + ConvBNLayer( + in_channels=inplanes, + out_channels=make_divisible(scale * cls_ch_squeeze), + kernel_size=1, + stride=1, + padding=0, + groups=1, + if_act=True, + act='hardswish', + name='conv_last')) + self.stages.append(nn.Sequential(*block_list)) + self.out_channels.append(make_divisible(scale * cls_ch_squeeze)) + for i, stage in enumerate(self.stages): + self.add_sublayer(sublayer=stage, name="stage{}".format(i)) + + def forward(self, x): + x = self.conv(x) + out_list = [] + for stage in self.stages: + x = stage(x) + out_list.append(x) + return out_list + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups=1, + if_act=True, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + self.if_act = if_act + self.act = act + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=None, + param_attr=ParamAttr(name=name + "_bn_scale"), + bias_attr=ParamAttr(name=name + "_bn_offset"), + moving_mean_name=name + "_bn_mean", + moving_variance_name=name + "_bn_variance") + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.if_act: + if self.act == "relu": + x = F.relu(x) + elif self.act == "hardswish": + x = F.hardswish(x) + else: + print("The activation function({}) is selected incorrectly.". 
+ format(self.act)) + exit() + return x + + +class ResidualUnit(nn.Layer): + def __init__(self, + in_channels, + mid_channels, + out_channels, + kernel_size, + stride, + use_se, + act=None, + name=''): + super(ResidualUnit, self).__init__() + self.if_shortcut = stride == 1 and in_channels == out_channels + self.if_se = use_se + + self.expand_conv = ConvBNLayer( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + if_act=True, + act=act, + name=name + "_expand") + self.bottleneck_conv = ConvBNLayer( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=int((kernel_size - 1) // 2), + groups=mid_channels, + if_act=True, + act=act, + name=name + "_depthwise") + if self.if_se: + self.mid_se = SEModule(mid_channels, name=name + "_se") + self.linear_conv = ConvBNLayer( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + if_act=False, + act=None, + name=name + "_linear") + + def forward(self, inputs): + x = self.expand_conv(inputs) + x = self.bottleneck_conv(x) + if self.if_se: + x = self.mid_se(x) + x = self.linear_conv(x) + if self.if_shortcut: + x = paddle.add(inputs, x) + return x + + +class SEModule(nn.Layer): + def __init__(self, in_channels, reduction=4, name=""): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2D(1) + self.conv1 = nn.Conv2D( + in_channels=in_channels, + out_channels=in_channels // reduction, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(name=name + "_1_weights"), + bias_attr=ParamAttr(name=name + "_1_offset")) + self.conv2 = nn.Conv2D( + in_channels=in_channels // reduction, + out_channels=in_channels, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(name + "_2_weights"), + bias_attr=ParamAttr(name=name + "_2_offset")) + + def forward(self, inputs): + outputs = self.avg_pool(inputs) + outputs = self.conv1(outputs) + outputs = F.relu(outputs) + outputs = self.conv2(outputs) + outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5) + return inputs * outputs \ No newline at end of file diff --git a/ppocr/modeling/backbones/table_resnet_vd.py b/ppocr/modeling/backbones/table_resnet_vd.py new file mode 100644 index 0000000000000000000000000000000000000000..1c07c2684eec8d0c4a445cc88c543bfe1da9c864 --- /dev/null +++ b/ppocr/modeling/backbones/table_resnet_vd.py @@ -0,0 +1,280 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F + +__all__ = ["ResNet"] + + +class ConvBNLayer(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_vd_mode=False, + act=None, + name=None, ): + super(ConvBNLayer, self).__init__() + + self.is_vd_mode = is_vd_mode + self._pool2d_avg = nn.AvgPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self._conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self._batch_norm = nn.BatchNorm( + out_channels, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def forward(self, inputs): + if self.is_vd_mode: + inputs = self._pool2d_avg(inputs) + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class BottleneckBlock(nn.Layer): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + name=None): + super(BottleneckBlock, self).__init__() + + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2b") + self.conv2 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels * 4, + kernel_size=1, + act=None, + name=name + "_branch2c") + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels * 4, + kernel_size=1, + stride=1, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + conv2 = self.conv2(conv1) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.add(x=short, y=conv2) + y = F.relu(y) + return y + + +class BasicBlock(nn.Layer): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + name=None): + super(BasicBlock, self).__init__() + self.stride = stride + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + act=None, + name=name + "_branch2b") + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.add(x=short, y=conv1) + y = F.relu(y) + return y + + +class ResNet(nn.Layer): + def __init__(self, in_channels=3, layers=50, **kwargs): + super(ResNet, self).__init__() + + self.layers = layers + supported_layers = [18, 34, 50, 101, 
152, 200] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format( + supported_layers, layers) + + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + elif layers == 200: + depth = [3, 12, 48, 3] + num_channels = [64, 256, 512, + 1024] if layers >= 50 else [64, 64, 128, 256] + num_filters = [64, 128, 256, 512] + + self.conv1_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=32, + kernel_size=3, + stride=2, + act='relu', + name="conv1_1") + self.conv1_2 = ConvBNLayer( + in_channels=32, + out_channels=32, + kernel_size=3, + stride=1, + act='relu', + name="conv1_2") + self.conv1_3 = ConvBNLayer( + in_channels=32, + out_channels=64, + kernel_size=3, + stride=1, + act='relu', + name="conv1_3") + self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + + self.stages = [] + self.out_channels = [] + if layers >= 50: + for block in range(len(depth)): + block_list = [] + shortcut = False + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + bottleneck_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BottleneckBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block] * 4, + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(bottleneck_block) + self.out_channels.append(num_filters[block] * 4) + self.stages.append(nn.Sequential(*block_list)) + else: + for block in range(len(depth)): + block_list = [] + shortcut = False + for i in range(depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + basic_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BasicBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block], + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(basic_block) + self.out_channels.append(num_filters[block]) + self.stages.append(nn.Sequential(*block_list)) + + def forward(self, inputs): + y = self.conv1_1(inputs) + y = self.conv1_2(y) + y = self.conv1_3(y) + y = self.pool2d_max(y) + out = [] + for block in self.stages: + y = block(y) + out.append(y) + return out diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 4852c7f2d14d72b9e4d59f40532469f7226c966d..5096479415f504aa9f074d55bd9b2e4a31c730b4 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -31,8 +31,10 @@ def build_head(config): from .cls_head import ClsHead support_dict = [ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - 'SRNHead', 'PGHead'] + 'SRNHead', 'PGHead', 'TableAttentionHead'] + #table head + from .table_att_head import TableAttentionHead module_name = config.pop('name') assert module_name in support_dict, Exception('head only support {}'.format( diff --git a/ppocr/modeling/heads/cls_head.py b/ppocr/modeling/heads/cls_head.py index d9b78b841b3c31ea349cfbf4e767328b12f39aa7..91bfa615a8206b5ec0f993429ccae990a05d0b9b 100644 --- a/ppocr/modeling/heads/cls_head.py +++ b/ppocr/modeling/heads/cls_head.py @@ -43,7 +43,7 @@ 
class ClsHead(nn.Layer): initializer=nn.initializer.Uniform(-stdv, stdv)), bias_attr=ParamAttr(name="fc_0.b_0"), ) - def forward(self, x): + def forward(self, x, targets=None): x = self.pool(x) x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]]) x = self.fc(x) diff --git a/ppocr/modeling/heads/det_db_head.py b/ppocr/modeling/heads/det_db_head.py index ca18d74a68f7b17ee6383d4a0c995a4c46a16187..f76cb34d37af7d81b5e628d06c1a4cfe126f8bb4 100644 --- a/ppocr/modeling/heads/det_db_head.py +++ b/ppocr/modeling/heads/det_db_head.py @@ -23,10 +23,10 @@ import paddle.nn.functional as F from paddle import ParamAttr -def get_bias_attr(k, name): +def get_bias_attr(k): stdv = 1.0 / math.sqrt(k * 1.0) initializer = paddle.nn.initializer.Uniform(-stdv, stdv) - bias_attr = ParamAttr(initializer=initializer, name=name + "_b_attr") + bias_attr = ParamAttr(initializer=initializer) return bias_attr @@ -38,18 +38,14 @@ class Head(nn.Layer): out_channels=in_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr(name=name_list[0] + '.w_0'), + weight_attr=ParamAttr(), bias_attr=False) self.conv_bn1 = nn.BatchNorm( num_channels=in_channels // 4, param_attr=ParamAttr( - name=name_list[1] + '.w_0', initializer=paddle.nn.initializer.Constant(value=1.0)), bias_attr=ParamAttr( - name=name_list[1] + '.b_0', initializer=paddle.nn.initializer.Constant(value=1e-4)), - moving_mean_name=name_list[1] + '.w_1', - moving_variance_name=name_list[1] + '.w_2', act='relu') self.conv2 = nn.Conv2DTranspose( in_channels=in_channels // 4, @@ -57,19 +53,14 @@ class Head(nn.Layer): kernel_size=2, stride=2, weight_attr=ParamAttr( - name=name_list[2] + '.w_0', initializer=paddle.nn.initializer.KaimingUniform()), - bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2")) + bias_attr=get_bias_attr(in_channels // 4)) self.conv_bn2 = nn.BatchNorm( num_channels=in_channels // 4, param_attr=ParamAttr( - name=name_list[3] + '.w_0', initializer=paddle.nn.initializer.Constant(value=1.0)), bias_attr=ParamAttr( - name=name_list[3] + '.b_0', initializer=paddle.nn.initializer.Constant(value=1e-4)), - moving_mean_name=name_list[3] + '.w_1', - moving_variance_name=name_list[3] + '.w_2', act="relu") self.conv3 = nn.Conv2DTranspose( in_channels=in_channels // 4, @@ -77,10 +68,8 @@ class Head(nn.Layer): kernel_size=2, stride=2, weight_attr=ParamAttr( - name=name_list[4] + '.w_0', initializer=paddle.nn.initializer.KaimingUniform()), - bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"), - ) + bias_attr=get_bias_attr(in_channels // 4), ) def forward(self, x): x = self.conv1(x) @@ -117,7 +106,7 @@ class DBHead(nn.Layer): def step_function(self, x, y): return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y))) - def forward(self, x): + def forward(self, x, targets=None): shrink_maps = self.binarize(x) if not self.training: return {'maps': shrink_maps} diff --git a/ppocr/modeling/heads/det_east_head.py b/ppocr/modeling/heads/det_east_head.py index 9d0c3c4cf83adb018fcc368374cbe305658e07a9..004eb5d7bb9a134d1a84f980e37e5336dc43a29a 100644 --- a/ppocr/modeling/heads/det_east_head.py +++ b/ppocr/modeling/heads/det_east_head.py @@ -109,7 +109,7 @@ class EASTHead(nn.Layer): act=None, name="f_geo") - def forward(self, x): + def forward(self, x, targets=None): f_det = self.det_conv1(x) f_det = self.det_conv2(f_det) f_score = self.score_conv(f_det) diff --git a/ppocr/modeling/heads/det_sast_head.py b/ppocr/modeling/heads/det_sast_head.py index 263b28672299e733369938fa03952dca7685fabe..7a88a2db6c29c8c4fa1ee94d27bd0701cdbc90f8 100644 --- 
a/ppocr/modeling/heads/det_sast_head.py +++ b/ppocr/modeling/heads/det_sast_head.py @@ -116,7 +116,7 @@ class SASTHead(nn.Layer): self.head1 = SAST_Header1(in_channels) self.head2 = SAST_Header2(in_channels) - def forward(self, x): + def forward(self, x, targets=None): f_score, f_border = self.head1(x) f_tvo, f_tco = self.head2(x) diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py index 0da9de7580a0ceb473f971b2246c966497026a5d..274e1cdac5172f45590c9f7d7b50522c74db6750 100644 --- a/ppocr/modeling/heads/e2e_pg_head.py +++ b/ppocr/modeling/heads/e2e_pg_head.py @@ -220,7 +220,7 @@ class PGHead(nn.Layer): weight_attr=ParamAttr(name="conv_f_direc{}".format(4)), bias_attr=False) - def forward(self, x): + def forward(self, x, targets=None): f_score = self.conv_f_score1(x) f_score = self.conv_f_score2(f_score) f_score = self.conv_f_score3(f_score) diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py index 69d4ef50b648c0251b9b8d0b4c1e731a6f236105..9c38d31fa0abcf39a583e5edcebfc8f336f41c46 100755 --- a/ppocr/modeling/heads/rec_ctc_head.py +++ b/ppocr/modeling/heads/rec_ctc_head.py @@ -23,32 +23,57 @@ from paddle import ParamAttr, nn from paddle.nn import functional as F -def get_para_bias_attr(l2_decay, k, name): +def get_para_bias_attr(l2_decay, k): regularizer = paddle.regularizer.L2Decay(l2_decay) stdv = 1.0 / math.sqrt(k * 1.0) initializer = nn.initializer.Uniform(-stdv, stdv) - weight_attr = ParamAttr( - regularizer=regularizer, initializer=initializer, name=name + "_w_attr") - bias_attr = ParamAttr( - regularizer=regularizer, initializer=initializer, name=name + "_b_attr") + weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer) + bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer) return [weight_attr, bias_attr] class CTCHead(nn.Layer): - def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs): + def __init__(self, + in_channels, + out_channels, + fc_decay=0.0004, + mid_channels=None, + **kwargs): super(CTCHead, self).__init__() - weight_attr, bias_attr = get_para_bias_attr( - l2_decay=fc_decay, k=in_channels, name='ctc_fc') - self.fc = nn.Linear( - in_channels, - out_channels, - weight_attr=weight_attr, - bias_attr=bias_attr, - name='ctc_fc') + if mid_channels is None: + weight_attr, bias_attr = get_para_bias_attr( + l2_decay=fc_decay, k=in_channels) + self.fc = nn.Linear( + in_channels, + out_channels, + weight_attr=weight_attr, + bias_attr=bias_attr) + else: + weight_attr1, bias_attr1 = get_para_bias_attr( + l2_decay=fc_decay, k=in_channels) + self.fc1 = nn.Linear( + in_channels, + mid_channels, + weight_attr=weight_attr1, + bias_attr=bias_attr1) + + weight_attr2, bias_attr2 = get_para_bias_attr( + l2_decay=fc_decay, k=mid_channels) + self.fc2 = nn.Linear( + mid_channels, + out_channels, + weight_attr=weight_attr2, + bias_attr=bias_attr2) self.out_channels = out_channels + self.mid_channels = mid_channels - def forward(self, x, labels=None): - predicts = self.fc(x) + def forward(self, x, targets=None): + if self.mid_channels is None: + predicts = self.fc(x) + else: + predicts = self.fc1(x) + predicts = self.fc2(predicts) + if not self.training: predicts = F.softmax(predicts, axis=2) return predicts diff --git a/ppocr/modeling/heads/rec_srn_head.py b/ppocr/modeling/heads/rec_srn_head.py index d2c7fc028d28c79057708d4e6f306c417ba6306a..8d59e4711a043afd9234f430a62c9876c0a8f6f4 100644 --- a/ppocr/modeling/heads/rec_srn_head.py +++ b/ppocr/modeling/heads/rec_srn_head.py 
@@ -250,7 +250,8 @@ class SRNHead(nn.Layer):
         self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
-    def forward(self, inputs, others):
+    def forward(self, inputs, targets=None):
+        others = targets[-4:]
         encoder_word_pos = others[0]
         gsrm_word_pos = others[1]
         gsrm_slf_attn_bias1 = others[2]
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..155f036d15673135eae9e5ee493648603609535d
--- /dev/null
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -0,0 +1,238 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+
+
+class TableAttentionHead(nn.Layer):
+    def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
+        super(TableAttentionHead, self).__init__()
+        self.input_size = in_channels[-1]
+        self.hidden_size = hidden_size
+        self.elem_num = 30
+        self.max_text_length = 100
+        self.max_elem_length = 500
+        self.max_cell_num = 500
+
+        self.structure_attention_cell = AttentionGRUCell(
+            self.input_size, hidden_size, self.elem_num, use_gru=False)
+        self.structure_generator = nn.Linear(hidden_size, self.elem_num)
+        self.loc_type = loc_type
+        self.in_max_len = in_max_len
+
+        if self.loc_type == 1:
+            self.loc_generator = nn.Linear(hidden_size, 4)
+        else:
+            if self.in_max_len == 640:
+                self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1)
+            elif self.in_max_len == 800:
+                self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1)
+            else:
+                self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1)
+            self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
+
+    def _char_to_onehot(self, input_char, onehot_dim):
+        input_one_hot = F.one_hot(input_char, onehot_dim)
+        return input_one_hot
+
+    def forward(self, inputs, targets=None):
+        # Both the if and the else branch below must assign the same variables:
+        # if a variable is modified in only one branch, the modification will
+        # not take effect.
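+        # `inputs` holds the feature maps from the backbone/neck; only the
+        # last one is used, flattened and transposed to (batch, width, channels).
+        # In training, the GRU attention cell is teacher-forced with the
+        # structure labels in targets[0]; in inference, it decodes greedily,
+        # feeding the argmax token of each step back as the next input.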
+ fea = inputs[-1] + if len(fea.shape) == 3: + pass + else: + last_shape = int(np.prod(fea.shape[2:])) # gry added + fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) + fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) + batch_size = fea.shape[0] + + hidden = paddle.zeros((batch_size, self.hidden_size)) + output_hiddens = [] + if self.training and targets is not None: + structure = targets[0] + for i in range(self.max_elem_length+1): + elem_onehots = self._char_to_onehot( + structure[:, i], onehot_dim=self.elem_num) + (outputs, hidden), alpha = self.structure_attention_cell( + hidden, fea, elem_onehots) + output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + output = paddle.concat(output_hiddens, axis=1) + structure_probs = self.structure_generator(output) + if self.loc_type == 1: + loc_preds = self.loc_generator(output) + loc_preds = F.sigmoid(loc_preds) + else: + loc_fea = fea.transpose([0, 2, 1]) + loc_fea = self.loc_fea_trans(loc_fea) + loc_fea = loc_fea.transpose([0, 2, 1]) + loc_concat = paddle.concat([output, loc_fea], axis=2) + loc_preds = self.loc_generator(loc_concat) + loc_preds = F.sigmoid(loc_preds) + else: + temp_elem = paddle.zeros(shape=[batch_size], dtype="int32") + structure_probs = None + loc_preds = None + elem_onehots = None + outputs = None + alpha = None + max_elem_length = paddle.to_tensor(self.max_elem_length) + i = 0 + while i < max_elem_length+1: + elem_onehots = self._char_to_onehot( + temp_elem, onehot_dim=self.elem_num) + (outputs, hidden), alpha = self.structure_attention_cell( + hidden, fea, elem_onehots) + output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + structure_probs_step = self.structure_generator(outputs) + temp_elem = structure_probs_step.argmax(axis=1, dtype="int32") + i += 1 + + output = paddle.concat(output_hiddens, axis=1) + structure_probs = self.structure_generator(output) + structure_probs = F.softmax(structure_probs) + if self.loc_type == 1: + loc_preds = self.loc_generator(output) + loc_preds = F.sigmoid(loc_preds) + else: + loc_fea = fea.transpose([0, 2, 1]) + loc_fea = self.loc_fea_trans(loc_fea) + loc_fea = loc_fea.transpose([0, 2, 1]) + loc_concat = paddle.concat([output, loc_fea], axis=2) + loc_preds = self.loc_generator(loc_concat) + loc_preds = F.sigmoid(loc_preds) + return {'structure_probs':structure_probs, 'loc_preds':loc_preds} + + +class AttentionGRUCell(nn.Layer): + def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): + super(AttentionGRUCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.score = nn.Linear(hidden_size, 1, bias_attr=False) + self.rnn = nn.GRUCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1) + res = paddle.add(batch_H_proj, prev_hidden_proj) + res = paddle.tanh(res) + e = self.score(res) + alpha = F.softmax(e, axis=1) + alpha = paddle.transpose(alpha, [0, 2, 1]) + context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) + concat_context = paddle.concat([context, char_onehots], 1) + cur_hidden = self.rnn(concat_context, prev_hidden) + return cur_hidden, alpha + + +class AttentionLSTM(nn.Layer): + def __init__(self, in_channels, out_channels, hidden_size, **kwargs): + super(AttentionLSTM, self).__init__() + self.input_size = in_channels + 
self.hidden_size = hidden_size + self.num_classes = out_channels + + self.attention_cell = AttentionLSTMCell( + in_channels, hidden_size, out_channels, use_gru=False) + self.generator = nn.Linear(hidden_size, out_channels) + + def _char_to_onehot(self, input_char, onehot_dim): + input_ont_hot = F.one_hot(input_char, onehot_dim) + return input_ont_hot + + def forward(self, inputs, targets=None, batch_max_length=25): + batch_size = inputs.shape[0] + num_steps = batch_max_length + + hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros( + (batch_size, self.hidden_size))) + output_hiddens = [] + + if targets is not None: + for i in range(num_steps): + # one-hot vectors for a i-th char + char_onehots = self._char_to_onehot( + targets[:, i], onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, inputs, + char_onehots) + + hidden = (hidden[1][0], hidden[1][1]) + output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1)) + output = paddle.concat(output_hiddens, axis=1) + probs = self.generator(output) + + else: + targets = paddle.zeros(shape=[batch_size], dtype="int32") + probs = None + + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets, onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, inputs, + char_onehots) + probs_step = self.generator(hidden[0]) + hidden = (hidden[1][0], hidden[1][1]) + if probs is None: + probs = paddle.unsqueeze(probs_step, axis=1) + else: + probs = paddle.concat( + [probs, paddle.unsqueeze( + probs_step, axis=1)], axis=1) + + next_input = probs_step.argmax(axis=1) + + targets = next_input + + return probs + + +class AttentionLSTMCell(nn.Layer): + def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): + super(AttentionLSTMCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.score = nn.Linear(hidden_size, 1, bias_attr=False) + if not use_gru: + self.rnn = nn.LSTMCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + else: + self.rnn = nn.GRUCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1) + res = paddle.add(batch_H_proj, prev_hidden_proj) + res = paddle.tanh(res) + e = self.score(res) + + alpha = F.softmax(e, axis=1) + alpha = paddle.transpose(alpha, [0, 2, 1]) + context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) + concat_context = paddle.concat([context, char_onehots], 1) + cur_hidden = self.rnn(concat_context, prev_hidden) + + return cur_hidden, alpha diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py index 37a5cf7863cb386884d82ed88c756c9fc06a541d..e97c4f64bdc9acd6729d67a9c6ff7a7563f6c95e 100644 --- a/ppocr/modeling/necks/__init__.py +++ b/ppocr/modeling/necks/__init__.py @@ -21,7 +21,8 @@ def build_neck(config): from .sast_fpn import SASTFPN from .rnn import SequenceEncoder from .pg_fpn import PGFPN - support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN'] + from .table_fpn import TableFPN + support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN'] module_name = config.pop('name') assert module_name in support_dict, Exception('neck only support {}'.format( diff --git a/ppocr/modeling/necks/db_fpn.py b/ppocr/modeling/necks/db_fpn.py index 
710023f30cdda90322b731c4bd3465d0dc06a139..1cf30cedd5b23e8a7ba243726a6d7eea7924750c 100644 --- a/ppocr/modeling/necks/db_fpn.py +++ b/ppocr/modeling/necks/db_fpn.py @@ -32,61 +32,53 @@ class DBFPN(nn.Layer): in_channels=in_channels[0], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_51.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.in3_conv = nn.Conv2D( in_channels=in_channels[1], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_50.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.in4_conv = nn.Conv2D( in_channels=in_channels[2], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_49.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.in5_conv = nn.Conv2D( in_channels=in_channels[3], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_48.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p5_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_52.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p4_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_53.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p3_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_54.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p2_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_55.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) def forward(self, x): diff --git a/ppocr/modeling/necks/table_fpn.py b/ppocr/modeling/necks/table_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..734f15af65e4e15a7ddb4004954a61bfa1934246 --- /dev/null +++ b/ppocr/modeling/necks/table_fpn.py @@ -0,0 +1,110 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
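+
+# TableFPN fuses the four backbone stages (C2-C5) for the table recognition
+# head: each stage is projected to 512 channels by a 1x1 conv, merged
+# top-down, resized to the C5 resolution, concatenated, and passed through a
+# 3x3 fuse conv whose output is scaled by 0.005 and added back to C5, so the
+# neck returns a single 1/32-scale feature map.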
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class TableFPN(nn.Layer): + def __init__(self, in_channels, out_channels, **kwargs): + super(TableFPN, self).__init__() + self.out_channels = 512 + weight_attr = paddle.nn.initializer.KaimingUniform() + self.in2_conv = nn.Conv2D( + in_channels=in_channels[0], + out_channels=self.out_channels, + kernel_size=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.in3_conv = nn.Conv2D( + in_channels=in_channels[1], + out_channels=self.out_channels, + kernel_size=1, + stride = 1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.in4_conv = nn.Conv2D( + in_channels=in_channels[2], + out_channels=self.out_channels, + kernel_size=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.in5_conv = nn.Conv2D( + in_channels=in_channels[3], + out_channels=self.out_channels, + kernel_size=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.p5_conv = nn.Conv2D( + in_channels=self.out_channels, + out_channels=self.out_channels // 4, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.p4_conv = nn.Conv2D( + in_channels=self.out_channels, + out_channels=self.out_channels // 4, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.p3_conv = nn.Conv2D( + in_channels=self.out_channels, + out_channels=self.out_channels // 4, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.p2_conv = nn.Conv2D( + in_channels=self.out_channels, + out_channels=self.out_channels // 4, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.fuse_conv = nn.Conv2D( + in_channels=self.out_channels * 4, + out_channels=512, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) + + def forward(self, x): + c2, c3, c4, c5 = x + + in5 = self.in5_conv(c5) + in4 = self.in4_conv(c4) + in3 = self.in3_conv(c3) + in2 = self.in2_conv(c2) + + out4 = in4 + F.upsample( + in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16 + out3 = in3 + F.upsample( + out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8 + out2 = in2 + F.upsample( + out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4 + + p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1) + p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1) + p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1) + fuse = paddle.concat([in5, p4, p3, p2], axis=1) + fuse_conv = self.fuse_conv(fuse) * 0.005 + return [c5 + fuse_conv] diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 042654a19d2d2d2f1363fedbb9ac3530696e6903..2f5bdc3b13135ed69e8af2e28ee0cd8042bf87e6 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -21,18 +21,20 @@ import copy __all__ = ['build_post_process'] +from .db_postprocess import DBPostProcess +from .east_postprocess import EASTPostProcess +from .sast_postprocess import SASTPostProcess +from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \ + TableLabelDecode +from .cls_postprocess import ClsPostProcess +from .pg_postprocess 
import PGPostProcess -def build_post_process(config, global_config=None): - from .db_postprocess import DBPostProcess - from .east_postprocess import EASTPostProcess - from .sast_postprocess import SASTPostProcess - from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode - from .cls_postprocess import ClsPostProcess - from .pg_postprocess import PGPostProcess +def build_post_process(config, global_config=None): support_dict = [ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', - 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess' + 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', + 'DistillationCTCLabelDecode', 'TableLabelDecode' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index d353391c9af2b85bd01ba659f541fa1791461f68..8426bcf2b9a71e0293d912e25f1b617fd18c59fc 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -44,16 +44,16 @@ class BaseRecLabelDecode(object): self.character_str = string.printable[:-6] dict_character = list(self.character_str) elif character_type in support_character_type: - self.character_str = "" + self.character_str = [] assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format( character_type) with open(character_dict_path, "rb") as fin: lines = fin.readlines() for line in lines: line = line.decode('utf-8').strip("\n").strip("\r\n") - self.character_str += line + self.character_str.append(line) if use_space_char: - self.character_str += " " + self.character_str.append(" ") dict_character = list(self.character_str) else: @@ -125,6 +125,37 @@ class CTCLabelDecode(BaseRecLabelDecode): return dict_character +class DistillationCTCLabelDecode(CTCLabelDecode): + """ + Convert + Convert between text-label and text-index + """ + + def __init__(self, + character_dict_path=None, + character_type='ch', + use_space_char=False, + model_name=["student"], + key=None, + **kwargs): + super(DistillationCTCLabelDecode, self).__init__( + character_dict_path, character_type, use_space_char) + if not isinstance(model_name, list): + model_name = [model_name] + self.model_name = model_name + + self.key = key + + def __call__(self, preds, label=None, *args, **kwargs): + output = dict() + for name in self.model_name: + pred = preds[name] + if self.key is not None: + pred = pred[self.key] + output[name] = super().__call__(pred, label=label, *args, **kwargs) + return output + + class AttnLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ @@ -288,3 +319,138 @@ class SRNLabelDecode(BaseRecLabelDecode): assert False, "unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end return idx + + +class TableLabelDecode(object): + """ """ + + def __init__(self, + character_dict_path, + **kwargs): + list_character, list_elem = self.load_char_elem_dict(character_dict_path) + list_character = self.add_special_char(list_character) + list_elem = self.add_special_char(list_elem) + self.dict_character = {} + self.dict_idx_character = {} + for i, char in enumerate(list_character): + self.dict_idx_character[i] = char + self.dict_character[char] = i + self.dict_elem = {} + self.dict_idx_elem = {} + for i, elem in enumerate(list_elem): + self.dict_idx_elem[i] = elem + self.dict_elem[elem] = i + + def load_char_elem_dict(self, character_dict_path): + list_character = [] + list_elem = [] + with open(character_dict_path, "rb") as 
fin:
+            lines = fin.readlines()
+            substr = lines[0].decode('utf-8').strip("\n").split("\t")
+            character_num = int(substr[0])
+            elem_num = int(substr[1])
+            for cno in range(1, 1 + character_num):
+                character = lines[cno].decode('utf-8').strip("\n")
+                list_character.append(character)
+            for eno in range(1 + character_num, 1 + character_num + elem_num):
+                elem = lines[eno].decode('utf-8').strip("\n")
+                list_elem.append(elem)
+        return list_character, list_elem
+
+    def add_special_char(self, list_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        list_character = [self.beg_str] + list_character + [self.end_str]
+        return list_character
+
+    def __call__(self, preds):
+        structure_probs = preds['structure_probs']
+        loc_preds = preds['loc_preds']
+        if isinstance(structure_probs, paddle.Tensor):
+            structure_probs = structure_probs.numpy()
+        if isinstance(loc_preds, paddle.Tensor):
+            loc_preds = loc_preds.numpy()
+        structure_idx = structure_probs.argmax(axis=2)
+        structure_probs = structure_probs.max(axis=2)
+        structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
+            structure_idx, structure_probs, 'elem')
+        res_html_code_list = []
+        res_loc_list = []
+        batch_num = len(structure_str)
+        for bno in range(batch_num):
+            res_loc = []
+            for sno in range(len(structure_str[bno])):
+                text = structure_str[bno][sno]
+                if text in ['<td>', '<td']:
+                    pos = structure_pos[bno][sno]
+                    res_loc.append(loc_preds[bno, pos])
+            res_html_code = ''.join(structure_str[bno])
+            res_loc = np.array(res_loc)
+            res_html_code_list.append(res_html_code)
+            res_loc_list.append(res_loc)
+        return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list,
+                'res_score_list': result_score_list,
+                'res_elem_idx_list': result_elem_idx_list,
+                'structure_str_list': structure_str}
+
+    def decode(self, text_index, structure_probs, char_or_elem):
+        """Convert text-index into text-label."""
+        if char_or_elem == "char":
+            current_dict = self.dict_idx_character
+        else:
+            current_dict = self.dict_idx_elem
+        ignored_tokens = self.get_ignored_tokens(char_or_elem)
+        beg_idx, end_idx = ignored_tokens
+
+        result_list = []
+        result_pos_list = []
+        result_score_list = []
+        result_elem_idx_list = []
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            char_list = []
+            elem_pos_list = []
+            elem_idx_list = []
+            score_list = []
+            for idx in range(len(text_index[batch_idx])):
+                tmp_elem_idx = int(text_index[batch_idx][idx])
+                if idx > 0 and tmp_elem_idx == end_idx:
+                    break
+                if tmp_elem_idx in ignored_tokens:
+                    continue
+
+                char_list.append(current_dict[tmp_elem_idx])
+                elem_pos_list.append(idx)
+                score_list.append(structure_probs[batch_idx, idx])
+                elem_idx_list.append(tmp_elem_idx)
+            result_list.append(char_list)
+            result_pos_list.append(elem_pos_list)
+            result_score_list.append(score_list)
+            result_elem_idx_list.append(elem_idx_list)
+        return result_list, result_pos_list, result_score_list, result_elem_idx_list
+
+    def get_ignored_tokens(self, char_or_elem):
+        beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
+        end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
+        if char_or_elem == "char":
+            if beg_or_end == "beg":
+                idx = self.dict_character[self.beg_str]
+            elif beg_or_end == "end":
+                idx = self.dict_character[self.end_str]
+            else:
+                assert False, "Unsupported type %s in get_beg_end_flag_idx of char" \
+                              % beg_or_end
+        elif char_or_elem == "elem":
+            if beg_or_end == "beg":
+                idx = self.dict_elem[self.beg_str]
+            elif beg_or_end == "end":
+                idx = self.dict_elem[self.end_str]
+            else:
+                assert False, "Unsupported type %s in get_beg_end_flag_idx of elem" \
+                              % beg_or_end
+        else:
+            assert False, "Unsupported type %s in char_or_elem" \
+                          % char_or_elem
+        return idx
diff --git a/ppocr/utils/dict/table_dict.txt b/ppocr/utils/dict/table_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2ef028c786cbce6d1e25856c62986d757b31f93b
--- /dev/null
+++ b/ppocr/utils/dict/table_dict.txt
@@ -0,0 +1,277 @@
+← + +☆ +─ +α + + +⋅ +$ +ω +ψ +χ +( +υ +≥ +σ +, +ρ +ε +0 +■ +4 +8 +✗ +b +< +✓ +Ψ +Ω +€ +D +3 +Π +H +║ + +L +Φ +Χ +θ +P +κ +λ +μ +T +ξ +X +β +γ +δ +\ +ζ +η +` +d + +h +f +l +Θ +p +√ +t + +x +Β +Γ +Δ +| +ǂ +ɛ +j +̧ +➢ +⁡ +̌ +′ +« +△ +▲ +# + +' +Ι ++ +¶ +/ +▼ +⇑ +□ +· +7 +▪ +; +? +➔ +∩ +C +÷ +G +⇒ +K + +O +S +С +W +Α +[ +○ +_ +● +‡ +c +z +g + +o + +〈 +〉 +s +⩽ +w +φ +ʹ +{ +» +∣ +̆ +e +ˆ +∈ +τ +◆ +ι +∅ +∆ +∙ +∘ +Ø +ß +✔ +∞ +∑ +− +× +◊ +∗ +∖ +˃ +˂ +∫ +" +i +& +π +↔ +* +∥ +æ +∧ +.
+⁄ +ø +Q +∼ +6 +⁎ +: +★ +> +a +B +≈ +F +J +̄ +N +♯ +R +V + +― +Z +♣ +^ +¤ +¥ +§ + +¢ +£ +≦ +­ +≤ +‖ +Λ +© +n +↓ +→ +↑ +r +° +± +v + +♂ +k +♀ +~ +ᅟ +̇ +@ +” +♦ +ł +® +⊕ +„ +! + +% +⇓ +) +- +1 +5 +9 += +А +A +‰ +⋆ +Σ +E +◦ +I +※ +M +m +̨ +⩾ +† + +• +U +Y +
 +] +̸ +2 +‐ +– +‒ +̂ +— +̀ +́ +’ +‘ +⋮ +⋯ +̊ +“ +̈ +≧ +q +u +ı +y + +​ +̃ +} +ν diff --git a/ppocr/utils/dict/table_structure_dict.txt b/ppocr/utils/dict/table_structure_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c4531e5f3b8c498e70d3c2ea0471e5e746a2c30 --- /dev/null +++ b/ppocr/utils/dict/table_structure_dict.txt @@ -0,0 +1,2759 @@ +277 28 1267 1186 + +V +a +r +i +b +l +e + +H +z +d + +t +o +9 +5 +% +C +I + +p + +v +u +* +A +g +( +m +n +) +0 +. +7 +1 +6 +≤ +> +8 +3 +– +2 +G +4 +M +F +T +y +f +s +L +w +c +U +h +D +S +Q +R +x +P +- +E +O +/ +k +, ++ +N +K +q +′ +[ +] +< +≥ + +− + +μ +± +J +j +W +_ +Δ +B +“ +: +Y +α +λ +; + + +? +∼ += +° +# +̊ +̈ +̂ +’ +Z +X +∗ +— +β +' +† +~ +@ +" +γ +↓ +↑ +& +‡ +χ +” +σ +§ +| +¶ +‐ +× +$ +→ +√ +✓ +‘ +\ +∞ +π +• +® +^ +∆ +≧ + + +́ +♀ +♂ +‒ +⁎ +▲ +· +£ +φ +Ψ +ß +△ +☆ +▪ +η +€ +∧ +̃ +Φ +ρ +̄ +δ +‰ +̧ +Ω +♦ +{ +} +̀ +∑ +∫ +ø +κ +ε +¥ +※ +` +ω +Σ +➔ +‖ +Β +̸ +
 +─ +● +⩾ +Χ +Α +⋅ +◆ +★ +■ +ψ +ǂ +□ +ζ +! +Γ +↔ +θ +⁄ +〈 +〉 +― +υ +τ +⋆ +Ø +© +∥ +С +˂ +➢ +ɛ +⁡ +✗ +← +○ +¢ +⩽ +∖ +˃ +­ +≈ +Π +̌ +≦ +∅ +ᅟ + + +∣ +¤ +♯ +̆ +ξ +÷ +▼ + +ι +ν +║ + + +◦ +​ +◊ +∙ +« +» +ł +ı +Θ +∈ +„ +∘ +✔ +̇ +æ +ʹ +ˆ +♣ +⇓ +∩ +⊕ +⇒ +⇑ +̨ +Ι +Λ +⋯ +А +⋮ + + + + + + + + + + colspan="2" + colspan="3" + rowspan="2" + colspan="4" + colspan="6" + rowspan="3" + colspan="9" + colspan="10" + colspan="7" + rowspan="4" + rowspan="5" + rowspan="9" + colspan="8" + rowspan="8" + rowspan="6" + rowspan="7" + rowspan="10" +0 2924682 +1 3405345 +2 2363468 +3 2709165 +4 4078680 +5 3250792 +6 1923159 +7 1617890 +8 1450532 +9 1717624 +10 1477550 +11 1489223 +12 915528 +13 819193 +14 593660 +15 518924 +16 682065 +17 494584 +18 400591 +19 396421 +20 340994 +21 280688 +22 250328 +23 226786 +24 199927 +25 182707 +26 164629 +27 141613 +28 127554 +29 116286 +30 107682 +31 96367 +32 88002 +33 79234 +34 72186 +35 65921 +36 60374 +37 55976 +38 52166 +39 47414 +40 44932 +41 41279 +42 38232 +43 35463 +44 33703 +45 30557 +46 29639 +47 27000 +48 25447 +49 23186 +50 22093 +51 20412 +52 19844 +53 18261 +54 17561 +55 16499 +56 15597 +57 14558 +58 14372 +59 13445 +60 13514 +61 12058 +62 11145 +63 10767 +64 10370 +65 9630 +66 9337 +67 8881 +68 8727 +69 8060 +70 7994 +71 7740 +72 7189 +73 6729 +74 6749 +75 6548 +76 6321 +77 5957 +78 5740 +79 5407 +80 5370 +81 5035 +82 4921 +83 4656 +84 4600 +85 4519 +86 4277 +87 4023 +88 3939 +89 3910 +90 3861 +91 3560 +92 3483 +93 3406 +94 3346 +95 3229 +96 3122 +97 3086 +98 3001 +99 2884 +100 2822 +101 2677 +102 2670 +103 2610 +104 2452 +105 2446 +106 2400 +107 2300 +108 2316 +109 2196 +110 2089 +111 2083 +112 2041 +113 1881 +114 1838 +115 1896 +116 1795 +117 1786 +118 1743 +119 1765 +120 1750 +121 1683 +122 1563 +123 1499 +124 1513 +125 1462 +126 1388 +127 1441 +128 1417 +129 1392 +130 1306 +131 1321 +132 1274 +133 1294 +134 1240 +135 1126 +136 1157 +137 1130 +138 1084 +139 1130 +140 1083 +141 1040 +142 980 +143 1031 +144 974 +145 980 +146 932 +147 898 +148 960 +149 907 +150 852 +151 912 +152 859 +153 847 +154 876 +155 792 +156 791 +157 765 +158 788 +159 787 +160 744 +161 673 +162 683 +163 697 +164 666 +165 680 +166 632 +167 677 +168 657 +169 618 +170 587 +171 585 +172 567 +173 549 +174 562 +175 548 +176 542 +177 539 +178 542 +179 549 +180 547 +181 526 +182 525 +183 514 +184 512 +185 505 +186 515 +187 467 +188 475 +189 458 +190 435 +191 443 +192 427 +193 424 +194 404 +195 389 +196 429 +197 404 +198 386 +199 351 +200 388 +201 408 +202 361 +203 346 +204 324 +205 361 +206 363 +207 364 +208 323 +209 336 +210 342 +211 315 +212 325 +213 328 +214 314 +215 327 +216 320 +217 300 +218 295 +219 315 +220 310 +221 295 +222 275 +223 248 +224 274 +225 232 +226 293 +227 259 +228 286 +229 263 +230 242 +231 214 +232 261 +233 231 +234 211 +235 250 +236 233 +237 206 +238 224 +239 210 +240 233 +241 223 +242 216 +243 222 +244 207 +245 212 +246 196 +247 205 +248 201 +249 202 +250 211 +251 201 +252 215 +253 179 +254 163 +255 179 +256 191 +257 188 +258 196 +259 150 +260 154 +261 176 +262 211 +263 166 +264 171 +265 165 +266 149 +267 182 +268 159 +269 161 +270 164 +271 161 +272 141 +273 151 +274 127 +275 129 +276 142 +277 158 +278 148 +279 135 +280 127 +281 134 +282 138 +283 131 +284 126 +285 125 +286 130 +287 126 +288 135 +289 125 +290 135 +291 131 +292 95 +293 135 +294 106 +295 117 +296 136 +297 128 +298 128 +299 118 +300 109 +301 112 +302 117 +303 108 +304 120 +305 100 +306 95 +307 108 +308 112 +309 77 +310 120 +311 104 +312 109 +313 89 +314 98 +315 82 +316 98 +317 93 +318 77 +319 93 +320 77 +321 98 
+322 93 +323 86 +324 89 +325 73 +326 70 +327 71 +328 77 +329 87 +330 77 +331 93 +332 100 +333 83 +334 72 +335 74 +336 69 +337 77 +338 68 +339 78 +340 90 +341 98 +342 75 +343 80 +344 63 +345 71 +346 83 +347 66 +348 71 +349 70 +350 62 +351 62 +352 59 +353 63 +354 62 +355 52 +356 64 +357 64 +358 56 +359 49 +360 57 +361 63 +362 60 +363 68 +364 62 +365 55 +366 54 +367 40 +368 75 +369 70 +370 53 +371 58 +372 57 +373 55 +374 69 +375 57 +376 53 +377 43 +378 45 +379 47 +380 56 +381 51 +382 59 +383 51 +384 43 +385 34 +386 57 +387 49 +388 39 +389 46 +390 48 +391 43 +392 40 +393 54 +394 50 +395 41 +396 43 +397 33 +398 27 +399 49 +400 44 +401 44 +402 38 +403 30 +404 32 +405 37 +406 39 +407 42 +408 53 +409 39 +410 34 +411 31 +412 32 +413 52 +414 27 +415 41 +416 34 +417 36 +418 50 +419 35 +420 32 +421 33 +422 45 +423 35 +424 40 +425 29 +426 41 +427 40 +428 39 +429 32 +430 31 +431 34 +432 29 +433 27 +434 26 +435 22 +436 34 +437 28 +438 30 +439 38 +440 35 +441 36 +442 36 +443 27 +444 24 +445 33 +446 31 +447 25 +448 33 +449 27 +450 32 +451 46 +452 31 +453 35 +454 35 +455 34 +456 26 +457 21 +458 25 +459 26 +460 24 +461 27 +462 33 +463 30 +464 35 +465 21 +466 32 +467 19 +468 27 +469 16 +470 28 +471 26 +472 27 +473 26 +474 25 +475 25 +476 27 +477 20 +478 28 +479 22 +480 23 +481 16 +482 25 +483 27 +484 19 +485 23 +486 19 +487 15 +488 15 +489 23 +490 24 +491 19 +492 20 +493 18 +494 17 +495 30 +496 28 +497 20 +498 29 +499 17 +500 19 +501 21 +502 15 +503 24 +504 15 +505 19 +506 25 +507 16 +508 23 +509 26 +510 21 +511 15 +512 12 +513 16 +514 18 +515 24 +516 26 +517 18 +518 8 +519 25 +520 14 +521 8 +522 24 +523 20 +524 18 +525 15 +526 13 +527 17 +528 18 +529 22 +530 21 +531 9 +532 16 +533 17 +534 13 +535 17 +536 15 +537 13 +538 20 +539 13 +540 19 +541 29 +542 10 +543 8 +544 18 +545 13 +546 9 +547 18 +548 10 +549 18 +550 18 +551 9 +552 9 +553 15 +554 13 +555 15 +556 14 +557 14 +558 18 +559 8 +560 13 +561 9 +562 7 +563 12 +564 6 +565 9 +566 9 +567 18 +568 9 +569 10 +570 13 +571 14 +572 13 +573 21 +574 8 +575 16 +576 12 +577 9 +578 16 +579 17 +580 22 +581 6 +582 14 +583 13 +584 15 +585 11 +586 13 +587 5 +588 12 +589 13 +590 15 +591 13 +592 15 +593 12 +594 7 +595 18 +596 12 +597 13 +598 13 +599 13 +600 12 +601 12 +602 10 +603 11 +604 6 +605 6 +606 2 +607 9 +608 8 +609 12 +610 9 +611 12 +612 13 +613 12 +614 14 +615 9 +616 8 +617 9 +618 14 +619 13 +620 12 +621 6 +622 8 +623 8 +624 8 +625 12 +626 8 +627 7 +628 5 +629 8 +630 12 +631 6 +632 10 +633 10 +634 7 +635 8 +636 9 +637 6 +638 9 +639 4 +640 12 +641 4 +642 3 +643 11 +644 10 +645 6 +646 12 +647 12 +648 4 +649 4 +650 9 +651 8 +652 6 +653 5 +654 14 +655 10 +656 11 +657 8 +658 5 +659 5 +660 9 +661 13 +662 4 +663 5 +664 9 +665 11 +666 12 +667 7 +668 13 +669 2 +670 1 +671 7 +672 7 +673 7 +674 10 +675 9 +676 6 +677 5 +678 7 +679 6 +680 3 +681 3 +682 4 +683 9 +684 8 +685 5 +686 3 +687 11 +688 9 +689 2 +690 6 +691 5 +692 9 +693 5 +694 6 +695 5 +696 9 +697 8 +698 3 +699 7 +700 5 +701 9 +702 8 +703 7 +704 2 +705 3 +706 7 +707 6 +708 6 +709 10 +710 2 +711 10 +712 6 +713 7 +714 5 +715 6 +716 4 +717 6 +718 8 +719 4 +720 6 +721 7 +722 5 +723 7 +724 3 +725 10 +726 10 +727 3 +728 7 +729 7 +730 5 +731 2 +732 1 +733 5 +734 1 +735 5 +736 6 +737 2 +738 2 +739 3 +740 7 +741 2 +742 7 +743 4 +744 5 +745 4 +746 5 +747 3 +748 1 +749 4 +750 4 +751 2 +752 4 +753 6 +754 6 +755 6 +756 3 +757 2 +758 5 +759 5 +760 3 +761 4 +762 2 +763 1 +764 8 +765 3 +766 4 +767 3 +768 1 +769 5 +770 3 +771 3 +772 4 +773 4 +774 1 +775 3 +776 2 +777 2 +778 3 +779 3 +780 1 +781 4 +782 3 +783 4 +784 6 +785 3 +786 5 +787 
4 +788 2 +789 4 +790 5 +791 4 +792 6 +794 4 +795 1 +796 1 +797 4 +798 2 +799 3 +800 3 +801 1 +802 5 +803 5 +804 3 +805 3 +806 3 +807 4 +808 4 +809 2 +811 5 +812 4 +813 6 +814 3 +815 2 +816 2 +817 3 +818 5 +819 3 +820 1 +821 1 +822 4 +823 3 +824 4 +825 8 +826 3 +827 5 +828 5 +829 3 +830 6 +831 3 +832 4 +833 8 +834 5 +835 3 +836 3 +837 2 +838 4 +839 2 +840 1 +841 3 +842 2 +843 1 +844 3 +846 4 +847 4 +848 3 +849 3 +850 2 +851 3 +853 1 +854 4 +855 4 +856 2 +857 4 +858 1 +859 2 +860 5 +861 1 +862 1 +863 4 +864 2 +865 2 +867 5 +868 1 +869 4 +870 1 +871 1 +872 1 +873 2 +875 5 +876 3 +877 1 +878 3 +879 3 +880 3 +881 2 +882 1 +883 6 +884 2 +885 2 +886 1 +887 1 +888 3 +889 2 +890 2 +891 3 +892 1 +893 3 +894 1 +895 5 +896 1 +897 3 +899 2 +900 2 +902 1 +903 2 +904 4 +905 4 +906 3 +907 1 +908 1 +909 2 +910 5 +911 2 +912 3 +914 1 +915 1 +916 2 +918 2 +919 2 +920 4 +921 4 +922 1 +923 1 +924 4 +925 5 +926 1 +928 2 +929 1 +930 1 +931 1 +932 1 +933 1 +934 2 +935 1 +936 1 +937 1 +938 2 +939 1 +941 1 +942 4 +944 2 +945 2 +946 2 +947 1 +948 1 +950 1 +951 2 +953 1 +954 2 +955 1 +956 1 +957 2 +958 1 +960 3 +962 4 +963 1 +964 1 +965 3 +966 2 +967 2 +968 1 +969 3 +970 3 +972 1 +974 4 +975 3 +976 3 +977 2 +979 2 +980 1 +981 1 +983 5 +984 1 +985 3 +986 1 +987 2 +988 4 +989 2 +991 2 +992 2 +993 1 +994 1 +996 2 +997 2 +998 1 +999 3 +1000 2 +1001 1 +1002 3 +1003 3 +1004 2 +1005 3 +1006 1 +1007 2 +1009 1 +1011 1 +1013 3 +1014 1 +1016 2 +1017 1 +1018 1 +1019 1 +1020 4 +1021 1 +1022 2 +1025 1 +1026 1 +1027 2 +1028 1 +1030 1 +1031 2 +1032 4 +1034 3 +1035 2 +1036 1 +1038 1 +1039 1 +1040 1 +1041 1 +1042 2 +1043 1 +1044 2 +1045 4 +1048 1 +1050 1 +1051 1 +1052 2 +1054 1 +1055 3 +1056 2 +1057 1 +1059 1 +1061 2 +1063 1 +1064 1 +1065 1 +1066 1 +1067 1 +1068 1 +1069 2 +1074 1 +1075 1 +1077 1 +1078 1 +1079 1 +1082 1 +1085 1 +1088 1 +1090 1 +1091 1 +1092 2 +1094 2 +1097 2 +1098 1 +1099 2 +1101 2 +1102 1 +1104 1 +1105 1 +1107 1 +1109 1 +1111 2 +1112 1 +1114 2 +1115 2 +1116 2 +1117 1 +1118 1 +1119 1 +1120 1 +1122 1 +1123 1 +1127 1 +1128 3 +1132 2 +1138 3 +1142 1 +1145 4 +1150 1 +1153 2 +1154 1 +1158 1 +1159 1 +1163 1 +1165 1 +1169 2 +1174 1 +1176 1 +1177 1 +1178 2 +1179 1 +1180 2 +1181 1 +1182 1 +1183 2 +1185 1 +1187 1 +1191 2 +1193 1 +1195 3 +1196 1 +1201 3 +1203 1 +1206 1 +1210 1 +1213 1 +1214 1 +1215 2 +1218 1 +1220 1 +1221 1 +1225 1 +1226 1 +1233 2 +1241 1 +1243 1 +1249 1 +1250 2 +1251 1 +1254 1 +1255 2 +1260 1 +1268 1 +1270 1 +1273 1 +1274 1 +1277 1 +1284 1 +1287 1 +1291 1 +1292 2 +1294 1 +1295 2 +1297 1 +1298 1 +1301 1 +1307 1 +1308 3 +1311 2 +1313 1 +1316 1 +1321 1 +1324 1 +1325 1 +1330 1 +1333 1 +1334 1 +1338 2 +1340 1 +1341 1 +1342 1 +1343 1 +1345 1 +1355 1 +1357 1 +1360 2 +1375 1 +1376 1 +1380 1 +1383 1 +1387 1 +1389 1 +1393 1 +1394 1 +1396 1 +1398 1 +1410 1 +1414 1 +1419 1 +1425 1 +1434 1 +1435 1 +1438 1 +1439 1 +1447 1 +1455 2 +1460 1 +1461 1 +1463 1 +1466 1 +1470 1 +1473 1 +1478 1 +1480 1 +1483 1 +1484 1 +1485 2 +1492 2 +1499 1 +1509 1 +1512 1 +1513 1 +1523 1 +1524 1 +1525 2 +1529 1 +1539 1 +1544 1 +1568 1 +1584 1 +1591 1 +1598 1 +1600 1 +1604 1 +1614 1 +1617 1 +1621 1 +1622 1 +1626 1 +1638 1 +1648 1 +1658 1 +1661 1 +1679 1 +1682 1 +1693 1 +1700 1 +1705 1 +1707 1 +1722 1 +1728 1 +1758 1 +1762 1 +1763 1 +1775 1 +1776 1 +1801 1 +1810 1 +1812 1 +1827 1 +1834 1 +1846 1 +1847 1 +1848 1 +1851 1 +1862 1 +1866 1 +1877 2 +1884 1 +1888 1 +1903 1 +1912 1 +1925 1 +1938 1 +1955 1 +1998 1 +2054 1 +2058 1 +2065 1 +2069 1 +2076 1 +2089 1 +2104 1 +2111 1 +2133 1 +2138 1 +2156 1 +2204 1 +2212 1 +2237 1 +2246 2 +2298 1 +2304 1 +2360 1 +2400 
1 +2481 1 +2544 1 +2586 1 +2622 1 +2666 1 +2682 1 +2725 1 +2920 1 +3997 1 +4019 1 +5211 1 +12 19 +14 1 +16 401 +18 2 +20 421 +22 557 +24 625 +26 50 +28 4481 +30 52 +32 550 +34 5840 +36 4644 +38 87 +40 5794 +41 33 +42 571 +44 11805 +46 4711 +47 7 +48 597 +49 12 +50 678 +51 2 +52 14715 +53 3 +54 7322 +55 3 +56 508 +57 39 +58 3486 +59 11 +60 8974 +61 45 +62 1276 +63 4 +64 15693 +65 15 +66 657 +67 13 +68 6409 +69 10 +70 3188 +71 25 +72 1889 +73 27 +74 10370 +75 9 +76 12432 +77 23 +78 520 +79 15 +80 1534 +81 29 +82 2944 +83 23 +84 12071 +85 36 +86 1502 +87 10 +88 10978 +89 11 +90 889 +91 16 +92 4571 +93 17 +94 7855 +95 21 +96 2271 +97 33 +98 1423 +99 15 +100 11096 +101 21 +102 4082 +103 13 +104 5442 +105 25 +106 2113 +107 26 +108 3779 +109 43 +110 1294 +111 29 +112 7860 +113 29 +114 4965 +115 22 +116 7898 +117 25 +118 1772 +119 28 +120 1149 +121 38 +122 1483 +123 32 +124 10572 +125 25 +126 1147 +127 31 +128 1699 +129 22 +130 5533 +131 22 +132 4669 +133 34 +134 3777 +135 10 +136 5412 +137 21 +138 855 +139 26 +140 2485 +141 46 +142 1970 +143 27 +144 6565 +145 40 +146 933 +147 15 +148 7923 +149 16 +150 735 +151 23 +152 1111 +153 33 +154 3714 +155 27 +156 2445 +157 30 +158 3367 +159 10 +160 4646 +161 27 +162 990 +163 23 +164 5679 +165 25 +166 2186 +167 17 +168 899 +169 32 +170 1034 +171 22 +172 6185 +173 32 +174 2685 +175 17 +176 1354 +177 38 +178 1460 +179 15 +180 3478 +181 20 +182 958 +183 20 +184 6055 +185 23 +186 2180 +187 15 +188 1416 +189 30 +190 1284 +191 22 +192 1341 +193 21 +194 2413 +195 18 +196 4984 +197 13 +198 830 +199 22 +200 1834 +201 19 +202 2238 +203 9 +204 3050 +205 22 +206 616 +207 17 +208 2892 +209 22 +210 711 +211 30 +212 2631 +213 19 +214 3341 +215 21 +216 987 +217 26 +218 823 +219 9 +220 3588 +221 20 +222 692 +223 7 +224 2925 +225 31 +226 1075 +227 16 +228 2909 +229 18 +230 673 +231 20 +232 2215 +233 14 +234 1584 +235 21 +236 1292 +237 29 +238 1647 +239 25 +240 1014 +241 30 +242 1648 +243 19 +244 4465 +245 10 +246 787 +247 11 +248 480 +249 25 +250 842 +251 15 +252 1219 +253 23 +254 1508 +255 8 +256 3525 +257 16 +258 490 +259 12 +260 1678 +261 14 +262 822 +263 16 +264 1729 +265 28 +266 604 +267 11 +268 2572 +269 7 +270 1242 +271 15 +272 725 +273 18 +274 1983 +275 13 +276 1662 +277 19 +278 491 +279 12 +280 1586 +281 14 +282 563 +283 10 +284 2363 +285 10 +286 656 +287 14 +288 725 +289 28 +290 871 +291 9 +292 2606 +293 12 +294 961 +295 9 +296 478 +297 13 +298 1252 +299 10 +300 736 +301 19 +302 466 +303 13 +304 2254 +305 12 +306 486 +307 14 +308 1145 +309 13 +310 955 +311 13 +312 1235 +313 13 +314 931 +315 14 +316 1768 +317 11 +318 330 +319 10 +320 539 +321 23 +322 570 +323 12 +324 1789 +325 13 +326 884 +327 5 +328 1422 +329 14 +330 317 +331 11 +332 509 +333 13 +334 1062 +335 12 +336 577 +337 27 +338 378 +339 10 +340 2313 +341 9 +342 391 +343 13 +344 894 +345 17 +346 664 +347 9 +348 453 +349 6 +350 363 +351 15 +352 1115 +353 13 +354 1054 +355 8 +356 1108 +357 12 +358 354 +359 7 +360 363 +361 16 +362 344 +363 11 +364 1734 +365 12 +366 265 +367 10 +368 969 +369 16 +370 316 +371 12 +372 757 +373 7 +374 563 +375 15 +376 857 +377 9 +378 469 +379 9 +380 385 +381 12 +382 921 +383 15 +384 764 +385 14 +386 246 +387 6 +388 1108 +389 14 +390 230 +391 8 +392 266 +393 11 +394 641 +395 8 +396 719 +397 9 +398 243 +399 4 +400 1108 +401 7 +402 229 +403 7 +404 903 +405 7 +406 257 +407 12 +408 244 +409 3 +410 541 +411 6 +412 744 +413 8 +414 419 +415 8 +416 388 +417 19 +418 470 +419 14 +420 612 +421 6 +422 342 +423 3 +424 1179 +425 3 +426 116 +427 14 +428 207 +429 6 +430 255 +431 4 +432 288 +433 12 
+434 343 +435 6 +436 1015 +437 3 +438 538 +439 10 +440 194 +441 6 +442 188 +443 15 +444 524 +445 7 +446 214 +447 7 +448 574 +449 6 +450 214 +451 5 +452 635 +453 9 +454 464 +455 5 +456 205 +457 9 +458 163 +459 2 +460 558 +461 4 +462 171 +463 14 +464 444 +465 11 +466 543 +467 5 +468 388 +469 6 +470 141 +471 4 +472 647 +473 3 +474 210 +475 4 +476 193 +477 7 +478 195 +479 7 +480 443 +481 10 +482 198 +483 3 +484 816 +485 6 +486 128 +487 9 +488 215 +489 9 +490 328 +491 7 +492 158 +493 11 +494 335 +495 8 +496 435 +497 6 +498 174 +499 1 +500 373 +501 5 +502 140 +503 7 +504 330 +505 9 +506 149 +507 5 +508 642 +509 3 +510 179 +511 3 +512 159 +513 8 +514 204 +515 7 +516 306 +517 4 +518 110 +519 5 +520 326 +521 6 +522 305 +523 6 +524 294 +525 7 +526 268 +527 5 +528 149 +529 4 +530 133 +531 2 +532 513 +533 10 +534 116 +535 5 +536 258 +537 4 +538 113 +539 4 +540 138 +541 6 +542 116 +544 485 +545 4 +546 93 +547 9 +548 299 +549 3 +550 256 +551 6 +552 92 +553 3 +554 175 +555 6 +556 253 +557 7 +558 95 +559 2 +560 128 +561 4 +562 206 +563 2 +564 465 +565 3 +566 69 +567 3 +568 157 +569 7 +570 97 +571 8 +572 118 +573 5 +574 130 +575 4 +576 301 +577 6 +578 177 +579 2 +580 397 +581 3 +582 80 +583 1 +584 128 +585 5 +586 52 +587 2 +588 72 +589 1 +590 84 +591 6 +592 323 +593 11 +594 77 +595 5 +596 205 +597 1 +598 244 +599 4 +600 69 +601 3 +602 89 +603 5 +604 254 +605 6 +606 147 +607 3 +608 83 +609 3 +610 77 +611 3 +612 194 +613 1 +614 98 +615 3 +616 243 +617 3 +618 50 +619 8 +620 188 +621 4 +622 67 +623 4 +624 123 +625 2 +626 50 +627 1 +628 239 +629 2 +630 51 +631 4 +632 65 +633 5 +634 188 +636 81 +637 3 +638 46 +639 3 +640 103 +641 1 +642 136 +643 3 +644 188 +645 3 +646 58 +648 122 +649 4 +650 47 +651 2 +652 155 +653 4 +654 71 +655 1 +656 71 +657 3 +658 50 +659 2 +660 177 +661 5 +662 66 +663 2 +664 183 +665 3 +666 50 +667 2 +668 53 +669 2 +670 115 +672 66 +673 2 +674 47 +675 1 +676 197 +677 2 +678 46 +679 3 +680 95 +681 3 +682 46 +683 3 +684 107 +685 1 +686 86 +687 2 +688 158 +689 4 +690 51 +691 1 +692 80 +694 56 +695 4 +696 40 +698 43 +699 3 +700 95 +701 2 +702 51 +703 2 +704 133 +705 1 +706 100 +707 2 +708 121 +709 2 +710 15 +711 3 +712 35 +713 2 +714 20 +715 3 +716 37 +717 2 +718 78 +720 55 +721 1 +722 42 +723 2 +724 218 +725 3 +726 23 +727 2 +728 26 +729 1 +730 64 +731 2 +732 65 +734 24 +735 2 +736 53 +737 1 +738 32 +739 1 +740 60 +742 81 +743 1 +744 77 +745 1 +746 47 +747 1 +748 62 +749 1 +750 19 +751 1 +752 86 +753 3 +754 40 +756 55 +757 2 +758 38 +759 1 +760 101 +761 1 +762 22 +764 67 +765 2 +766 35 +767 1 +768 38 +769 1 +770 22 +771 1 +772 82 +773 1 +774 73 +776 29 +777 1 +778 55 +780 23 +781 1 +782 16 +784 84 +785 3 +786 28 +788 59 +789 1 +790 33 +791 3 +792 24 +794 13 +795 1 +796 110 +797 2 +798 15 +800 22 +801 3 +802 29 +803 1 +804 87 +806 21 +808 29 +810 48 +812 28 +813 1 +814 58 +815 1 +816 48 +817 1 +818 31 +819 1 +820 66 +822 17 +823 2 +824 58 +826 10 +827 2 +828 25 +829 1 +830 29 +831 1 +832 63 +833 1 +834 26 +835 3 +836 52 +837 1 +838 18 +840 27 +841 2 +842 12 +843 1 +844 83 +845 1 +846 7 +847 1 +848 10 +850 26 +852 25 +853 1 +854 15 +856 27 +858 32 +859 1 +860 15 +862 43 +864 32 +865 1 +866 6 +868 39 +870 11 +872 25 +873 1 +874 10 +875 1 +876 20 +877 2 +878 19 +879 1 +880 30 +882 11 +884 53 +886 25 +887 1 +888 28 +890 6 +892 36 +894 10 +896 13 +898 14 +900 31 +902 14 +903 2 +904 43 +906 25 +908 9 +910 11 +911 1 +912 16 +913 1 +914 24 +916 27 +918 6 +920 15 +922 27 +923 1 +924 23 +926 13 +928 42 +929 1 +930 3 +932 27 +934 17 +936 8 +937 1 +938 11 +940 33 +942 4 +943 1 +944 18 +946 15 +948 13 +950 
18 +952 12 +954 11 +956 21 +958 10 +960 13 +962 5 +964 32 +966 13 +968 8 +970 8 +971 1 +972 23 +973 2 +974 12 +975 1 +976 22 +978 7 +979 1 +980 14 +982 8 +984 22 +985 1 +986 6 +988 17 +989 1 +990 6 +992 13 +994 19 +996 11 +998 4 +1000 9 +1002 2 +1004 14 +1006 5 +1008 3 +1010 9 +1012 29 +1014 6 +1016 22 +1017 1 +1018 8 +1019 1 +1020 7 +1022 6 +1023 1 +1024 10 +1026 2 +1028 8 +1030 11 +1031 2 +1032 8 +1034 9 +1036 13 +1038 12 +1040 12 +1042 3 +1044 12 +1046 3 +1048 11 +1050 2 +1051 1 +1052 2 +1054 11 +1056 6 +1058 8 +1059 1 +1060 23 +1062 6 +1063 1 +1064 8 +1066 3 +1068 6 +1070 8 +1071 1 +1072 5 +1074 3 +1076 5 +1078 3 +1080 11 +1081 1 +1082 7 +1084 18 +1086 4 +1087 1 +1088 3 +1090 3 +1092 7 +1094 3 +1096 12 +1098 6 +1099 1 +1100 2 +1102 6 +1104 14 +1106 3 +1108 6 +1110 5 +1112 2 +1114 8 +1116 3 +1118 3 +1120 7 +1122 10 +1124 6 +1126 8 +1128 1 +1130 4 +1132 3 +1134 2 +1136 5 +1138 5 +1140 8 +1142 3 +1144 7 +1146 3 +1148 11 +1150 1 +1152 5 +1154 1 +1156 5 +1158 1 +1160 5 +1162 3 +1164 6 +1165 1 +1166 1 +1168 4 +1169 1 +1170 3 +1171 1 +1172 2 +1174 5 +1176 3 +1177 1 +1180 8 +1182 2 +1184 4 +1186 2 +1188 3 +1190 2 +1192 5 +1194 6 +1196 1 +1198 2 +1200 2 +1204 10 +1206 2 +1208 9 +1210 1 +1214 6 +1216 3 +1218 4 +1220 9 +1221 2 +1222 1 +1224 5 +1226 4 +1228 8 +1230 1 +1232 1 +1234 3 +1236 5 +1240 3 +1242 1 +1244 3 +1245 1 +1246 4 +1248 6 +1250 2 +1252 7 +1256 3 +1258 2 +1260 2 +1262 3 +1264 4 +1265 1 +1266 1 +1270 1 +1271 1 +1272 2 +1274 3 +1276 3 +1278 1 +1280 3 +1284 1 +1286 1 +1290 1 +1292 3 +1294 1 +1296 7 +1300 2 +1302 4 +1304 3 +1306 2 +1308 2 +1312 1 +1314 1 +1316 3 +1318 2 +1320 1 +1324 8 +1326 1 +1330 1 +1331 1 +1336 2 +1338 1 +1340 3 +1341 1 +1344 1 +1346 2 +1347 1 +1348 3 +1352 1 +1354 2 +1356 1 +1358 1 +1360 3 +1362 1 +1364 4 +1366 1 +1370 1 +1372 3 +1380 2 +1384 2 +1388 2 +1390 2 +1392 2 +1394 1 +1396 1 +1398 1 +1400 2 +1402 1 +1404 1 +1406 1 +1410 1 +1412 5 +1418 1 +1420 1 +1424 1 +1432 2 +1434 2 +1442 3 +1444 5 +1448 1 +1454 1 +1456 1 +1460 3 +1462 4 +1468 1 +1474 1 +1476 1 +1478 2 +1480 1 +1486 2 +1488 1 +1492 1 +1496 1 +1500 3 +1503 1 +1506 1 +1512 2 +1516 1 +1522 1 +1524 2 +1534 4 +1536 1 +1538 1 +1540 2 +1544 2 +1548 1 +1556 1 +1560 1 +1562 1 +1564 2 +1566 1 +1568 1 +1570 1 +1572 1 +1576 1 +1590 1 +1594 1 +1604 1 +1608 1 +1614 1 +1622 1 +1624 2 +1628 1 +1629 1 +1636 1 +1642 1 +1654 2 +1660 1 +1664 1 +1670 1 +1684 4 +1698 1 +1732 3 +1742 1 +1752 1 +1760 1 +1764 1 +1772 2 +1798 1 +1808 1 +1820 1 +1852 1 +1856 1 +1874 1 +1902 1 +1908 1 +1952 1 +2004 1 +2018 1 +2020 1 +2028 1 +2174 1 +2233 1 +2244 1 +2280 1 +2290 1 +2352 1 +2604 1 +4190 1 diff --git a/ppocr/utils/logging.py b/ppocr/utils/logging.py index 951141db8f39acac612029c8b69f4a29a0ab27ce..11896c37d9285e19a9526caa9c637d7eda7b1979 100644 --- a/ppocr/utils/logging.py +++ b/ppocr/utils/logging.py @@ -22,7 +22,7 @@ logger_initialized = {} @functools.lru_cache() -def get_logger(name='root', log_file=None, log_level=logging.INFO): +def get_logger(name='root', log_file=None, log_level=logging.DEBUG): """Initialize and get a logger by name. If the logger has not been initialized, this method will initialize the logger by adding one or two handlers, otherwise the initialized logger will diff --git a/ppocr/utils/network.py b/ppocr/utils/network.py new file mode 100644 index 0000000000000000000000000000000000000000..453abb693d4c0ed370c1031b677d5bf51661add9 --- /dev/null +++ b/ppocr/utils/network.py @@ -0,0 +1,82 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import tarfile
+import requests
+from tqdm import tqdm
+
+from ppocr.utils.logging import get_logger
+
+
+def download_with_progressbar(url, save_path):
+    logger = get_logger()
+    response = requests.get(url, stream=True)
+    total_size_in_bytes = int(response.headers.get('content-length', 0))
+    block_size = 1024  # 1 Kibibyte
+    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+    with open(save_path, 'wb') as file:
+        for data in response.iter_content(block_size):
+            progress_bar.update(len(data))
+            file.write(data)
+    progress_bar.close()
+    if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
+        logger.error("Something went wrong while downloading models")
+        sys.exit(1)  # exit with a non-zero code on download failure
+
+
+def maybe_download(model_storage_directory, url):
+    # using custom model
+    tar_file_name_list = [
+        'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
+    ]
+    if not os.path.exists(
+            os.path.join(model_storage_directory, 'inference.pdiparams')
+    ) or not os.path.exists(
+            os.path.join(model_storage_directory, 'inference.pdmodel')):
+        assert url.endswith('.tar'), 'Only .tar compressed packages are supported'
+        tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
+        print('download {} to {}'.format(url, tmp_path))
+        os.makedirs(model_storage_directory, exist_ok=True)
+        download_with_progressbar(url, tmp_path)
+        with tarfile.open(tmp_path, 'r') as tarObj:
+            for member in tarObj.getmembers():
+                filename = None
+                for tar_file_name in tar_file_name_list:
+                    if tar_file_name in member.name:
+                        filename = tar_file_name
+                if filename is None:
+                    continue
+                file = tarObj.extractfile(member)
+                with open(
+                        os.path.join(model_storage_directory, filename),
+                        'wb') as f:
+                    f.write(file.read())
+        os.remove(tmp_path)
+
+
+def is_link(s):
+    return s is not None and s.startswith('http')
+
+
+def confirm_model_dir_url(model_dir, default_model_dir, default_url):
+    url = default_url
+    if model_dir is None or is_link(model_dir):
+        if is_link(model_dir):
+            url = model_dir
+        file_name = url.split('/')[-1][:-4]
+        model_dir = default_model_dir
+        model_dir = os.path.join(model_dir, file_name)
+    return model_dir, url
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 3d1c5c356c9510dd701048aee8cbb3e73e8a059a..23f5401bb71a2ef50ff2ff2c3c27275d7e10b3c0 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -23,6 +23,8 @@ import six
 
 import paddle
 
+from ppocr.utils.logging import get_logger
+
 __all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
 
 
@@ -42,44 +44,11 @@ def _mkdir_if_not_exist(path, logger):
             raise OSError('Failed to mkdir {}'.format(path))
 
 
-def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
-    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
-        raise ValueError("Model pretrain path {} does not "
-                         "exists.".format(path))
-    if load_static_weights:
-        pre_state_dict = paddle.static.load_program_state(path)
-        param_state_dict = {}
-        model_dict = model.state_dict()
-        for key in model_dict.keys():
-            weight_name = model_dict[key].name
-            weight_name = weight_name.replace('binarize', '').replace(
-                'thresh', '')  # for DB
-            if weight_name in pre_state_dict.keys():
-                # logger.info('Load weight: {}, shape: {}'.format(
-                #     weight_name, pre_state_dict[weight_name].shape))
-                if 'encoder_rnn' in key:
-                    # delete axis which is 1
-                    pre_state_dict[weight_name] = pre_state_dict[
-                        weight_name].squeeze()
-                    # change axis
-                    if len(pre_state_dict[weight_name].shape) > 1:
-                        pre_state_dict[weight_name] = pre_state_dict[
-                            weight_name].transpose((1, 0))
-                param_state_dict[key] = pre_state_dict[weight_name]
-            else:
-                param_state_dict[key] = model_dict[key]
-        model.set_state_dict(param_state_dict)
-        return
-
-    param_state_dict = paddle.load(path + '.pdparams')
-    model.set_state_dict(param_state_dict)
-    return
-
-
-def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
+def init_model(config, model, optimizer=None, lr_scheduler=None):
     """
     load model from checkpoint or pretrained_model
     """
+    logger = get_logger()
     global_config = config['Global']
     checkpoints = global_config.get('checkpoints')
     pretrained_model = global_config.get('pretrained_model')
@@ -102,18 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
         best_model_dict = states_dict.get('best_model_dict', {})
         if 'epoch' in states_dict:
             best_model_dict['start_epoch'] = states_dict['epoch'] + 1
-        logger.info("resume from {}".format(checkpoints))
     elif pretrained_model:
-        load_static_weights = global_config.get('load_static_weights', False)
         if not isinstance(pretrained_model, list):
             pretrained_model = [pretrained_model]
-        if not isinstance(load_static_weights, list):
-            load_static_weights = [load_static_weights] * len(pretrained_model)
-        for idx, pretrained in enumerate(pretrained_model):
-            load_static = load_static_weights[idx]
-            load_dygraph_pretrain(
-                model, logger, path=pretrained, load_static_weights=load_static)
+        for pretrained in pretrained_model:
+            if not (os.path.isdir(pretrained) or
+                    os.path.exists(pretrained + '.pdparams')):
+                raise ValueError("Model pretrain path {} does not "
+                                 "exist.".format(pretrained))
+            param_state_dict = paddle.load(pretrained + '.pdparams')
+            model.set_state_dict(param_state_dict)
         logger.info("load pretrained model from {}".format(
             pretrained_model))
     else:
diff --git a/test1/MANIFEST.in b/test1/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..2961e722b7cebe8e1912be2dd903fcdecb694019
--- /dev/null
+++ b/test1/MANIFEST.in
@@ -0,0 +1,9 @@
+include LICENSE
+include README.md
+
+recursive-include ppocr/utils *.txt utility.py logging.py network.py
+recursive-include ppocr/data/ *.py
+recursive-include ppocr/postprocess *.py
+recursive-include tools/infer *.py
+recursive-include ppstructure *.py
+
diff --git a/test1/__init__.py b/test1/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7055bee443fb86648b80bcb892778a114bc47d71
--- /dev/null
+++ b/test1/__init__.py
@@ -0,0 +1,17 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .paddlestructure import PaddleStructure, draw_result, to_excel
+
+__all__ = ['PaddleStructure', 'draw_result', 'to_excel']
diff --git a/test1/api.md b/test1/api.md
new file mode 100644
index 0000000000000000000000000000000000000000..6ce2e5904188643839e2a21c137eaa6cf78619c9
--- /dev/null
+++ b/test1/api.md
@@ -0,0 +1,86 @@
+# PaddleStructure
+
+Install layoutparser:
+```sh
+wget https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+pip3 install layoutparser-0.0.0-py3-none-any.whl
+```
+
+## 1. Introduction to pipeline
+
+PaddleStructure is a toolkit for OCR of documents with complex layouts. The pipeline is as follows:
+
+![pipeline](../doc/table/pipeline.png)
+
+In PaddleStructure, the image is first analyzed by layoutparser: the regions of the image are classified, and each region then goes through the OCR process that matches its category.
+
+Currently layoutparser outputs five categories:
+1. Text
+2. Title
+3. Figure
+4. List
+5. Table
+
+Categories 1-4 follow the traditional OCR process; category 5 follows the Table OCR process.
+
+## 2. LayoutParser
+
+[doc](layout/README.md)
+
+## 3. Table OCR
+
+[doc](table/README.md)
+
+## 4. Inference with the prediction engine
+
+Use the following command to run inference:
+```bash
+python3 table/predict_system.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+```
+After running, each image gets a directory of the same name under the directory given by the `output` option. Each table in the image is stored as an Excel file, whose file name is the coordinates of the table in the image.
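+
+For example, for a hypothetical input image `1.png` containing one table, the output might be laid out like this (file names are illustrative; each `.xlsx` is named after the `[x1, y1, x2, y2]` box of its table, and `res.txt`/`show.jpg` come from `save_res` and the visualization step in `predict_system.py`):
+```
+output/table/
+└── 1/
+    ├── [12, 15, 527, 744].xlsx   # one Excel file per detected table
+    ├── res.txt                   # OCR results of the non-table regions
+    └── show.jpg                  # visualization of the layout analysis
+```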
+## 5. PaddleStructure whl package introduction
+
+### 5.1 Use
+
+5.1.1 Use by code
+```python
+import os
+import cv2
+from paddlestructure import PaddleStructure, draw_result, save_res
+
+table_engine = PaddleStructure(show_log=True)
+
+save_folder = './output/table'
+img_path = '../doc/table/1.png'
+img = cv2.imread(img_path)
+result = table_engine(img)
+save_res(result, save_folder, os.path.basename(img_path).split('.')[0])
+
+for line in result:
+    print(line)
+
+from PIL import Image
+
+font_path = 'path/to/PaddleOCR/doc/fonts/simfang.ttf'
+image = Image.open(img_path).convert('RGB')
+im_show = draw_result(image, result, font_path=font_path)
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
+
+5.1.2 Use by command line
+```bash
+paddlestructure --image_dir=../doc/table/1.png
+```
+
+### Parameter Description
+Most of the parameters are consistent with the paddleocr whl package; see the [whl package documentation](../doc/doc_ch/whl.md)
+
+| Parameter | Description | Default |
+|------------------------|------------------------------------------------------|------------------|
+| output | The path where Excel files and recognition results are saved | ./output/table |
+| table_max_len | Long-side size to which the image is resized for table structure prediction | 488 |
+| table_model_dir | Table structure inference model path | None |
+| table_char_dict_path | Path of the dictionary used by the table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
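+
+Each element of `result` is a dict, one per detected layout region. A sketch of its layout, based on `predict_system.py` in this package (field values are illustrative):
+```python
+{
+    'type': 'Table',            # layout category: Text / Title / Figure / List / Table
+    'bbox': [x1, y1, x2, y2],   # region coordinates in the original image
+    'res': ...                  # table regions: HTML string of the recovered table;
+                                # text regions: (text boxes, recognition results)
+}
+```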
diff --git a/test1/api_ch.md b/test1/api_ch.md
new file mode 100644
index 0000000000000000000000000000000000000000..585379e8c6f717733ab436749441be0668b4c6d8
--- /dev/null
+++ b/test1/api_ch.md
@@ -0,0 +1,86 @@
+# PaddleStructure
+
+Install layoutparser:
+```sh
+wget https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+pip3 install layoutparser-0.0.0-py3-none-any.whl
+```
+
+## 1. Introduction to pipeline
+
+PaddleStructure is a toolkit for OCR of documents with complex layouts. The pipeline is as follows:
+
+![pipeline](../doc/table/pipeline.png)
+
+In PaddleStructure, the image is first analyzed by layoutparser: the regions of the image are classified, and each region then goes through the OCR process that matches its category.
+
+Currently layoutparser outputs five categories:
+1. Text
+2. Title
+3. Figure
+4. List
+5. Table
+
+Categories 1-4 follow the traditional OCR process; category 5 follows the Table OCR process.
+
+## 2. LayoutParser
+
+[doc](layout/README.md)
+
+## 3. Table OCR
+
+[doc](table/README_ch.md)
+
+## 4. Inference with the prediction engine
+
+Use the following command to run inference:
+```bash
+python3 table/predict_system.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+```
+After running, each image gets a directory of the same name under the directory given by the `output` option. Each table in the image is stored as an Excel file, whose file name is the coordinates of the table in the image.
+
+## 5. PaddleStructure whl package introduction
+
+### 5.1 Use
+
+5.1.1 Use by code
+```python
+import os
+import cv2
+from paddlestructure import PaddleStructure, draw_result, save_res
+
+table_engine = PaddleStructure(show_log=True)
+
+save_folder = './output/table'
+img_path = '../doc/table/1.png'
+img = cv2.imread(img_path)
+result = table_engine(img)
+save_res(result, save_folder, os.path.basename(img_path).split('.')[0])
+
+for line in result:
+    print(line)
+
+from PIL import Image
+
+font_path = 'path/to/PaddleOCR/doc/fonts/simfang.ttf'
+image = Image.open(img_path).convert('RGB')
+im_show = draw_result(image, result, font_path=font_path)
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
+
+5.1.2 Use by command line
+```bash
+paddlestructure --image_dir=../doc/table/1.png
+```
+
+### Parameter Description
+Most of the parameters are consistent with the paddleocr whl package; see the [whl package documentation](../doc/doc_ch/whl.md)
+
+| Parameter | Description | Default |
+|------------------------|------------------------------------------------------|------------------|
+| output | The path where Excel files and recognition results are saved | ./output/table |
+| table_max_len | Long-side size to which the image is resized for table structure prediction | 488 |
+| table_model_dir | Table structure inference model path | None |
+| table_char_dict_path | Path of the dictionary used by the table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
+
diff --git a/test1/layout/README.md b/test1/layout/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..274a8c63a58543d3769bbd4b11133496e74f405a
--- /dev/null
+++ b/test1/layout/README.md
@@ -0,0 +1,133 @@
+# Layout analysis usage guide
+
+* [1. Install the whl package](#1-install-the-whl-package)
+* [2. Usage](#2-usage)
+* [3. Post-processing](#3-post-processing)
+* [4. Metrics](#4-metrics)
+* [5. Training a layout analysis model](#5-training-a-layout-analysis-model)
+
+
+## 1. Install the whl package
+```bash
+wget https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+pip install -U layoutparser-0.0.0-py3-none-any.whl
+```
+
+
+## 2. Usage
+
+Use layoutparser to detect the layout of a given document image:
+
+```python
+import cv2
+import layoutparser as lp
+
+image = cv2.imread("images/paper-image.jpg")
+image = image[..., ::-1]
+
+# load the model
+model = lp.PaddleDetectionLayoutModel(config_path="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
+                                      threshold=0.5,
+                                      label_map={0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"},
+                                      enforce_cpu=False,
+                                      enable_mkldnn=True)
+# detect
+layout = model.detect(image)
+
+# show the result
+lp.draw_box(image, layout, box_width=3, show_element_type=True)
+```
+
+The figure below shows the result. Boxes of different colors represent different categories, and with `show_element_type` the category is shown at the top-left corner of each box:
+
+*(figure: layout detection result)*
+
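+Each element of `layout` can also be inspected directly. A minimal sketch (using `layout` from the snippet above; `type` and `coordinates` are the same block fields this PR relies on in `predict_system.py`):
+
+```python
+for block in layout:
+    x1, y1, x2, y2 = block.coordinates   # region box in pixels
+    print(block.type, int(x1), int(y1), int(x2), int(y2))
+```
+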
+
+The parameters of `PaddleDetectionLayoutModel` are described below:
+
+| Parameter | Meaning | Default | Note |
+| :------------: | :-------------------------: | :---------: | :----------------------------------------------------------: |
+| config_path | model config path | None | setting config_path downloads the model automatically (only the first time; afterwards the local copy is used) |
+| model_path | model path | None | local model path; one of config_path and model_path must be set, they cannot both be None |
+| threshold | score threshold for predictions | 0.5 | \ |
+| input_shape | image size after reshape | [3,640,640] | \ |
+| batch_size | test batch size | 1 | \ |
+| label_map | category mapping table | None | may be None when config_path is set; it is then derived from the dataset name |
+| enforce_cpu | whether to force CPU execution | False | False means use GPU, True forces CPU |
+| enable_mkldnn | whether to enable MKLDNN acceleration for CPU prediction | True | \ |
+| thread_num | number of CPU threads | 10 | \ |
+
+The following model configs and label maps are currently supported; set `--config_path` and `--label_map` accordingly to detect different kinds of content:
+
+| dataset | config_path | label_map |
+| ------------------------------------------------------------ | ------------------------------------------------------------ | --------------------------------------------------------- |
+| [TableBank](https://doc-analysis.github.io/tablebank-page/index.html) word | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_word/config | {0:"Table"} |
+| TableBank latex | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_latex/config | {0:"Table"} |
+| [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config | {0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"} |
+
+* The TableBank word and TableBank latex models are trained on the word-document and latex-document datasets respectively;
+* The TableBank download contains both the word and latex data.
+
+
+## 3. Post-processing
+
+Layout analysis detects multiple categories. If you only want the detection boxes of a particular category (such as "Text"), you can use the following code:
+
+```python
+# first filter the regions of the desired text type
+text_blocks = lp.Layout([b for b in layout if b.type=='Text'])
+figure_blocks = lp.Layout([b for b in layout if b.type=='Figure'])
+
+# text regions may also be detected inside figure regions, so drop those
+text_blocks = lp.Layout([b for b in text_blocks \
+                         if not any(b.is_in(b_fig) for b_fig in figure_blocks)])
+
+# sort the text regions and assign ids
+h, w = image.shape[:2]
+
+left_interval = lp.Interval(0, w/2*1.05, axis='x').put_on_canvas(image)
+
+left_blocks = text_blocks.filter_by(left_interval, center=True)
+left_blocks.sort(key = lambda b:b.coordinates[1])
+
+right_blocks = [b for b in text_blocks if b not in left_blocks]
+right_blocks.sort(key = lambda b:b.coordinates[1])
+
+# finally merge the two lists and add an index in reading order
+text_blocks = lp.Layout([b.set(id = idx) for idx, b in enumerate(left_blocks + right_blocks)])
+
+# show the result
+lp.draw_box(image, text_blocks,
+            box_width=3,
+            show_element_id=True)
+```
+
+Result showing only the "Text" category:
+
+*(figure: detection result filtered to the "Text" category)*
+
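+To run OCR on these ordered text regions, each block can be cropped out of the original image first. A minimal, illustrative sketch (it mirrors the ROI cropping done in `predict_system.py` of this PR):
+
+```python
+for block in text_blocks:
+    x1, y1, x2, y2 = map(int, block.coordinates)
+    roi = image[y1:y2, x1:x2]   # crop the text region, then feed it to a text recognizer
+```
+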
+
+## 4. Metrics
+
+| Dataset | mAP | CPU time cost | GPU time cost |
+| --------- | ---- | ------------- | ------------- |
+| PubLayNet | 93.6 | 1713.7ms | 66.6ms |
+| TableBank | 96.2 | 1968.4ms | 65.1ms |
+
+**Environment:**
+
+**CPU:** Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz, 24 cores
+
+**GPU:** a single NVIDIA Tesla P40
+
+
+## 5. Training a layout analysis model
+
+The models above are trained with [PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection). If you want to train your own layout analysis model, please refer to [train_layoutparser_model](train_layoutparser_model.md).
+
diff --git a/test1/layout/train_layoutparser_model.md b/test1/layout/train_layoutparser_model.md
new file mode 100644
index 0000000000000000000000000000000000000000..0a4554e12d9e565fa8e3de4a83cbd2eb5b515c6e
--- /dev/null
+++ b/test1/layout/train_layoutparser_model.md
@@ -0,0 +1,188 @@
+# Training a layout analysis model
+
+* [1. Installation](#1-installation)
+  * [1.1 Requirements](#11-requirements)
+  * [1.2 Install PaddleDetection](#12-install-paddledetection)
+* [2. Prepare the data](#2-prepare-the-data)
+* [3. Config file changes](#3-config-file-changes)
+* [4. Training with PaddleDetection](#4-training-with-paddledetection)
+* [5. Prediction with PaddleDetection](#5-prediction-with-paddledetection)
+* [6. Deployment](#6-deployment)
+  * [6.1 Export the model](#61-export-the-model)
+  * [6.2 Predict with layout parser](#62-predict-with-layout-parser)
+
+
+## 1. Installation
+
+
+### 1.1 Requirements
+
+- PaddlePaddle 2.1
+- OS 64 bit
+- Python 3(3.5.1+/3.6/3.7/3.8/3.9), 64 bit
+- pip/pip3(9.0.1+), 64 bit
+- CUDA >= 10.1
+- cuDNN >= 7.6
+
+
+### 1.2 Install PaddleDetection
+
+```bash
+# clone the PaddleDetection repository
+cd <path/to/clone/PaddleDetection>
+git clone https://github.com/PaddlePaddle/PaddleDetection.git
+
+cd PaddleDetection
+# install the other dependencies
+pip install -r requirements.txt
+```
+
+For more installation instructions, please refer to the [install doc](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/INSTALL_cn.md)
+
+
+## 2. Prepare the data
+
+Download the [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) dataset:
+
+```bash
+cd PaddleDetection/dataset/
+mkdir publaynet
+# download
+wget -O publaynet.tar.gz https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz?_ga=2.104193024.1076900768.1622560733-649911202.1622560733
+# extract
+tar -xvf publaynet.tar.gz
+```
+
+After extraction, the PubLayNet directory structure is:
+
+| File or Folder | Description | num |
+| :------------- | :----------------------------------------------- | ------- |
+| `train/` | Images in the training subset | 335,703 |
+| `val/` | Images in the validation subset | 11,245 |
+| `test/` | Images in the testing subset | 11,405 |
+| `train.json` | Annotations for training images | |
+| `val.json` | Annotations for validation images | |
+| `LICENSE.txt` | Plaintext version of the CDLA-Permissive license | |
+| `README.txt` | Text file with the file names and description | |
+
+To use another dataset, please refer to [prepare the training data](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/PrepareDataSet.md)
+
+
+## 3. Config file changes
+
+We train with the `configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml` config; a summary of the config file is shown below:
+
+*(figure: summary of ppyolov2_r50vd_dcn_365e_coco.yml)*
+
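+In PaddleDetection 2.x, this composition is expressed with a `_BASE_` list at the top of the config. A rough sketch of what that file looks like (abridged; the exact content may differ between versions, so check the file in your checkout):
+
+```yaml
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  './_base_/ppyolov2_r50vd_dcn.yml',
+  './_base_/optimizer_365e.yml',
+  './_base_/ppyolov2_reader.yml',
+]
+```
+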
+
+As shown above, the `ppyolov2_r50vd_dcn_365e_coco.yml` config depends on other config files. In this example, those are:
+
+```
+coco_detection.yml: mainly describes the paths of the training and validation data
+
+runtime.yml: mainly describes common runtime parameters, such as whether to use a GPU and how many epochs to wait between saving checkpoints
+
+optimizer_365e.yml: mainly describes the learning rate and optimizer configuration
+
+ppyolov2_r50vd_dcn.yml: mainly describes the model and the backbone network
+
+ppyolov2_reader.yml: mainly describes the data reader configuration, such as the batch size and the number of concurrent loading subprocesses, as well as the post-reading preprocessing, such as resizing and data augmentation
+```
+
+Modify these files as needed for your setup, e.g. the dataset path or the batch size.
+
+
+## 4. Training with PaddleDetection
+
+PaddleDetection provides single-GPU and multi-GPU training modes to cover a variety of training needs.
+
+* GPU training with a single card
+
+```bash
+export CUDA_VISIBLE_DEVICES=0 # not needed on Windows and Mac
+python tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml
+```
+
+* GPU training with multiple cards
+
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval
+```
+
+--eval: evaluate while training
+
+* Resuming training
+
+If training is interrupted for some reason, it can be resumed with the -r option:
+
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval -r output/ppyolov2_r50vd_dcn_365e_coco/10000
+```
+
+Note: if you run into an "Out of memory error", try reducing `batch_size` in `ppyolov2_reader.yml`.
+
+
+## 5. Prediction with PaddleDetection
+
+Set the parameters and predict with PaddleDetection:
+
+```bash
+export CUDA_VISIBLE_DEVICES=0
+python tools/infer.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --infer_img=images/paper-image.jpg --output_dir=infer_output/ --draw_threshold=0.5 -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final --use_vdl=True
+```
+
+`--draw_threshold` is an optional parameter; because of [NMS](https://ieeexplore.ieee.org/document/1699659), different thresholds produce different results. `keep_top_k` sets the maximum number of output targets; its default is 100 and it can be changed as needed.
+
+
+## 6. Deployment
+
+To use your own trained model in layout parser:
+
+
+### 6.1 Export the model
+
+The model files saved during training contain both the forward prediction and the backward propagation processes, while deployment only needs forward prediction, so the model has to be exported into the format required for deployment. PaddleDetection provides the `tools/export_model.py` script for this.
+
+The exported model is named `model.*` by default, while the layout parser code expects `inference.*`, so edit [PaddleDetection/ppdet/engine/trainer.py](https://github.com/PaddlePaddle/PaddleDetection/blob/b87a1ea86fa18ce69e44a17ad1b49c1326f19ff9/ppdet/engine/trainer.py#L512) (open the link for the exact line) and change `model` to `inference`.
+
+Run the export script:
+
+```bash
+python tools/export_model.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --output_dir=./inference -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final.pdparams
+```
+
+The inference model is exported to the `inference/ppyolov2_r50vd_dcn_365e_coco` directory, containing `infer_cfg.yml` (not needed for prediction), `inference.pdiparams`, `inference.pdiparams.info` and `inference.pdmodel`.
+
+For more model-export tutorials, please refer to [EXPORT_MODEL](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/deploy/EXPORT_MODEL.md)
+
+
+### 6.2 Predict with layout parser
+
+Point `model_path` at the trained model and predict with layout parser:
+
+```python
+import layoutparser as lp
+model = lp.PaddleDetectionLayoutModel(model_path="inference/ppyolov2_r50vd_dcn_365e_coco", threshold=0.5, label_map={0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}, enforce_cpu=True, enable_mkldnn=True)
+```
+
+
+***
+
+For more PaddleDetection training tutorials, please refer to [PaddleDetection training](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/GETTING_STARTED_cn.md)
+
+***
+
diff --git a/test1/paddlestructure.py b/test1/paddlestructure.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8199101bb2d97b7bad063c4bec66eeea656c1fa
--- /dev/null
+++ b/test1/paddlestructure.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import sys + +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.join(__dir__, '..')) + +import cv2 +import numpy as np +from pathlib import Path + +from ppocr.utils.logging import get_logger +from test1.predict_system import OCRSystem, save_res +from test1.table.predict_table import to_excel +from test1.utility import init_args, draw_result + +logger = get_logger() +from ppocr.utils.utility import check_and_read_gif, get_image_file_list +from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link + +__all__ = ['PaddleStructure', 'draw_result', 'save_res'] + +VERSION = '2.1' +BASE_DIR = os.path.expanduser("~/.paddlestructure/") + +model_urls = { + 'det': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar', + 'rec': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar', + 'table': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar' + +} + + +def parse_args(mMain=True): + import argparse + parser = init_args() + parser.add_help = mMain + + for action in parser._actions: + if action.dest in ['rec_char_dict_path', 'table_char_dict_path']: + action.default = None + if mMain: + return parser.parse_args() + else: + inference_args_dict = {} + for action in parser._actions: + inference_args_dict[action.dest] = action.default + return argparse.Namespace(**inference_args_dict) + + +class PaddleStructure(OCRSystem): + def __init__(self, **kwargs): + params = parse_args(mMain=False) + params.__dict__.update(**kwargs) + if not params.show_log: + logger.setLevel(logging.INFO) + params.use_angle_cls = False + # init model dir + params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir, + os.path.join(BASE_DIR, VERSION, 'det'), + model_urls['det']) + params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir, + os.path.join(BASE_DIR, VERSION, 'rec'), + model_urls['rec']) + params.table_model_dir, table_url = confirm_model_dir_url(params.table_model_dir, + os.path.join(BASE_DIR, VERSION, 'table'), + model_urls['table']) + # download model + maybe_download(params.det_model_dir, det_url) + maybe_download(params.rec_model_dir, rec_url) + maybe_download(params.table_model_dir, table_url) + + if params.rec_char_dict_path is None: + params.rec_char_type = 'EN' + if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')): + params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt') + else: + params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt') + if params.table_char_dict_path is None: + if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')): + params.table_char_dict_path = str( + Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt') + else: + 
params.table_char_dict_path = str(
+                    Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt')
+
+        print(params)
+        super().__init__(params)
+
+    def __call__(self, img):
+        if isinstance(img, str):
+            # download net image
+            if img.startswith('http'):
+                download_with_progressbar(img, 'tmp.jpg')
+                img = 'tmp.jpg'
+            image_file = img
+            img, flag = check_and_read_gif(image_file)
+            if not flag:
+                with open(image_file, 'rb') as f:
+                    np_arr = np.frombuffer(f.read(), dtype=np.uint8)
+                    img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
+            if img is None:
+                logger.error("error in loading image:{}".format(image_file))
+                return None
+        if isinstance(img, np.ndarray) and len(img.shape) == 2:
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+        res = super().__call__(img)
+        return res
+
+
+def main():
+    # for cmd
+    args = parse_args(mMain=True)
+    image_dir = args.image_dir
+    save_folder = args.output
+    if image_dir.startswith('http'):
+        download_with_progressbar(image_dir, 'tmp.jpg')
+        image_file_list = ['tmp.jpg']
+    else:
+        image_file_list = get_image_file_list(args.image_dir)
+    if len(image_file_list) == 0:
+        logger.error('no images found in {}'.format(args.image_dir))
+        return
+
+    structure_engine = PaddleStructure(**(args.__dict__))
+    for img_path in image_file_list:
+        img_name = os.path.basename(img_path).split('.')[0]
+        logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
+        result = structure_engine(img_path)
+        for item in result:
+            logger.info(item['res'])
+        save_res(result, save_folder, img_name)
+        logger.info('result saved to {}'.format(os.path.join(save_folder, img_name)))
+
+if __name__ == '__main__':
+    table_engine = PaddleStructure(show_log=True)
+
+    img_path = '../test/test_imgs/PMC1173095_006_00.png'
+    img = cv2.imread(img_path)
+    result = table_engine(img)
+    save_res(result, '/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table',
+             os.path.basename(img_path).split('.')[0])
+
+    for line in result:
+        print(line)
+
+    from PIL import Image
+
+    font_path = '../doc/fonts/simfang.ttf'
+    image = Image.open(img_path).convert('RGB')
+    im_show = draw_result(image, result, font_path=font_path)
+    im_show = Image.fromarray(im_show)
+    im_show.save('result.jpg')
\ No newline at end of file
diff --git a/test1/predict_system.py b/test1/predict_system.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e99a48cdf033f1cdb2263fc7a655a26a53ded92
--- /dev/null
+++ b/test1/predict_system.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import subprocess
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+import cv2
+import numpy as np
+import time
+import logging
+
+import layoutparser as lp
+
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppocr.utils.logging import get_logger
+from tools.infer.predict_system import TextSystem
+from test1.table.predict_table import TableSystem, to_excel
+from test1.utility import parse_args, draw_result
+
+logger = get_logger()
+
+
+class OCRSystem(object):
+    def __init__(self, args):
+        args.det_limit_type = 'resize_long'
+        args.drop_score = 0
+        if not args.show_log:
+            logger.setLevel(logging.INFO)
+        self.text_system = TextSystem(args)
+        self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
+        self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
+                                                          threshold=0.5, enable_mkldnn=args.enable_mkldnn,
+                                                          enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
+        self.use_angle_cls = args.use_angle_cls
+        self.drop_score = args.drop_score
+
+    def __call__(self, img):
+        ori_im = img.copy()
+        layout_res = self.table_layout.detect(img[..., ::-1])
+        res_list = []
+        for region in layout_res:
+            x1, y1, x2, y2 = region.coordinates
+            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+            roi_img = ori_im[y1:y2, x1:x2, :]
+            if region.type == 'Table':
+                res = self.table_system(roi_img)
+            else:
+                filter_boxes, filter_rec_res = self.text_system(roi_img)
+                # map the boxes back to the full-image coordinate system
+                filter_boxes = [x + [x1, y1] for x in filter_boxes]
+                filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
+
+                res = (filter_boxes, filter_rec_res)
+            res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
+        return res_list
+
+
+def save_res(res, save_folder, img_name):
+    excel_save_folder = os.path.join(save_folder, img_name)
+    os.makedirs(excel_save_folder, exist_ok=True)
+    # save res
+    for region in res:
+        if region['type'] == 'Table':
+            excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
+            to_excel(region['res'], excel_path)
+        elif region['type'] == 'Figure':
+            pass
+        else:
+            with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f:
+                for box, rec_res in zip(region['res'][0], region['res'][1]):
+                    f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
+
+
+def main(args):
+    image_file_list = get_image_file_list(args.image_dir)
+    image_file_list = image_file_list[args.process_id::args.total_process_num]
+    save_folder = args.output
+    os.makedirs(save_folder, exist_ok=True)
+
+    structure_sys = OCRSystem(args)
+    img_num = len(image_file_list)
+    for i, image_file in enumerate(image_file_list):
+        logger.info("[{}/{}] {}".format(i, img_num, image_file))
+        img, flag = check_and_read_gif(image_file)
+        img_name = os.path.basename(image_file).split('.')[0]
+
+        if not flag:
+            img = cv2.imread(image_file)
+        if img is None:
+            logger.error("error in loading image:{}".format(image_file))
+            continue
+        starttime = time.time()
+        res = structure_sys(img)
+        save_res(res, save_folder, img_name)
+        draw_img = draw_result(img, res, args.vis_font_path)
+        cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
+        logger.info('result saved to {}'.format(os.path.join(save_folder, img_name)))
+        elapse = time.time() - starttime
+
logger.info("Predict time : {:.3f}s".format(elapse)) + + +if __name__ == "__main__": + args = parse_args() + if args.use_mp: + p_list = [] + total_process_num = args.total_process_num + for process_id in range(total_process_num): + cmd = [sys.executable, "-u"] + sys.argv + [ + "--process_id={}".format(process_id), + "--use_mp={}".format(False) + ] + p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout) + p_list.append(p) + for p in p_list: + p.wait() + else: + main(args) diff --git a/test1/setup.py b/test1/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..0b092c49a4db98def28a7c2942993806b0ffc27c --- /dev/null +++ b/test1/setup.py @@ -0,0 +1,72 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from setuptools import setup +from io import open +import shutil + +with open('../requirements.txt', encoding="utf-8-sig") as f: + requirements = f.readlines() + requirements.append('tqdm') + requirements.append('layoutparser') + requirements.append('iopath') + + +def readme(): + with open('api_ch.md', encoding="utf-8-sig") as f: + README = f.read() + return README + + +shutil.copytree('./table', './test1/table') +shutil.copyfile('./predict_system.py', './test1/predict_system.py') +shutil.copyfile('./utility.py', './test1/utility.py') +shutil.copytree('../ppocr', './ppocr') +shutil.copytree('../tools', './tools') +shutil.copyfile('../LICENSE', './LICENSE') + +setup( + name='paddlestructure', + packages=['paddlestructure'], + package_dir={'paddlestructure': ''}, + include_package_data=True, + entry_points={"console_scripts": ["paddlestructure= paddlestructure.paddlestructure:main"]}, + version='1.0', + install_requires=requirements, + license='Apache License 2.0', + description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices', + long_description=readme(), + long_description_content_type='text/markdown', + url='https://github.com/PaddlePaddle/PaddleOCR', + download_url='https://github.com/PaddlePaddle/PaddleOCR.git', + keywords=[ + 'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition' + ], + classifiers=[ + 'Intended Audience :: Developers', 'Operating System :: OS Independent', + 'Natural Language :: Chinese (Simplified)', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.2', + 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Topic :: Utilities' + ], ) + +shutil.rmtree('ppocr') +shutil.rmtree('tools') +shutil.rmtree('test1') +os.remove('LICENSE') diff --git a/test1/table/README.md b/test1/table/README.md new file mode 100644 index 
--- /dev/null
+++ b/test1/table/README.md
@@ -0,0 +1,49 @@
+# Table structure and content prediction
+
+## 1. pipeline
+Table OCR mainly involves three models:
+1. Single-line text detection - DB
+2. Single-line text recognition - CRNN
+3. Table structure and cell-coordinate prediction - RARE
+
+The table OCR flow chart is as follows:
+
+![tableocr_pipeline](../../doc/table/tableocr_pipeline.png)
+
+1. The coordinates of single-line texts are detected by the DB model and sent to the recognition model to get the recognition results.
+2. The table structure and cell coordinates are predicted by the RARE model.
+3. The recognition result of each cell is assembled from the coordinates and recognition results of the single lines together with the coordinates of the cell.
+4. The cell recognition results and the table structure together construct the HTML string of the table.
+
+## 2. How to use
+
+### 2.1 Train
+TBD
+
+### 2.2 Eval
+First cd to the PaddleOCR/ppstructure directory.
+
+The table model uses TEDS (Tree-Edit-Distance-based Similarity) as its evaluation metric. Before evaluation, the three models in the pipeline need to be exported as inference models (we have provided them), and the gt for evaluation needs to be prepared. An example of the gt looks like this:
+```json
+{"PMC4289340_004_00.png": [["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"], [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]], [["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]]}
+```
+In the gt json, the key is the image name and the value is the corresponding gt. The gt is a list of four items:
+1. the HTML string list of the table structure
+2. the coordinates of each cell (excluding cells with empty text)
+3. the text of each cell (excluding cells with empty text)
+4. the text of each cell (including cells with empty text)
+
+Use the following command to evaluate. When the evaluation finishes, the teds metric is printed.
+```bash
+python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
+```
+
+### 2.3 Inference
+First cd to the PaddleOCR/ppstructure directory.
+
+```bash
+python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+```
+After running, the Excel sheet of each image is saved in the directory given by the `output` option.
", "", "", "
", "", "", "
", "", ""], [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]], [["", "F", "e", "a", "t", "u", "r", "e", ""], ["", "G", "b", "3", " ", "+", ""], ["", "G", "b", "3", " ", "-", ""], ["", "P", "a", "t", "i", "e", "n", "t", "s", ""], ["6", "2"], ["4", "5"]]]} +``` +json 中,key为图片名,value为对于的gt,gt是一个由四个item组成的list,每个item分别为 +1. 表格结构的html字符串list +2. 每个cell的坐标 (不包括cell里文字为空的) +3. 每个cell里的文字信息 (不包括cell里文字为空的) +4. 每个cell里的文字信息 (包括cell里文字为空的) + +准备完成后使用如下命令进行评估,评估完成后会输出teds指标。 +```python +python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json +``` + + +### 2.3 预测 +先cd到PaddleOCR/ppstructure目录下 + +```python +python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table +``` +运行完成后,每张图片的excel表格会保存到output字段指定的目录下 diff --git a/test1/table/__init__.py b/test1/table/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d11e265597c7c8e39098a228108da3bb954b892 --- /dev/null +++ b/test1/table/__init__.py @@ -0,0 +1,13 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/test1/table/eval_table.py b/test1/table/eval_table.py new file mode 100755 index 0000000000000000000000000000000000000000..dc63e34e2a85657a6487e7abb081854e937cf669 --- /dev/null +++ b/test1/table/eval_table.py @@ -0,0 +1,72 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os
+import sys
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+import cv2
+import json
+from tqdm import tqdm
+from test1.table.table_metric import TEDS
+from test1.table.predict_table import TableSystem
+from test1.utility import init_args
+from ppocr.utils.logging import get_logger
+
+logger = get_logger()
+
+
+def parse_args():
+    parser = init_args()
+    parser.add_argument("--gt_path", type=str)
+    return parser.parse_args()
+
+
+def main(gt_path, img_root, args):
+    teds = TEDS(n_jobs=16)
+
+    text_sys = TableSystem(args)
+    jsons_gt = json.load(open(gt_path))  # gt
+    pred_htmls = []
+    gt_htmls = []
+    for img_name in tqdm(jsons_gt):
+        # read image
+        img = cv2.imread(os.path.join(img_root, img_name))
+        pred_html = text_sys(img)
+        pred_htmls.append(pred_html)
+
+        gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name]
+        gt_html, gt = get_gt_html(gt_structures, contents_with_block)
+        gt_htmls.append(gt_html)
+    scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
+    logger.info('teds: {}'.format(sum(scores) / len(scores)))
+
+
+def get_gt_html(gt_structures, contents_with_block):
+    end_html = []
+    td_index = 0
+    for tag in gt_structures:
+        if '</td>' in tag:
+            if contents_with_block[td_index] != []:
+                end_html.extend(contents_with_block[td_index])
+            end_html.append(tag)
+            td_index += 1
+        else:
+            end_html.append(tag)
+    return ''.join(end_html), end_html
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args.gt_path, args.image_dir, args)
diff --git a/test1/table/matcher.py b/test1/table/matcher.py
new file mode 100755
index 0000000000000000000000000000000000000000..c3b56384403f5fd92a8db4b4bb378a6d55e5a76c
--- /dev/null
+++ b/test1/table/matcher.py
@@ -0,0 +1,192 @@
+import json
+
+
+def distance(box_1, box_2):
+    x1, y1, x2, y2 = box_1
+    x3, y3, x4, y4 = box_2
+    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
+    dis_2 = abs(x3 - x1) + abs(y3 - y1)
+    dis_3 = abs(x4 - x2) + abs(y4 - y2)
+    return dis + min(dis_2, dis_3)
+
+
+def compute_iou(rec1, rec2):
+    """
+    computing IoU
+    :param rec1: (y0, x0, y1, x1), which reflects
+            (top, left, bottom, right)
+    :param rec2: (y0, x0, y1, x1)
+    :return: scala value of IoU
+    """
+    # computing area of each rectangle
+    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+
+    # computing the sum_area
+    sum_area = S_rec1 + S_rec2
+
+    # find each edge of the intersect rectangle
+    left_line = max(rec1[1], rec2[1])
+    right_line = min(rec1[3], rec2[3])
+    top_line = max(rec1[0], rec2[0])
+    bottom_line = min(rec1[2], rec2[2])
+
+    # judge if there is an intersection
+    if left_line >= right_line or top_line >= bottom_line:
+        return 0.0
+    else:
+        intersect = (right_line - left_line) * (bottom_line - top_line)
+        return (intersect / (sum_area - intersect)) * 1.0
+
+
+def matcher_merge(ocr_bboxes, pred_bboxes):
+    all_dis = []
+    ious = []
+    matched = {}
+    for i, gt_box in enumerate(ocr_bboxes):
+        distances = []
+        for j, pred_box in enumerate(pred_bboxes):
+            # compute the L1 distance and IoU between the two boxes
+            distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
+        sorted_distances = distances.copy()
+        # select the nearest cell
+        sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
+        if distances.index(sorted_distances[0]) not in matched.keys():
+            matched[distances.index(sorted_distances[0])] = [i]
+        else:
+            matched[distances.index(sorted_distances[0])].append(i)
+    return matched  # , sum(ious) / len(ious)
+
+
+def complex_num(pred_bboxes):
+    complex_nums = []
+    for bbox in pred_bboxes:
+        distances = []
+        temp_ious = []
+        for pred_bbox in pred_bboxes:
+            if bbox != pred_bbox:
+                distances.append(distance(bbox, pred_bbox))
+                temp_ious.append(compute_iou(bbox, pred_bbox))
+        complex_nums.append(temp_ious[distances.index(min(distances))])
+    return sum(complex_nums) / len(complex_nums)
+
+
+def get_rows(pred_bboxes):
+    pre_bbox = pred_bboxes[0]
+    res = []
+    step = 0
+    for i in range(len(pred_bboxes)):
+        bbox = pred_bboxes[i]
+        if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0:
+            break
+        else:
+            res.append(bbox)
+            step += 1
+    for i in range(step):
+        pred_bboxes.pop(0)
+    return res, pred_bboxes
+
+
+def refine_rows(pred_bboxes):  # adjust the boxes of a row so that they lie on one horizontal line
+    ys_1 = []
+    ys_2 = []
+    for box in pred_bboxes:
+        ys_1.append(box[1])
+        ys_2.append(box[3])
+    min_y_1 = sum(ys_1) / len(ys_1)
+    min_y_2 = sum(ys_2) / len(ys_2)
+    re_boxes = []
+    for box in pred_bboxes:
+        box[1] = min_y_1
+        box[3] = min_y_2
+        re_boxes.append(box)
+    return re_boxes
+
+
+def matcher_refine_row(gt_bboxes, pred_bboxes):
+    before_refine_pred_bboxes = pred_bboxes.copy()
+    pred_bboxes = []
+    while (len(before_refine_pred_bboxes) != 0):
+        row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes)
+        print(row_bboxes)
+        pred_bboxes.extend(refine_rows(row_bboxes))
+    all_dis = []
+    ious = []
+    matched = {}
+    for i, gt_box in enumerate(gt_bboxes):
+        distances = []
+        #temp_ious = []
+        for j, pred_box in enumerate(pred_bboxes):
+            distances.append(distance(gt_box, pred_box))
+            #temp_ious.append(compute_iou(gt_box, pred_box))
+        #all_dis.append(min(distances))
+        #ious.append(temp_ious[distances.index(min(distances))])
+        if distances.index(min(distances)) not in matched.keys():
+            matched[distances.index(min(distances))] = [i]
+        else:
+            matched[distances.index(min(distances))].append(i)
+    return matched  # , sum(ious) / len(ious)
+
+
+# pick out one row first, then match
+def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
+    gt_box_index = 0
+    delete_gt_bboxes = gt_bboxes.copy()
+    match_bboxes_ready = []
+    matched = {}
+    while (len(delete_gt_bboxes) != 0):
+        row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
+        row_bboxes = sorted(row_bboxes, key=lambda key: key[0])
+        if len(pred_bboxes_rows) > 0:
+            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+        print(row_bboxes)
+        for i, gt_box in enumerate(row_bboxes):
+            #print(gt_box)
+            pred_distances = []
+            distances = []
+            for pred_bbox in pred_bboxes:
+                pred_distances.append(distance(gt_box, pred_bbox))
+            for j, pred_box in enumerate(match_bboxes_ready):
+                distances.append(distance(gt_box, pred_box))
+            index = pred_distances.index(min(distances))
+            #print('index', index)
+            if index not in matched.keys():
+                matched[index] = [gt_box_index]
+            else:
+                matched[index].append(gt_box_index)
+            gt_box_index += 1
+    return matched
+
+
+def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
+    '''
+    gt_bboxes: already sorted
+    pred_bboxes:
+    '''
+    pre_bbox = gt_bboxes[0]
+    matched = {}
+    match_bboxes_ready = []
+    match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+    for i, gt_box in enumerate(gt_bboxes):
+        pred_distances = []
+        for pred_bbox in pred_bboxes:
+            pred_distances.append(distance(gt_box, pred_bbox))
+        distances = []
+        gap_pre = gt_box[1] - pre_bbox[1]
+        gap_pre_1 = gt_box[0] - pre_bbox[2]
+        #print(gap_pre, len(pred_bboxes_rows))
+        if (gap_pre_1 < 0 and len(pred_bboxes_rows) > 0):
+            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+        if len(pred_bboxes_rows) == 1:
+            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+        if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0:
+            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+        if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0:
+            break
+        #print(match_bboxes_ready)
+        for j, pred_box in enumerate(match_bboxes_ready):
+            distances.append(distance(gt_box, pred_box))
+        index = pred_distances.index(min(distances))
+        #print(gt_box, index)
+        #match_bboxes_ready.pop(distances.index(min(distances)))
+        print(gt_box, match_bboxes_ready[distances.index(min(distances))])
+        if index not in matched.keys():
+            matched[index] = [i]
+        else:
+            matched[index].append(i)
+        pre_bbox = gt_box
+    return matched
diff --git a/test1/table/predict_structure.py b/test1/table/predict_structure.py
new file mode 100755
index 0000000000000000000000000000000000000000..1070c93ea61ac0efea7e700d00c8144f4139fbd8
--- /dev/null
+++ b/test1/table/predict_structure.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import numpy as np +import math +import time +import traceback +import paddle + +import tools.infer.utility as utility +from ppocr.data import create_operators, transform +from ppocr.postprocess import build_post_process +from ppocr.utils.logging import get_logger +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +from test1.utility import parse_args + +logger = get_logger() + + +class TableStructurer(object): + def __init__(self, args): + pre_process_list = [{ + 'ResizeTableImage': { + 'max_len': args.table_max_len + } + }, { + 'NormalizeImage': { + 'std': [0.229, 0.224, 0.225], + 'mean': [0.485, 0.456, 0.406], + 'scale': '1./255.', + 'order': 'hwc' + } + }, { + 'PaddingTableImage': None + }, { + 'ToCHWImage': None + }, { + 'KeepKeys': { + 'keep_keys': ['image'] + } + }] + postprocess_params = { + 'name': 'TableLabelDecode', + "character_type": args.table_char_type, + "character_dict_path": args.table_char_dict_path, + } + + self.preprocess_op = create_operators(pre_process_list) + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors, self.config = \ + utility.create_predictor(args, 'table', logger) + + def __call__(self, img): + ori_im = img.copy() + data = {'image': img} + data = transform(data, self.preprocess_op) + img = data[0] + if img is None: + return None, 0 + img = np.expand_dims(img, axis=0) + img = img.copy() + starttime = time.time() + + self.input_tensor.copy_from_cpu(img) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + + preds = {} + preds['structure_probs'] = outputs[1] + preds['loc_preds'] = outputs[0] + + post_result = self.postprocess_op(preds) + + structure_str_list = post_result['structure_str_list'] + res_loc = post_result['res_loc'] + imgh, imgw = ori_im.shape[0:2] + res_loc_final = [] + for rno in range(len(res_loc[0])): + x0, y0, x1, y1 = res_loc[0][rno] + left = max(int(imgw * x0), 0) + top = max(int(imgh * y0), 0) + right = min(int(imgw * x1), imgw - 1) + bottom = min(int(imgh * y1), imgh - 1) + res_loc_final.append([left, top, right, bottom]) + + structure_str_list = structure_str_list[0][:-1] + structure_str_list = ['', '', ''] + structure_str_list + ['
', '', ''] + + elapse = time.time() - starttime + return (structure_str_list, res_loc_final), elapse + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + table_structurer = TableStructurer(args) + count = 0 + total_time = 0 + for image_file in image_file_list: + img, flag = check_and_read_gif(image_file) + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + structure_res, elapse = table_structurer(img) + + logger.info("result: {}".format(structure_res)) + + if count > 0: + total_time += elapse + count += 1 + logger.info("Predict time of {}: {}".format(image_file, elapse)) + + +if __name__ == "__main__": + main(parse_args()) diff --git a/test1/table/predict_table.py b/test1/table/predict_table.py new file mode 100644 index 0000000000000000000000000000000000000000..b06a4f4d53402ca809f0ab846f83176795ca7217 --- /dev/null +++ b/test1/table/predict_table.py @@ -0,0 +1,221 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import subprocess + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' +import cv2 +import copy +import numpy as np +import time +import tools.infer.predict_rec as predict_rec +import tools.infer.predict_det as predict_det +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +from ppocr.utils.logging import get_logger +from test1.table.matcher import distance, compute_iou +from test1.utility import parse_args +import test1.table.predict_structure as predict_strture + +logger = get_logger() + + +def expand(pix, det_box, shape): + x0, y0, x1, y1 = det_box + # print(shape) + h, w, c = shape + tmp_x0 = x0 - pix + tmp_x1 = x1 + pix + tmp_y0 = y0 - pix + tmp_y1 = y1 + pix + x0_ = tmp_x0 if tmp_x0 >= 0 else 0 + x1_ = tmp_x1 if tmp_x1 <= w else w + y0_ = tmp_y0 if tmp_y0 >= 0 else 0 + y1_ = tmp_y1 if tmp_y1 <= h else h + return x0_, y0_, x1_, y1_ + + +class TableSystem(object): + def __init__(self, args, text_detector=None, text_recognizer=None): + self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector + self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer + self.table_structurer = predict_strture.TableStructurer(args) + + def __call__(self, img): + ori_im = img.copy() + structure_res, elapse = self.table_structurer(copy.deepcopy(img)) + dt_boxes, elapse = self.text_detector(copy.deepcopy(img)) + dt_boxes = sorted_boxes(dt_boxes) + + r_boxes = [] + for box in dt_boxes: + x_min = box[:, 0].min() - 1 + x_max = box[:, 0].max() + 1 + y_min = box[:, 1].min() - 1 + y_max = box[:, 1].max() + 1 + box = [x_min, y_min, x_max, y_max] + r_boxes.append(box) + dt_boxes = 
np.array(r_boxes) + + logger.debug("dt_boxes num : {}, elapse : {}".format( + len(dt_boxes), elapse)) + if dt_boxes is None: + return None, None + img_crop_list = [] + + for i in range(len(dt_boxes)): + det_box = dt_boxes[i] + x0, y0, x1, y1 = expand(2, det_box, ori_im.shape) + text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :] + img_crop_list.append(text_rect) + rec_res, elapse = self.text_recognizer(img_crop_list) + logger.debug("rec_res num : {}, elapse : {}".format( + len(rec_res), elapse)) + + pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res) + return pred_html + + def rebuild_table(self, structure_res, dt_boxes, rec_res): + pred_structures, pred_bboxes = structure_res + matched_index = self.match_result(dt_boxes, pred_bboxes) + pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res) + return pred_html, pred + + def match_result(self, dt_boxes, pred_bboxes): + matched = {} + for i, gt_box in enumerate(dt_boxes): + # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])] + distances = [] + for j, pred_box in enumerate(pred_bboxes): + distances.append( + (distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) # 获取两两cell之间的L1距离和 1- IOU + sorted_distances = distances.copy() + # 根据距离和IOU挑选最"近"的cell + sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0])) + if distances.index(sorted_distances[0]) not in matched.keys(): + matched[distances.index(sorted_distances[0])] = [i] + else: + matched[distances.index(sorted_distances[0])].append(i) + return matched + + def get_pred_html(self, pred_structures, matched_index, ocr_contents): + end_html = [] + td_index = 0 + for tag in pred_structures: + if '' in tag: + if td_index in matched_index.keys(): + b_with = False + if '' in ocr_contents[matched_index[td_index][0]] and len(matched_index[td_index]) > 1: + b_with = True + end_html.extend('') + for i, td_index_index in enumerate(matched_index[td_index]): + content = ocr_contents[td_index_index][0] + if len(matched_index[td_index]) > 1: + if len(content) == 0: + continue + if content[0] == ' ': + content = content[1:] + if '' in content: + content = content[3:] + if '' in content: + content = content[:-4] + if len(content) == 0: + continue + if i != len(matched_index[td_index]) - 1 and ' ' != content[-1]: + content += ' ' + end_html.extend(content) + if b_with: + end_html.extend('') + + end_html.append(tag) + td_index += 1 + else: + end_html.append(tag) + return ''.join(end_html), end_html + + +def sorted_boxes(dt_boxes): + """ + Sort text boxes in order from top to bottom, left to right + args: + dt_boxes(array):detected text boxes with shape [4, 2] + return: + sorted boxes(array) with shape [4, 2] + """ + num_boxes = dt_boxes.shape[0] + sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes = list(sorted_boxes) + + for i in range(num_boxes - 1): + if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): + tmp = _boxes[i] + _boxes[i] = _boxes[i + 1] + _boxes[i + 1] = tmp + return _boxes + + +def to_excel(html_table, excel_path): + from tablepyxl import tablepyxl + tablepyxl.document_to_xl(html_table, excel_path) + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + image_file_list = image_file_list[args.process_id::args.total_process_num] + os.makedirs(args.output, exist_ok=True) + + text_sys = TableSystem(args) + img_num = len(image_file_list) + for i, image_file in 
+        logger.info("[{}/{}] {}".format(i, img_num, image_file))
+        img, flag = check_and_read_gif(image_file)
+        excel_path = os.path.join(args.output, os.path.basename(image_file).split('.')[0] + '.xlsx')
+        if not flag:
+            img = cv2.imread(image_file)
+        if img is None:
+            logger.error("error in loading image:{}".format(image_file))
+            continue
+        starttime = time.time()
+        pred_html = text_sys(img)
+
+        to_excel(pred_html, excel_path)
+        logger.info('excel saved to {}'.format(excel_path))
+        logger.info(pred_html)
+        elapse = time.time() - starttime
+        logger.info("Predict time : {:.3f}s".format(elapse))
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    if args.use_mp:
+        p_list = []
+        total_process_num = args.total_process_num
+        for process_id in range(total_process_num):
+            cmd = [sys.executable, "-u"] + sys.argv + [
+                "--process_id={}".format(process_id),
+                "--use_mp={}".format(False)
+            ]
+            p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
+            p_list.append(p)
+        for p in p_list:
+            p.wait()
+    else:
+        main(args)
diff --git a/test1/table/table_metric/__init__.py b/test1/table/table_metric/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..de2d307430f68881ece1e41357d3b2f423e07ddd
--- /dev/null
+++ b/test1/table/table_metric/__init__.py
@@ -0,0 +1,16 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ['TEDS']
+from .table_metric import TEDS
\ No newline at end of file
diff --git a/test1/table/table_metric/parallel.py b/test1/table/table_metric/parallel.py
new file mode 100755
index 0000000000000000000000000000000000000000..f7326a1f506ca5fb7b3e97b0d077dc016e7eb7c7
--- /dev/null
+++ b/test1/table/table_metric/parallel.py
@@ -0,0 +1,51 @@
+from tqdm import tqdm
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+
+def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
+    """
+    A parallel version of the map function with a progress bar.
+    Args:
+        array (array-like): An array to iterate over.
+        function (function): A python function to apply to the elements of array.
+        n_jobs (int, default=16): The number of cores to use.
+        use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
+            keyword arguments to function.
+        front_num (int, default=0): The number of iterations to run serially before kicking off the parallel job.
+            Useful for catching bugs.
+    Returns:
+        [function(array[0]), function(array[1]), ...]
+    """
+    # We run the first few iterations serially to catch bugs
+    if front_num > 0:
+        front = [function(**a) if use_kwargs else function(a)
+                 for a in array[:front_num]]
+    else:
+        front = []
+    # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
+ if n_jobs == 1: + return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])] + # Assemble the workers + with ProcessPoolExecutor(max_workers=n_jobs) as pool: + # Pass the elements of array into function + if use_kwargs: + futures = [pool.submit(function, **a) for a in array[front_num:]] + else: + futures = [pool.submit(function, a) for a in array[front_num:]] + kwargs = { + 'total': len(futures), + 'unit': 'it', + 'unit_scale': True, + 'leave': True + } + # Print out the progress as tasks complete + for f in tqdm(as_completed(futures), **kwargs): + pass + out = [] + # Get the results from the futures. + for i, future in tqdm(enumerate(futures)): + try: + out.append(future.result()) + except Exception as e: + out.append(e) + return front + out diff --git a/test1/table/table_metric/table_metric.py b/test1/table/table_metric/table_metric.py new file mode 100755 index 0000000000000000000000000000000000000000..9aca98ad785d4614a803fa5a277a6e4a27b3b078 --- /dev/null +++ b/test1/table/table_metric/table_metric.py @@ -0,0 +1,247 @@ +# Copyright 2020 IBM +# Author: peter.zhong@au1.ibm.com +# +# This is free software; you can redistribute it and/or modify +# it under the terms of the Apache 2.0 License. +# +# This software is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Apache 2.0 License for more details. + +import distance +from apted import APTED, Config +from apted.helpers import Tree +from lxml import etree, html +from collections import deque +from .parallel import parallel_process +from tqdm import tqdm + + +class TableTree(Tree): + def __init__(self, tag, colspan=None, rowspan=None, content=None, *children): + self.tag = tag + self.colspan = colspan + self.rowspan = rowspan + self.content = content + self.children = list(children) + + def bracket(self): + """Show tree using brackets notation""" + if self.tag == 'td': + result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \ + (self.tag, self.colspan, self.rowspan, self.content) + else: + result = '"tag": %s' % self.tag + for child in self.children: + result += child.bracket() + return "{{{}}}".format(result) + + +class CustomConfig(Config): + @staticmethod + def maximum(*sequences): + """Get maximum possible value + """ + return max(map(len, sequences)) + + def normalized_distance(self, *sequences): + """Get distance from 0 to 1 + """ + return float(distance.levenshtein(*sequences)) / self.maximum(*sequences) + + def rename(self, node1, node2): + """Compares attributes of trees""" + #print(node1.tag) + if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan): + return 1. + if node1.tag == 'td': + if node1.content or node2.content: + #print(node1.content, ) + return self.normalized_distance(node1.content, node2.content) + return 0. + + + +class CustomConfig_del_short(Config): + @staticmethod + def maximum(*sequences): + """Get maximum possible value + """ + return max(map(len, sequences)) + + def normalized_distance(self, *sequences): + """Get distance from 0 to 1 + """ + return float(distance.levenshtein(*sequences)) / self.maximum(*sequences) + + def rename(self, node1, node2): + """Compares attributes of trees""" + if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan): + return 1. 
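+        # For <td> nodes, also compare cell text; contents shorter than 3 tokens
+        # are masked as '####' below so trivial cell texts compare as equal.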
+        if node1.tag == 'td':
+            if node1.content or node2.content:
+                node1_content = node1.content
+                node2_content = node2.content
+                if len(node1_content) < 3:
+                    node1_content = ['####']
+                if len(node2_content) < 3:
+                    node2_content = ['####']
+                return self.normalized_distance(node1_content, node2_content)
+        return 0.
+
+class CustomConfig_del_block(Config):
+    @staticmethod
+    def maximum(*sequences):
+        """Get maximum possible value
+        """
+        return max(map(len, sequences))
+
+    def normalized_distance(self, *sequences):
+        """Get distance from 0 to 1
+        """
+        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
+
+    def rename(self, node1, node2):
+        """Compares attributes of trees"""
+        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
+            return 1.
+        if node1.tag == 'td':
+            if node1.content or node2.content:
+                node1_content = node1.content
+                node2_content = node2.content
+                # drop whitespace tokens before comparing cell contents
+                while ' ' in node1_content:
+                    node1_content.pop(node1_content.index(' '))
+                while ' ' in node2_content:
+                    node2_content.pop(node2_content.index(' '))
+                return self.normalized_distance(node1_content, node2_content)
+        return 0.
+
+class TEDS(object):
+    ''' Tree Edit Distance based Similarity
+    '''
+
+    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
+        assert isinstance(n_jobs, int) and (
+            n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
+        self.structure_only = structure_only
+        self.n_jobs = n_jobs
+        self.ignore_nodes = ignore_nodes
+        self.__tokens__ = []
+
+    def tokenize(self, node):
+        ''' Tokenizes table cells
+        '''
+        self.__tokens__.append('<%s>' % node.tag)
+        if node.text is not None:
+            self.__tokens__ += list(node.text)
+        for n in node.getchildren():
+            self.tokenize(n)
+        if node.tag != 'unk':
+            self.__tokens__.append('</%s>' % node.tag)
+        if node.tag != 'td' and node.tail is not None:
+            self.__tokens__ += list(node.tail)
+
+    def load_html_tree(self, node, parent=None):
+        ''' Converts HTML tree to the format required by apted
+        '''
+        if node.tag == 'td':
+            if self.structure_only:
+                cell = []
+            else:
+                self.__tokens__ = []
+                self.tokenize(node)
+                cell = self.__tokens__[1:-1].copy()
+            new_node = TableTree(node.tag,
+                                 int(node.attrib.get('colspan', '1')),
+                                 int(node.attrib.get('rowspan', '1')),
+                                 cell, *deque())
+        else:
+            new_node = TableTree(node.tag, None, None, None, *deque())
+        if parent is not None:
+            parent.children.append(new_node)
+        if node.tag != 'td':
+            for n in node.getchildren():
+                self.load_html_tree(n, new_node)
+        if parent is None:
+            return new_node
+
+    def evaluate(self, pred, true):
+        ''' Computes TEDS score between the prediction and the ground truth of a
+            given sample
+        '''
+        if (not pred) or (not true):
+            return 0.0
+        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
+        pred = html.fromstring(pred, parser=parser)
+        true = html.fromstring(true, parser=parser)
+        if pred.xpath('body/table') and true.xpath('body/table'):
+            pred = pred.xpath('body/table')[0]
+            true = true.xpath('body/table')[0]
+            if self.ignore_nodes:
+                etree.strip_tags(pred, *self.ignore_nodes)
+                etree.strip_tags(true, *self.ignore_nodes)
+            n_nodes_pred = len(pred.xpath(".//*"))
+            n_nodes_true = len(true.xpath(".//*"))
+            n_nodes = max(n_nodes_pred, n_nodes_true)
+            tree_pred = self.load_html_tree(pred)
+            tree_true = self.load_html_tree(true)
+            distance = APTED(tree_pred,
tree_true, + CustomConfig()).compute_edit_distance() + return 1.0 - (float(distance) / n_nodes) + else: + return 0.0 + + def batch_evaluate(self, pred_json, true_json): + ''' Computes TEDS score between the prediction and the ground truth of + a batch of samples + @params pred_json: {'FILENAME': 'HTML CODE', ...} + @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...} + @output: {'FILENAME': 'TEDS SCORE', ...} + ''' + samples = true_json.keys() + if self.n_jobs == 1: + scores = [self.evaluate(pred_json.get( + filename, ''), true_json[filename]['html']) for filename in tqdm(samples)] + else: + inputs = [{'pred': pred_json.get( + filename, ''), 'true': true_json[filename]['html']} for filename in samples] + scores = parallel_process( + inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1) + scores = dict(zip(samples, scores)) + return scores + + def batch_evaluate_html(self, pred_htmls, true_htmls): + ''' Computes TEDS score between the prediction and the ground truth of + a batch of samples + ''' + if self.n_jobs == 1: + scores = [self.evaluate(pred_html, true_html) for ( + pred_html, true_html) in zip(pred_htmls, true_htmls)] + else: + inputs = [{"pred": pred_html, "true": true_html} for( + pred_html, true_html) in zip(pred_htmls, true_htmls)] + + scores = parallel_process( + inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1) + return scores + + +if __name__ == '__main__': + import json + import pprint + with open('sample_pred.json') as fp: + pred_json = json.load(fp) + with open('sample_gt.json') as fp: + true_json = json.load(fp) + teds = TEDS(n_jobs=4) + scores = teds.batch_evaluate(pred_json, true_json) + pp = pprint.PrettyPrinter() + pp.pprint(scores) diff --git a/test1/table/tablepyxl/__init__.py b/test1/table/tablepyxl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0085071cf4497b01fc648e7c38f2e8d9d173d0 --- /dev/null +++ b/test1/table/tablepyxl/__init__.py @@ -0,0 +1,13 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/test1/table/tablepyxl/style.py b/test1/table/tablepyxl/style.py new file mode 100644 index 0000000000000000000000000000000000000000..ebd794b1b47d7f9e4f9294dde7330f592d613656 --- /dev/null +++ b/test1/table/tablepyxl/style.py @@ -0,0 +1,283 @@ +# This is where we handle translating css styles into openpyxl styles +# and cascading those from parent to child in the dom. 
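+#
+# A rough usage sketch (illustrative only; the HTML string below is a
+# hypothetical example, not shipped with this module):
+#
+#   from lxml import html
+#   table = Table(html.fromstring(
+#       "<table><tr><td style='font-weight: bold'>1</td></tr></table>"))
+#   named_style = table.body.rows[0].cells[0].style()  # cascaded openpyxl NamedStyle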
+ +from openpyxl.cell import cell +from openpyxl.styles import Font, Alignment, PatternFill, NamedStyle, Border, Side, Color +from openpyxl.styles.fills import FILL_SOLID +from openpyxl.styles.numbers import FORMAT_CURRENCY_USD_SIMPLE, FORMAT_PERCENTAGE +from openpyxl.styles.colors import BLACK + +FORMAT_DATE_MMDDYYYY = 'mm/dd/yyyy' + + +def colormap(color): + """ + Convenience for looking up known colors + """ + cmap = {'black': BLACK} + return cmap.get(color, color) + + +def style_string_to_dict(style): + """ + Convert css style string to a python dictionary + """ + def clean_split(string, delim): + return (s.strip() for s in string.split(delim)) + styles = [clean_split(s, ":") for s in style.split(";") if ":" in s] + return dict(styles) + + +def get_side(style, name): + return {'border_style': style.get('border-{}-style'.format(name)), + 'color': colormap(style.get('border-{}-color'.format(name)))} + +known_styles = {} + + +def style_dict_to_named_style(style_dict, number_format=None): + """ + Change css style (stored in a python dictionary) to openpyxl NamedStyle + """ + + style_and_format_string = str({ + 'style_dict': style_dict, + 'parent': style_dict.parent, + 'number_format': number_format, + }) + + if style_and_format_string not in known_styles: + # Font + font = Font(bold=style_dict.get('font-weight') == 'bold', + color=style_dict.get_color('color', None), + size=style_dict.get('font-size')) + + # Alignment + alignment = Alignment(horizontal=style_dict.get('text-align', 'general'), + vertical=style_dict.get('vertical-align'), + wrap_text=style_dict.get('white-space', 'nowrap') == 'normal') + + # Fill + bg_color = style_dict.get_color('background-color') + fg_color = style_dict.get_color('foreground-color', Color()) + fill_type = style_dict.get('fill-type') + if bg_color and bg_color != 'transparent': + fill = PatternFill(fill_type=fill_type or FILL_SOLID, + start_color=bg_color, + end_color=fg_color) + else: + fill = PatternFill() + + # Border + border = Border(left=Side(**get_side(style_dict, 'left')), + right=Side(**get_side(style_dict, 'right')), + top=Side(**get_side(style_dict, 'top')), + bottom=Side(**get_side(style_dict, 'bottom')), + diagonal=Side(**get_side(style_dict, 'diagonal')), + diagonal_direction=None, + outline=Side(**get_side(style_dict, 'outline')), + vertical=None, + horizontal=None) + + name = 'Style {}'.format(len(known_styles) + 1) + + pyxl_style = NamedStyle(name=name, font=font, fill=fill, alignment=alignment, border=border, + number_format=number_format) + + known_styles[style_and_format_string] = pyxl_style + + return known_styles[style_and_format_string] + + +class StyleDict(dict): + """ + It's like a dictionary, but it looks for items in the parent dictionary + """ + def __init__(self, *args, **kwargs): + self.parent = kwargs.pop('parent', None) + super(StyleDict, self).__init__(*args, **kwargs) + + def __getitem__(self, item): + if item in self: + return super(StyleDict, self).__getitem__(item) + elif self.parent: + return self.parent[item] + else: + raise KeyError('{} not found'.format(item)) + + def __hash__(self): + return hash(tuple([(k, self.get(k)) for k in self._keys()])) + + # Yielding the keys avoids creating unnecessary data structures + # and happily works with both python2 and python3 where the + # .keys() method is a dictionary_view in python3 and a list in python2. 
+    def _keys(self):
+        yielded = set()
+        for k in self.keys():
+            yielded.add(k)
+            yield k
+        if self.parent:
+            for k in self.parent._keys():
+                if k not in yielded:
+                    yielded.add(k)
+                    yield k
+
+    def get(self, k, d=None):
+        try:
+            return self[k]
+        except KeyError:
+            return d
+
+    def get_color(self, k, d=None):
+        """
+        Strip leading # off colors if necessary
+        """
+        color = self.get(k, d)
+        if hasattr(color, 'startswith') and color.startswith('#'):
+            color = color[1:]
+            if len(color) == 3:  # Premailer reduces colors like #00ff00 to #0f0, openpyxl doesn't like that
+                color = ''.join(2 * c for c in color)
+        return color
+
+
+class Element(object):
+    """
+    Our base class for representing an html element along with a cascading style.
+    The element is created along with a parent so that the StyleDict that we store
+    can point to the parent's StyleDict.
+    """
+    def __init__(self, element, parent=None):
+        self.element = element
+        self.number_format = None
+        parent_style = parent.style_dict if parent else None
+        self.style_dict = StyleDict(style_string_to_dict(element.get('style', '')), parent=parent_style)
+        self._style_cache = None
+
+    def style(self):
+        """
+        Turn the css styles for this element into an openpyxl NamedStyle.
+        """
+        if not self._style_cache:
+            self._style_cache = style_dict_to_named_style(self.style_dict, number_format=self.number_format)
+        return self._style_cache
+
+    def get_dimension(self, dimension_key):
+        """
+        Extracts the dimension from the style dict of the Element and returns it as a float.
+        """
+        dimension = self.style_dict.get(dimension_key)
+        if dimension:
+            if dimension[-2:] in ['px', 'em', 'pt', 'in', 'cm']:
+                dimension = dimension[:-2]
+            dimension = float(dimension)
+        return dimension
+
+
+class Table(Element):
+    """
+    The concrete implementations of Elements are semantically named for the types of elements we are interested in.
+    This defines a very concrete tree structure for html tables that we expect to deal with. I prefer this compared to
+    allowing Element to have an arbitrary number of children and dealing with an abstract element tree.
+    """
+    def __init__(self, table):
+        """
+        takes an html table object (from lxml)
+        """
+        super(Table, self).__init__(table)
+        table_head = table.find('thead')
+        self.head = TableHead(table_head, parent=self) if table_head is not None else None
+        table_body = table.find('tbody')
+        self.body = TableBody(table_body if table_body is not None else table, parent=self)
+
+
+class TableHead(Element):
+    """
+    This class maps to the `<thead>` element of the html table.
+    """
+    def __init__(self, head, parent=None):
+        super(TableHead, self).__init__(head, parent=parent)
+        self.rows = [TableRow(tr, parent=self) for tr in head.findall('tr')]
+
+
+class TableBody(Element):
+    """
+    This class maps to the `<tbody>` element of the html table.
+    """
+    def __init__(self, body, parent=None):
+        super(TableBody, self).__init__(body, parent=parent)
+        self.rows = [TableRow(tr, parent=self) for tr in body.findall('tr')]
+
+
+class TableRow(Element):
+    """
+    This class maps to the `<tr>` element of the html table.
+ """ + def __init__(self, tr, parent=None): + super(TableRow, self).__init__(tr, parent=parent) + self.cells = [TableCell(cell, parent=self) for cell in tr.findall('th') + tr.findall('td')] + + +def element_to_string(el): + return _element_to_string(el).strip() + + +def _element_to_string(el): + string = '' + + for x in el.iterchildren(): + string += '\n' + _element_to_string(x) + + text = el.text.strip() if el.text else '' + tail = el.tail.strip() if el.tail else '' + + return text + string + '\n' + tail + + +class TableCell(Element): + """ + This class maps to the `` element of the html table. + """ + CELL_TYPES = {'TYPE_STRING', 'TYPE_FORMULA', 'TYPE_NUMERIC', 'TYPE_BOOL', 'TYPE_CURRENCY', 'TYPE_PERCENTAGE', + 'TYPE_NULL', 'TYPE_INLINE', 'TYPE_ERROR', 'TYPE_FORMULA_CACHE_STRING', 'TYPE_INTEGER'} + + def __init__(self, cell, parent=None): + super(TableCell, self).__init__(cell, parent=parent) + self.value = element_to_string(cell) + self.number_format = self.get_number_format() + + def data_type(self): + cell_types = self.CELL_TYPES & set(self.element.get('class', '').split()) + if cell_types: + if 'TYPE_FORMULA' in cell_types: + # Make sure TYPE_FORMULA takes precedence over the other classes in the set. + cell_type = 'TYPE_FORMULA' + elif cell_types & {'TYPE_CURRENCY', 'TYPE_INTEGER', 'TYPE_PERCENTAGE'}: + cell_type = 'TYPE_NUMERIC' + else: + cell_type = cell_types.pop() + else: + cell_type = 'TYPE_STRING' + return getattr(cell, cell_type) + + def get_number_format(self): + if 'TYPE_CURRENCY' in self.element.get('class', '').split(): + return FORMAT_CURRENCY_USD_SIMPLE + if 'TYPE_INTEGER' in self.element.get('class', '').split(): + return '#,##0' + if 'TYPE_PERCENTAGE' in self.element.get('class', '').split(): + return FORMAT_PERCENTAGE + if 'TYPE_DATE' in self.element.get('class', '').split(): + return FORMAT_DATE_MMDDYYYY + if self.data_type() == cell.TYPE_NUMERIC: + try: + int(self.value) + except ValueError: + return '#,##0.##' + else: + return '#,##0' + + def format(self, cell): + cell.style = self.style() + data_type = self.data_type() + if data_type: + cell.data_type = data_type \ No newline at end of file diff --git a/test1/table/tablepyxl/tablepyxl.py b/test1/table/tablepyxl/tablepyxl.py new file mode 100644 index 0000000000000000000000000000000000000000..ba3cc0fc499fccd93ffe3993a99296bc6603ed8a --- /dev/null +++ b/test1/table/tablepyxl/tablepyxl.py @@ -0,0 +1,118 @@ +# Do imports like python3 so our package works for 2 and 3 +from __future__ import absolute_import + +from lxml import html +from openpyxl import Workbook +from openpyxl.utils import get_column_letter +from premailer import Premailer +from tablepyxl.style import Table + + +def string_to_int(s): + if s.isdigit(): + return int(s) + return 0 + + +def get_Tables(doc): + tree = html.fromstring(doc) + comments = tree.xpath('//comment()') + for comment in comments: + comment.drop_tag() + return [Table(table) for table in tree.xpath('//table')] + + +def write_rows(worksheet, elem, row, column=1): + """ + Writes every tr child element of elem to a row in the worksheet + returns the next row after all rows are written + """ + from openpyxl.cell.cell import MergedCell + + initial_column = column + for table_row in elem.rows: + for table_cell in table_row.cells: + cell = worksheet.cell(row=row, column=column) + while isinstance(cell, MergedCell): + column += 1 + cell = worksheet.cell(row=row, column=column) + + colspan = string_to_int(table_cell.element.get("colspan", "1")) + rowspan = 
string_to_int(table_cell.element.get("rowspan", "1")) + if rowspan > 1 or colspan > 1: + worksheet.merge_cells(start_row=row, start_column=column, + end_row=row + rowspan - 1, end_column=column + colspan - 1) + + cell.value = table_cell.value + table_cell.format(cell) + min_width = table_cell.get_dimension('min-width') + max_width = table_cell.get_dimension('max-width') + + if colspan == 1: + # Initially, when iterating for the first time through the loop, the width of all the cells is None. + # As we start filling in contents, the initial width of the cell (which can be retrieved by: + # worksheet.column_dimensions[get_column_letter(column)].width) is equal to the width of the previous + # cell in the same column (i.e. width of A2 = width of A1) + width = max(worksheet.column_dimensions[get_column_letter(column)].width or 0, len(table_cell.value) + 2) + if max_width and width > max_width: + width = max_width + elif min_width and width < min_width: + width = min_width + worksheet.column_dimensions[get_column_letter(column)].width = width + column += colspan + row += 1 + column = initial_column + return row + + +def table_to_sheet(table, wb): + """ + Takes a table and workbook and writes the table to a new sheet. + The sheet title will be the same as the table attribute name. + """ + ws = wb.create_sheet(title=table.element.get('name')) + insert_table(table, ws, 1, 1) + + +def document_to_workbook(doc, wb=None, base_url=None): + """ + Takes a string representation of an html document and writes one sheet for + every table in the document. + The workbook is returned + """ + if not wb: + wb = Workbook() + wb.remove(wb.active) + + inline_styles_doc = Premailer(doc, base_url=base_url, remove_classes=False).transform() + tables = get_Tables(inline_styles_doc) + + for table in tables: + table_to_sheet(table, wb) + + return wb + + +def document_to_xl(doc, filename, base_url=None): + """ + Takes a string representation of an html document and writes one sheet for + every table in the document. The workbook is written out to a file called filename + """ + wb = document_to_workbook(doc, base_url=base_url) + wb.save(filename) + + +def insert_table(table, worksheet, column, row): + if table.head: + row = write_rows(worksheet, table.head, row, column) + if table.body: + row = write_rows(worksheet, table.body, row, column) + + +def insert_table_at_cell(table, cell): + """ + Inserts a table at the location of an openpyxl Cell object. + """ + ws = cell.parent + column, row = cell.column, cell.row + insert_table(table, ws, column, row) \ No newline at end of file diff --git a/test1/utility.py b/test1/utility.py new file mode 100644 index 0000000000000000000000000000000000000000..f21b287f7d4a044838d6949fae588547ee93ec3e --- /dev/null +++ b/test1/utility.py @@ -0,0 +1,54 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
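+
+# A typical invocation of the table pipeline that consumes these arguments
+# (model and image paths below are placeholders, not shipped defaults):
+#
+#   python3 test1/table/predict_table.py --det_model_dir=./inference/det \
+#       --rec_model_dir=./inference/rec --table_model_dir=./inference/table \
+#       --image_dir=./doc/table/1.png --output=./output/table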
+ +from PIL import Image +import numpy as np +from tools.infer.utility import draw_ocr_box_txt, init_args as infer_args + + +def init_args(): + parser = infer_args() + + # params for output + parser.add_argument("--output", type=str, default='./output/table') + # params for table structure + parser.add_argument("--table_max_len", type=int, default=488) + parser.add_argument("--table_model_dir", type=str) + parser.add_argument("--table_char_type", type=str, default='en') + parser.add_argument("--table_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt") + + return parser + + +def parse_args(): + parser = init_args() + return parser.parse_args() + + +def draw_result(image, result, font_path): + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + boxes, txts, scores = [], [], [] + for region in result: + if region['type'] == 'Table': + pass + elif region['type'] == 'Figure': + pass + else: + for box, rec_res in zip(region['res'][0], region['res'][1]): + boxes.append(np.array(box).reshape(-1, 2)) + txts.append(rec_res[0]) + scores.append(rec_res[1]) + im_show = draw_ocr_box_txt(image, boxes, txts, scores, font_path=font_path,drop_score=0) + return im_show \ No newline at end of file diff --git a/tools/eval.py b/tools/eval.py index 9817fa75093dd5127e3d11501ebc0473c9b53365..c1315805b5ff9bf29dee87a21688a145b4662b9a 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -44,12 +44,20 @@ def main(): # build model # for rec algorithm if hasattr(post_process_class, 'character'): - config['Architecture']["Head"]['out_channels'] = len( - getattr(post_process_class, 'character')) + char_num = len(getattr(post_process_class, 'character')) + if config['Architecture']["algorithm"] in ["Distillation", + ]: # distillation model + for key in config['Architecture']["Models"]: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + else: # base rec model + config['Architecture']["Head"]['out_channels'] = char_num + model = build_model(config['Architecture']) use_srn = config['Architecture']['algorithm'] == "SRN" + model_type = config['Architecture']['model_type'] - best_model_dict = init_model(config, model, logger) + best_model_dict = init_model(config, model) if len(best_model_dict): logger.info('metric in ckpt ***************') for k, v in best_model_dict.items(): @@ -60,7 +68,7 @@ def main(): # start eval metric = program.eval(model, valid_dataloader, post_process_class, - eval_class, use_srn) + eval_class, model_type, use_srn) logger.info('metric eval ***************') for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) diff --git a/tools/export_model.py b/tools/export_model.py index bdff89f755d465742f1c2a810f8ae76153a558c6..785aca10e46200bda49bdff2b89ba00cafbe7a20 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -17,7 +17,7 @@ import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.append(os.path.abspath(os.path.join(__dir__, ".."))) import argparse @@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger from tools.program import load_config, merge_config, ArgsParser -def main(): - FLAGS = ArgsParser().parse_args() - config = load_config(FLAGS.config) - merge_config(FLAGS.opt) - logger = get_logger() - # build post process - - post_process_class = build_post_process(config['PostProcess'], - config['Global']) - - # build model - # for rec algorithm - if hasattr(post_process_class, 'character'): - 
char_num = len(getattr(post_process_class, 'character')) - config['Architecture']["Head"]['out_channels'] = char_num - model = build_model(config['Architecture']) - init_model(config, model, logger) - model.eval() - - save_path = '{}/inference'.format(config['Global']['save_inference_dir']) - - if config['Architecture']['algorithm'] == "SRN": - max_text_length = config['Architecture']['Head']['max_text_length'] +def export_single_model(model, arch_config, save_path, logger): + if arch_config["algorithm"] == "SRN": + max_text_length = arch_config["Head"]["max_text_length"] other_shape = [ paddle.static.InputSpec( - shape=[None, 1, 64, 256], dtype='float32'), [ + shape=[None, 1, 64, 256], dtype="float32"), [ paddle.static.InputSpec( shape=[None, 256, 1], dtype="int64"), paddle.static.InputSpec( @@ -71,24 +51,67 @@ def main(): model = to_static(model, input_spec=other_shape) else: infer_shape = [3, -1, -1] - if config['Architecture']['model_type'] == "rec": + if arch_config["model_type"] == "rec": infer_shape = [3, 32, -1] # for rec model, H must be 32 - if 'Transform' in config['Architecture'] and config['Architecture'][ - 'Transform'] is not None and config['Architecture'][ - 'Transform']['name'] == 'TPS': + if "Transform" in arch_config and arch_config[ + "Transform"] is not None and arch_config["Transform"][ + "name"] == "TPS": logger.info( - 'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training' + "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training" ) infer_shape[-1] = 100 + elif arch_config["model_type"] == "table": + infer_shape = [3, 488, 488] model = to_static( model, input_spec=[ paddle.static.InputSpec( - shape=[None] + infer_shape, dtype='float32') + shape=[None] + infer_shape, dtype="float32") ]) paddle.jit.save(model, save_path) - logger.info('inference model is saved to {}'.format(save_path)) + logger.info("inference model is saved to {}".format(save_path)) + return + + +def main(): + FLAGS = ArgsParser().parse_args() + config = load_config(FLAGS.config) + merge_config(FLAGS.opt) + logger = get_logger() + # build post process + + post_process_class = build_post_process(config["PostProcess"], + config["Global"]) + + # build model + # for rec algorithm + if hasattr(post_process_class, "character"): + char_num = len(getattr(post_process_class, "character")) + if config["Architecture"]["algorithm"] in ["Distillation", + ]: # distillation model + for key in config["Architecture"]["Models"]: + config["Architecture"]["Models"][key]["Head"][ + "out_channels"] = char_num + else: # base rec model + config["Architecture"]["Head"]["out_channels"] = char_num + model = build_model(config["Architecture"]) + init_model(config, model) + model.eval() + + save_path = config["Global"]["save_inference_dir"] + + arch_config = config["Architecture"] + + if arch_config["algorithm"] in ["Distillation", ]: # distillation model + archs = list(arch_config["Models"].values()) + for idx, name in enumerate(model.model_name_list): + sub_model_save_path = os.path.join(save_path, name, "inference") + export_single_model(model.model_list[idx], archs[idx], + sub_model_save_path, logger) + else: + save_path = os.path.join(save_path, "inference") + export_single_model(model, arch_config, save_path, logger) if __name__ == "__main__": diff --git a/tools/infer/benchmark_utils.py b/tools/infer/benchmark_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..1a241d063368d19567e253bf1dada09801d468bc --- /dev/null +++ b/tools/infer/benchmark_utils.py @@ -0,0 +1,232 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import time +import logging + +import paddle +import paddle.inference as paddle_infer + +from pathlib import Path + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + + +class PaddleInferBenchmark(object): + def __init__(self, + config, + model_info: dict={}, + data_info: dict={}, + perf_info: dict={}, + resource_info: dict={}, + save_log_path: str="", + **kwargs): + """ + Construct PaddleInferBenchmark Class to format logs. + args: + config(paddle.inference.Config): paddle inference config + model_info(dict): basic model info + {'model_name': 'resnet50' + 'precision': 'fp32'} + data_info(dict): input data info + {'batch_size': 1 + 'shape': '3,224,224' + 'data_num': 1000} + perf_info(dict): performance result + {'preprocess_time_s': 1.0 + 'inference_time_s': 2.0 + 'postprocess_time_s': 1.0 + 'total_time_s': 4.0} + resource_info(dict): + cpu and gpu resources + {'cpu_rss': 100 + 'gpu_rss': 100 + 'gpu_util': 60} + """ + # PaddleInferBenchmark Log Version + self.log_version = 1.0 + + # Paddle Version + self.paddle_version = paddle.__version__ + self.paddle_commit = paddle.__git_commit__ + paddle_infer_info = paddle_infer.get_version() + self.paddle_branch = paddle_infer_info.strip().split(': ')[-1] + + # model info + self.model_info = model_info + + # data info + self.data_info = data_info + + # perf info + self.perf_info = perf_info + + try: + self.model_name = model_info['model_name'] + self.precision = model_info['precision'] + + self.batch_size = data_info['batch_size'] + self.shape = data_info['shape'] + self.data_num = data_info['data_num'] + + self.preprocess_time_s = round(perf_info['preprocess_time_s'], 4) + self.inference_time_s = round(perf_info['inference_time_s'], 4) + self.postprocess_time_s = round(perf_info['postprocess_time_s'], 4) + self.total_time_s = round(perf_info['total_time_s'], 4) + except: + self.print_help() + raise ValueError( + "Set argument wrong, please check input argument and its type") + + # conf info + self.config_status = self.parse_config(config) + self.save_log_path = save_log_path + # mem info + if isinstance(resource_info, dict): + self.cpu_rss_mb = int(resource_info.get('cpu_rss_mb', 0)) + self.gpu_rss_mb = int(resource_info.get('gpu_rss_mb', 0)) + self.gpu_util = round(resource_info.get('gpu_util', 0), 2) + else: + self.cpu_rss_mb = 0 + self.gpu_rss_mb = 0 + self.gpu_util = 0 + + # init benchmark logger + self.benchmark_logger() + + def benchmark_logger(self): + """ + benchmark logger + """ + # Init logger + FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + log_output = f"{self.save_log_path}/{self.model_name}.log" + Path(f"{self.save_log_path}").mkdir(parents=True, exist_ok=True) + logging.basicConfig( + level=logging.INFO, + format=FORMAT, + 
handlers=[ + logging.FileHandler( + filename=log_output, mode='w'), + logging.StreamHandler(), + ]) + self.logger = logging.getLogger(__name__) + self.logger.info( + f"Paddle Inference benchmark log will be saved to {log_output}") + + def parse_config(self, config) -> dict: + """ + parse paddle predictor config + args: + config(paddle.inference.Config): paddle inference config + return: + config_status(dict): dict style config info + """ + config_status = {} + config_status['runtime_device'] = "gpu" if config.use_gpu() else "cpu" + config_status['ir_optim'] = config.ir_optim() + config_status['enable_tensorrt'] = config.tensorrt_engine_enabled() + config_status['precision'] = self.precision + config_status['enable_mkldnn'] = config.mkldnn_enabled() + config_status[ + 'cpu_math_library_num_threads'] = config.cpu_math_library_num_threads( + ) + return config_status + + def report(self, identifier=None): + """ + print log report + args: + identifier(string): identify log + """ + if identifier: + identifier = f"[{identifier}]" + else: + identifier = "" + + self.logger.info("\n") + self.logger.info( + "---------------------- Paddle info ----------------------") + self.logger.info(f"{identifier} paddle_version: {self.paddle_version}") + self.logger.info(f"{identifier} paddle_commit: {self.paddle_commit}") + self.logger.info(f"{identifier} paddle_branch: {self.paddle_branch}") + self.logger.info(f"{identifier} log_api_version: {self.log_version}") + self.logger.info( + "----------------------- Conf info -----------------------") + self.logger.info( + f"{identifier} runtime_device: {self.config_status['runtime_device']}" + ) + self.logger.info( + f"{identifier} ir_optim: {self.config_status['ir_optim']}") + self.logger.info(f"{identifier} enable_memory_optim: {True}") + self.logger.info( + f"{identifier} enable_tensorrt: {self.config_status['enable_tensorrt']}" + ) + self.logger.info( + f"{identifier} enable_mkldnn: {self.config_status['enable_mkldnn']}") + self.logger.info( + f"{identifier} cpu_math_library_num_threads: {self.config_status['cpu_math_library_num_threads']}" + ) + self.logger.info( + "----------------------- Model info ----------------------") + self.logger.info(f"{identifier} model_name: {self.model_name}") + self.logger.info(f"{identifier} precision: {self.precision}") + self.logger.info( + "----------------------- Data info -----------------------") + self.logger.info(f"{identifier} batch_size: {self.batch_size}") + self.logger.info(f"{identifier} input_shape: {self.shape}") + self.logger.info(f"{identifier} data_num: {self.data_num}") + self.logger.info( + "----------------------- Perf info -----------------------") + self.logger.info( + f"{identifier} cpu_rss(MB): {self.cpu_rss_mb}, gpu_rss(MB): {self.gpu_rss_mb}, gpu_util: {self.gpu_util}%" + ) + self.logger.info( + f"{identifier} total time spent(s): {self.total_time_s}") + self.logger.info( + f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, inference_time(ms): {round(self.inference_time_s*1000, 1)}, postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}" + ) + + def print_help(self): + """ + print function help + """ + print("""Usage: + ==== Print inference benchmark logs. 
==== + config = paddle.inference.Config() + model_info = {'model_name': 'resnet50' + 'precision': 'fp32'} + data_info = {'batch_size': 1 + 'shape': '3,224,224' + 'data_num': 1000} + perf_info = {'preprocess_time_s': 1.0 + 'inference_time_s': 2.0 + 'postprocess_time_s': 1.0 + 'total_time_s': 4.0} + resource_info = {'cpu_rss_mb': 100 + 'gpu_rss_mb': 100 + 'gpu_util': 60} + log = PaddleInferBenchmark(config, model_info, data_info, perf_info, resource_info) + log('Test') + """) + + def __call__(self, identifier=None): + """ + __call__ + args: + identifier(string): identify log + """ + self.report(identifier) diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index d2592c6c95b0f466ea3ad5b45a35781282c9a492..0037b226df8e1de8edbdb7668e349925a942e8b9 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -45,9 +45,11 @@ class TextClassifier(object): "label_list": args.label_list, } self.postprocess_op = build_post_process(postprocess_params) - self.predictor, self.input_tensor, self.output_tensors = \ + self.predictor, self.input_tensor, self.output_tensors, _ = \ utility.create_predictor(args, 'cls', logger) + self.cls_times = utility.Timer() + def resize_norm_img(self, img): imgC, imgH, imgW = self.cls_image_shape h = img.shape[0] @@ -83,7 +85,9 @@ class TextClassifier(object): cls_res = [['', 0.0]] * img_num batch_num = self.cls_batch_num elapse = 0 + self.cls_times.total_time.start() for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] max_wh_ratio = 0 @@ -91,6 +95,7 @@ class TextClassifier(object): h, w = img_list[indices[ino]].shape[0:2] wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) + self.cls_times.preprocess_time.start() for ino in range(beg_img_no, end_img_no): norm_img = self.resize_norm_img(img_list[indices[ino]]) norm_img = norm_img[np.newaxis, :] @@ -98,11 +103,17 @@ class TextClassifier(object): norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() starttime = time.time() + self.cls_times.preprocess_time.end() + self.cls_times.inference_time.start() + self.input_tensor.copy_from_cpu(norm_img_batch) self.predictor.run() prob_out = self.output_tensors[0].copy_to_cpu() + self.cls_times.inference_time.end() + self.cls_times.postprocess_time.start() self.predictor.try_shrink_memory() cls_result = self.postprocess_op(prob_out) + self.cls_times.postprocess_time.end() elapse += time.time() - starttime for rno in range(len(cls_result)): label, score = cls_result[rno] @@ -110,6 +121,9 @@ class TextClassifier(object): if '180' in label and score > self.cls_thresh: img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]], 1) + self.cls_times.total_time.end() + self.cls_times.img_num += img_num + elapse = self.cls_times.total_time.value() return img_list, cls_res, elapse @@ -141,8 +155,9 @@ def main(args): for ino in range(len(img_list)): logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[ino])) - logger.info("Total predict time for {} images, cost: {:.3f}".format( - len(img_list), predict_time)) + logger.info( + "The predict time about text angle classify module is as follows: ") + text_classifier.cls_times.info(average=False) if __name__ == "__main__": diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 59bb49f90abb198933b91f222febad7a416018e8..1b52e717ca1edbbd9ede31260e47ec8973270d3f 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -31,6 
+31,8 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.data import create_operators, transform from ppocr.postprocess import build_post_process +# import tools.infer.benchmark_utils as benchmark_utils + logger = get_logger() @@ -41,7 +43,7 @@ class TextDetector(object): pre_process_list = [{ 'DetResizeForTest': { 'limit_side_len': args.det_limit_side_len, - 'limit_type': args.det_limit_type + 'limit_type': args.det_limit_type, } }, { 'NormalizeImage': { @@ -95,9 +97,8 @@ class TextDetector(object): self.preprocess_op = create_operators(pre_process_list) self.postprocess_op = build_post_process(postprocess_params) - self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor( - args, 'det', logger) # paddle.jit.load(args.det_model_dir) - # self.predictor.eval() + self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor( + args, 'det', logger) def order_points_clockwise(self, pts): """ @@ -155,6 +156,8 @@ class TextDetector(object): def __call__(self, img): ori_im = img.copy() data = {'image': img} + + st = time.time() data = transform(data, self.preprocess_op) img, shape_list = data if img is None: @@ -162,7 +165,6 @@ class TextDetector(object): img = np.expand_dims(img, axis=0) shape_list = np.expand_dims(shape_list, axis=0) img = img.copy() - starttime = time.time() self.input_tensor.copy_from_cpu(img) self.predictor.run() @@ -184,6 +186,7 @@ class TextDetector(object): preds['maps'] = outputs[0] else: raise NotImplementedError + self.predictor.try_shrink_memory() post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] @@ -191,8 +194,9 @@ class TextDetector(object): dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) else: dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) - elapse = time.time() - starttime - return dt_boxes, elapse + + et = time.time() + return dt_boxes, et - st if __name__ == "__main__": @@ -202,6 +206,14 @@ if __name__ == "__main__": count = 0 total_time = 0 draw_img_save = "./inference_results" + + if args.warmup: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(10): + res = text_detector(img) + + cpu_mem, gpu_mem, gpu_util = 0, 0, 0 + if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) for image_file in image_file_list: @@ -211,10 +223,13 @@ if __name__ == "__main__": if img is None: logger.info("error in loading image:{}".format(image_file)) continue - dt_boxes, elapse = text_detector(img) + st = time.time() + dt_boxes, _ = text_detector(img) + elapse = time.time() - st if count > 0: total_time += elapse count += 1 + logger.info("Predict time of {}: {}".format(image_file, elapse)) src_im = utility.draw_text_det_res(dt_boxes, image_file) img_name_pure = os.path.split(image_file)[-1] @@ -222,5 +237,3 @@ if __name__ == "__main__": "det_res_{}".format(img_name_pure)) cv2.imwrite(img_path, src_im) logger.info("The visualized image saved in {}".format(img_path)) - if count > 1: - logger.info("Avg Time: {}".format(total_time / (count - 1))) diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 24388026b8f395427c93e285ed550446e3aa9b9c..0d847046530c02c9b0591bb4b379fd7ddeac1263 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -28,6 +28,7 @@ import traceback import paddle import tools.infer.utility as utility +import tools.infer.benchmark_utils as benchmark_utils from ppocr.postprocess import build_post_process from 
ppocr.utils.logging import get_logger from ppocr.utils.utility import get_image_file_list, check_and_read_gif @@ -41,7 +42,6 @@ class TextRecognizer(object): self.character_type = args.rec_char_type self.rec_batch_num = args.rec_batch_num self.rec_algorithm = args.rec_algorithm - self.max_text_length = args.max_text_length postprocess_params = { 'name': 'CTCLabelDecode', "character_type": args.rec_char_type, @@ -63,9 +63,11 @@ class TextRecognizer(object): "use_space_char": args.use_space_char } self.postprocess_op = build_post_process(postprocess_params) - self.predictor, self.input_tensor, self.output_tensors = \ + self.predictor, self.input_tensor, self.output_tensors, self.config = \ utility.create_predictor(args, 'rec', logger) + self.rec_times = utility.Timer() + def resize_norm_img(self, img, max_wh_ratio): imgC, imgH, imgW = self.rec_image_shape assert imgC == img.shape[2] @@ -166,17 +168,15 @@ class TextRecognizer(object): width_list.append(img.shape[1] / float(img.shape[0])) # Sorting can speed up the recognition process indices = np.argsort(np.array(width_list)) - - # rec_res = [] + self.rec_times.total_time.start() rec_res = [['', 0.0]] * img_num batch_num = self.rec_batch_num - elapse = 0 for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] max_wh_ratio = 0 + self.rec_times.preprocess_time.start() for ino in range(beg_img_no, end_img_no): - # h, w = img_list[ino].shape[0:2] h, w = img_list[indices[ino]].shape[0:2] wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) @@ -187,9 +187,8 @@ class TextRecognizer(object): norm_img = norm_img[np.newaxis, :] norm_img_batch.append(norm_img) else: - norm_img = self.process_image_srn(img_list[indices[ino]], - self.rec_image_shape, 8, - self.max_text_length) + norm_img = self.process_image_srn( + img_list[indices[ino]], self.rec_image_shape, 8, 25) encoder_word_pos_list = [] gsrm_word_pos_list = [] gsrm_slf_attn_bias1_list = [] @@ -203,7 +202,6 @@ class TextRecognizer(object): norm_img_batch = norm_img_batch.copy() if self.rec_algorithm == "SRN": - starttime = time.time() encoder_word_pos_list = np.concatenate(encoder_word_pos_list) gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) gsrm_slf_attn_bias1_list = np.concatenate( @@ -218,19 +216,23 @@ class TextRecognizer(object): gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list, ] + self.rec_times.preprocess_time.end() + self.rec_times.inference_time.start() input_names = self.predictor.get_input_names() for i in range(len(input_names)): input_tensor = self.predictor.get_input_handle(input_names[ i]) input_tensor.copy_from_cpu(inputs[i]) self.predictor.run() + self.rec_times.inference_time.end() outputs = [] for output_tensor in self.output_tensors: output = output_tensor.copy_to_cpu() outputs.append(output) preds = {"predict": outputs[2]} else: - starttime = time.time() + self.rec_times.preprocess_time.end() + self.rec_times.inference_time.start() self.input_tensor.copy_from_cpu(norm_img_batch) self.predictor.run() @@ -239,22 +241,33 @@ class TextRecognizer(object): output = output_tensor.copy_to_cpu() outputs.append(output) preds = outputs[0] - self.predictor.try_shrink_memory() + self.rec_times.inference_time.end() + self.rec_times.postprocess_time.start() rec_result = self.postprocess_op(preds) for rno in range(len(rec_result)): rec_res[indices[beg_img_no + rno]] = rec_result[rno] - elapse += time.time() - starttime - return rec_res, elapse + self.rec_times.postprocess_time.end() + self.rec_times.img_num 
+= int(norm_img_batch.shape[0]) + self.rec_times.total_time.end() + return rec_res, self.rec_times.total_time.value() def main(args): image_file_list = get_image_file_list(args.image_dir) text_recognizer = TextRecognizer(args) - total_run_time = 0.0 - total_images_num = 0 valid_image_file_list = [] img_list = [] - for idx, image_file in enumerate(image_file_list): + + # warmup 10 times + if args.warmup: + img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8) + for i in range(10): + res = text_recognizer([img]) + + cpu_mem, gpu_mem, gpu_util = 0, 0, 0 + count = 0 + + for image_file in image_file_list: img, flag = check_and_read_gif(image_file) if not flag: img = cv2.imread(image_file) @@ -263,29 +276,54 @@ def main(args): continue valid_image_file_list.append(image_file) img_list.append(img) - if len(img_list) >= args.rec_batch_num or idx == len( - image_file_list) - 1: - try: - rec_res, predict_time = text_recognizer(img_list) - total_run_time += predict_time - except: - logger.info(traceback.format_exc()) - logger.info( - "ERROR!!!! \n" - "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" - "If your model has tps module: " - "TPS does not support variable shape.\n" - "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' " - ) - exit() - for ino in range(len(img_list)): - logger.info("Predicts of {}:{}".format(valid_image_file_list[ - ino], rec_res[ino])) - total_images_num += len(valid_image_file_list) - valid_image_file_list = [] - img_list = [] - logger.info("Total predict time for {} images, cost: {:.3f}".format( - total_images_num, total_run_time)) + try: + rec_res, _ = text_recognizer(img_list) + if args.benchmark: + cm, gm, gu = utility.get_current_memory_mb(0) + cpu_mem += cm + gpu_mem += gm + gpu_util += gu + count += 1 + + except Exception as E: + logger.info(traceback.format_exc()) + logger.info(E) + exit() + for ino in range(len(img_list)): + logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], + rec_res[ino])) + if args.benchmark: + mems = { + 'cpu_rss_mb': cpu_mem / count, + 'gpu_rss_mb': gpu_mem / count, + 'gpu_util': gpu_util * 100 / count + } + else: + mems = None + logger.info("The predict time about recognizer module is as follows: ") + rec_time_dict = text_recognizer.rec_times.report(average=True) + rec_model_name = args.rec_model_dir + + if args.benchmark: + # construct log information + model_info = { + 'model_name': args.rec_model_dir.split('/')[-1], + 'precision': args.precision + } + data_info = { + 'batch_size': args.rec_batch_num, + 'shape': 'dynamic_shape', + 'data_num': rec_time_dict['img_num'] + } + perf_info = { + 'preprocess_time_s': rec_time_dict['preprocess_time'], + 'inference_time_s': rec_time_dict['inference_time'], + 'postprocess_time_s': rec_time_dict['postprocess_time'], + 'total_time_s': rec_time_dict['total_time'] + } + benchmark_log = benchmark_utils.PaddleInferBenchmark( + text_recognizer.config, model_info, data_info, perf_info, mems) + benchmark_log("Rec") if __name__ == "__main__": diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index ba81aff0a940fbee234e59e98f73c62fc7f69f09..c008f9679684e2433859cd104261aeff56b410a2 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -25,6 +25,7 @@ import cv2 import copy import numpy as np import time +import logging from PIL import Image import tools.infer.utility as utility import tools.infer.predict_rec as predict_rec @@ -32,13 +33,16 @@ import tools.infer.predict_det as predict_det import 
tools.infer.predict_cls as predict_cls from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger -from tools.infer.utility import draw_ocr_box_txt - +from tools.infer.utility import draw_ocr_box_txt, get_current_memory_mb +import tools.infer.benchmark_utils as benchmark_utils logger = get_logger() class TextSystem(object): def __init__(self, args): + if not args.show_log: + logger.setLevel(logging.INFO) + self.text_detector = predict_det.TextDetector(args) self.text_recognizer = predict_rec.TextRecognizer(args) self.use_angle_cls = args.use_angle_cls @@ -85,10 +89,11 @@ class TextSystem(object): cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno]) logger.info(bno, rec_res[bno]) - def __call__(self, img): + def __call__(self, img, cls=True): ori_im = img.copy() dt_boxes, elapse = self.text_detector(img) - logger.info("dt_boxes num : {}, elapse : {}".format( + + logger.debug("dt_boxes num : {}, elapse : {}".format( len(dt_boxes), elapse)) if dt_boxes is None: return None, None @@ -100,14 +105,14 @@ class TextSystem(object): tmp_box = copy.deepcopy(dt_boxes[bno]) img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) - if self.use_angle_cls: + if self.use_angle_cls and cls: img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list) - logger.info("cls num : {}, elapse : {}".format( + logger.debug("cls num : {}, elapse : {}".format( len(img_crop_list), elapse)) rec_res, elapse = self.text_recognizer(img_crop_list) - logger.info("rec_res num : {}, elapse : {}".format( + logger.debug("rec_res num : {}, elapse : {}".format( len(rec_res), elapse)) # self.print_draw_crop_rec_res(img_crop_list, rec_res) filter_boxes, filter_rec_res = [], [] @@ -147,7 +152,19 @@ def main(args): is_visualize = True font_path = args.vis_font_path drop_score = args.drop_score - for image_file in image_file_list: + + # warm up 10 times + if args.warmup: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(10): + res = text_sys(img) + + total_time = 0 + cpu_mem, gpu_mem, gpu_util = 0, 0, 0 + _st = time.time() + count = 0 + for idx, image_file in enumerate(image_file_list): + img, flag = check_and_read_gif(image_file) if not flag: img = cv2.imread(image_file) @@ -157,8 +174,16 @@ def main(args): starttime = time.time() dt_boxes, rec_res = text_sys(img) elapse = time.time() - starttime - logger.info("Predict time of %s: %.3fs" % (image_file, elapse)) + total_time += elapse + if args.benchmark and idx % 20 == 0: + cm, gm, gu = get_current_memory_mb(0) + cpu_mem += cm + gpu_mem += gm + gpu_util += gu + count += 1 + logger.info( + str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse)) for text, score in rec_res: logger.info("{}, {:.3f}".format(text, score)) @@ -178,12 +203,74 @@ def main(args): draw_img_save = "./inference_results/" if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) + if flag: + image_file = image_file[:-3] + "png" cv2.imwrite( os.path.join(draw_img_save, os.path.basename(image_file)), draw_img[:, :, ::-1]) logger.info("The visualized image saved in {}".format( os.path.join(draw_img_save, os.path.basename(image_file)))) + logger.info("The predict total time is {}".format(time.time() - _st)) + logger.info("\nThe predict total time is {}".format(total_time)) + + img_num = text_sys.text_detector.det_times.img_num + if args.benchmark: + mems = { + 'cpu_rss_mb': cpu_mem / count, + 'gpu_rss_mb': gpu_mem / count, + 'gpu_util': gpu_util * 100 / count + } 
+ else: + mems = None + det_time_dict = text_sys.text_detector.det_times.report(average=True) + rec_time_dict = text_sys.text_recognizer.rec_times.report(average=True) + det_model_name = args.det_model_dir + rec_model_name = args.rec_model_dir + + # construct det log information + model_info = { + 'model_name': args.det_model_dir.split('/')[-1], + 'precision': args.precision + } + data_info = { + 'batch_size': 1, + 'shape': 'dynamic_shape', + 'data_num': det_time_dict['img_num'] + } + perf_info = { + 'preprocess_time_s': det_time_dict['preprocess_time'], + 'inference_time_s': det_time_dict['inference_time'], + 'postprocess_time_s': det_time_dict['postprocess_time'], + 'total_time_s': det_time_dict['total_time'] + } + + benchmark_log = benchmark_utils.PaddleInferBenchmark( + text_sys.text_detector.config, model_info, data_info, perf_info, mems, + args.save_log_path) + benchmark_log("Det") + + # construct rec log information + model_info = { + 'model_name': args.rec_model_dir.split('/')[-1], + 'precision': args.precision + } + data_info = { + 'batch_size': args.rec_batch_num, + 'shape': 'dynamic_shape', + 'data_num': rec_time_dict['img_num'] + } + perf_info = { + 'preprocess_time_s': rec_time_dict['preprocess_time'], + 'inference_time_s': rec_time_dict['inference_time'], + 'postprocess_time_s': rec_time_dict['postprocess_time'], + 'total_time_s': rec_time_dict['total_time'] + } + benchmark_log = benchmark_utils.PaddleInferBenchmark( + text_sys.text_recognizer.config, model_info, data_info, perf_info, mems, + args.save_log_path) + benchmark_log("Rec") + if __name__ == "__main__": args = utility.parse_args() diff --git a/tools/infer/utility.py b/tools/infer/utility.py index b5fe3ba9813b319d5baf7a9cf2ed0cb655e12021..90ac5aa5ba2a33707965159d0486bb42957e3fce 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -21,18 +21,23 @@ import json from PIL import Image, ImageDraw, ImageFont import math from paddle import inference +import time +from ppocr.utils.logging import get_logger +logger = get_logger() + + +def str2bool(v): + return v.lower() in ("true", "t", "1") -def parse_args(): - def str2bool(v): - return v.lower() in ("true", "t", "1") +def init_args(): parser = argparse.ArgumentParser() # params for prediction engine parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_tensorrt", type=str2bool, default=False) - parser.add_argument("--use_fp16", type=str2bool, default=False) + parser.add_argument("--precision", type=str, default="fp32") parser.add_argument("--gpu_mem", type=int, default=500) # params for text detector @@ -98,15 +103,97 @@ def parse_args(): parser.add_argument("--cls_thresh", type=float, default=0.9) parser.add_argument("--enable_mkldnn", type=str2bool, default=False) + parser.add_argument("--cpu_threads", type=int, default=10) parser.add_argument("--use_pdserving", type=str2bool, default=False) + parser.add_argument("--warmup", type=str2bool, default=True) + # multi-process parser.add_argument("--use_mp", type=str2bool, default=False) parser.add_argument("--total_process_num", type=int, default=1) parser.add_argument("--process_id", type=int, default=0) + parser.add_argument("--benchmark", type=bool, default=False) + parser.add_argument("--save_log_path", type=str, default="./log_output/") + + parser.add_argument("--show_log", type=str2bool, default=True) + return parser + + +def parse_args(): + parser = init_args() return parser.parse_args() +class Times(object): 
+    def __init__(self): +        self.time = 0. +        self.st = 0. +        self.et = 0. + +    def start(self): +        self.st = time.time() + +    def end(self, accumulative=True): +        self.et = time.time() +        if accumulative: +            self.time += self.et - self.st +        else: +            self.time = self.et - self.st + +    def reset(self): +        self.time = 0. +        self.st = 0. +        self.et = 0. + +    def value(self): +        return round(self.time, 4) + + +class Timer(Times): +    def __init__(self): +        super(Timer, self).__init__() +        self.total_time = Times() +        self.preprocess_time = Times() +        self.inference_time = Times() +        self.postprocess_time = Times() +        self.img_num = 0 + +    def info(self, average=False): +        logger.info("----------------------- Perf info -----------------------") +        logger.info("total_time: {}, img_num: {}".format(self.total_time.value( +        ), self.img_num)) +        preprocess_time = round(self.preprocess_time.value() / self.img_num, +                                4) if average else self.preprocess_time.value() +        postprocess_time = round( +            self.postprocess_time.value() / self.img_num, +            4) if average else self.postprocess_time.value() +        inference_time = round(self.inference_time.value() / self.img_num, +                               4) if average else self.inference_time.value() + +        average_latency = self.total_time.value() / self.img_num +        logger.info("average_latency(ms): {:.2f}, QPS: {:.2f}".format( +            average_latency * 1000, 1 / average_latency)) +        logger.info( +            "preprocess_latency(ms): {:.2f}, inference_latency(ms): {:.2f}, postprocess_latency(ms): {:.2f}". +            format(preprocess_time * 1000, inference_time * 1000, +                   postprocess_time * 1000)) + +    def report(self, average=False): +        dic = {} +        dic['preprocess_time'] = round( +            self.preprocess_time.value() / self.img_num, +            4) if average else self.preprocess_time.value() +        dic['postprocess_time'] = round( +            self.postprocess_time.value() / self.img_num, +            4) if average else self.postprocess_time.value() +        dic['inference_time'] = round( +            self.inference_time.value() / self.img_num, +            4) if average else self.inference_time.value() +        dic['img_num'] = self.img_num +        dic['total_time'] = round(self.total_time.value(), 4) +        return dic + + def create_predictor(args, mode, logger): if mode == "det": model_dir = args.det_model_dir @@ -114,6 +201,8 @@ def create_predictor(args, mode, logger): model_dir = args.cls_model_dir elif mode == 'rec': model_dir = args.rec_model_dir + elif mode == 'table': + model_dir = args.table_model_dir else: model_dir = args.e2e_model_dir @@ -131,30 +220,121 @@ def create_predictor(args, mode, logger): config = inference.Config(model_file_path, params_file_path) + if hasattr(args, 'precision'): + if args.precision == "fp16" and args.use_tensorrt: + precision = inference.PrecisionType.Half + elif args.precision == "int8": + precision = inference.PrecisionType.Int8 + else: + precision = inference.PrecisionType.Float32 + else: + precision = inference.PrecisionType.Float32 + if args.use_gpu: config.enable_use_gpu(args.gpu_mem, 0) if args.use_tensorrt: config.enable_tensorrt_engine( - precision_mode=inference.PrecisionType.Half - if args.use_fp16 else inference.PrecisionType.Float32, - max_batch_size=args.max_batch_size) + precision_mode=precision, + max_batch_size=args.max_batch_size, + min_subgraph_size=3) # skip the minimum trt subgraph + if mode == "det" and "mobile" in model_file_path: + min_input_shape = { + "x": [1, 3, 50, 50], + "conv2d_92.tmp_0": [1, 96, 20, 20], + "conv2d_91.tmp_0": [1, 96, 10, 10], + "nearest_interp_v2_1.tmp_0": [1, 96, 10, 10], + "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20], +
"nearest_interp_v2_3.tmp_0": [1, 24, 20, 20], + "nearest_interp_v2_4.tmp_0": [1, 24, 20, 20], + "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20], + "elementwise_add_7": [1, 56, 2, 2], + "nearest_interp_v2_0.tmp_0": [1, 96, 2, 2] + } + max_input_shape = { + "x": [1, 3, 2000, 2000], + "conv2d_92.tmp_0": [1, 96, 400, 400], + "conv2d_91.tmp_0": [1, 96, 200, 200], + "nearest_interp_v2_1.tmp_0": [1, 96, 200, 200], + "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400], + "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400], + "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400], + "nearest_interp_v2_5.tmp_0": [1, 24, 400, 400], + "elementwise_add_7": [1, 56, 400, 400], + "nearest_interp_v2_0.tmp_0": [1, 96, 400, 400] + } + opt_input_shape = { + "x": [1, 3, 640, 640], + "conv2d_92.tmp_0": [1, 96, 160, 160], + "conv2d_91.tmp_0": [1, 96, 80, 80], + "nearest_interp_v2_1.tmp_0": [1, 96, 80, 80], + "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160], + "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160], + "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160], + "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160], + "elementwise_add_7": [1, 56, 40, 40], + "nearest_interp_v2_0.tmp_0": [1, 96, 40, 40] + } + if mode == "det" and "server" in model_file_path: + min_input_shape = { + "x": [1, 3, 50, 50], + "conv2d_59.tmp_0": [1, 96, 20, 20], + "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20], + "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20], + "nearest_interp_v2_4.tmp_0": [1, 24, 20, 20], + "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20] + } + max_input_shape = { + "x": [1, 3, 2000, 2000], + "conv2d_59.tmp_0": [1, 96, 400, 400], + "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400], + "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400], + "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400], + "nearest_interp_v2_5.tmp_0": [1, 24, 400, 400] + } + opt_input_shape = { + "x": [1, 3, 640, 640], + "conv2d_59.tmp_0": [1, 96, 160, 160], + "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160], + "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160], + "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160], + "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160] + } + elif mode == "rec": + min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]} + max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]} + opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]} + elif mode == "cls": + min_input_shape = {"x": [args.rec_batch_num, 3, 48, 10]} + max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]} + opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]} + else: + min_input_shape = {"x": [1, 3, 10, 10]} + max_input_shape = {"x": [1, 3, 1000, 1000]} + opt_input_shape = {"x": [1, 3, 500, 500]} + config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, + opt_input_shape) + else: config.disable_gpu() - config.set_cpu_math_library_num_threads(6) + if hasattr(args, "cpu_threads"): + config.set_cpu_math_library_num_threads(args.cpu_threads) + else: + # default cpu threads as 10 + config.set_cpu_math_library_num_threads(10) if args.enable_mkldnn: # cache 10 different shapes for mkldnn to avoid memory leak config.set_mkldnn_cache_capacity(10) config.enable_mkldnn() - # TODO LDOUBLEV: fix mkldnn bug when bach_size > 1 - #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'}) - args.rec_batch_num = 1 # enable memory optim config.enable_memory_optim() config.disable_glog_info() config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") + if mode == 'table': + config.delete_pass("fc_fuse_pass") # not supported for table config.switch_use_feed_fetch_ops(False) + 
config.switch_ir_optim(True) # create predictor predictor = inference.create_predictor(config) @@ -166,7 +346,7 @@ def create_predictor(args, mode, logger): for output_name in output_names: output_tensor = predictor.get_output_handle(output_name) output_tensors.append(output_tensor) - return predictor, input_tensor, output_tensors + return predictor, input_tensor, output_tensors, config def draw_e2e_res(dt_boxes, strs, img_path): @@ -210,7 +390,7 @@ def draw_ocr(image, txts=None, scores=None, drop_score=0.5, - font_path="./doc/simfang.ttf"): + font_path="./doc/fonts/simfang.ttf"): """ Visualize the results of OCR detection and recognition args: @@ -417,23 +597,30 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5): return image +def get_current_memory_mb(gpu_id=None): + """ + Obtain the memory usage of the CPU and GPU while the program is running. + Note that calling this function is itself time-consuming. + """ + import pynvml + import psutil + import GPUtil + pid = os.getpid() + p = psutil.Process(pid) + info = p.memory_full_info() + cpu_mem = info.uss / 1024. / 1024. + gpu_mem = 0 + gpu_percent = 0 + if gpu_id is not None: + GPUs = GPUtil.getGPUs() + gpu_load = GPUs[gpu_id].load + gpu_percent = gpu_load + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id) + meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) + gpu_mem = meminfo.used / 1024. / 1024. + return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4) + + if __name__ == '__main__': - test_img = "./doc/test_v2" - predict_txt = "./doc/predict.txt" - f = open(predict_txt, 'r') - data = f.readlines() - img_path, anno = data[0].strip().split('\t') - img_name = os.path.basename(img_path) - img_path = os.path.join(test_img, img_name) - image = Image.open(img_path) - - data = json.loads(anno) - boxes, txts, scores = [], [], [] - for dic in data: - boxes.append(dic['points']) - txts.append(dic['transcription']) - scores.append(round(dic['scores'], 3)) - - new_img = draw_ocr(image, boxes, txts, scores) - - cv2.imwrite(img_name, new_img) + pass diff --git a/tools/infer_cls.py b/tools/infer_cls.py index 496964826b0b063f9f937c31342932c6cd95502f..a588cab433442695e3bd395da63e35a2052de501 100755 --- a/tools/infer_cls.py +++ b/tools/infer_cls.py @@ -47,7 +47,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model, logger) + init_model(config, model) # create data ops transforms = [] diff --git a/tools/infer_det.py b/tools/infer_det.py index 913d617defea18fe881e6fd2212b1df20f7d26d3..a964cd28c934504ce79ea4873d3345295c1266e5 100755 --- a/tools/infer_det.py +++ b/tools/infer_det.py @@ -61,7 +61,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model, logger) + init_model(config, model) # build post process post_process_class = build_post_process(config['PostProcess']) @@ -112,4 +112,4 @@ def main(): if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess() - main() \ No newline at end of file + main() diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py index 9c079f6074f088ef0298cab839f74faefad82abb..1cd468b8e552237af31d985b8b68ddbeecba9c96 100755 --- a/tools/infer_e2e.py +++ b/tools/infer_e2e.py @@ -68,7 +68,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model, logger) + init_model(config, model) # build post process post_process_class = build_post_process(config['PostProcess'], diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 
2563f5a8197ed39b1b5d44c7cfee32797e760758..09f5a0c767b15c312cdfbe8ed695ea06bdc8cdc4 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -20,6 +20,7 @@ import numpy as np import os import sys +import json __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) @@ -46,12 +47,18 @@ def main(): # build model if hasattr(post_process_class, 'character'): - config['Architecture']["Head"]['out_channels'] = len( - getattr(post_process_class, 'character')) + char_num = len(getattr(post_process_class, 'character')) + if config['Architecture']["algorithm"] in ["Distillation", + ]: # distillation model + for key in config['Architecture']["Models"]: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + else: # base rec model + config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - init_model(config, model, logger) + init_model(config, model) # create data ops transforms = [] @@ -107,11 +114,23 @@ def main(): else: preds = model(images) post_result = post_process_class(preds) - for rec_reuslt in post_result: - logger.info('\t result: {}'.format(rec_reuslt)) - if len(rec_reuslt) >= 2: - fout.write(file + "\t" + rec_reuslt[0] + "\t" + str( - rec_reuslt[1]) + "\n") + info = None + if isinstance(post_result, dict): + rec_info = dict() + for key in post_result: + if len(post_result[key][0]) >= 2: + rec_info[key] = { + "label": post_result[key][0][0], + "score": post_result[key][0][1], + } + info = json.dumps(rec_info) + else: + if len(post_result[0]) >= 2: + info = post_result[0][0] + "\t" + str(post_result[0][1]) + + if info is not None: + logger.info("\t result: {}".format(info)) + fout.write(file + "\t" + info + "\n") logger.info("success!") diff --git a/tools/infer_table.py b/tools/infer_table.py new file mode 100644 index 0000000000000000000000000000000000000000..f743d87540f7fd64157a808db156c9f62a042d9c --- /dev/null +++ b/tools/infer_table.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import os +import sys +import json + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import paddle +from paddle.jit import to_static + +from ppocr.data import create_operators, transform +from ppocr.modeling.architectures import build_model +from ppocr.postprocess import build_post_process +from ppocr.utils.save_load import init_model +from ppocr.utils.utility import get_image_file_list +import tools.program as program +import cv2 + +def main(config, device, logger, vdl_writer): + global_config = config['Global'] + + # build post process + post_process_class = build_post_process(config['PostProcess'], + global_config) + + # build model + if hasattr(post_process_class, 'character'): + config['Architecture']["Head"]['out_channels'] = len( + getattr(post_process_class, 'character')) + + model = build_model(config['Architecture']) + + init_model(config, model) + + # create data ops + transforms = [] + use_padding = False + for op in config['Eval']['dataset']['transforms']: + op_name = list(op)[0] + if 'Label' in op_name: + continue + if op_name == 'KeepKeys': + op[op_name]['keep_keys'] = ['image'] + if op_name == "ResizeTableImage": + use_padding = True + padding_max_len = op['ResizeTableImage']['max_len'] + transforms.append(op) + + global_config['infer_mode'] = True + ops = create_operators(transforms, global_config) + + model.eval() + for file in get_image_file_list(config['Global']['infer_img']): + logger.info("infer_img: {}".format(file)) + with open(file, 'rb') as f: + img = f.read() + data = {'image': img} + batch = transform(data, ops) + images = np.expand_dims(batch[0], axis=0) + images = paddle.to_tensor(images) + preds = model(images) + post_result = post_process_class(preds) + res_html_code = post_result['res_html_code'] + res_loc = post_result['res_loc'] + img = cv2.imread(file) + imgh, imgw = img.shape[0:2] + res_loc_final = [] + for rno in range(len(res_loc[0])): + x0, y0, x1, y1 = res_loc[0][rno] + left = max(int(imgw * x0), 0) + top = max(int(imgh * y0), 0) + right = min(int(imgw * x1), imgw - 1) + bottom = min(int(imgh * y1), imgh - 1) + cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2) + res_loc_final.append([left, top, right, bottom]) + res_loc_str = json.dumps(res_loc_final) + logger.info("result: {}, {}".format(res_html_code, res_loc_final)) + logger.info("success!") + + +if __name__ == '__main__': + config, device, logger, vdl_writer = program.preprocess() + main(config, device, logger, vdl_writer) + diff --git a/tools/program.py b/tools/program.py index 7e54a2f8c2f1db8881aa476a309c8a8c563fcae5..2d99f2968a3f0c8acc359ed0fbb199650bd7010c 100755 --- a/tools/program.py +++ b/tools/program.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
@@ -186,6 +186,7 @@ def train(config, model.train() use_srn = config['Architecture']['algorithm'] == "SRN" + model_type = config['Architecture']['model_type'] if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] @@ -208,9 +209,9 @@ def train(config, lr = optimizer.get_lr() images = batch[0] if use_srn: - others = batch[-4:] - preds = model(images, others) model_average = True + if use_srn or model_type == 'table': + preds = model(images, data=batch[1:]) else: preds = model(images) loss = loss_class(preds, batch) @@ -232,8 +233,11 @@ def train(config, if cal_metric_during_train: # only rec and cls need batch = [item.numpy() for item in batch] - post_result = post_process_class(preds, batch[1]) - eval_class(post_result, batch) + if model_type == 'table': + eval_class(preds, batch) + else: + post_result = post_process_class(preds, batch[1]) + eval_class(post_result, batch) metric = eval_class.get_metric() train_stats.update(metric) @@ -269,6 +273,7 @@ def train(config, valid_dataloader, post_process_class, eval_class, + model_type, use_srn=use_srn) cur_metric_str = 'cur metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in cur_metric.items()])) @@ -336,7 +341,11 @@ def train(config, return -def eval(model, valid_dataloader, post_process_class, eval_class, +def eval(model, + valid_dataloader, + post_process_class, + eval_class, + model_type, use_srn=False): model.eval() with paddle.no_grad(): @@ -350,19 +359,19 @@ def eval(model, valid_dataloader, post_process_class, eval_class, break images = batch[0] start = time.time() - - if use_srn: - others = batch[-4:] - preds = model(images, others) + if use_srn or model_type == 'table': + preds = model(images, data=batch[1:]) else: preds = model(images) - batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods - post_result = post_process_class(preds, batch[1]) total_time += time.time() - start # Evaluate the results of the current batch - eval_class(post_result, batch) + if model_type == 'table': + eval_class(preds, batch) + else: + post_result = post_process_class(preds, batch[1]) + eval_class(post_result, batch) pbar.update(1) total_frame += len(images) # Get final metric,eg. 
acc or hmean @@ -386,7 +395,7 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PGNet' + 'CLS', 'PGNet', 'Distillation', 'TableAttn' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' diff --git a/tools/train.py b/tools/train.py index 47358ca43da46b7eb6a04cd1f7fe284efd7e96f7..b024240b4d5d4973645336c62d3762087ec7bbeb 100755 --- a/tools/train.py +++ b/tools/train.py @@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer): # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) - config['Architecture']["Head"]['out_channels'] = char_num + if config['Architecture']["algorithm"] in ["Distillation", + ]: # distillation model + for key in config['Architecture']["Models"]: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + else: # base rec model + config['Architecture']["Head"]['out_channels'] = char_num + model = build_model(config['Architecture']) if config['Global']['distributed']: model = paddle.DataParallel(model) @@ -90,7 +97,7 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) # load pretrain model - pre_best_model_dict = init_model(config, model, logger, optimizer) + pre_best_model_dict = init_model(config, model, optimizer) logger.info('train dataloader has {} iters'.format(len(train_dataloader))) if valid_dataloader is not None:
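
A minimal sketch of how the reworked `TextSystem` in `tools/infer/predict_system.py` is driven after this patch: the 10-iteration warmup mirrors `main()`, while the image path and the rest of the scaffolding are illustrative only.

```python
import cv2
import numpy as np

import tools.infer.utility as utility
from tools.infer.predict_system import TextSystem

# parse_args() now exposes --warmup, --benchmark, --precision and --show_log
args = utility.parse_args()
text_sys = TextSystem(args)

# warm up 10 times so first-run allocation cost stays out of the timings
if args.warmup:
    dummy = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
    for _ in range(10):
        text_sys(dummy)

img = cv2.imread("./doc/imgs/11.jpg")        # illustrative input path
dt_boxes, rec_res = text_sys(img, cls=True)  # cls=False skips the angle classifier
```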
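The new `Times`/`Timer` helpers in `tools/infer/utility.py` are plain wall-clock accumulators; a predictor is expected to bracket each stage with `start()`/`end()` and to bump `img_num` once per image. A sketch of the intended call pattern (the stage bodies are placeholders, not the patch's code):

```python
from tools.infer.utility import Timer

times = Timer()

times.total_time.start()

times.preprocess_time.start()
# ... resize / normalize the input image here ...
times.preprocess_time.end(accumulative=True)

times.inference_time.start()
# ... predictor.run() here ...
times.inference_time.end(accumulative=True)

times.postprocess_time.start()
# ... decode boxes / text here ...
times.postprocess_time.end(accumulative=True)

times.total_time.end()
times.img_num += 1

times.info(average=True)             # logs per-image latency and QPS
report = times.report(average=True)  # dict fed into PaddleInferBenchmark
```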
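`get_current_memory_mb()` reads CPU RSS via psutil and GPU memory/utilization via pynvml and GPUtil, so the probe itself is slow; that is why `main()` samples it only every 20 images when `--benchmark` is set. A sketch of that sampling loop, reusing the hypothetical `args`/`text_sys` from the first example:

```python
import cv2
from ppocr.utils.utility import get_image_file_list
from tools.infer.utility import get_current_memory_mb

image_file_list = get_image_file_list(args.image_dir)
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
count = 0
for idx, image_file in enumerate(image_file_list):
    dt_boxes, rec_res = text_sys(cv2.imread(image_file))
    if idx % 20 == 0:  # the probe is expensive, so sample sparsely
        cm, gm, gu = get_current_memory_mb(gpu_id=0)
        cpu_mem += cm
        gpu_mem += gm
        gpu_util += gu
        count += 1

mems = {
    'cpu_rss_mb': cpu_mem / count,
    'gpu_rss_mb': gpu_mem / count,
    'gpu_util': gpu_util * 100 / count,  # GPUtil load is reported in [0, 1]
}
```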
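In the TensorRT branch of `create_predictor()`, every named tensor is given a [min, max, opt] shape range so one engine can serve variable-sized inputs. A condensed sketch using only the rec-model shapes from the patch; the model file names are placeholders:

```python
from paddle import inference

# placeholder model files; in the patch these come from args.rec_model_dir
config = inference.Config("inference.pdmodel", "inference.pdiparams")
config.enable_use_gpu(500, 0)
config.enable_tensorrt_engine(
    precision_mode=inference.PrecisionType.Float32,
    max_batch_size=6,
    min_subgraph_size=3)

# one [min, max, opt] range per tensor name; "x" is the model input
min_shape = {"x": [6, 3, 32, 10]}
max_shape = {"x": [6, 3, 32, 2000]}
opt_shape = {"x": [6, 3, 32, 320]}
config.set_trt_dynamic_shape_info(min_shape, max_shape, opt_shape)

predictor = inference.create_predictor(config)
```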
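Both `tools/train.py` and `tools/infer_rec.py` now special-case distillation configs when wiring the CTC head size: every sub-model under `Architecture.Models` receives the same `out_channels`. Restated as a standalone sketch over a toy config dict:

```python
# toy config mirroring the YAML structure used by the distillation recipe
config = {
    'Architecture': {
        'algorithm': 'Distillation',
        'Models': {
            'Student': {'Head': {}},
            'Teacher': {'Head': {}},
        },
    }
}
char_num = 100  # illustrative; in practice len(post_process_class.character)

arch = config['Architecture']
if arch['algorithm'] == 'Distillation':
    for name in arch['Models']:  # e.g. Student, Teacher
        arch['Models'][name]['Head']['out_channels'] = char_num
else:
    arch['Head']['out_channels'] = char_num
```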
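With `DistillationCTCLabelDecode`, the post-process returns a dict keyed by sub-model name rather than a flat list, which is why `infer_rec.py` now branches on `isinstance(post_result, dict)`. An illustrative example of that output shape and the JSON line it becomes (the values are made up):

```python
import json

# hypothetical decode output for one image
post_result = {
    "Student": [("hello", 0.98)],
    "Teacher": [("hello", 0.99)],
}

rec_info = {
    key: {"label": res[0][0], "score": res[0][1]}
    for key, res in post_result.items()
    if len(res[0]) >= 2
}
print(json.dumps(rec_info))
# {"Student": {"label": "hello", "score": 0.98}, "Teacher": {"label": "hello", "score": 0.99}}
```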