diff --git a/configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml index 0ad1ab0adc189102ff07094fcda92d4f9ea9c662..8c650bd826d127f25c907f97d20d1a52f67f9203 100644 --- a/configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml @@ -12,7 +12,7 @@ Global: checkpoints: save_inference_dir: use_visualdl: false - infer_img: doc/imgs_words/ch/word_1.jpg + infer_img: ./doc/imgs_words/arabic/ar_2.jpg character_dict_path: ppocr/utils/dict/arabic_dict.txt max_text_length: &max_text_length 25 infer_mode: false diff --git a/doc/overview_en.png b/doc/overview_en.png deleted file mode 100644 index b44da4e9874d6a2162a8bb05ff1b479875bd65f3..0000000000000000000000000000000000000000 Binary files a/doc/overview_en.png and /dev/null differ diff --git a/doc/ppocr_v3/svtr_tiny.jpg b/doc/ppocr_v3/svtr_tiny.jpg deleted file mode 100644 index 26261047ef253e9802956f4c64449870d10de850..0000000000000000000000000000000000000000 Binary files a/doc/ppocr_v3/svtr_tiny.jpg and /dev/null differ diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py index ec64b0327a3c2172c25443c9116fdc98679c2710..8e10ed7b48e9aff344b71e5a04970d1a5dab8a71 100644 --- a/ppocr/modeling/backbones/vqa_layoutlm.py +++ b/ppocr/modeling/backbones/vqa_layoutlm.py @@ -29,14 +29,14 @@ __all__ = ["LayoutXLMForSer", "LayoutLMForSer"] pretrained_model_dict = { LayoutXLMModel: { "base": "layoutxlm-base-uncased", - "vi": "layoutxlm-wo-backbone-base-uncased", + "vi": "vi-layoutxlm-base-uncased", }, LayoutLMModel: { "base": "layoutlm-base-uncased", }, LayoutLMv2Model: { "base": "layoutlmv2-base-uncased", - "vi": "layoutlmv2-wo-backbone-base-uncased", + "vi": "vi-layoutlmv2-base-uncased", }, } diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index fc9fccfb143bf31ec66989e279d0bcc1c9baa5cc..f77631700648e84f28223cb14738e7b4ab679012 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -45,6 +45,27 @@ class BaseRecLabelDecode(object): self.dict[char] = i self.character = dict_character + if 'arabic' in character_dict_path: + self.reverse = True + else: + self.reverse = False + + def pred_reverse(self, pred): + pred_re = [] + c_current = '' + for c in pred: + if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)): + if c_current != '': + pred_re.append(c_current) + pred_re.append(c) + c_current = '' + else: + c_current += c + if c_current != '': + pred_re.append(c_current) + + return ''.join(pred_re[::-1]) + def add_special_char(self, dict_character): return dict_character @@ -73,6 +94,10 @@ class BaseRecLabelDecode(object): conf_list = [0] text = ''.join(char_list) + + if self.reverse: # for arabic rec + text = self.pred_reverse(text) + result_list.append((text, np.mean(conf_list).tolist())) return result_list diff --git a/ppocr/utils/dict/arabic_dict.txt b/ppocr/utils/dict/arabic_dict.txt index e97abf39274df77fbad066ee4635aebc6743140c..916d421c53bad563dfd980c1b64dcce07a3c9d24 100644 --- a/ppocr/utils/dict/arabic_dict.txt +++ b/ppocr/utils/dict/arabic_dict.txt @@ -1,4 +1,3 @@ - ! # $ diff --git a/ppstructure/docs/models_list_en.md b/ppstructure/docs/models_list_en.md index 85531fb753c4e32f0cdc9296ab97a9faebbb0ebd..291d42f995fdd7fabc293a0e4df35c2249945fd2 100644 --- a/ppstructure/docs/models_list_en.md +++ b/ppstructure/docs/models_list_en.md @@ -13,7 +13,7 @@ |model name| description | inference model size |download|dict path| | --- |---------------------------------------------------------------------------------------------------------------------------------------------------------| --- | --- | --- | | picodet_lcnet_x1_0_fgd_layout | The layout analysis English model trained on the PubLayNet dataset based on PicoDet LCNet_x1_0 and FGD . the model can recognition 5 types of areas such as **Text, Title, Table, Picture and List** | 9.7M | [inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout.pdparams) | [PubLayNet dict](../../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt) | -| ppyolov2_r50vd_dcn_365e_publaynet | The layout analysis English model trained on the PubLayNet dataset based on PP-YOLOv2 | 221M | [inference_moel]](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) / [trained model](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet_pretrained.pdparams) | sme as above | +| ppyolov2_r50vd_dcn_365e_publaynet | The layout analysis English model trained on the PubLayNet dataset based on PP-YOLOv2 | 221M | [inference_moel](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) / [trained model](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet_pretrained.pdparams) | same as above | | picodet_lcnet_x1_0_fgd_layout_cdla | The layout analysis Chinese model trained on the CDLA dataset, the model can recognition 10 types of areas such as **Table、Figure、Figure caption、Table、Table caption、Header、Footer、Reference、Equation** | 9.7M | [inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_cdla_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_cdla.pdparams) | [CDLA dict](../../ppocr/utils/dict/layout_dict/layout_cdla_dict.txt) | | picodet_lcnet_x1_0_fgd_layout_table | The layout analysis model trained on the table dataset, the model can detect tables in Chinese and English documents | 9.7M | [inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_table_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_table.pdparams) | [Table dict](../../ppocr/utils/dict/layout_dict/layout_table_dict.txt) | | ppyolov2_r50vd_dcn_365e_tableBank_word | The layout analysis model trained on the TableBank Word dataset based on PP-YOLOv2, the model can detect tables in English documents | 221M | [inference model](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar) | same as above | diff --git a/ppstructure/docs/quickstart.md b/ppstructure/docs/quickstart.md index b9367cab327a2f6232e34431c12532db03c75389..60642f78b6691c3ac2eeba99680a2af23299ddc9 100644 --- a/ppstructure/docs/quickstart.md +++ b/ppstructure/docs/quickstart.md @@ -48,7 +48,7 @@ pip3 install "paddleocr>=2.6" # 安装 图像方向分类依赖包paddleclas(如不需要图像方向分类功能,可跳过) -pip3 install paddleclas +pip3 install paddleclas>=2.4.3 # 安装 关键信息抽取 依赖包(如不需要KIE功能,可跳过) pip3 install -r ppstructure/kie/requirements.txt diff --git a/ppstructure/docs/quickstart_en.md b/ppstructure/docs/quickstart_en.md index b1df40b267a82fd48853edf607acd43f3a5431c9..e0eec4b38ba57b1bebd0e711093e5dfd4773fdd9 100644 --- a/ppstructure/docs/quickstart_en.md +++ b/ppstructure/docs/quickstart_en.md @@ -50,7 +50,7 @@ For more software version requirements, please refer to the instructions in [Ins pip3 install "paddleocr>=2.6" # Install the image direction classification dependency package paddleclas (if you do not use the image direction classification, you can skip it) -pip3 install paddleclas +pip3 install paddleclas>=2.4.3 # Install the KIE dependency packages (if you do not use the KIE, you can skip it) pip3 install -r kie/requirements.txt diff --git a/ppstructure/layout/README.md b/ppstructure/layout/README.md new file mode 100644 index 0000000000000000000000000000000000000000..01faa7b279c618602cafb8ef7d086753061ea559 --- /dev/null +++ b/ppstructure/layout/README.md @@ -0,0 +1,468 @@ +English | [简体中文](README_ch.md) + +# Layout analysis + +- [1. Introduction](#1-Introduction) +- [2. Install](#2-Install) + - [2.1 Install PaddlePaddle](#21-Install-paddlepaddle) + - [2.2 Install PaddleDetection](#22-Install-paddledetection) +- [3. Data preparation](#3-Data-preparation) + - [3.1 English data set](#31-English-data-set) + - [3.2 More datasets](#32-More-datasets) +- [4. Start training](#4-Start-training) + - [4.1 Train](#41-Train) + - [4.2 FGD Distillation training](#42-FGD-Distillation-training) +- [5. Model evaluation and prediction](#5-Model-evaluation-and-prediction) + - [5.1 Indicator evaluation](#51-Indicator-evaluation) + - [5.2 Test layout analysis results](#52-Test-layout-analysis-results) +- [6 Model export and inference](#6-Model-export-and-inference) + - [6.1 Model export](#61-Model-export) + - [6.2 Model inference](#62-Model-inference) + + +## 1. Introduction + +Layout analysis refers to the regional division of documents in the form of pictures and the positioning of key areas, such as text, title, table, picture, etc. The layout analysis algorithm is based on the lightweight model PP-picodet of [PaddleDetection]( https://github.com/PaddlePaddle/PaddleDetection ) + +
+ +
+ + + +## 2. Install + +### 2.1. Install PaddlePaddle + +- **(1) Install PaddlePaddle** + +```bash +python3 -m pip install --upgrade pip + +# GPU Install +python3 -m pip install "paddlepaddle-gpu>=2.3" -i https://mirror.baidu.com/pypi/simple + +# CPU Install +python3 -m pip install "paddlepaddle>=2.3" -i https://mirror.baidu.com/pypi/simple +``` +For more requirements, please refer to the instructions in the [Install file](https://www.paddlepaddle.org.cn/install/quick)。 + +### 2.2. Install PaddleDetection + +- **(1)Download PaddleDetection Source code** + +```bash +git clone https://github.com/PaddlePaddle/PaddleDetection.git +``` + +- **(2)Install third-party libraries** + +```bash +cd PaddleDetection +python3 -m pip install -r requirements.txt +``` + +## 3. Data preparation + +If you want to experience the prediction process directly, you can skip data preparation and download the pre-training model. + +### 3.1. English data set + +Download document analysis data set [PubLayNet](https://developer.ibm.com/exchanges/data/all/publaynet/)(Dataset 96G),contains 5 classes:`{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}` + +``` +# Download data +wget https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz +# Decompress data +tar -xvf publaynet.tar.gz +``` + +Uncompressed **directory structure:** + +``` +|-publaynet + |- test + |- PMC1277013_00004.jpg + |- PMC1291385_00002.jpg + | ... + |- train.json + |- train + |- PMC1291385_00002.jpg + |- PMC1277013_00004.jpg + | ... + |- val.json + |- val + |- PMC538274_00004.jpg + |- PMC539300_00004.jpg + | ... +``` + +**data distribution:** + +| File or Folder | Description | num | +| :------------- | :------------- | ------- | +| `train/` | Training set pictures | 335,703 | +| `val/` | Verification set pictures | 11,245 | +| `test/` | Test set pictures | 11,405 | +| `train.json` | Training set annotation files | - | +| `val.json` | Validation set dimension files | - | + +**Data Annotation** + +The JSON file contains the annotations of all images, and the data is stored in a dictionary nested manner.Contains the following keys: + +- info,represents the dimension file info。 + +- licenses,represents the dimension file licenses。 + +- images,represents the list of image information in the annotation file,each element is the information of an image。The information of one of the images is as follows: + + ``` + { + 'file_name': 'PMC4055390_00006.jpg', # file_name + 'height': 601, # image height + 'width': 792, # image width + 'id': 341427 # image id + } + ``` + +- annotations, represents the list of annotation information of the target object in the annotation file,each element is the annotation information of a target object。The following is the annotation information of one of the target objects: + + ``` + { + + 'segmentation': # Segmentation annotation of objects + 'area': 60518.099043117836, # Area of object + 'iscrowd': 0, # iscrowd + 'image_id': 341427, # image id + 'bbox': [50.58, 490.86, 240.15, 252.16], # bbox [x1,y1,w,h] + 'category_id': 1, # category_id + 'id': 3322348 # image id + } + ``` + +### 3.2. More datasets + +We provide CDLA(Chinese layout analysis), TableBank(Table layout analysis)etc. data set download links,process to the JSON format of the above annotation file,that is, the training can be conducted in the same way。 + +| dataset | 简介 | +| ------------------------------------------------------------ | ------------------------------------------------------------ | +| [cTDaR2019_cTDaR](https://cndplab-founder.github.io/cTDaR2019/) | For form detection (TRACKA) and form identification (TRACKB).Image types include historical data sets (beginning with cTDaR_t0, such as CTDAR_T00872.jpg) and modern data sets (beginning with cTDaR_t1, CTDAR_T10482.jpg). | +| [IIIT-AR-13K](http://cvit.iiit.ac.in/usodi/iiitar13k.php) | Data sets constructed by manually annotating figures or pages from publicly available annual reports, containing 5 categories:table, figure, natural image, logo, and signature. | +| [TableBank](https://github.com/doc-analysis/TableBank) | For table detection and recognition of large datasets, including Word and Latex document formats | +| [CDLA](https://github.com/buptlihang/CDLA) | Chinese document layout analysis data set, for Chinese literature (paper) scenarios, including 10 categories:Table, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation | +| [DocBank](https://github.com/doc-analysis/DocBank) | Large-scale dataset (500K document pages) constructed using weakly supervised methods for document layout analysis, containing 12 categories:Author, Caption, Date, Equation, Figure, Footer, List, Paragraph, Reference, Section, Table, Title | + + +## 4. Start training + +Training scripts, evaluation scripts, and prediction scripts are provided, and the PubLayNet pre-training model is used as an example in this section. + +If you do not want training and directly experience the following process of model evaluation, prediction, motion to static, and inference, you can download the provided pre-trained model (PubLayNet dataset) and skip this part. + +``` +mkdir pretrained_model +cd pretrained_model +# Download PubLayNet pre-training model(Direct experience model evaluates, predicts, and turns static) +wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout.pdparams +# Download the PubLaynet inference model(Direct experience model reasoning) +wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_infer.tar +``` + +If the test image is Chinese, the pre-trained model of Chinese CDLA dataset can be downloaded to identify 10 types of document regions:Table, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation,Download the training model and inference model of Model 'picodet_lcnet_x1_0_fgd_layout_cdla' in [layout analysis model](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/ppstructure/docs/models_list.md)。If only the table area in the image is detected, you can download the pre-trained model of the table dataset, and download the training model and inference model of the 'picodet_LCnet_x1_0_FGd_layout_table' model in [Layout Analysis model](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/ppstructure/docs/models_list.md) + +### 4.1. Train + +Train: + +* Modify Profile + +If you want to train your own data set, you need to modify the data configuration and the number of categories in the configuration file. + + +Using 'configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x1_0_layout.yml' as an example, the change is as follows: + +```yaml +metric: COCO +# Number of categories +num_classes: 5 + +TrainDataset: + !COCODataSet + # Modify to your own training data directory + image_dir: train + # Modify to your own training data label file + anno_path: train.json + # Modify to your own training data root directory + dataset_dir: /root/publaynet/ + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] + +EvalDataset: + !COCODataSet + # Modify to your own validation data directory + image_dir: val + # Modify to your own validation data label file + anno_path: val.json + # Modify to your own validation data root + dataset_dir: /root/publaynet/ + +TestDataset: + !ImageFolder + # Modify to your own test data label file + anno_path: /root/publaynet/val.json +``` + +* Start training. During training, PP picodet pre training model will be downloaded by default. There is no need to download in advance. + +```bash +# GPU training supports single-card and multi-card training +# The training log is automatically saved to the log directory + +# Single card training +export CUDA_VISIBLE_DEVICES=0 +python3 tools/train.py \ + -c configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x1_0_layout.yml \ + --eval + +# Multi-card training, with the -- GPUS parameter specifying the card number +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py \ + -c configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x1_0_layout.yml \ + --eval +``` + +**Attention:**If the video memory is out during training, adjust Batch_size in TrainReader and base_LR in LearningRate. The published config is obtained by 8-card training. If the number of GPU cards is changed to 1, then the base_LR needs to be reduced by 8 times. + +After starting training normally, you will see the following log output: + +``` +[08/15 04:02:30] ppdet.utils.checkpoint INFO: Finish loading model weights: /root/.cache/paddle/weights/LCNet_x1_0_pretrained.pdparams +[08/15 04:02:46] ppdet.engine INFO: Epoch: [0] [ 0/1929] learning_rate: 0.040000 loss_vfl: 1.216707 loss_bbox: 1.142163 loss_dfl: 0.544196 loss: 2.903065 eta: 17 days, 13:50:26 batch_cost: 15.7452 data_cost: 2.9112 ips: 1.5243 images/s +[08/15 04:03:19] ppdet.engine INFO: Epoch: [0] [ 20/1929] learning_rate: 0.064000 loss_vfl: 1.180627 loss_bbox: 0.939552 loss_dfl: 0.442436 loss: 2.628206 eta: 2 days, 12:18:53 batch_cost: 1.5770 data_cost: 0.0008 ips: 15.2184 images/s +[08/15 04:03:47] ppdet.engine INFO: Epoch: [0] [ 40/1929] learning_rate: 0.088000 loss_vfl: 0.543321 loss_bbox: 1.071401 loss_dfl: 0.457817 loss: 2.057003 eta: 2 days, 0:07:03 batch_cost: 1.3190 data_cost: 0.0007 ips: 18.1954 images/s +[08/15 04:04:12] ppdet.engine INFO: Epoch: [0] [ 60/1929] learning_rate: 0.112000 loss_vfl: 0.630989 loss_bbox: 0.859183 loss_dfl: 0.384702 loss: 1.883143 eta: 1 day, 19:01:29 batch_cost: 1.2177 data_cost: 0.0006 ips: 19.7087 images/s +``` + +- `--eval` indicates that the best model is saved as `output/picodet_lcnet_x1_0_layout/best_accuracy` by default during the evaluation process 。 + +**Note that the configuration file for prediction / evaluation must be consistent with the training.** + +### 4.2. FGD Distillation Training + +PaddleDetection supports FGD-based [Focal and Global Knowledge Distillation for Detectors]( https://arxiv.org/abs/2111.11837v1) The training process of the target detection model of distillation, FGD distillation is divided into two parts `Focal` and `Global`. `Focal` Distillation separates the foreground and background of the image, allowing the student model to focus on the key pixels of the foreground and background features of the teacher model respectively;` Global`Distillation section reconstructs the relationships between different pixels and transfers them from the teacher to the student to compensate for the global information lost in `Focal`Distillation. + +Change the dataset and modify the data configuration and number of categories in the [TODO] configuration, referring to 4.1. Start training: + +```bash +# Single Card Training +export CUDA_VISIBLE_DEVICES=0 +python3 tools/train.py \ + -c configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x1_0_layout.yml \ + --slim_config configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x2_5_layout.yml \ + --eval +``` + +- `-c`: Specify the model configuration file. +- `--slim_config`: Specify the compression policy profile. + +## 5. Model evaluation and prediction + +### 5.1. Indicator evaluation + + Model parameters in training are saved by default in `output/picodet_ Lcnet_ X1_ 0_ Under the layout` directory. When evaluating indicators, you need to set `weights` to point to the saved parameter file.Assessment datasets can be accessed via `configs/picodet/legacy_ Model/application/layout_ Analysis/picodet_ Lcnet_ X1_ 0_ Layout. Yml` . Modify `EvalDataset` : `img_dir`,`anno_ Path`and`dataset_dir` setting. + +```bash +# GPU evaluation, weights as weights to be measured +python3 tools/eval.py \ + -c configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x1_0_layout.yml \ + -o weights=./output/picodet_lcnet_x1_0_layout/best_model +``` + +The following information will be printed out, such as mAP, AP0.5, etc. + +```py + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.935 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.979 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.956 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.404 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.782 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.969 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.539 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.938 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.949 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.495 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.818 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.978 +[08/15 07:07:09] ppdet.engine INFO: Total sample number: 11245, averge FPS: 24.405059207157436 +[08/15 07:07:09] ppdet.engine INFO: Best test bbox ap is 0.935. +``` + +If you use the provided pre-training model for evaluation or the FGD distillation training model, replace the `weights` model path and execute the following command for evaluation: + +``` +python3 tools/eval.py \ + -c configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x1_0_layout.yml \ + --slim_config configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x2_5_layout.yml \ + -o weights=output/picodet_lcnet_x2_5_layout/best_model +``` + +- `-c`: Specify the model configuration file. +- `--slim_config`: Specify the distillation policy profile. +- `-o weights`: Specify the model path trained by the distillation algorithm. + +### 5.2. Test Layout Analysis Results + + +The profile predicted to be used must be consistent with the training, for example, if you pass `python3 tools/train'. Py-c configs/picodet/legacy_ Model/application/layout_ Analysis/picodet_ Lcnet_ X1_ 0_ Layout. Yml` completed the training process for the model. + +With trained PaddleDetection model, you can use the following commands to make model predictions. + +```bash +python3 tools/infer.py \ + -c configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x1_0_layout.yml \ + -o weights='output/picodet_lcnet_x1_0_layout/best_model.pdparams' \ + --infer_img='docs/images/layout.jpg' \ + --output_dir=output_dir/ \ + --draw_threshold=0.5 +``` + +- `--infer_img`: Reasoning for a single picture can also be done via `--infer_ Dir`Inform all pictures in the file. +- `--output_dir`: Specify the path to save the visualization results. +- `--draw_threshold`:Specify the NMS threshold for drawing the result box. + +If you use the provided pre-training model for prediction or the FGD distillation training model, change the `weights` model path and execute the following command to make the prediction: + +``` +python3 tools/infer.py \ + -c configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x1_0_layout.yml \ + --slim_config configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x2_5_layout.yml \ + -o weights='output/picodet_lcnet_x2_5_layout/best_model.pdparams' \ + --infer_img='docs/images/layout.jpg' \ + --output_dir=output_dir/ \ + --draw_threshold=0.5 +``` + + +## 6. Model Export and Inference + + +### 6.1 Model Export + +The inference model (the model saved by `paddle.jit.save`) is generally a solidified model saved after the model training is completed, and is mostly used to give prediction in deployment. + +The model saved during the training process is the checkpoints model, which saves the parameters of the model and is mostly used to resume training. + +Compared with the checkpoints model, the inference model will additionally save the structural information of the model. Therefore, it is easier to deploy because the model structure and model parameters are already solidified in the inference model file, and is suitable for integration with actual systems. + +Layout analysis model to inference model steps are as follows: + +```bash +python3 tools/export_model.py \ + -c configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x1_0_layout.yml \ + -o weights=output/picodet_lcnet_x1_0_layout/best_model \ + --output_dir=output_inference/ +``` + +* If no post-export processing is required, specify:`-o export.benchmark=True`(If -o already exists, delete -o here) +* If you do not need to export NMS, specify:`-o export.nms=False` + +After successful conversion, there are three files in the directory: + +``` +output_inference/picodet_lcnet_x1_0_layout/ + ├── model.pdiparams # inference Parameter file for model + ├── model.pdiparams.info # inference Model parameter information, ignorable + └── model.pdmodel # inference Model Structure File for Model +``` + +If you change the `weights` model path using the provided pre-training model to the Inference model, or using the FGD distillation training model, the model to inference model steps are as follows: + +```bash +python3 tools/export_model.py \ + -c configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x1_0_layout.yml \ + --slim_config configs/picodet/legacy_model/application/layout_analysis/picodet_lcnet_x2_5_layout.yml \ + -o weights=./output/picodet_lcnet_x2_5_layout/best_model \ + --output_dir=output_inference/ +``` + +### 6.2 Model inference + +Replace model_with the provided inference training model for inference or the FGD distillation training `model_dir`Inference model path, execute the following commands for inference: + +```bash +python3 deploy/python/infer.py \ + --model_dir=output_inference/picodet_lcnet_x1_0_layout/ \ + --image_file=docs/images/layout.jpg \ + --device=CPU +``` + +- --device:Specify the GPU or CPU device + +When model inference is complete, you will see the following log output: + +``` +------------------------------------------ +----------- Model Configuration ----------- +Model Arch: PicoDet +Transform Order: +--transform op: Resize +--transform op: NormalizeImage +--transform op: Permute +--transform op: PadStride +-------------------------------------------- +class_id:0, confidence:0.9921, left_top:[20.18,35.66],right_bottom:[341.58,600.99] +class_id:0, confidence:0.9914, left_top:[19.77,611.42],right_bottom:[341.48,901.82] +class_id:0, confidence:0.9904, left_top:[369.36,375.10],right_bottom:[691.29,600.59] +class_id:0, confidence:0.9835, left_top:[369.60,608.60],right_bottom:[691.38,736.72] +class_id:0, confidence:0.9830, left_top:[369.58,805.38],right_bottom:[690.97,901.80] +class_id:0, confidence:0.9716, left_top:[383.68,271.44],right_bottom:[688.93,335.39] +class_id:0, confidence:0.9452, left_top:[370.82,34.48],right_bottom:[688.10,63.54] +class_id:1, confidence:0.8712, left_top:[370.84,771.03],right_bottom:[519.30,789.13] +class_id:3, confidence:0.9856, left_top:[371.28,67.85],right_bottom:[685.73,267.72] +save result to: output/layout.jpg +Test iter 0 +------------------ Inference Time Info ---------------------- +total_time(ms): 2196.0, img_num: 1 +average latency time(ms): 2196.00, QPS: 0.455373 +preprocess_time(ms): 2172.50, inference_time(ms): 11.90, postprocess_time(ms): 11.60 +``` + +- Model:model structure +- Transform Order:Preprocessing operation +- class_id, confidence, left_top, right_bottom:Indicates category id, confidence level, upper left coordinate, lower right coordinate, respectively +- save result to:Save path of visual layout analysis results, default save to ./output folder +- inference time info:Inference time, where preprocess_time represents the preprocessing time, Inference_time represents the model prediction time, and postprocess_time represents the post-processing time + +The result of visualization layout is shown in the following figure + +
+ +
+ + + +## Citations + +``` +@inproceedings{zhong2019publaynet, + title={PubLayNet: largest dataset ever for document layout analysis}, + author={Zhong, Xu and Tang, Jianbin and Yepes, Antonio Jimeno}, + booktitle={2019 International Conference on Document Analysis and Recognition (ICDAR)}, + year={2019}, + volume={}, + number={}, + pages={1015-1022}, + doi={10.1109/ICDAR.2019.00166}, + ISSN={1520-5363}, + month={Sep.}, + organization={IEEE} +} + +@inproceedings{yang2022focal, + title={Focal and global knowledge distillation for detectors}, + author={Yang, Zhendong and Li, Zhe and Jiang, Xiaohu and Gong, Yuan and Yuan, Zehuan and Zhao, Danpei and Yuan, Chun}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={4643--4652}, + year={2022} +} +``` diff --git a/ppstructure/layout/README_ch.md b/ppstructure/layout/README_ch.md index f8d1978e25d7fb17cfd3fcb363b4ce981e19c8dc..49c10c7e7726a35dadbc936e94c9ab5b55628e82 100644 --- a/ppstructure/layout/README_ch.md +++ b/ppstructure/layout/README_ch.md @@ -1,3 +1,7 @@ +简体中文 | [English](README.md) + +# 版面分析 + - [1. 简介](#1-简介) - [2. 安装](#2-安装) - [2.1 安装PaddlePaddle](#21-安装paddlepaddle) @@ -15,8 +19,6 @@ - [6.1 模型导出](#61-模型导出) - [6.2 模型推理](#62-模型推理) -# 版面分析 - ## 1. 简介 版面分析指的是对图片形式的文档进行区域划分,定位其中的关键区域,如文字、标题、表格、图片等。版面分析算法基于[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection)的轻量模型PP-PicoDet进行开发。 @@ -37,10 +39,10 @@ python3 -m pip install --upgrade pip # GPU安装 -python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple +python3 -m pip install "paddlepaddle-gpu>=2.3" -i https://mirror.baidu.com/pypi/simple # CPU安装 -python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple +python3 -m pip install "paddlepaddle>=2.3" -i https://mirror.baidu.com/pypi/simple ``` 更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。 diff --git a/ppstructure/recovery/README.md b/ppstructure/recovery/README.md index 59aef707dd67799bb46dc18dc58f883c502c8b86..b1eaae46df87499d11f196d02d17d0690ffd0f16 100644 --- a/ppstructure/recovery/README.md +++ b/ppstructure/recovery/README.md @@ -66,7 +66,7 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR - **(2) Install recovery's `requirements`** -The layout restoration is exported as docx and PDF files, so python-docx and docx2pdf API need to be installed, and fitz and PyMuPDF apis need to be installed to process the input files in pdf format. +The layout restoration is exported as docx and PDF files, so python-docx and docx2pdf API need to be installed, and PyMuPDF api([requires Python >= 3.7](https://pypi.org/project/PyMuPDF/)) need to be installed to process the input files in pdf format. ```bash python3 -m pip install -r ppstructure/recovery/requirements.txt diff --git a/ppstructure/recovery/README_ch.md b/ppstructure/recovery/README_ch.md index ae3b7ed82464f513af585542ef8e92d66f2c8756..cd99f7f725f4a3d275ab920e9fe0125a74a995e5 100644 --- a/ppstructure/recovery/README_ch.md +++ b/ppstructure/recovery/README_ch.md @@ -68,7 +68,7 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR - **(2)安装recovery的`requirements`** -版面恢复导出为docx、pdf文件,所以需要安装python-docx、docx2pdf API,同时处理pdf格式的输入文件,需要安装fitz、PyMuPDF API。 +版面恢复导出为docx、pdf文件,所以需要安装python-docx、docx2pdf API,同时处理pdf格式的输入文件,需要安装PyMuPDF API([要求Python >= 3.7](https://pypi.org/project/PyMuPDF/))。 ```bash python3 -m pip install -r ppstructure/recovery/requirements.txt diff --git a/ppstructure/recovery/requirements.txt b/ppstructure/recovery/requirements.txt index b118a41e516ec20e5807030649943e5f7d848107..25e8cdbb0d58b0a243b176f563c66717d6f4c112 100644 --- a/ppstructure/recovery/requirements.txt +++ b/ppstructure/recovery/requirements.txt @@ -1,5 +1,4 @@ python-docx docx2pdf -fitz -PyMuPDF==1.16.14 +PyMuPDF beautifulsoup4 \ No newline at end of file