diff --git a/configs/rec/rec_d28_can.yml b/configs/rec/rec_d28_can.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7c3b0fd3d60368d196837826c252301fb5f3b59e
--- /dev/null
+++ b/configs/rec/rec_d28_can.yml
@@ -0,0 +1,122 @@
+Global:
+ use_gpu: True
+ epoch_num: 240
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec/can/
+ save_epoch_step: 1
+ # evaluation is run every 1105 iterations (1 epoch, batch_size = 8)
+ eval_batch_step: [0, 1105]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/datasets/crohme_demo/hme_00.jpg
+ # for data or label process
+ character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
+ max_text_length: 36
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_can.txt
+
+Optimizer:
+ name: Momentum
+ momentum: 0.9
+ clip_norm_global: 100.0
+ lr:
+ name: TwoStepCosine
+ learning_rate: 0.01
+ warmup_epoch: 1
+ weight_decay: 0.0001
+
+Architecture:
+ model_type: rec
+ algorithm: CAN
+ in_channels: 1
+ Transform:
+ Backbone:
+ name: DenseNet
+ growthRate: 24
+ reduction: 0.5
+ bottleneck: True
+ use_dropout: True
+ input_channel: 1
+ Head:
+ name: CANHead
+ in_channel: 684
+ out_channel: 111
+ max_text_length: 36
+ ratio: 16
+ attdecoder:
+ is_train: True
+ input_size: 256
+ hidden_size: 256
+ encoder_out_channel: 684
+ dropout: True
+ dropout_ratio: 0.5
+ word_num: 111
+ counting_decoder_out_channel: 111
+ attention:
+ attention_dim: 512
+ word_conv_kernel: 1
+
+Loss:
+ name: CANLoss
+
+PostProcess:
+ name: CANLabelDecode
+
+Metric:
+ name: CANMetric
+ main_indicator: exp_rate
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/CROHME/training/images/
+ label_file_list: ["./train_data/CROHME/training/labels.txt"]
+ transforms:
+ - DecodeImage:
+ channel_first: False
+ - NormalizeImage:
+ mean: [0,0,0]
+ std: [1,1,1]
+ order: 'hwc'
+ - GrayImageChannelFormat:
+ inverse: True
+ - CANLabelEncode:
+ lower: False
+ - KeepKeys:
+ keep_keys: ['image', 'label']
+ loader:
+ shuffle: True
+ batch_size_per_card: 8
+ drop_last: False
+ num_workers: 4
+ collate_fn: DyMaskCollator
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/CROHME/evaluation/images/
+ label_file_list: ["./train_data/CROHME/evaluation/labels.txt"]
+ transforms:
+ - DecodeImage:
+ channel_first: False
+ - NormalizeImage:
+ mean: [0,0,0]
+ std: [1,1,1]
+ order: 'hwc'
+ - GrayImageChannelFormat:
+ inverse: True
+ - CANLabelEncode:
+ lower: False
+ - KeepKeys:
+ keep_keys: ['image', 'label']
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1
+ num_workers: 4
+ collate_fn: DyMaskCollator
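
One way to sanity-check the config's coupled values (a sketch, assuming it is run from the repo root): `Head.out_channel`, `word_num` and `counting_decoder_out_channel` must all equal the symbol-dictionary size, and `Head.in_channel` must match the DenseNet output width (see the channel walk-through after `rec_densenet.py` below).

```python
# Sketch: the CAN head sizes are tied to the LaTeX symbol dictionary.
with open('ppocr/utils/dict/latex_symbol_dict.txt') as f:
    num_symbols = len(f.read().splitlines())
print(num_symbols)  # 111 == Head.out_channel == word_num == counting_decoder_out_channel
```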
diff --git a/doc/datasets/crohme_demo/hme_00.jpg b/doc/datasets/crohme_demo/hme_00.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..66ff27db266b5d4fa05d8acd95ba881bb8a1aec0
Binary files /dev/null and b/doc/datasets/crohme_demo/hme_00.jpg differ
diff --git a/doc/datasets/crohme_demo/hme_01.jpg b/doc/datasets/crohme_demo/hme_01.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..68b7f09fc2f330ee523ded27a14486b3c92763cb
Binary files /dev/null and b/doc/datasets/crohme_demo/hme_01.jpg differ
diff --git a/doc/datasets/crohme_demo/hme_02.jpg b/doc/datasets/crohme_demo/hme_02.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ecc760f5382bfe3d94de6141379f6a5a196e8430
Binary files /dev/null and b/doc/datasets/crohme_demo/hme_02.jpg differ
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 235763d8a85e74173cb0a244833732223d7fe887..44c1e117ec0cdea33f3c2b74286eb58eb83e67a3 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -102,7 +102,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
-|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
+|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/VisionLAN/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) |
diff --git a/doc/doc_ch/algorithm_rec_can.md b/doc/doc_ch/algorithm_rec_can.md
new file mode 100644
index 0000000000000000000000000000000000000000..4f266cb33b800b446b88b507f3710d9c96db00a1
--- /dev/null
+++ b/doc/doc_ch/algorithm_rec_can.md
@@ -0,0 +1,174 @@
+# Handwritten Mathematical Expression Recognition Algorithm - CAN
+
+- [1. Algorithm Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+  - [3.1 Training](#3-1)
+  - [3.2 Evaluation](#3-2)
+  - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+  - [4.1 Python Inference](#4-1)
+  - [4.2 C++ Inference](#4-2)
+  - [4.3 Serving](#4-3)
+  - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Algorithm Introduction
+
+Paper:
+> [When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition](https://arxiv.org/abs/2207.11463)
+> Bohan Li, Ye Yuan, Dingkang Liang, Xiao Liu, Zhilong Ji, Jinfeng Bai, Wenyu Liu, Xiang Bai
+> ECCV, 2022
+
+
+
+`CAN` is trained on the CROHME handwritten mathematical expression dataset; its accuracy on the corresponding test set is as follows:
+
+|Model|Backbone|Config|ExpRate|Download link|
+| ----- | ----- | ----- | ----- | ----- |
+|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|[trained model](https://paddleocr.bj.bcebos.com/contribution/can_train.tar)|
+
+
+## 2. Environment
+Please refer to [Environment Preparation](./environment.md) to configure the PaddleOCR environment, and refer to [Project Clone](./clone.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+
+### 3.1 Model Training
+
+Please refer to the [text recognition training tutorial](./recognition.md). PaddleOCR modularizes the code; to train the `CAN` recognition model, you only need to **switch the configuration file** to CAN's [config file](../../configs/rec/rec_d28_can.yml).
+
+#### Start Training
+
+
+Specifically, after the data preparation is completed, start training with the following commands:
+```shell
+# Single-GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_d28_can.yml
+
+# Multi-GPU training; specify the card numbers through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_d28_can.yml
+```
+
+**Note:**
+- The [CROHME dataset](https://paddleocr.bj.bcebos.com/dataset/CROHME.tar) we provide stores handwritten expressions as white text on a black background. If your own dataset is the opposite (black text on a white background), make the following change at training time (see the sketch after these notes for the underlying math):
+```
+python3 tools/train.py -c configs/rec/rec_d28_can.yml \
+    -o Train.dataset.transforms.GrayImageChannelFormat.inverse=False
+```
+- By default, evaluation runs once per training epoch (every 1105 iterations). If you change the training batch_size or switch datasets, adjust the evaluation interval accordingly (see the sketch below):
+```
+python3 tools/train.py -c configs/rec/rec_d28_can.yml \
+    -o Global.eval_batch_step=[0,{length_of_dataset//batch_size}]
+```
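+
+A minimal sanity-check sketch for the two notes above. It assumes `NormalizeImage` has already scaled pixels to [0, 1] (so `inverse=True` computes `|x - 1| = 1 - x`), and takes the CROHME training-set size of 8,836 images from the CAN paper; treat both as assumptions if your setup differs:
+```python
+import math
+import numpy as np
+
+# 1) Color inversion as performed by GrayImageChannelFormat(inverse=True)
+img = np.array([[0.0, 1.0], [0.2, 0.8]], dtype='float32')  # toy normalized image
+print(np.abs(img - 1))  # [[1.0, 0.0], [0.8, 0.2]] -- white background becomes black
+
+# 2) Evaluation interval = iterations per epoch (drop_last is False)
+print(math.ceil(8836 / 8))  # 1105 -> Global.eval_batch_step=[0, 1105]
+```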
+
+
+### 3.2 Evaluation
+
+Download the trained [model file](https://paddleocr.bj.bcebos.com/contribution/can_train.tar) and evaluate it with the following command:
+
+```shell
+# Make sure pretrained_model is set to a local path. If you use a model you trained and saved yourself, change the path and file name to {path/to/weights}/{model_name}.
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/CAN
+```
+
+
+### 3.3 Prediction
+
+Use the following command to predict a single image:
+```shell
+# Make sure pretrained_model is set to a local path.
+python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/datasets/crohme_demo/hme_00.jpg' Global.pretrained_model=./rec_d28_can_train/CAN
+
+# To predict all images in a folder, set infer_img to the folder path, e.g. Global.infer_img='./doc/datasets/crohme_demo/'.
+```
+
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, convert the best model obtained during training into an inference model. Taking the trained model as an example ([model download link](https://paddleocr.bj.bcebos.com/contribution/can_train.tar)), use the following command to convert it:
+
+```shell
+# Make sure pretrained_model is set to a local path.
+python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/CAN Global.save_inference_dir=./inference/rec_d28_can/ Architecture.Head.attdecoder.is_train=False
+
+# The static-graph model currently caps the output length at 36. If you need to predict longer sequences, set an appropriate output length when exporting, e.g. Architecture.Head.max_text_length=72
+```
+**Note:**
+- If you trained the model on your own dataset and changed the dictionary file, check that `character_dict_path` in the configuration file points to the dictionary you need.
+
+After a successful conversion, the directory contains three files:
+```
+/inference/rec_d28_can/
+    ├── inference.pdiparams         # parameter file of the recognition inference model
+    ├── inference.pdiparams.info    # parameter information of the recognition inference model, can be ignored
+    └── inference.pdmodel           # program file of the recognition inference model
+```
+
+Run the following command for model inference:
+
+```shell
+python3 tools/infer/predict_rec.py --image_dir="./doc/datasets/crohme_demo/hme_00.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt"
+
+# To predict all images in a folder, set image_dir to the folder path, e.g. --image_dir='./doc/datasets/crohme_demo/'.
+
+# To predict images with black text on a white background, set --rec_image_inverse=False
+```
+
+
+
+After the command runs, the prediction result (recognized text) of the image above is printed to the screen, for example:
+```shell
+Predicts of ./doc/datasets/crohme_demo/hme_00.jpg:['x _ { k } x x _ { k } + y _ { k } y x _ { k }', []]
+```
+
+
+**Note**:
+
+- The input image must be **white text on a black background**, i.e. the handwritten expression is white and the background is black.
+- Inference requires `rec_char_dict_path` to specify the dictionary; if you changed the dictionary, point this parameter to your dictionary file.
+- If you changed the preprocessing, update the CAN preprocessing in `tools/infer/predict_rec.py` to match.
+
+
+
+### 4.2 C++ Inference
+
+Not supported yet, because the C++ pre- and post-processing for CAN are not available.
+
+
+### 4.3 Serving
+
+Not supported yet.
+
+
+### 4.4 More
+
+Not supported yet.
+
+
+## 5. FAQ
+
+1. The CROHME dataset comes from the [CAN source repo](https://github.com/LBH1024/CAN).
+
+## Citation
+
+```bibtex
+@misc{https://doi.org/10.48550/arxiv.2207.11463,
+ doi = {10.48550/ARXIV.2207.11463},
+ url = {https://arxiv.org/abs/2207.11463},
+ author = {Li, Bohan and Yuan, Ye and Liang, Dingkang and Liu, Xiao and Ji, Zhilong and Bai, Jinfeng and Liu, Wenyu and Bai, Xiang},
+ keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},
+ title = {When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition},
+ publisher = {arXiv},
+ year = {2022},
+ copyright = {arXiv.org perpetual, non-exclusive license}
+}
+```
diff --git a/doc/doc_ch/algorithm_rec_visionlan.md b/doc/doc_ch/algorithm_rec_visionlan.md
index df039491d49e192349d57b44cc448c57e4211098..b4474c29f8596197fb536f07fa96b9926e5b20f4 100644
--- a/doc/doc_ch/algorithm_rec_visionlan.md
+++ b/doc/doc_ch/algorithm_rec_visionlan.md
@@ -27,7 +27,7 @@
|Model|Backbone|Config|Acc|Download link|
| --- | --- | --- | --- | --- |
-|VisionLAN|ResNet45|[rec_r45_visionlan.yml](../../configs/rec/rec_r45_visionlan.yml)|90.3%|[pre-trained & trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)|
+|VisionLAN|ResNet45|[rec_r45_visionlan.yml](../../configs/rec/rec_r45_visionlan.yml)|90.3%|[pre-trained & trained model](https://paddleocr.bj.bcebos.com/VisionLAN/rec_r45_visionlan_train.tar)|
## 2. Environment
@@ -80,7 +80,7 @@ python3 tools/infer_rec.py -c configs/rec/rec_r45_visionlan.yml -o Global.infer_
### 4.1 Python Inference
-First, convert the best model obtained during training into an inference model. Taking the trained model as an example ([model download link](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)), use the following command to convert it:
+First, convert the best model obtained during training into an inference model. Taking the trained model as an example ([model download link](https://paddleocr.bj.bcebos.com/VisionLAN/rec_r45_visionlan_train.tar)), use the following command to convert it:
```shell
# Make sure pretrained_model is set to a local path.
@@ -139,7 +139,7 @@ Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.9999493)
## 5. FAQ
1. The MJSynth and SynthText datasets come from the [VisionLAN source repo](https://github.com/wangyuxin87/VisionLAN).
-2. We finetune from the pre-trained model provided by the VisionLAN authors.
+2. We finetune from the pre-trained model provided by the VisionLAN authors; the dictionary matching the pre-trained model is 'ppocr/utils/ic15_dict.txt'.
## Citation
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index ff84b9a6822f72ff1cd225892c9febcab8a2ae75..2614226e001b84d7316c9497de1a74bd548a64f6 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -99,7 +99,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
-|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
+|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/VisionLAN/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) |
diff --git a/doc/doc_en/algorithm_rec_can_en.md b/doc/doc_en/algorithm_rec_can_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..da6c9c6096fa7170b108012165b7c69862671e1a
--- /dev/null
+++ b/doc/doc_en/algorithm_rec_can_en.md
@@ -0,0 +1,119 @@
+# CAN
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+ - [3.1 Training](#3-1)
+ - [3.2 Evaluation](#3-2)
+ - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+ - [4.1 Python Inference](#4-1)
+ - [4.2 C++ Inference](#4-2)
+ - [4.3 Serving](#4-3)
+ - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition](https://arxiv.org/abs/2207.11463)
+> Bohan Li, Ye Yuan, Dingkang Liang, Xiao Liu, Zhilong Ji, Jinfeng Bai, Wenyu Liu, Xiang Bai
+> ECCV, 2022
+
+Trained on the CROHME handwritten mathematical expression dataset and evaluated on its test set, the algorithm reproduces the following accuracy:
+
+|Model|Backbone|Config|ExpRate|Download link|
+| --- | --- | --- | --- | --- |
+|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|[trained model](https://paddleocr.bj.bcebos.com/contribution/can_train.tar)|
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+
+Training:
+
+Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
+
+```
+# Single-GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_d28_can.yml
+
+# Multi-GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_d28_can.yml
+```
+
+Evaluation:
+
+```
+# GPU evaluation
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/CAN
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the training
+python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/datasets/crohme_demo/hme_00.jpg' Global.pretrained_model=./rec_d28_can_train/CAN
+```
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, the model saved during the CAN handwritten mathematical expression recognition training process is converted into an inference model ([model download link](https://paddleocr.bj.bcebos.com/contribution/can_train.tar)). You can use the following command to convert it:
+
+```
+# Make sure pretrained_model is set to a local path.
+python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/CAN Global.save_inference_dir=./inference/rec_d28_can/ Architecture.Head.attdecoder.is_train=False
+
+# The default max output length of the model is 36. If you need to predict a longer sequence, specify an appropriate output length when exporting the model, e.g. Architecture.Head.max_text_length=72
+```
+
+For CAN handwritten mathematical expression recognition model inference, the following commands can be executed:
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/datasets/crohme_demo/hme_00.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt"
+
+# If you need to predict on a picture with black characters on a white background, please set --rec_image_inverse=False
+```
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+1. The CROHME dataset comes from the [CAN source repo](https://github.com/LBH1024/CAN).
+
+## Citation
+
+```bibtex
+@misc{https://doi.org/10.48550/arxiv.2207.11463,
+ doi = {10.48550/ARXIV.2207.11463},
+ url = {https://arxiv.org/abs/2207.11463},
+ author = {Li, Bohan and Yuan, Ye and Liang, Dingkang and Liu, Xiao and Ji, Zhilong and Bai, Jinfeng and Liu, Wenyu and Bai, Xiang},
+ keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},
+ title = {When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition},
+ publisher = {arXiv},
+ year = {2022},
+ copyright = {arXiv.org perpetual, non-exclusive license}
+}
+```
diff --git a/doc/doc_en/algorithm_rec_visionlan_en.md b/doc/doc_en/algorithm_rec_visionlan_en.md
index 70c2ccc470af0a03485d9d234e86e384c087617f..f67aa3c622d706a387075b37bd9e493740574cdd 100644
--- a/doc/doc_en/algorithm_rec_visionlan_en.md
+++ b/doc/doc_en/algorithm_rec_visionlan_en.md
@@ -25,7 +25,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
-|VisionLAN|ResNet45|[rec_r45_visionlan.yml](../../configs/rec/rec_r45_visionlan.yml)|90.3%|[pre-trained & trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)|
+|VisionLAN|ResNet45|[rec_r45_visionlan.yml](../../configs/rec/rec_r45_visionlan.yml)|90.3%|[pre-trained & trained model](https://paddleocr.bj.bcebos.com/VisionLAN/rec_r45_visionlan_train.tar)|
## 2. Environment
@@ -68,7 +68,7 @@ python3 tools/infer_rec.py -c configs/rec/rec_r45_visionlan.yml -o Global.infer_
### 4.1 Python Inference
-First, the model saved during the VisionLAN text recognition training process is converted into an inference model ([Model download link](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)). You can use the following command to convert:
+First, the model saved during the VisionLAN text recognition training process is converted into an inference model ([Model download link](https://paddleocr.bj.bcebos.com/VisionLAN/rec_r45_visionlan_train.tar)). You can use the following command to convert:
```
python3 tools/export_model.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy Global.save_inference_dir=./inference/rec_r45_visionlan/
@@ -120,7 +120,7 @@ Not supported
## 5. FAQ
1. Note that the MJSynth and SynthText datasets come from [VisionLAN repo](https://github.com/wangyuxin87/VisionLAN).
-2. We use the pre-trained model provided by the VisionLAN authors for finetune training.
+2. We use the pre-trained model provided by the VisionLAN authors for finetune training. The dictionary for the pre-trained model is 'ppocr/utils/ic15_dict.txt'.
## Citation
diff --git a/paddleocr.py b/paddleocr.py
index 6b4de93e9b7ce6b34b610d9bec2de31ab56dda6f..44308a823ed8edca0e979fd8d83414cec337ab9b 100644
--- a/paddleocr.py
+++ b/paddleocr.py
@@ -663,6 +663,16 @@ def main():
if not flag_gif and not flag_pdf:
img = cv2.imread(img_path)
+ if args.recovery and args.use_pdf2docx_api and flag_pdf:
+ from pdf2docx.converter import Converter
+ docx_file = os.path.join(args.output,
+ '{}.docx'.format(img_name))
+ cv = Converter(img_path)
+ cv.convert(docx_file)
+ cv.close()
+ logger.info('docx save to {}'.format(docx_file))
+ continue
+
if not flag_pdf:
if img is None:
logger.error("error in loading image:{}".format(img_path))
diff --git a/ppocr/data/collate_fn.py b/ppocr/data/collate_fn.py
index 0da6060f042a0e60cdf211d8bc13aede32d5930a..067b2158aca183c68c3a09999483c059bb10eb14 100644
--- a/ppocr/data/collate_fn.py
+++ b/ppocr/data/collate_fn.py
@@ -70,3 +70,49 @@ class SSLRotateCollate(object):
def __call__(self, batch):
output = [np.concatenate(d, axis=0) for d in zip(*batch)]
return output
+
+
+class DyMaskCollator(object):
+ """
+ Pad each sample in a batch to the batch maxima and return:
+ images [batch_size, channel, maxHinbatch, maxWinbatch]
+ image_masks [batch_size, 1, maxHinbatch, maxWinbatch]
+ labels [batch_size, maxLabelLen]
+ label_masks [batch_size, maxLabelLen]
+ """
+
+ def __call__(self, batch):
+ max_width, max_height, max_length = 0, 0, 0
+ bs, channel = len(batch), batch[0][0].shape[0]
+ proper_items = []
+ for item in batch:
+ # drop samples whose padding to the batch maxima would exceed the 1600 x 320 area budget
+ if item[0].shape[1] * max_width > 1600 * 320 or item[0].shape[
+ 2] * max_height > 1600 * 320:
+ continue
+ max_height = item[0].shape[1] if item[0].shape[
+ 1] > max_height else max_height
+ max_width = item[0].shape[2] if item[0].shape[
+ 2] > max_width else max_width
+ max_length = len(item[1]) if len(item[
+ 1]) > max_length else max_length
+ proper_items.append(item)
+
+ images, image_masks = np.zeros(
+ (len(proper_items), channel, max_height, max_width),
+ dtype='float32'), np.zeros(
+ (len(proper_items), 1, max_height, max_width), dtype='float32')
+ labels, label_masks = np.zeros(
+ (len(proper_items), max_length), dtype='int64'), np.zeros(
+ (len(proper_items), max_length), dtype='int64')
+
+ for i in range(len(proper_items)):
+ _, h, w = proper_items[i][0].shape
+ images[i][:, :h, :w] = proper_items[i][0]
+ image_masks[i][:, :h, :w] = 1
+ l = len(proper_items[i][1])
+ labels[i][:l] = proper_items[i][1]
+ label_masks[i][:l] = 1
+
+ return images, image_masks, labels, label_masks
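
A quick usage sketch for the collator (sample shapes and label indices are made up for illustration):

```python
import numpy as np
from ppocr.data.collate_fn import DyMaskCollator

# two variable-sized grayscale samples: (image CHW, encoded label)
batch = [
    (np.ones((1, 32, 64), dtype='float32'), [101, 0]),
    (np.ones((1, 48, 40), dtype='float32'), [101, 109, 107, 13, 108, 0]),
]
images, image_masks, labels, label_masks = DyMaskCollator()(batch)
print(images.shape, image_masks.shape)  # (2, 1, 48, 64) (2, 1, 48, 64) -- padded to batch maxima
print(labels.shape, label_masks.shape)  # (2, 6) (2, 6) -- zero-padded labels with 0/1 masks
```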
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 2a2ac2decd1abf4daf2c5325a8f69fc26f4fc0ef..63c5d6aa7851422e21a567dfe938c417793ca7ea 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -1400,8 +1400,6 @@ class VLLabelEncode(BaseRecLabelEncode):
**kwargs):
super(VLLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char, lower)
- self.character = self.character[10:] + self.character[
- 1:10] + [self.character[0]]
self.dict = {}
for i, char in enumerate(self.character):
self.dict[char] = i
@@ -1476,4 +1474,33 @@ class CTLabelEncode(object):
data['polys'] = boxes
data['texts'] = txts
- return data
\ No newline at end of file
+ return data
+
+
+class CANLabelEncode(BaseRecLabelEncode):
+ def __init__(self,
+ character_dict_path,
+ max_text_length=100,
+ use_space_char=False,
+ lower=True,
+ **kwargs):
+ super(CANLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char, lower)
+
+ def encode(self, text_seq):
+ text_seq_encoded = []
+ for text in text_seq:
+ if text not in self.character:
+ continue
+ text_seq_encoded.append(self.dict.get(text))
+ if len(text_seq_encoded) == 0:
+ return None
+ return text_seq_encoded
+
+ def __call__(self, data):
+ label = data['label']
+ if isinstance(label, str):
+ label = label.strip().split()
+ label.append(self.end_str)
+ data['label'] = self.encode(label)
+ return data
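
A minimal encode sketch (run from the repo root; the exact indices come from `latex_symbol_dict.txt`, so treat the printed values as illustrative):

```python
from ppocr.data.imaug.label_ops import CANLabelEncode

encoder = CANLabelEncode(
    character_dict_path='ppocr/utils/dict/latex_symbol_dict.txt', lower=False)
data = encoder({'label': 'x ^ { 2 }'})  # labels are space-separated dictionary symbols
print(data['label'])  # e.g. [101, 109, 107, 13, 108, 0] -- the trailing 0 is the appended 'eos'
```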
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index 5e84b1aac9c54d8a8283468af6826ca917ba0384..4ff2d29ed32df906c42b28f97a81b20f716cb0fd 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -498,3 +498,27 @@ class ResizeNormalize(object):
img_numpy = np.array(img).astype("float32")
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
return img_numpy
+
+
+class GrayImageChannelFormat(object):
+ """
+ Convert an image to a single-channel gray image: (h, w, 3) -> (1, h, w)
+ Args:
+ inverse: invert the gray image (assumes pixel values are already scaled to [0, 1])
+ """
+
+ def __init__(self, inverse=False, **kwargs):
+ self.inverse = inverse
+
+ def __call__(self, data):
+ img = data['image']
+ img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ img_expanded = np.expand_dims(img_single_channel, 0)
+
+ if self.inverse:
+ data['image'] = np.abs(img_expanded - 1)
+ else:
+ data['image'] = img_expanded
+
+ data['src_image'] = img
+ return data
\ No newline at end of file
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
old mode 100755
new mode 100644
index cfa9d5fad5811dd3e9f77c28f4c58a95cd15afe8..c7142e3e5e73e25764dde4631a47be939905e3be
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -40,6 +40,7 @@ from .rec_multi_loss import MultiLoss
from .rec_vl_loss import VLLoss
from .rec_spin_att_loss import SPINAttentionLoss
from .rec_rfl_loss import RFLLoss
+from .rec_can_loss import CANLoss
# cls loss
from .cls_loss import ClsLoss
@@ -72,7 +73,7 @@ def build_loss(config):
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
- 'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'TelescopeLoss'
+ 'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/rec_can_loss.py b/ppocr/losses/rec_can_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..227e17f5e1ef1ff398b112b19dfd05b0b1fb7ab1
--- /dev/null
+++ b/ppocr/losses/rec_can_loss.py
@@ -0,0 +1,79 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/LBH1024/CAN/models/can.py
+"""
+
+import paddle
+import paddle.nn as nn
+import numpy as np
+
+
+class CANLoss(nn.Layer):
+ '''
+ CANLoss consists of two parts:
+ word_average_loss: cross-entropy loss of the predicted symbols
+ counting_loss: smooth L1 loss of the per-symbol counts
+ '''
+
+ def __init__(self):
+ super(CANLoss, self).__init__()
+
+ self.use_label_mask = False
+ self.out_channel = 111
+ self.cross = nn.CrossEntropyLoss(
+ reduction='none') if self.use_label_mask else nn.CrossEntropyLoss()
+ self.counting_loss = nn.SmoothL1Loss(reduction='mean')
+ self.ratio = 16
+
+ def forward(self, preds, batch):
+ word_probs = preds[0]
+ counting_preds = preds[1]
+ counting_preds1 = preds[2]
+ counting_preds2 = preds[3]
+ labels = batch[2]
+ labels_mask = batch[3]
+ counting_labels = gen_counting_label(labels, self.out_channel, True)
+ counting_loss = self.counting_loss(counting_preds1, counting_labels) + self.counting_loss(counting_preds2, counting_labels) \
+ + self.counting_loss(counting_preds, counting_labels)
+
+ word_loss = self.cross(
+ paddle.reshape(word_probs, [-1, word_probs.shape[-1]]),
+ paddle.reshape(labels, [-1]))
+ word_average_loss = paddle.sum(
+ paddle.reshape(word_loss * labels_mask, [-1])) / (
+ paddle.sum(labels_mask) + 1e-10
+ ) if self.use_label_mask else word_loss
+ loss = word_average_loss + counting_loss
+ return {'loss': loss}
+
+
+def gen_counting_label(labels, channel, tag):
+ b, t = labels.shape
+ counting_labels = np.zeros([b, channel])
+
+ if tag:
+ # do not count 'eos', 'sos' and the layout symbols '{', '}', '^', '_'
+ ignore = [0, 1, 107, 108, 109, 110]
+ else:
+ ignore = []
+ for i in range(b):
+ for j in range(t):
+ k = labels[i][j]
+ if k in ignore:
+ continue
+ else:
+ counting_labels[i][k] += 1
+ counting_labels = paddle.to_tensor(counting_labels, dtype='float32')
+ return counting_labels
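
A small sketch of the counting target (assumes PaddlePaddle is installed; indices follow `latex_symbol_dict.txt`, where 0/1 are eos/sos and 107-110 are '{', '}', '^', '_'):

```python
import paddle
from ppocr.losses.rec_can_loss import gen_counting_label

# encoded 'x ^ { 2 }' plus eos; with tag=True the ignored symbols are not counted
labels = paddle.to_tensor([[101, 109, 107, 13, 108, 0]], dtype='int64')
counting = gen_counting_label(labels, channel=111, tag=True)
print(counting.shape)  # [1, 111]
print(float(counting[0, 101]), float(counting[0, 13]))  # 1.0 1.0 -- one 'x', one '2'
```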
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index 20aea8b5995a49306d427bc427048c9df8d0923d..5e840a194adc2683e92c308f232dc869df34de8e 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -22,7 +22,7 @@ import copy
__all__ = ["build_metric"]
from .det_metric import DetMetric, DetFCEMetric
-from .rec_metric import RecMetric, CNTMetric
+from .rec_metric import RecMetric, CNTMetric, CANMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
@@ -38,7 +38,7 @@ def build_metric(config):
support_dict = [
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
- 'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric'
+ 'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric', 'CANMetric'
]
config = copy.deepcopy(config)
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index 4758e71d0930261044841a6a820308a04391fc0b..305b913c72da5842b6654f1fc9b27e6e2b46b436 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -13,6 +13,9 @@
# limitations under the License.
from rapidfuzz.distance import Levenshtein
+from difflib import SequenceMatcher
+
+import numpy as np
import string
@@ -106,3 +109,71 @@ class CNTMetric(object):
def reset(self):
self.correct_num = 0
self.all_num = 0
+
+
+class CANMetric(object):
+ def __init__(self, main_indicator='exp_rate', **kwargs):
+ self.main_indicator = main_indicator
+ self.word_right = []
+ self.exp_right = []
+ self.word_total_length = 0
+ self.exp_total_num = 0
+ self.word_rate = 0
+ self.exp_rate = 0
+ self.reset()
+ self.epoch_reset()
+
+ def __call__(self, preds, batch, **kwargs):
+ for k, v in kwargs.items():
+ epoch_reset = v
+ if epoch_reset:
+ self.epoch_reset()
+ word_probs = preds
+ word_label, word_label_mask = batch
+ line_right = 0
+ if word_probs is not None:
+ word_pred = word_probs.argmax(2)
+ word_pred = word_pred.cpu().detach().numpy()
+ # SequenceMatcher.ratio() is 2*M / (len(s1) + len(s2)) for M matching
+ # symbols, so after the rescaling below each score equals M / label_length
+ word_scores = [
+ SequenceMatcher(
+ None,
+ s1[:int(np.sum(s3))],
+ s2[:int(np.sum(s3))],
+ autojunk=False).ratio() * (
+ len(s1[:int(np.sum(s3))]) + len(s2[:int(np.sum(s3))])) /
+ len(s1[:int(np.sum(s3))]) / 2
+ for s1, s2, s3 in zip(word_label, word_pred, word_label_mask)
+ ]
+ batch_size = len(word_scores)
+ for i in range(batch_size):
+ if word_scores[i] == 1:
+ line_right += 1
+ self.word_rate = np.mean(word_scores) #float
+ self.exp_rate = line_right / batch_size #float
+ exp_length, word_length = word_label.shape[:2]
+ self.word_right.append(self.word_rate * word_length)
+ self.exp_right.append(self.exp_rate * exp_length)
+ self.word_total_length = self.word_total_length + word_length
+ self.exp_total_num = self.exp_total_num + exp_length
+
+ def get_metric(self):
+ """
+ return {
+ 'word_rate': 0,
+ "exp_rate": 0,
+ }
+ """
+ cur_word_rate = sum(self.word_right) / self.word_total_length
+ cur_exp_rate = sum(self.exp_right) / self.exp_total_num
+ self.reset()
+ return {'word_rate': cur_word_rate, "exp_rate": cur_exp_rate}
+
+ def reset(self):
+ self.word_rate = 0
+ self.exp_rate = 0
+
+ def epoch_reset(self):
+ self.word_right = []
+ self.exp_right = []
+ self.word_total_length = 0
+ self.exp_total_num = 0
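
The word score above simplifies to matched-symbols / label-length, and a line counts toward `exp_rate` only when it scores exactly 1. A toy check (made-up indices):

```python
from difflib import SequenceMatcher

label, pred = [101, 109, 107], [101, 109, 13]  # 2 of 3 symbols match
ratio = SequenceMatcher(None, label, pred, autojunk=False).ratio()
score = ratio * (len(label) + len(pred)) / len(label) / 2
print(score)  # 0.666... == M / len(label)
```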
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 84892fa9c7fd61838e984690b17931f367ab0585..e2c2e9c4a4ed526b36d512d824ae8a8a701c17bc 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -43,10 +43,12 @@ def build_backbone(config, model_type):
from .rec_svtrnet import SVTRNet
from .rec_vitstr import ViTSTR
from .rec_resnet_rfl import ResNetRFL
+ from .rec_densenet import DenseNet
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
- 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL'
+ 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
+ 'DenseNet'
]
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_densenet.py b/ppocr/modeling/backbones/rec_densenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..65c5fa4f245f9825ce8c728db487e8888b5bc3c6
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_densenet.py
@@ -0,0 +1,146 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/LBH1024/CAN/models/densenet.py
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class Bottleneck(nn.Layer):
+ def __init__(self, nChannels, growthRate, use_dropout):
+ super(Bottleneck, self).__init__()
+ interChannels = 4 * growthRate
+ self.bn1 = nn.BatchNorm2D(interChannels)
+ self.conv1 = nn.Conv2D(
+ nChannels, interChannels, kernel_size=1,
+ bias_attr=None) # bias_attr=None keeps the framework-default bias
+ self.bn2 = nn.BatchNorm2D(growthRate)
+ self.conv2 = nn.Conv2D(
+ interChannels, growthRate, kernel_size=3, padding=1,
+ bias_attr=None) # bias_attr=None keeps the framework-default bias
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(p=0.2)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = F.relu(self.bn2(self.conv2(out)))
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = paddle.concat([x, out], 1)
+ return out
+
+
+class SingleLayer(nn.Layer):
+ def __init__(self, nChannels, growthRate, use_dropout):
+ super(SingleLayer, self).__init__()
+ self.bn1 = nn.BatchNorm2D(nChannels)
+ self.conv1 = nn.Conv2D(
+ nChannels, growthRate, kernel_size=3, padding=1, bias_attr=False)
+
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(p=0.2)
+
+ def forward(self, x):
+ out = self.conv1(F.relu(x))
+ if self.use_dropout:
+ out = self.dropout(out)
+
+ out = paddle.concat([x, out], 1)
+ return out
+
+
+class Transition(nn.Layer):
+ def __init__(self, nChannels, out_channels, use_dropout):
+ super(Transition, self).__init__()
+ self.bn1 = nn.BatchNorm2D(out_channels)
+ self.conv1 = nn.Conv2D(
+ nChannels, out_channels, kernel_size=1, bias_attr=False)
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(p=0.2)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = F.avg_pool2d(out, 2, ceil_mode=True, exclusive=False)
+ return out
+
+
+class DenseNet(nn.Layer):
+ def __init__(self, growthRate, reduction, bottleneck, use_dropout,
+ input_channel, **kwargs):
+ super(DenseNet, self).__init__()
+
+ nDenseBlocks = 16
+ nChannels = 2 * growthRate
+
+ self.conv1 = nn.Conv2D(
+ input_channel,
+ nChannels,
+ kernel_size=7,
+ padding=3,
+ stride=2,
+ bias_attr=False)
+ self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks,
+ bottleneck, use_dropout)
+ nChannels += nDenseBlocks * growthRate
+ out_channels = int(math.floor(nChannels * reduction))
+ self.trans1 = Transition(nChannels, out_channels, use_dropout)
+
+ nChannels = out_channels
+ self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks,
+ bottleneck, use_dropout)
+ nChannels += nDenseBlocks * growthRate
+ out_channels = int(math.floor(nChannels * reduction))
+ self.trans2 = Transition(nChannels, out_channels, use_dropout)
+
+ nChannels = out_channels
+ self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks,
+ bottleneck, use_dropout)
+ self.out_channels = out_channels
+
+ def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck,
+ use_dropout):
+ layers = []
+ for i in range(int(nDenseBlocks)):
+ if bottleneck:
+ layers.append(Bottleneck(nChannels, growthRate, use_dropout))
+ else:
+ layers.append(SingleLayer(nChannels, growthRate, use_dropout))
+ nChannels += growthRate
+ return nn.Sequential(*layers)
+
+ def forward(self, inputs):
+ x, x_m, y = inputs
+ out = self.conv1(x)
+ out = F.relu(out)
+ out = F.max_pool2d(out, 2, ceil_mode=True)
+ out = self.dense1(out)
+ out = self.trans1(out)
+ out = self.dense2(out)
+ out = self.trans2(out)
+ out = self.dense3(out)
+ return out, x_m, y
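
The channel bookkeeping explains `in_channel: 684` in the config. With growthRate=24, reduction=0.5 and 16 layers per dense block, a sketch of the arithmetic the constructor performs:

```python
import math

growth_rate, reduction, n_blocks = 24, 0.5, 16
c = 2 * growth_rate                 # 48 after the stem conv (stride 2)
c += n_blocks * growth_rate         # 432 after dense1
c = int(math.floor(c * reduction))  # 216 after trans1 (avg-pool /2)
c += n_blocks * growth_rate         # 600 after dense2
c = int(math.floor(c * reduction))  # 300 after trans2 (avg-pool /2)
c += n_blocks * growth_rate         # 684 after dense3 -> CANHead.in_channel
print(c)
```

The stem conv, the max-pool and the two transition poolings each halve H and W, for a total spatial downsampling of 16, which is where `ratio: 16` in the config comes from.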
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 63002140c5be4bd7e32b56995c6410ecc8a0fa36..65afaf84f4453f2d4199371576ac71bb93a1e6d5 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -24,7 +24,6 @@ def build_head(config):
from .det_fce_head import FCEHead
from .e2e_pg_head import PGHead
from .det_ct_head import CT_Head
- from .det_drrg_head import DRRGHead
# rec head
from .rec_ctc_head import CTCHead
@@ -40,6 +39,7 @@ def build_head(config):
from .rec_robustscanner_head import RobustScannerHead
from .rec_visionlan_head import VLHead
from .rec_rfl_head import RFLHead
+ from .rec_can_head import CANHead
# cls head
from .cls_head import ClsHead
@@ -56,9 +56,13 @@ def build_head(config):
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
- 'DRRGHead'
+ 'DRRGHead', 'CANHead'
]
+ # lazy import; 'DRRGHead' is already listed in support_dict above
+ if config['name'] == 'DRRGHead':
+ from .det_drrg_head import DRRGHead
+
#table head
module_name = config.pop('name')
diff --git a/ppocr/modeling/heads/rec_can_head.py b/ppocr/modeling/heads/rec_can_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..732dbfe2db080b5e5da6c4656d7bc9de92bbc6e0
--- /dev/null
+++ b/ppocr/modeling/heads/rec_can_head.py
@@ -0,0 +1,319 @@
+# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/LBH1024/CAN/models/can.py
+https://github.com/LBH1024/CAN/models/counting.py
+https://github.com/LBH1024/CAN/models/decoder.py
+https://github.com/LBH1024/CAN/models/attention.py
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle.nn as nn
+import paddle
+import math
+'''
+Counting Module
+'''
+
+
+class ChannelAtt(nn.Layer):
+ def __init__(self, channel, reduction):
+ super(ChannelAtt, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2D(1)
+
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction),
+ nn.ReLU(), nn.Linear(channel // reduction, channel), nn.Sigmoid())
+
+ def forward(self, x):
+ b, c, _, _ = x.shape
+ y = paddle.reshape(self.avg_pool(x), [b, c])
+ y = paddle.reshape(self.fc(y), [b, c, 1, 1])
+ return x * y
+
+
+class CountingDecoder(nn.Layer):
+ def __init__(self, in_channel, out_channel, kernel_size):
+ super(CountingDecoder, self).__init__()
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+ self.trans_layer = nn.Sequential(
+ nn.Conv2D(
+ self.in_channel,
+ 512,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ bias_attr=False),
+ nn.BatchNorm2D(512))
+
+ self.channel_att = ChannelAtt(512, 16)
+
+ self.pred_layer = nn.Sequential(
+ nn.Conv2D(
+ 512, self.out_channel, kernel_size=1, bias_attr=False),
+ nn.Sigmoid())
+
+ def forward(self, x, mask):
+ b, _, h, w = x.shape
+ x = self.trans_layer(x)
+ x = self.channel_att(x)
+ x = self.pred_layer(x)
+
+ if mask is not None:
+ x = x * mask
+ x = paddle.reshape(x, [b, self.out_channel, -1])
+ x1 = paddle.sum(x, axis=-1)
+
+ return x1, paddle.reshape(x, [b, self.out_channel, h, w])
+
+
+'''
+Attention Decoder
+'''
+
+
+class PositionEmbeddingSine(nn.Layer):
+ def __init__(self,
+ num_pos_feats=64,
+ temperature=10000,
+ normalize=False,
+ scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask):
+ y_embed = paddle.cumsum(mask, 1, dtype='float32')
+ x_embed = paddle.cumsum(mask, 2, dtype='float32')
+
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+ dim_t = paddle.arange(self.num_pos_feats, dtype='float32')
+ dim_d = paddle.expand(paddle.to_tensor(2), dim_t.shape)
+ dim_t = self.temperature**(2 * (dim_t / dim_d).astype('int64') /
+ self.num_pos_feats)
+
+ pos_x = paddle.unsqueeze(x_embed, [3]) / dim_t
+ pos_y = paddle.unsqueeze(y_embed, [3]) / dim_t
+
+ pos_x = paddle.flatten(
+ paddle.stack(
+ [
+ paddle.sin(pos_x[:, :, :, 0::2]),
+ paddle.cos(pos_x[:, :, :, 1::2])
+ ],
+ axis=4),
+ 3)
+ pos_y = paddle.flatten(
+ paddle.stack(
+ [
+ paddle.sin(pos_y[:, :, :, 0::2]),
+ paddle.cos(pos_y[:, :, :, 1::2])
+ ],
+ axis=4),
+ 3)
+
+ pos = paddle.transpose(
+ paddle.concat(
+ [pos_y, pos_x], axis=3), [0, 3, 1, 2])
+
+ return pos
+
+
+class AttDecoder(nn.Layer):
+ def __init__(self, ratio, is_train, input_size, hidden_size,
+ encoder_out_channel, dropout, dropout_ratio, word_num,
+ counting_decoder_out_channel, attention):
+ super(AttDecoder, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.out_channel = encoder_out_channel
+ self.attention_dim = attention['attention_dim']
+ self.dropout_prob = dropout
+ self.ratio = ratio
+ self.word_num = word_num
+
+ self.counting_num = counting_decoder_out_channel
+ self.is_train = is_train
+
+ self.init_weight = nn.Linear(self.out_channel, self.hidden_size)
+ self.embedding = nn.Embedding(self.word_num, self.input_size)
+ self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
+ self.word_attention = Attention(hidden_size, attention['attention_dim'])
+
+ self.encoder_feature_conv = nn.Conv2D(
+ self.out_channel,
+ self.attention_dim,
+ kernel_size=attention['word_conv_kernel'],
+ padding=attention['word_conv_kernel'] // 2)
+
+ self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size)
+ self.word_embedding_weight = nn.Linear(self.input_size,
+ self.hidden_size)
+ self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size)
+ self.counting_context_weight = nn.Linear(self.counting_num,
+ self.hidden_size)
+ self.word_convert = nn.Linear(self.hidden_size, self.word_num)
+
+ if dropout:
+ self.dropout = nn.Dropout(dropout_ratio)
+
+ def forward(self, cnn_features, labels, counting_preds, images_mask):
+ if self.is_train:
+ _, num_steps = labels.shape
+ else:
+ num_steps = 36
+
+ batch_size, _, height, width = cnn_features.shape
+ images_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
+
+ word_probs = paddle.zeros((batch_size, num_steps, self.word_num))
+ word_alpha_sum = paddle.zeros((batch_size, 1, height, width))
+
+ hidden = self.init_hidden(cnn_features, images_mask)
+ counting_context_weighted = self.counting_context_weight(counting_preds)
+ cnn_features_trans = self.encoder_feature_conv(cnn_features)
+
+ position_embedding = PositionEmbeddingSine(256, normalize=True)
+ pos = position_embedding(cnn_features_trans, images_mask[:, 0, :, :])
+
+ cnn_features_trans = cnn_features_trans + pos
+
+ word = paddle.ones([batch_size, 1], dtype='int64') # init word as sos
+ word = word.squeeze(axis=1)
+ for i in range(num_steps):
+ word_embedding = self.embedding(word)
+ _, hidden = self.word_input_gru(word_embedding, hidden)
+ word_context_vec, _, word_alpha_sum = self.word_attention(
+ cnn_features, cnn_features_trans, hidden, word_alpha_sum,
+ images_mask)
+
+ current_state = self.word_state_weight(hidden)
+ word_weighted_embedding = self.word_embedding_weight(word_embedding)
+ word_context_weighted = self.word_context_weight(word_context_vec)
+
+ if self.dropout_prob:
+ word_out_state = self.dropout(
+ current_state + word_weighted_embedding +
+ word_context_weighted + counting_context_weighted)
+ else:
+ word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted
+
+ word_prob = self.word_convert(word_out_state)
+ word_probs[:, i] = word_prob
+
+ if self.is_train:
+ word = labels[:, i]
+ else:
+ word = word_prob.argmax(1)
+ word = paddle.multiply(
+ word, labels[:, i]
+ ) # labels are oneslike tensor in infer/predict mode
+
+ return word_probs
+
+ def init_hidden(self, features, feature_mask):
+ average = paddle.sum(paddle.sum(features * feature_mask, axis=-1),
+ axis=-1) / paddle.sum(
+ (paddle.sum(feature_mask, axis=-1)), axis=-1)
+ average = self.init_weight(average)
+ return paddle.tanh(average)
+
+
+'''
+Attention Module
+'''
+
+
+class Attention(nn.Layer):
+ def __init__(self, hidden_size, attention_dim):
+ super(Attention, self).__init__()
+ self.hidden = hidden_size
+ self.attention_dim = attention_dim
+ self.hidden_weight = nn.Linear(self.hidden, self.attention_dim)
+ self.attention_conv = nn.Conv2D(
+ 1, 512, kernel_size=11, padding=5, bias_attr=False)
+ self.attention_weight = nn.Linear(
+ 512, self.attention_dim, bias_attr=False)
+ self.alpha_convert = nn.Linear(self.attention_dim, 1)
+
+ def forward(self,
+ cnn_features,
+ cnn_features_trans,
+ hidden,
+ alpha_sum,
+ image_mask=None):
+ query = self.hidden_weight(hidden)
+ alpha_sum_trans = self.attention_conv(alpha_sum)
+ coverage_alpha = self.attention_weight(
+ paddle.transpose(alpha_sum_trans, [0, 2, 3, 1]))
+ alpha_score = paddle.tanh(
+ paddle.unsqueeze(query, [1, 2]) + coverage_alpha + paddle.transpose(
+ cnn_features_trans, [0, 2, 3, 1]))
+ energy = self.alpha_convert(alpha_score)
+ energy = energy - energy.max()
+ energy_exp = paddle.exp(paddle.squeeze(energy, -1))
+
+ if image_mask is not None:
+ energy_exp = energy_exp * paddle.squeeze(image_mask, 1)
+ alpha = energy_exp / (paddle.unsqueeze(
+ paddle.sum(paddle.sum(energy_exp, -1), -1), [1, 2]) + 1e-10)
+ alpha_sum = paddle.unsqueeze(alpha, 1) + alpha_sum
+ context_vector = paddle.sum(
+ paddle.sum((paddle.unsqueeze(alpha, 1) * cnn_features), -1), -1)
+
+ return context_vector, alpha, alpha_sum
+
+
+class CANHead(nn.Layer):
+ def __init__(self, in_channel, out_channel, ratio, attdecoder, **kwargs):
+ super(CANHead, self).__init__()
+
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+ self.counting_decoder1 = CountingDecoder(self.in_channel,
+ self.out_channel, 3) # mscm
+ self.counting_decoder2 = CountingDecoder(self.in_channel,
+ self.out_channel, 5)
+
+ self.decoder = AttDecoder(ratio, **attdecoder)
+
+ self.ratio = ratio
+
+ def forward(self, inputs, targets=None):
+ cnn_features, images_mask, labels = inputs
+
+ counting_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
+ counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask)
+ counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask)
+ counting_preds = (counting_preds1 + counting_preds2) / 2
+
+ word_probs = self.decoder(cnn_features, labels, counting_preds,
+ images_mask)
+ return word_probs, counting_preds, counting_preds1, counting_preds2
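
A shape-level smoke test for the head (a sketch assuming PaddlePaddle is installed; the hyperparameters mirror rec_d28_can.yml, and the ones-like labels reproduce what the framework feeds in predict mode):

```python
import paddle
from ppocr.modeling.heads.rec_can_head import CANHead

head = CANHead(
    in_channel=684, out_channel=111, ratio=16,
    attdecoder=dict(
        is_train=False, input_size=256, hidden_size=256,
        encoder_out_channel=684, dropout=True, dropout_ratio=0.5,
        word_num=111, counting_decoder_out_channel=111,
        attention={'attention_dim': 512, 'word_conv_kernel': 1}))
head.eval()  # disable dropout for the shape check

feats = paddle.randn([1, 684, 8, 16])         # DenseNet features at H/16 x W/16
mask = paddle.ones([1, 1, 128, 256])          # full-resolution image mask
labels = paddle.ones([1, 36], dtype='int64')  # ones-like labels in predict mode
word_probs, cnt, cnt1, cnt2 = head((feats, mask, labels))
print(word_probs.shape, cnt.shape)  # [1, 36, 111] [1, 111]
```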
diff --git a/ppocr/optimizer/learning_rate.py b/ppocr/optimizer/learning_rate.py
index 7d45109b4857871f52764c64d6d32e5322fc7c57..be52a918458d64f0ae15b52ebf511e5068184f59 100644
--- a/ppocr/optimizer/learning_rate.py
+++ b/ppocr/optimizer/learning_rate.py
@@ -18,7 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals
from paddle.optimizer import lr
-from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay
+from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay, TwoStepCosineDecay
class Linear(object):
@@ -386,3 +386,44 @@ class MultiStepDecay(object):
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
+
+
+class TwoStepCosine(object):
+ """
+ Two-step cosine learning rate decay:
+ the learning rate follows a cosine schedule with period T_max1 = 200 epochs first,
+ then a second cosine schedule with period T_max2 = total epochs.
+ Args:
+ lr(float): initial learning rate
+ step_each_epoch(int): steps each epoch
+ epochs(int): total training epochs
+ last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+ """
+
+ def __init__(self,
+ learning_rate,
+ step_each_epoch,
+ epochs,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs):
+ super(TwoStepCosine, self).__init__()
+ self.learning_rate = learning_rate
+ self.T_max1 = step_each_epoch * 200
+ self.T_max2 = step_each_epoch * epochs
+ self.last_epoch = last_epoch
+ self.warmup_epoch = round(warmup_epoch * step_each_epoch)
+
+ def __call__(self):
+ learning_rate = TwoStepCosineDecay(
+ learning_rate=self.learning_rate,
+ T_max1=self.T_max1,
+ T_max2=self.T_max2,
+ last_epoch=self.last_epoch)
+ if self.warmup_epoch > 0:
+ learning_rate = lr.LinearWarmup(
+ learning_rate=learning_rate,
+ warmup_steps=self.warmup_epoch,
+ start_lr=0.0,
+ end_lr=self.learning_rate,
+ last_epoch=self.last_epoch)
+ return learning_rate
diff --git a/ppocr/optimizer/lr_scheduler.py b/ppocr/optimizer/lr_scheduler.py
index f62f1f3b0adbd8df0e03a66faa4565f2f7df28bc..cd09367e2ab8a649e3c375698f5b182eb5c3ff7a 100644
--- a/ppocr/optimizer/lr_scheduler.py
+++ b/ppocr/optimizer/lr_scheduler.py
@@ -160,3 +160,63 @@ class OneCycleDecay(LRScheduler):
start_step = phase['end_step']
return computed_lr
+
+
+class TwoStepCosineDecay(LRScheduler):
+ def __init__(self,
+ learning_rate,
+ T_max1,
+ T_max2,
+ eta_min=0,
+ last_epoch=-1,
+ verbose=False):
+ if not isinstance(T_max1, int):
+ raise TypeError(
+ "The type of 'T_max1' in 'TwoStepCosineDecay' must be 'int', but received %s."
+ % type(T_max1))
+ if not isinstance(T_max2, int):
+ raise TypeError(
+ "The type of 'T_max2' in 'TwoStepCosineDecay' must be 'int', but received %s."
+ % type(T_max2))
+ if not isinstance(eta_min, (float, int)):
+ raise TypeError(
+ "The type of 'eta_min' in 'TwoStepCosineDecay' must be 'float, int', but received %s."
+ % type(eta_min))
+ assert T_max1 > 0 and isinstance(
+ T_max1, int), " 'T_max1' must be a positive integer."
+ assert T_max2 > 0 and isinstance(
+ T_max2, int), " 'T_max2' must be a positive integer."
+ self.T_max1 = T_max1
+ self.T_max2 = T_max2
+ self.eta_min = float(eta_min)
+ super(TwoStepCosineDecay, self).__init__(learning_rate, last_epoch,
+ verbose)
+
+ def get_lr(self):
+
+ if self.last_epoch <= self.T_max1:
+ if self.last_epoch == 0:
+ return self.base_lr
+ elif (self.last_epoch - 1 - self.T_max1) % (2 * self.T_max1) == 0:
+ return self.last_lr + (self.base_lr - self.eta_min) * (
+ 1 - math.cos(math.pi / self.T_max1)) / 2
+
+ return (1 + math.cos(math.pi * self.last_epoch / self.T_max1)) / (
+ 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max1)) * (
+ self.last_lr - self.eta_min) + self.eta_min
+ else:
+ if (self.last_epoch - 1 - self.T_max2) % (2 * self.T_max2) == 0:
+ return self.last_lr + (self.base_lr - self.eta_min) * (
+ 1 - math.cos(math.pi / self.T_max2)) / 2
+
+ return (1 + math.cos(math.pi * self.last_epoch / self.T_max2)) / (
+ 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max2)) * (
+ self.last_lr - self.eta_min) + self.eta_min
+
+ def _get_closed_form_lr(self):
+ if self.last_epoch <= self.T_max1:
+ return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
+ math.pi * self.last_epoch / self.T_max1)) / 2
+ else:
+ return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
+ math.pi * self.last_epoch / self.T_max2)) / 2
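
The closed form makes the schedule easy to see: one cosine annealing over the first T_max1 steps, then a second, flatter cosine over the full T_max2 horizon. A standalone sketch (step counts assume 1105 iterations per epoch, as in rec_d28_can.yml):

```python
import math

def two_step_cosine(step, base_lr, t_max1, t_max2, eta_min=0.0):
    # mirrors TwoStepCosineDecay._get_closed_form_lr
    t = t_max1 if step <= t_max1 else t_max2
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t)) / 2

steps = 1105  # iterations per epoch
for epoch in (0, 100, 200, 201, 240):
    print(epoch, round(two_step_cosine(epoch * steps, 0.01, 200 * steps, 240 * steps), 6))
# the LR decays to 0 at epoch 200, then restarts on the longer 240-epoch cosine
```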
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 3a09030b25461029d9160699dc591eaedab9e0db..36a3152f2f2d68ed0884bd415844d209d850f5ca 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -37,6 +37,7 @@ from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
from .picodet_postprocess import PicoDetPostProcess
from .ct_postprocess import CTPostProcess
from .drrg_postprocess import DRRGPostprocess
+from .rec_postprocess import CANLabelDecode
def build_post_process(config, global_config=None):
@@ -51,7 +52,7 @@ def build_post_process(config, global_config=None):
'TableMasterLabelDecode', 'SPINLabelDecode',
'DistillationSerPostProcess', 'DistillationRePostProcess',
'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess',
- 'RFLLabelDecode', 'DRRGPostprocess'
+ 'RFLLabelDecode', 'DRRGPostprocess', 'CANLabelDecode'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 59b5254e480e7c52aca4ce648c379a280683db4f..fbf8b93e3d11121c99ce5b2dcbf2149e15453d4a 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -26,6 +26,7 @@ class BaseRecLabelDecode(object):
self.end_str = "eos"
self.reverse = False
self.character_str = []
+
if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
@@ -805,8 +806,6 @@ class VLLabelDecode(BaseRecLabelDecode):
super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25)
self.nclass = len(self.character) + 1
- self.character = self.character[10:] + self.character[
- 1:10] + [self.character[0]]
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
@@ -897,3 +896,36 @@ class VLLabelDecode(BaseRecLabelDecode):
return text
label = self.decode(label)
return text, label
+
+
+class CANLabelDecode(BaseRecLabelDecode):
+ """ Convert between latex-symbol and symbol-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(CANLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def decode(self, text_index, preds_prob=None):
+ result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ seq_end = text_index[batch_idx].argmin(0)
+ idx_list = text_index[batch_idx][:seq_end].tolist()
+ symbol_list = [self.character[idx] for idx in idx_list]
+ probs = []
+ if preds_prob is not None:
+ probs = preds_prob[batch_idx][:len(symbol_list)].tolist()
+
+ result_list.append([' '.join(symbol_list), probs])
+ return result_list
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ pred_prob, _, _, _ = preds
+ preds_idx = pred_prob.argmax(axis=2)
+
+ text = self.decode(preds_idx)
+ if label is None:
+ return text
+ label = self.decode(label)
+ return text, label
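
A decode sketch mirroring the CANLabelEncode example (indices are illustrative; 0 is 'eos', which argmin locates as the sequence end):

```python
import numpy as np
from ppocr.postprocess.rec_postprocess import CANLabelDecode

decoder = CANLabelDecode(
    character_dict_path='ppocr/utils/dict/latex_symbol_dict.txt')
pred_idx = np.array([[101, 109, 107, 13, 108, 0, 0]])
print(decoder.decode(pred_idx))  # e.g. [['x ^ { 2 }', []]]
```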
diff --git a/ppocr/utils/dict/latex_symbol_dict.txt b/ppocr/utils/dict/latex_symbol_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b43f1fa8b904e3107eb450f6d7332aec6b5b81e2
--- /dev/null
+++ b/ppocr/utils/dict/latex_symbol_dict.txt
@@ -0,0 +1,111 @@
+eos
+sos
+!
+'
+(
+)
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+<
+=
+>
+A
+B
+C
+E
+F
+G
+H
+I
+L
+M
+N
+P
+R
+S
+T
+V
+X
+Y
+[
+\Delta
+\alpha
+\beta
+\cdot
+\cdots
+\cos
+\div
+\exists
+\forall
+\frac
+\gamma
+\geq
+\in
+\infty
+\int
+\lambda
+\ldots
+\leq
+\lim
+\log
+\mu
+\neq
+\phi
+\pi
+\pm
+\prime
+\rightarrow
+\sigma
+\sin
+\sqrt
+\sum
+\tan
+\theta
+\times
+]
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+\{
+|
+\}
+{
+}
+^
+_
\ No newline at end of file
diff --git a/ppstructure/docs/quickstart.md b/ppstructure/docs/quickstart.md
index 74a3ff1aeee83622e6b3f1937c31f13896fda039..6fbd31c3c19b9d5bb8d6045efaac76628c18a3d9 100644
--- a/ppstructure/docs/quickstart.md
+++ b/ppstructure/docs/quickstart.md
@@ -97,6 +97,19 @@ paddleocr --image_dir=ppstructure/docs/table/table.jpg --type=structure --layout
#### 2.1.6 版面恢复
+Layout recovery provides two methods; for a detailed introduction, see the [layout recovery tutorial](../recovery/README_ch.md):
+
+- PDF parsing
+- OCR
+
+Recovery via PDF parsing (only PDF input is supported):
+
+```bash
+paddleocr --image_dir=ppstructure/recovery/UnrealText.pdf --type=structure --recovery=true --use_pdf2docx_api=true
+```
+
+Recovery via OCR:
+
```bash
# Chinese test image
paddleocr --image_dir=ppstructure/docs/table/1.png --type=structure --recovery=true
diff --git a/ppstructure/docs/quickstart_en.md b/ppstructure/docs/quickstart_en.md
index e6b1419cbf2a58aca3567e174f30341a26d88634..446f9d2ee387a169cbfeb067de9d1a0aa0ff7584 100644
--- a/ppstructure/docs/quickstart_en.md
+++ b/ppstructure/docs/quickstart_en.md
@@ -98,7 +98,21 @@ Key information extraction does not currently support use by the whl package. Fo
#### 2.1.6 layout recovery
+
+Two layout recovery methods are provided. For detailed usage tutorials, please refer to [Layout Recovery](../recovery/README.md).
+
+- PDF parsing
+- OCR
+
+Recovery using PDF parsing (only PDF input is supported):
+
+```bash
+paddleocr --image_dir=ppstructure/recovery/UnrealText.pdf --type=structure --recovery=true --use_pdf2docx_api=true
```
+
+Recovery using OCR:
+
+```bash
paddleocr --image_dir=ppstructure/docs/table/1.png --type=structure --recovery=true --lang='en'
```
diff --git a/ppstructure/kie/requirements.txt b/ppstructure/kie/requirements.txt
index 11fa98da1bff7a1863d8a077ca73435d15072523..6cfcba764190fd46f98b76c27e93db6f4fa36c45 100644
--- a/ppstructure/kie/requirements.txt
+++ b/ppstructure/kie/requirements.txt
@@ -4,4 +4,4 @@ seqeval
pypandoc
attrdict
python_docx
-https://paddleocr.bj.bcebos.com/ppstructure/whl/paddlenlp-2.3.0.dev0-py3-none-any.whl
+paddlenlp>=2.4.1
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index 417002d1ef58471268071f96868617a4c9c52056..bb061c998f6f8b16c06f9ee94299af0f59c53eb2 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -216,16 +216,26 @@ def main(args):
image_file_list = image_file_list
image_file_list = image_file_list[args.process_id::args.total_process_num]
- structure_sys = StructureSystem(args)
+ if not args.use_pdf2docx_api:
+ structure_sys = StructureSystem(args)
+ save_folder = os.path.join(args.output, structure_sys.mode)
+ os.makedirs(save_folder, exist_ok=True)
img_num = len(image_file_list)
- save_folder = os.path.join(args.output, structure_sys.mode)
- os.makedirs(save_folder, exist_ok=True)
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag_gif, flag_pdf = check_and_read(image_file)
img_name = os.path.basename(image_file).split('.')[0]
+ if args.recovery and args.use_pdf2docx_api and flag_pdf:
+ from pdf2docx.converter import Converter
+ docx_file = os.path.join(args.output, '{}.docx'.format(img_name))
+ cv = Converter(image_file)
+ cv.convert(docx_file)
+ cv.close()
+ logger.info('docx save to {}'.format(docx_file))
+ continue
+
if not flag_gif and not flag_pdf:
img = cv2.imread(image_file)
diff --git a/ppstructure/recovery/README.md b/ppstructure/recovery/README.md
index 0e06c65475b67bcdfc119069fa6f6076322c0e99..46a348c8e5d4cf3e43c4287ee5b37030426c1524 100644
--- a/ppstructure/recovery/README.md
+++ b/ppstructure/recovery/README.md
@@ -6,18 +6,39 @@ English | [简体中文](README_ch.md)
- [2. Install](#2)
- [2.1 Install PaddlePaddle](#2.1)
- [2.2 Install PaddleOCR](#2.2)
-- [3. Quick Start](#3)
- - [3.1 Download models](#3.1)
- - [3.2 Layout recovery](#3.2)
-- [4. More](#4)
+- [3. Quick Start using standard PDF parsing](#3)
+- [4. Quick Start using image-format PDF parsing](#4)
+  - [4.1 Download models](#4.1)
+  - [4.2 Layout recovery](#4.2)
+- [5. More](#5)
## 1. Introduction
-Layout recovery means that after OCR recognition, the content is still arranged like the original document pictures, and the paragraphs are output to word document in the same order.
+The layout recovery module restores an image or PDF to an editable Word file whose layout is consistent with the original.
-Layout recovery combines [layout analysis](../layout/README.md)、[table recognition](../table/README.md) to better recover images, tables, titles, etc. supports input files in PDF and document image formats in Chinese and English. The following figure shows the effect of restoring the layout of English and Chinese documents:
+Two layout recovery methods are provided; choose according to your PDF type:
+
+- **Standard PDF parsing (the input is a standard PDF)**: the Python PDF-to-Word library [pdf2docx](https://github.com/dothinking/pdf2docx) is adapted for this method. It extracts data from the PDF with PyMuPDF, parses the layout with rules, and finally generates the docx file with python-docx.
+
+- **Image-format PDF parsing (the input can be a standard PDF or an image-format PDF)**: layout recovery combines [layout analysis](../layout/README.md) and [table recognition](../table/README.md) to better recover images, tables, titles, etc. It supports Chinese and English input files in PDF and document-image formats.
+
+The input formats and application scenarios of the two methods are as follows:
+
+| Method | Input formats | Application scenarios / limitations |
+| :-----: | :----------: | :----------------------------------------------------------: |
+| Standard PDF parsing | pdf | Advantages: better recovery of non-paper documents; each page stays on the same page after restoration <br> Disadvantages: English characters in some Chinese documents come out garbled, some content still overflows the page, whole-page content may be restored as a table, and some pictures recover poorly |
+| Image-format PDF parsing | pdf, picture | Advantages: better suited to recovering paper-document content, with stronger OCR recognition <br> Disadvantages: recovery is currently rule-based, content typesetting (spacing, fonts, etc.) needs further improvement, and the result depends on the layout analysis quality |
+
+The following figure shows the effect of restoring the layout of documents by using PDF parsing:
+
+