diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml index df429314cd0ec058aa6779a0ff55656f1b211bbf..acf438950a43af3356c7ab0aadf956fdf226814e 100644 --- a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml +++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml @@ -14,6 +14,9 @@ Global: use_visualdl: False infer_img: doc/imgs_en/img_10.jpg save_res_path: ./output/det_db/predicts_db.txt + use_amp: False + amp_level: O2 + amp_custom_black_list: ['exp'] Architecture: name: DistillationModel diff --git a/configs/rec/rec_r31_robustscanner.yml b/configs/rec/rec_r31_robustscanner.yml new file mode 100644 index 0000000000000000000000000000000000000000..40d39aee3c42c18085ace035944dba057b923245 --- /dev/null +++ b/configs/rec/rec_r31_robustscanner.yml @@ -0,0 +1,109 @@ +Global: + use_gpu: true + epoch_num: 5 + log_smooth_window: 20 + print_batch_step: 20 + save_model_dir: ./output/rec/rec_r31_robustscanner/ + save_epoch_step: 1 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: ./inference/rec_inference + # for data or label process + character_dict_path: ppocr/utils/dict90.txt + max_text_length: &max_text_length 40 + infer_mode: False + use_space_char: False + rm_symbol: True + save_res_path: ./output/rec/predicts_robustscanner.txt + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Piecewise + decay_epochs: [3, 4] + values: [0.001, 0.0001, 0.00001] + regularizer: + name: 'L2' + factor: 0 + +Architecture: + model_type: rec + algorithm: RobustScanner + Transform: + Backbone: + name: ResNet31 + init_type: KaimingNormal + Head: + name: RobustScannerHead + enc_outchannles: 128 + hybrid_dec_rnn_layers: 2 + hybrid_dec_dropout: 0 + position_dec_rnn_layers: 2 + start_idx: 91 + mask: True + padding_idx: 92 + encode_value: False + max_text_length: *max_text_length + +Loss: + name: SARLoss + +PostProcess: + name: SARLabelDecode + +Metric: + name: RecMetric + is_filter: True + + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/training/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SARLabelEncode: # Class handling label + - RobustScannerRecResizeImg: + image_shape: [3, 48, 48, 160] # h:48 w:[48,160] + width_downsample_ratio: 0.25 + max_text_length: *max_text_length + - KeepKeys: + keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 64 + drop_last: True + num_workers: 8 + use_shared_memory: False + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/evaluation/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SARLabelEncode: # Class handling label + - RobustScannerRecResizeImg: + image_shape: [3, 48, 48, 160] + max_text_length: *max_text_length + width_downsample_ratio: 0.25 + - KeepKeys: + keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 64 + num_workers: 4 + use_shared_memory: False + diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index b96c1a3b2a2dcd4917a7e9d369eda8a1ad118463..b889d0b8ffbc190664b278a50ac867f1e14cbb7d 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md 
@@ -72,6 +72,7 @@ - [x] [ABINet](./algorithm_rec_abinet.md) - [x] [VisionLAN](./algorithm_rec_visionlan.md) - [x] [SPIN](./algorithm_rec_spin.md) +- [x] [RobustScanner](./algorithm_rec_robustscanner.md) 参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -94,6 +95,7 @@ |ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) | |VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) | |SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | +|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon | diff --git a/doc/doc_ch/algorithm_rec_robustscanner.md b/doc/doc_ch/algorithm_rec_robustscanner.md new file mode 100644 index 0000000000000000000000000000000000000000..869f9a7c00b617de87ab3c96326e18e536bc18a8 --- /dev/null +++ b/doc/doc_ch/algorithm_rec_robustscanner.md @@ -0,0 +1,113 @@ +# RobustScanner + +- [1. 算法简介](#1) +- [2. 环境配置](#2) +- [3. 模型训练、评估、预测](#3) + - [3.1 训练](#3-1) + - [3.2 评估](#3-2) + - [3.3 预测](#3-3) +- [4. 推理部署](#4) + - [4.1 Python推理](#4-1) + - [4.2 C++推理](#4-2) + - [4.3 Serving服务化部署](#4-3) + - [4.4 更多推理部署](#4-4) +- [5. FAQ](#5) + + +## 1. 算法简介 + +论文信息: +> [RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition](https://arxiv.org/pdf/2007.07542.pdf) +> Xiaoyu Yue, Zhanghui Kuang, Chenhao Lin, Hongbin Sun, Wayne +Zhang +> ECCV, 2020 + +使用MJSynth和SynthText两个合成文字识别数据集训练,在IIIT, SVT, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下: + +|模型|骨干网络|配置文件|Acc|下载链接| +| --- | --- | --- | --- | --- | +|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon| + +注:除了使用MJSynth和SynthText两个文字识别数据集外,还加入了[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg)数据(提取码:627x),和部分真实数据,具体数据细节可以参考论文。 + + +## 2. 环境配置 +请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。 + + + +## 3. 模型训练、评估、预测 + +请参考[文本识别教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。 + +训练 + +具体地,在完成数据准备后,便可以启动训练,训练命令如下: + +``` +#单卡训练(训练周期长,不建议) +python3 tools/train.py -c configs/rec/rec_r31_robustscanner.yml + +#多卡训练,通过--gpus参数指定卡号 +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r31_robustscanner.yml +``` + +评估 + +``` +# GPU 评估, Global.pretrained_model 为待测权重 +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy +``` + +预测: + +``` +# 预测使用的配置文件必须与训练一致 +python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png +``` + + +## 4. 
推理部署 + + +### 4.1 Python推理 +首先将RobustScanner文本识别训练过程中保存的模型,转换成inference model。可以使用如下命令进行转换: + +``` +python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner +``` +RobustScanner文本识别模型推理,可以执行如下命令: + +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r31_robustscanner/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="RobustScanner" --rec_char_dict_path="ppocr/utils/dict90.txt" --use_space_char=False +``` + + +### 4.2 C++推理 + +由于C++预处理后处理还未支持RobustScanner,所以暂未支持 + + +### 4.3 Serving服务化部署 + +暂不支持 + + +### 4.4 更多推理部署 + +暂不支持 + + +## 5. FAQ + + +## 引用 + +```bibtex +@article{2020RobustScanner, + title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition}, + author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang}, + journal={ECCV2020}, + year={2020}, +} +``` diff --git a/doc/doc_ch/algorithm_rec_sar.md b/doc/doc_ch/algorithm_rec_sar.md index b8304313994754480a89d708e39149d67f828c0d..cfb1de25390bda8c6ba4be1db9101269873e8b5b 100644 --- a/doc/doc_ch/algorithm_rec_sar.md +++ b/doc/doc_ch/algorithm_rec_sar.md @@ -79,7 +79,7 @@ python3 tools/export_model.py -c configs/rec/rec_r31_sar.yml -o Global.pretraine SAR文本识别模型推理,可以执行如下命令: ``` -python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_char_type="ch" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False ``` diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index dfd8ecda5c306aeb41902caccc2b6079f4f86542..3412ccbf76f6c04b61420a6abd91a55efb383db6 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -70,6 +70,7 @@ Supported text recognition algorithms (Click the link to get the tutorial): - [x] [ABINet](./algorithm_rec_abinet_en.md) - [x] [VisionLAN](./algorithm_rec_visionlan_en.md) - [x] [SPIN](./algorithm_rec_spin_en.md) +- [x] [RobustScanner](./algorithm_rec_robustscanner_en.md) Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: @@ -92,6 +93,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) | |VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) | |SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | +|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon | diff --git a/doc/doc_en/algorithm_rec_robustscanner_en.md b/doc/doc_en/algorithm_rec_robustscanner_en.md new file mode 100644 index 0000000000000000000000000000000000000000..a324a6d547a9e448566276234c750ad4497abf9c --- /dev/null +++ b/doc/doc_en/algorithm_rec_robustscanner_en.md @@ -0,0 +1,114 @@ +# RobustScanner + +- 
[1. Introduction](#1) +- [2. Environment](#2) +- [3. Model Training / Evaluation / Prediction](#3) + - [3.1 Training](#3-1) + - [3.2 Evaluation](#3-2) + - [3.3 Prediction](#3-3) +- [4. Inference and Deployment](#4) + - [4.1 Python Inference](#4-1) + - [4.2 C++ Inference](#4-2) + - [4.3 Serving](#4-3) + - [4.4 More](#4-4) +- [5. FAQ](#5) + + +## 1. Introduction + +Paper: +> [RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition](https://arxiv.org/pdf/2007.07542.pdf) +> Xiaoyu Yue, Zhanghui Kuang, Chenhao Lin, Hongbin Sun, Wayne +Zhang +> ECCV, 2020 + +Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows: + +|Model|Backbone|config|Acc|Download link| +| --- | --- | --- | --- | --- | +|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon| + +Note:In addition to using the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x), and some real data are used in training, the specific data details can refer to the paper. + + +## 2. Environment +Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code. + + + +## 3. Model Training / Evaluation / Prediction + +Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**. + +Training: + +Specifically, after the data preparation is completed, the training can be started. The training command is as follows: + +``` +#Single GPU training (long training period, not recommended) +python3 tools/train.py -c configs/rec/rec_r31_robustscanner.yml + +#Multi GPU training, specify the gpu number through the --gpus parameter +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r31_robustscanner.yml +``` + +Evaluation: + +``` +# GPU evaluation +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy +``` + +Prediction: + +``` +# The configuration file used for prediction must match the training +python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png +``` + + +## 4. Inference and Deployment + + +### 4.1 Python Inference +First, the model saved during the RobustScanner text recognition training process is converted into an inference model. 
you can use the following command to convert: + +``` +python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner +``` + +For RobustScanner text recognition model inference, the following commands can be executed: + +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r31_robustscanner/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="RobustScanner" --rec_char_dict_path="ppocr/utils/dict90.txt" --use_space_char=False +``` + + +### 4.2 C++ Inference + +Not supported + + +### 4.3 Serving + +Not supported + + +### 4.4 More + +Not supported + + +## 5. FAQ + + +## Citation + +```bibtex +@article{2020RobustScanner, + title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition}, + author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang}, + journal={ECCV2020}, + year={2020}, +} +``` diff --git a/doc/doc_en/algorithm_rec_sar_en.md b/doc/doc_en/algorithm_rec_sar_en.md index 24b87c10c3b2839909392bf3de0e0c850112fcdc..5c8319da3bc63dce55b0d5eae749ed4500b9d2f6 100644 --- a/doc/doc_en/algorithm_rec_sar_en.md +++ b/doc/doc_en/algorithm_rec_sar_en.md @@ -79,7 +79,7 @@ python3 tools/export_model.py -c configs/rec/rec_r31_sar.yml -o Global.pretraine For SAR text recognition model inference, the following commands can be executed: ``` -python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_char_type="ch" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False ``` diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index a2332b6c07be63ecfe2fa9003cbe9d0c1b0e8001..102f48fcc19e59d9f8ffb0ad496f54cc64864f7d 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -26,8 +26,7 @@ from .make_pse_gt import MakePseGt from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \ - ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg - + ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg from .ssl_img_aug import SSLRotateResize from .randaugment import RandAugment from .copy_paste import CopyPaste diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 725b4b0617c2f0808c7bf99077e2f62caa3afbf0..a5e0de8496559a40d42641a043848d5d43c98de1 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -414,6 +414,23 @@ class SVTRRecResizeImg(object): data['valid_ratio'] = valid_ratio return data +class RobustScannerRecResizeImg(object): + def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs): + self.image_shape = image_shape + self.width_downsample_ratio = width_downsample_ratio + self.max_text_length = max_text_length + + def __call__(self, data): + img = data['image'] + norm_img, resize_shape, pad_shape, valid_ratio = 
resize_norm_img_sar( + img, self.image_shape, self.width_downsample_ratio) + word_positons = np.array(range(0, self.max_text_length)).astype('int64') + data['image'] = norm_img + data['resized_shape'] = resize_shape + data['pad_shape'] = pad_shape + data['valid_ratio'] = valid_ratio + data['word_positons'] = word_positons + return data def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): imgC, imgH, imgW_min, imgW_max = image_shape diff --git a/ppocr/modeling/backbones/rec_resnet_31.py b/ppocr/modeling/backbones/rec_resnet_31.py index 965170138d00a53fca720b3b5f535a3dd34272d9..46dc374008b56a20dbd4be257775368e9cbbace4 100644 --- a/ppocr/modeling/backbones/rec_resnet_31.py +++ b/ppocr/modeling/backbones/rec_resnet_31.py @@ -29,27 +29,29 @@ import numpy as np __all__ = ["ResNet31"] - -def conv3x3(in_channel, out_channel, stride=1): +def conv3x3(in_channel, out_channel, stride=1, conv_weight_attr=None): return nn.Conv2D( in_channel, out_channel, kernel_size=3, stride=stride, padding=1, + weight_attr=conv_weight_attr, bias_attr=False) class BasicBlock(nn.Layer): expansion = 1 - def __init__(self, in_channels, channels, stride=1, downsample=False): + def __init__(self, in_channels, channels, stride=1, downsample=False, conv_weight_attr=None, bn_weight_attr=None): super().__init__() - self.conv1 = conv3x3(in_channels, channels, stride) - self.bn1 = nn.BatchNorm2D(channels) + self.conv1 = conv3x3(in_channels, channels, stride, + conv_weight_attr=conv_weight_attr) + self.bn1 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr) self.relu = nn.ReLU() - self.conv2 = conv3x3(channels, channels) - self.bn2 = nn.BatchNorm2D(channels) + self.conv2 = conv3x3(channels, channels, + conv_weight_attr=conv_weight_attr) + self.bn2 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr) self.downsample = downsample if downsample: self.downsample = nn.Sequential( @@ -58,8 +60,9 @@ class BasicBlock(nn.Layer): channels * self.expansion, 1, stride, + weight_attr=conv_weight_attr, bias_attr=False), - nn.BatchNorm2D(channels * self.expansion), ) + nn.BatchNorm2D(channels * self.expansion, weight_attr=bn_weight_attr)) else: self.downsample = nn.Sequential() self.stride = stride @@ -91,6 +94,7 @@ class ResNet31(nn.Layer): channels (list[int]): List of out_channels of Conv2d layer. out_indices (None | Sequence[int]): Indices of output stages. last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage. + init_type (None | str): the config to control the initialization. 
''' def __init__(self, @@ -98,7 +102,8 @@ class ResNet31(nn.Layer): layers=[1, 2, 5, 3], channels=[64, 128, 256, 256, 512, 512, 512], out_indices=None, - last_stage_pool=False): + last_stage_pool=False, + init_type=None): super(ResNet31, self).__init__() assert isinstance(in_channels, int) assert isinstance(last_stage_pool, bool) @@ -106,42 +111,55 @@ class ResNet31(nn.Layer): self.out_indices = out_indices self.last_stage_pool = last_stage_pool + conv_weight_attr = None + bn_weight_attr = None + + if init_type is not None: + support_dict = ['KaimingNormal'] + assert init_type in support_dict, Exception( + "resnet31 only support {}".format(support_dict)) + conv_weight_attr = nn.initializer.KaimingNormal() + bn_weight_attr = ParamAttr(initializer=nn.initializer.Uniform(), learning_rate=1) + # conv 1 (Conv Conv) self.conv1_1 = nn.Conv2D( - in_channels, channels[0], kernel_size=3, stride=1, padding=1) - self.bn1_1 = nn.BatchNorm2D(channels[0]) + in_channels, channels[0], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr) + self.bn1_1 = nn.BatchNorm2D(channels[0], weight_attr=bn_weight_attr) self.relu1_1 = nn.ReLU() self.conv1_2 = nn.Conv2D( - channels[0], channels[1], kernel_size=3, stride=1, padding=1) - self.bn1_2 = nn.BatchNorm2D(channels[1]) + channels[0], channels[1], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr) + self.bn1_2 = nn.BatchNorm2D(channels[1], weight_attr=bn_weight_attr) self.relu1_2 = nn.ReLU() # conv 2 (Max-pooling, Residual block, Conv) self.pool2 = nn.MaxPool2D( kernel_size=2, stride=2, padding=0, ceil_mode=True) - self.block2 = self._make_layer(channels[1], channels[2], layers[0]) + self.block2 = self._make_layer(channels[1], channels[2], layers[0], + conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr) self.conv2 = nn.Conv2D( - channels[2], channels[2], kernel_size=3, stride=1, padding=1) - self.bn2 = nn.BatchNorm2D(channels[2]) + channels[2], channels[2], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr) + self.bn2 = nn.BatchNorm2D(channels[2], weight_attr=bn_weight_attr) self.relu2 = nn.ReLU() # conv 3 (Max-pooling, Residual block, Conv) self.pool3 = nn.MaxPool2D( kernel_size=2, stride=2, padding=0, ceil_mode=True) - self.block3 = self._make_layer(channels[2], channels[3], layers[1]) + self.block3 = self._make_layer(channels[2], channels[3], layers[1], + conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr) self.conv3 = nn.Conv2D( - channels[3], channels[3], kernel_size=3, stride=1, padding=1) - self.bn3 = nn.BatchNorm2D(channels[3]) + channels[3], channels[3], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr) + self.bn3 = nn.BatchNorm2D(channels[3], weight_attr=bn_weight_attr) self.relu3 = nn.ReLU() # conv 4 (Max-pooling, Residual block, Conv) self.pool4 = nn.MaxPool2D( kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True) - self.block4 = self._make_layer(channels[3], channels[4], layers[2]) + self.block4 = self._make_layer(channels[3], channels[4], layers[2], + conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr) self.conv4 = nn.Conv2D( - channels[4], channels[4], kernel_size=3, stride=1, padding=1) - self.bn4 = nn.BatchNorm2D(channels[4]) + channels[4], channels[4], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr) + self.bn4 = nn.BatchNorm2D(channels[4], weight_attr=bn_weight_attr) self.relu4 = nn.ReLU() # conv 5 ((Max-pooling), Residual block, Conv) @@ -149,15 +167,16 @@ class ResNet31(nn.Layer): if self.last_stage_pool: self.pool5 = 
nn.MaxPool2D( kernel_size=2, stride=2, padding=0, ceil_mode=True) - self.block5 = self._make_layer(channels[4], channels[5], layers[3]) + self.block5 = self._make_layer(channels[4], channels[5], layers[3], + conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr) self.conv5 = nn.Conv2D( - channels[5], channels[5], kernel_size=3, stride=1, padding=1) - self.bn5 = nn.BatchNorm2D(channels[5]) + channels[5], channels[5], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr) + self.bn5 = nn.BatchNorm2D(channels[5], weight_attr=bn_weight_attr) self.relu5 = nn.ReLU() self.out_channels = channels[-1] - def _make_layer(self, input_channels, output_channels, blocks): + def _make_layer(self, input_channels, output_channels, blocks, conv_weight_attr=None, bn_weight_attr=None): layers = [] for _ in range(blocks): downsample = None @@ -168,12 +187,14 @@ class ResNet31(nn.Layer): output_channels, kernel_size=1, stride=1, + weight_attr=conv_weight_attr, bias_attr=False), - nn.BatchNorm2D(output_channels), ) + nn.BatchNorm2D(output_channels, weight_attr=bn_weight_attr)) layers.append( BasicBlock( - input_channels, output_channels, downsample=downsample)) + input_channels, output_channels, downsample=downsample, + conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)) input_channels = output_channels return nn.Sequential(*layers) diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 3f6ff0c4e0240ff4f241f475e70dc6211106a659..0feda6c6e062fa314d97b8949d8545ed3305c22e 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -35,6 +35,7 @@ def build_head(config): from .rec_multi_head import MultiHead from .rec_spin_att_head import SPINAttentionHead from .rec_abinet_head import ABINetHead + from .rec_robustscanner_head import RobustScannerHead from .rec_visionlan_head import VLHead # cls head @@ -51,7 +52,7 @@ def build_head(config): 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', - 'VLHead', 'SLAHead' + 'VLHead', 'SLAHead', 'RobustScannerHead' ] #table head diff --git a/ppocr/modeling/heads/rec_robustscanner_head.py b/ppocr/modeling/heads/rec_robustscanner_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7956059ecfe01f27db364d3d748d6af24dad0aac --- /dev/null +++ b/ppocr/modeling/heads/rec_robustscanner_head.py @@ -0,0 +1,709 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
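+
+# Overview of this module (the classes defined below):
+#   ChannelReductionEncoder    -- 1x1 conv that maps backbone features to dim_model channels.
+#   SequenceAttentionDecoder   -- "hybrid" branch: LSTM over character embeddings followed by
+#                                 dot-product attention over the encoded feature map.
+#   PositionAttentionDecoder   -- "position" branch: position-index embeddings plus a
+#                                 PositionAwareLayer applied to the encoder output, then the
+#                                 same attention, providing purely positional clues.
+#   RobustScannerFusionLayer   -- gated (GLU) fusion of the two attention glimpses.
+#   RobustScannerHead          -- ChannelReductionEncoder + RobustScannerDecoder; expects
+#                                 targets = [label, valid_ratio, word_positions] and returns
+#                                 logits of shape (N, T, C-1) during training and softmax
+#                                 probabilities at inference.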
+ +""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/channel_reduction_encoder.py +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/robust_scanner_decoder.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F + +class BaseDecoder(nn.Layer): + def __init__(self, **kwargs): + super().__init__() + + def forward_train(self, feat, out_enc, targets, img_metas): + raise NotImplementedError + + def forward_test(self, feat, out_enc, img_metas): + raise NotImplementedError + + def forward(self, + feat, + out_enc, + label=None, + valid_ratios=None, + word_positions=None, + train_mode=True): + self.train_mode = train_mode + + if train_mode: + return self.forward_train(feat, out_enc, label, valid_ratios, word_positions) + return self.forward_test(feat, out_enc, valid_ratios, word_positions) + +class ChannelReductionEncoder(nn.Layer): + """Change the channel number with a one by one convoluational layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + """ + + def __init__(self, + in_channels, + out_channels, + **kwargs): + super(ChannelReductionEncoder, self).__init__() + + self.layer = nn.Conv2D( + in_channels, out_channels, kernel_size=1, stride=1, padding=0, weight_attr=nn.initializer.XavierNormal()) + + def forward(self, feat): + """ + Args: + feat (Tensor): Image features with the shape of + :math:`(N, C_{in}, H, W)`. + + Returns: + Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`. + """ + return self.layer(feat) + + +def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + +class DotProductAttentionLayer(nn.Layer): + + def __init__(self, dim_model=None): + super().__init__() + + self.scale = dim_model**-0.5 if dim_model is not None else 1. + + def forward(self, query, key, value, h, w, valid_ratios=None): + query = paddle.transpose(query, (0, 2, 1)) + logits = paddle.matmul(query, key) * self.scale + n, c, t = logits.shape + # reshape to (n, c, h, w) + logits = paddle.reshape(logits, [n, c, h, w]) + if valid_ratios is not None: + # cal mask of attention weight + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, int(w * valid_ratio + 0.5)) + if valid_width < w: + logits[i, :, :, valid_width:] = float('-inf') + + # reshape to (n, c, h, w) + logits = paddle.reshape(logits, [n, c, t]) + weights = F.softmax(logits, axis=2) + value = paddle.transpose(value, (0, 2, 1)) + glimpse = paddle.matmul(weights, value) + glimpse = paddle.transpose(glimpse, (0, 2, 1)) + return glimpse + +class SequenceAttentionDecoder(BaseDecoder): + """Sequence attention decoder for RobustScanner. + + RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for + Robust Text Recognition `_ + + Args: + num_classes (int): Number of output classes :math:`C`. + rnn_layers (int): Number of RNN layers. + dim_input (int): Dimension :math:`D_i` of input vector ``feat``. + dim_model (int): Dimension :math:`D_m` of the model. Should also be the + same as encoder output vector ``out_enc``. + max_seq_len (int): Maximum output sequence length :math:`T`. + start_idx (int): The index of ``. + mask (bool): Whether to mask input features according to + ``img_meta['valid_ratio']``. + padding_idx (int): The index of ``. 
+ dropout (float): Dropout rate. + return_feature (bool): Return feature or logits as the result. + encode_value (bool): Whether to use the output of encoder ``out_enc`` + as `value` of attention layer. If False, the original feature + ``feat`` will be used. + + Warning: + This decoder will not predict the final class which is assumed to be + ``. Therefore, its output size is always :math:`C - 1`. `` + is also ignored by loss as specified in + :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`. + """ + + def __init__(self, + num_classes=None, + rnn_layers=2, + dim_input=512, + dim_model=128, + max_seq_len=40, + start_idx=0, + mask=True, + padding_idx=None, + dropout=0, + return_feature=False, + encode_value=False): + super().__init__() + + self.num_classes = num_classes + self.dim_input = dim_input + self.dim_model = dim_model + self.return_feature = return_feature + self.encode_value = encode_value + self.max_seq_len = max_seq_len + self.start_idx = start_idx + self.mask = mask + + self.embedding = nn.Embedding( + self.num_classes, self.dim_model, padding_idx=padding_idx) + + self.sequence_layer = nn.LSTM( + input_size=dim_model, + hidden_size=dim_model, + num_layers=rnn_layers, + time_major=False, + dropout=dropout) + + self.attention_layer = DotProductAttentionLayer() + + self.prediction = None + if not self.return_feature: + pred_num_classes = num_classes - 1 + self.prediction = nn.Linear( + dim_model if encode_value else dim_input, pred_num_classes) + + def forward_train(self, feat, out_enc, targets, valid_ratios): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + targets (Tensor): a tensor of shape :math:`(N, T)`. Each element is the index of a + character. + valid_ratios (Tensor): valid length ratio of img. + Returns: + Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if + ``return_feature=False``. Otherwise it would be the hidden feature + before the prediction projection layer, whose shape is + :math:`(N, T, D_m)`. + """ + + tgt_embedding = self.embedding(targets) + + n, c_enc, h, w = out_enc.shape + assert c_enc == self.dim_model + _, c_feat, _, _ = feat.shape + assert c_feat == self.dim_input + _, len_q, c_q = tgt_embedding.shape + assert c_q == self.dim_model + assert len_q <= self.max_seq_len + + query, _ = self.sequence_layer(tgt_embedding) + query = paddle.transpose(query, (0, 2, 1)) + key = paddle.reshape(out_enc, [n, c_enc, h * w]) + if self.encode_value: + value = key + else: + value = paddle.reshape(feat, [n, c_feat, h * w]) + + attn_out = self.attention_layer(query, key, value, h, w, valid_ratios) + attn_out = paddle.transpose(attn_out, (0, 2, 1)) + + if self.return_feature: + return attn_out + + out = self.prediction(attn_out) + + return out + + def forward_test(self, feat, out_enc, valid_ratios): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + valid_ratios (Tensor): valid length ratio of img. + + Returns: + Tensor: The output logit sequence tensor of shape + :math:`(N, T, C-1)`. 
+ """ + seq_len = self.max_seq_len + batch_size = feat.shape[0] + + decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx) + + outputs = [] + for i in range(seq_len): + step_out = self.forward_test_step(feat, out_enc, decode_sequence, + i, valid_ratios) + outputs.append(step_out) + max_idx = paddle.argmax(step_out, axis=1, keepdim=False) + if i < seq_len - 1: + decode_sequence[:, i + 1] = max_idx + + outputs = paddle.stack(outputs, 1) + + return outputs + + def forward_test_step(self, feat, out_enc, decode_sequence, current_step, + valid_ratios): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + decode_sequence (Tensor): Shape :math:`(N, T)`. The tensor that + stores history decoding result. + current_step (int): Current decoding step. + valid_ratios (Tensor): valid length ratio of img + + Returns: + Tensor: Shape :math:`(N, C-1)`. The logit tensor of predicted + tokens at current time step. + """ + + embed = self.embedding(decode_sequence) + + n, c_enc, h, w = out_enc.shape + assert c_enc == self.dim_model + _, c_feat, _, _ = feat.shape + assert c_feat == self.dim_input + _, _, c_q = embed.shape + assert c_q == self.dim_model + + query, _ = self.sequence_layer(embed) + query = paddle.transpose(query, (0, 2, 1)) + key = paddle.reshape(out_enc, [n, c_enc, h * w]) + if self.encode_value: + value = key + else: + value = paddle.reshape(feat, [n, c_feat, h * w]) + + # [n, c, l] + attn_out = self.attention_layer(query, key, value, h, w, valid_ratios) + out = attn_out[:, :, current_step] + + if self.return_feature: + return out + + out = self.prediction(out) + out = F.softmax(out, dim=-1) + + return out + + +class PositionAwareLayer(nn.Layer): + + def __init__(self, dim_model, rnn_layers=2): + super().__init__() + + self.dim_model = dim_model + + self.rnn = nn.LSTM( + input_size=dim_model, + hidden_size=dim_model, + num_layers=rnn_layers, + time_major=False) + + self.mixer = nn.Sequential( + nn.Conv2D( + dim_model, dim_model, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2D( + dim_model, dim_model, kernel_size=3, stride=1, padding=1)) + + def forward(self, img_feature): + n, c, h, w = img_feature.shape + rnn_input = paddle.transpose(img_feature, (0, 2, 3, 1)) + rnn_input = paddle.reshape(rnn_input, (n * h, w, c)) + rnn_output, _ = self.rnn(rnn_input) + rnn_output = paddle.reshape(rnn_output, (n, h, w, c)) + rnn_output = paddle.transpose(rnn_output, (0, 3, 1, 2)) + out = self.mixer(rnn_output) + return out + + +class PositionAttentionDecoder(BaseDecoder): + """Position attention decoder for RobustScanner. + + RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for + Robust Text Recognition `_ + + Args: + num_classes (int): Number of output classes :math:`C`. + rnn_layers (int): Number of RNN layers. + dim_input (int): Dimension :math:`D_i` of input vector ``feat``. + dim_model (int): Dimension :math:`D_m` of the model. Should also be the + same as encoder output vector ``out_enc``. + max_seq_len (int): Maximum output sequence length :math:`T`. + mask (bool): Whether to mask input features according to + ``img_meta['valid_ratio']``. + return_feature (bool): Return feature or logits as the result. + encode_value (bool): Whether to use the output of encoder ``out_enc`` + as `value` of attention layer. If False, the original feature + ``feat`` will be used. + + Warning: + This decoder will not predict the final class which is assumed to be + ``. 
Therefore, its output size is always :math:`C - 1`. `` + is also ignored by loss + + """ + + def __init__(self, + num_classes=None, + rnn_layers=2, + dim_input=512, + dim_model=128, + max_seq_len=40, + mask=True, + return_feature=False, + encode_value=False): + super().__init__() + + self.num_classes = num_classes + self.dim_input = dim_input + self.dim_model = dim_model + self.max_seq_len = max_seq_len + self.return_feature = return_feature + self.encode_value = encode_value + self.mask = mask + + self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model) + + self.position_aware_module = PositionAwareLayer( + self.dim_model, rnn_layers) + + self.attention_layer = DotProductAttentionLayer() + + self.prediction = None + if not self.return_feature: + pred_num_classes = num_classes - 1 + self.prediction = nn.Linear( + dim_model if encode_value else dim_input, pred_num_classes) + + def _get_position_index(self, length, batch_size): + position_index_list = [] + for i in range(batch_size): + position_index = paddle.arange(0, end=length, step=1, dtype='int64') + position_index_list.append(position_index) + batch_position_index = paddle.stack(position_index_list, axis=0) + return batch_position_index + + def forward_train(self, feat, out_enc, targets, valid_ratios, position_index): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + targets (dict): A dict with the key ``padded_targets``, a + tensor of shape :math:`(N, T)`. Each element is the index of a + character. + valid_ratios (Tensor): valid length ratio of img. + position_index (Tensor): The position of each word. + + Returns: + Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if + ``return_feature=False``. Otherwise it will be the hidden feature + before the prediction projection layer, whose shape is + :math:`(N, T, D_m)`. + """ + n, c_enc, h, w = out_enc.shape + assert c_enc == self.dim_model + _, c_feat, _, _ = feat.shape + assert c_feat == self.dim_input + _, len_q = targets.shape + assert len_q <= self.max_seq_len + + position_out_enc = self.position_aware_module(out_enc) + + query = self.embedding(position_index) + query = paddle.transpose(query, (0, 2, 1)) + key = paddle.reshape(position_out_enc, (n, c_enc, h * w)) + if self.encode_value: + value = paddle.reshape(out_enc,(n, c_enc, h * w)) + else: + value = paddle.reshape(feat,(n, c_feat, h * w)) + + attn_out = self.attention_layer(query, key, value, h, w, valid_ratios) + attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v] + + if self.return_feature: + return attn_out + + return self.prediction(attn_out) + + def forward_test(self, feat, out_enc, valid_ratios, position_index): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + valid_ratios (Tensor): valid length ratio of img + position_index (Tensor): The position of each word. + + Returns: + Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if + ``return_feature=False``. Otherwise it would be the hidden feature + before the prediction projection layer, whose shape is + :math:`(N, T, D_m)`. 
+ """ + n, c_enc, h, w = out_enc.shape + assert c_enc == self.dim_model + _, c_feat, _, _ = feat.shape + assert c_feat == self.dim_input + + position_out_enc = self.position_aware_module(out_enc) + + query = self.embedding(position_index) + query = paddle.transpose(query, (0, 2, 1)) + key = paddle.reshape(position_out_enc, (n, c_enc, h * w)) + if self.encode_value: + value = paddle.reshape(out_enc,(n, c_enc, h * w)) + else: + value = paddle.reshape(feat,(n, c_feat, h * w)) + + attn_out = self.attention_layer(query, key, value, h, w, valid_ratios) + attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v] + + if self.return_feature: + return attn_out + + return self.prediction(attn_out) + +class RobustScannerFusionLayer(nn.Layer): + + def __init__(self, dim_model, dim=-1): + super(RobustScannerFusionLayer, self).__init__() + + self.dim_model = dim_model + self.dim = dim + self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2) + + def forward(self, x0, x1): + assert x0.shape == x1.shape + fusion_input = paddle.concat([x0, x1], self.dim) + output = self.linear_layer(fusion_input) + output = F.glu(output, self.dim) + return output + +class RobustScannerDecoder(BaseDecoder): + """Decoder for RobustScanner. + + RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for + Robust Text Recognition `_ + + Args: + num_classes (int): Number of output classes :math:`C`. + dim_input (int): Dimension :math:`D_i` of input vector ``feat``. + dim_model (int): Dimension :math:`D_m` of the model. Should also be the + same as encoder output vector ``out_enc``. + max_seq_len (int): Maximum output sequence length :math:`T`. + start_idx (int): The index of ``. + mask (bool): Whether to mask input features according to + ``img_meta['valid_ratio']``. + padding_idx (int): The index of ``. + encode_value (bool): Whether to use the output of encoder ``out_enc`` + as `value` of attention layer. If False, the original feature + ``feat`` will be used. + + Warning: + This decoder will not predict the final class which is assumed to be + ``. Therefore, its output size is always :math:`C - 1`. `` + is also ignored by loss as specified in + :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`. 
+ """ + + def __init__(self, + num_classes=None, + dim_input=512, + dim_model=128, + hybrid_decoder_rnn_layers=2, + hybrid_decoder_dropout=0, + position_decoder_rnn_layers=2, + max_seq_len=40, + start_idx=0, + mask=True, + padding_idx=None, + encode_value=False): + super().__init__() + self.num_classes = num_classes + self.dim_input = dim_input + self.dim_model = dim_model + self.max_seq_len = max_seq_len + self.encode_value = encode_value + self.start_idx = start_idx + self.padding_idx = padding_idx + self.mask = mask + + # init hybrid decoder + self.hybrid_decoder = SequenceAttentionDecoder( + num_classes=num_classes, + rnn_layers=hybrid_decoder_rnn_layers, + dim_input=dim_input, + dim_model=dim_model, + max_seq_len=max_seq_len, + start_idx=start_idx, + mask=mask, + padding_idx=padding_idx, + dropout=hybrid_decoder_dropout, + encode_value=encode_value, + return_feature=True + ) + + # init position decoder + self.position_decoder = PositionAttentionDecoder( + num_classes=num_classes, + rnn_layers=position_decoder_rnn_layers, + dim_input=dim_input, + dim_model=dim_model, + max_seq_len=max_seq_len, + mask=mask, + encode_value=encode_value, + return_feature=True + ) + + + self.fusion_module = RobustScannerFusionLayer( + self.dim_model if encode_value else dim_input) + + pred_num_classes = num_classes - 1 + self.prediction = nn.Linear(dim_model if encode_value else dim_input, + pred_num_classes) + + def forward_train(self, feat, out_enc, target, valid_ratios, word_positions): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + target (dict): A dict with the key ``padded_targets``, a + tensor of shape :math:`(N, T)`. Each element is the index of a + character. + valid_ratios (Tensor): + word_positions (Tensor): The position of each word. + + Returns: + Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`. + """ + hybrid_glimpse = self.hybrid_decoder.forward_train( + feat, out_enc, target, valid_ratios) + position_glimpse = self.position_decoder.forward_train( + feat, out_enc, target, valid_ratios, word_positions) + + fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse) + + out = self.prediction(fusion_out) + + return out + + def forward_test(self, feat, out_enc, valid_ratios, word_positions): + """ + Args: + feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. + out_enc (Tensor): Encoder output of shape + :math:`(N, D_m, H, W)`. + valid_ratios (Tensor): + word_positions (Tensor): The position of each word. + Returns: + Tensor: The output logit sequence tensor of shape + :math:`(N, T, C-1)`. 
+ """ + seq_len = self.max_seq_len + batch_size = feat.shape[0] + + decode_sequence = (paddle.ones((batch_size, seq_len), dtype='int64') * self.start_idx) + + position_glimpse = self.position_decoder.forward_test( + feat, out_enc, valid_ratios, word_positions) + + outputs = [] + for i in range(seq_len): + hybrid_glimpse_step = self.hybrid_decoder.forward_test_step( + feat, out_enc, decode_sequence, i, valid_ratios) + + fusion_out = self.fusion_module(hybrid_glimpse_step, + position_glimpse[:, i, :]) + + char_out = self.prediction(fusion_out) + char_out = F.softmax(char_out, -1) + outputs.append(char_out) + max_idx = paddle.argmax(char_out, axis=1, keepdim=False) + if i < seq_len - 1: + decode_sequence[:, i + 1] = max_idx + + outputs = paddle.stack(outputs, 1) + + return outputs + +class RobustScannerHead(nn.Layer): + def __init__(self, + out_channels, # 90 + unknown + start + padding + in_channels, + enc_outchannles=128, + hybrid_dec_rnn_layers=2, + hybrid_dec_dropout=0, + position_dec_rnn_layers=2, + start_idx=0, + max_text_length=40, + mask=True, + padding_idx=None, + encode_value=False, + **kwargs): + super(RobustScannerHead, self).__init__() + + # encoder module + self.encoder = ChannelReductionEncoder( + in_channels=in_channels, out_channels=enc_outchannles) + + # decoder module + self.decoder =RobustScannerDecoder( + num_classes=out_channels, + dim_input=in_channels, + dim_model=enc_outchannles, + hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers, + hybrid_decoder_dropout=hybrid_dec_dropout, + position_decoder_rnn_layers=position_dec_rnn_layers, + max_seq_len=max_text_length, + start_idx=start_idx, + mask=mask, + padding_idx=padding_idx, + encode_value=encode_value) + + def forward(self, inputs, targets=None): + ''' + targets: [label, valid_ratio, word_positions] + ''' + out_enc = self.encoder(inputs) + valid_ratios = None + word_positions = targets[-1] + + if len(targets) > 1: + valid_ratios = targets[-2] + + if self.training: + label = targets[0] # label + label = paddle.to_tensor(label, dtype='int64') + final_out = self.decoder( + inputs, out_enc, label, valid_ratios, word_positions) + if not self.training: + final_out = self.decoder( + inputs, + out_enc, + label=None, + valid_ratios=valid_ratios, + word_positions=word_positions, + train_mode=False) + return final_out diff --git a/ppocr/utils/dict/layout_dict/layout_cdla_dict.txt b/ppocr/utils/dict/layout_dict/layout_cdla_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..8be0f48600a88463d840fffe27eebd63629576ce --- /dev/null +++ b/ppocr/utils/dict/layout_dict/layout_cdla_dict.txt @@ -0,0 +1,10 @@ +text +title +figure +figure_caption +table +table_caption +header +footer +reference +equation \ No newline at end of file diff --git a/ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt b/ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca6acf4eef8d4d5f9ba5a4ced4858a119a4ef983 --- /dev/null +++ b/ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt @@ -0,0 +1,5 @@ +text +title +list +table +figure \ No newline at end of file diff --git a/ppocr/utils/dict/layout_dict/layout_table_dict.txt b/ppocr/utils/dict/layout_dict/layout_table_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..faea15ea07d7d1a6f77dbd4287bb9fa87165cbb9 --- /dev/null +++ b/ppocr/utils/dict/layout_dict/layout_table_dict.txt @@ -0,0 +1 @@ +table \ No newline at end of file diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py index 
b881fcab20bc5ca076a0002bd72349768c7d881a..18357c8e97bcea8ee321856a87146a4a7b901469 100755 --- a/ppocr/utils/utility.py +++ b/ppocr/utils/utility.py @@ -50,7 +50,7 @@ def get_check_global_params(mode): def _check_image_file(path): - img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} + img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'} return any([path.lower().endswith(e) for e in img_end]) @@ -59,7 +59,7 @@ def get_image_file_list(img_file): if img_file is None or not os.path.exists(img_file): raise Exception("not found any img file in {}".format(img_file)) - img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} + img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'} if os.path.isfile(img_file) and _check_image_file(img_file): imgs_lists.append(img_file) elif os.path.isdir(img_file): @@ -73,7 +73,7 @@ def get_image_file_list(img_file): return imgs_lists -def check_and_read_gif(img_path): +def check_and_read(img_path): if os.path.basename(img_path)[-3:] in ['gif', 'GIF']: gif = cv2.VideoCapture(img_path) ret, frame = gif.read() @@ -84,8 +84,26 @@ def check_and_read_gif(img_path): if len(frame.shape) == 2 or frame.shape[-1] == 1: frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) imgvalue = frame[:, :, ::-1] - return imgvalue, True - return None, False + return imgvalue, True, False + elif os.path.basename(img_path)[-3:] in ['pdf']: + import fitz + from PIL import Image + imgs = [] + with fitz.open(img_path) as pdf: + for pg in range(0, pdf.pageCount): + page = pdf[pg] + mat = fitz.Matrix(2, 2) + pm = page.getPixmap(matrix=mat, alpha=False) + + # if width or height > 2000 pixels, don't enlarge the image + if pm.width > 2000 or pm.height > 2000: + pm = page.getPixmap(matrix=fitz.Matrix(1, 1), alpha=False) + + img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples) + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + imgs.append(img) + return imgs, False, True + return None, False, False def load_vqa_bio_label_maps(label_map_path): diff --git a/ppstructure/docs/layout/layout.png b/ppstructure/docs/layout/layout.png new file mode 100644 index 0000000000000000000000000000000000000000..da9640e245e34659771353e328bf97da129bd622 Binary files /dev/null and b/ppstructure/docs/layout/layout.png differ diff --git a/ppstructure/docs/layout/layout_res.jpg b/ppstructure/docs/layout/layout_res.jpg new file mode 100644 index 0000000000000000000000000000000000000000..93b3a8bef3bfc9f5c80a9505239af05d526b45a7 Binary files /dev/null and b/ppstructure/docs/layout/layout_res.jpg differ diff --git a/ppstructure/layout/README.md b/ppstructure/layout/README.md index 3a4f5291763e34c8aec2c5b327d40a459bb4be1e..3762544b834d752a705216ca3f93d326aa1391ad 100644 --- a/ppstructure/layout/README.md +++ b/ppstructure/layout/README.md @@ -1,127 +1,469 @@ -English | [简体中文](README_ch.md) -- [Getting Started](#getting-started) - - [1. Install whl package](#1--install-whl-package) - - [2. Quick Start](#2-quick-start) - - [3. PostProcess](#3-postprocess) - - [4. Results](#4-results) - - [5. Training](#5-training) +- [1. 简介](#1-简介) -# Getting Started +- [2. 安装](#2-安装) + + - [2.1 安装PaddlePaddle](#21-安装paddlepaddle) + - [2.2 安装PaddleDetection](#22-安装paddledetection) + +- [3. 数据准备](#3-数据准备) + + - [3.1 英文数据集](#31-英文数据集) + - [3.2 更多数据集](#32-更多数据集) + +- [4. 开始训练](#4-开始训练) + + - [4.1 启动训练](#41-启动训练) + - [4.2 FGD蒸馏训练](#42-FGD蒸馏训练) + +- [5. 
模型评估与预测](#5-模型评估与预测) + + - [5.1 指标评估](#51-指标评估) + - [5.2 测试版面分析结果](#52-测试版面分析结果) + +- [6 模型导出与预测](#6-模型导出与预测) + + - [6.1 模型导出](#61-模型导出) + + - [6.2 模型推理](#62-模型推理) + +# 版面分析 + +## 1. 简介 + +版面分析指的是对图片形式的文档进行区域划分,定位其中的关键区域,如文字、标题、表格、图片等。版面分析算法基于[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection)的轻量模型PP-PicoDet进行开发。 + +
+<div align="center">
+    <img src="../docs/layout/layout.png" width="800">
+</div>
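+
+各场景支持的类别集合可以在 `ppocr/utils/dict/layout_dict/` 目录下的类别字典文件中查看。以英文 PubLayNet 场景为例,`layout_publaynet_dict.txt` 每行一个类别,顺序与 3.1 节数据集标注中的类别 id `{0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}` 一一对应:
+
+```
+text
+title
+list
+table
+figure
+```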
+ + + +## 2. 安装依赖 + +### 2.1. 安装PaddlePaddle + +- **(1) 安装PaddlePaddle** -## 1. Install whl package ```bash -wget https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl -pip install -U layoutparser-0.0.0-py3-none-any.whl +python3 -m pip install --upgrade pip + +# GPU安装 +python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple + +# CPU安装 +python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple ``` +更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。 -## 2. Quick Start +### 2.2. 安装PaddleDetection -Use LayoutParser to identify the layout of a document: +- **(1)下载PaddleDetection源码** -```python -import cv2 -import layoutparser as lp -image = cv2.imread("doc/table/layout.jpg") -image = image[..., ::-1] +```bash +git clone https://github.com/PaddlePaddle/PaddleDetection.git +``` -# load model -model = lp.PaddleDetectionLayoutModel(config_path="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config", - threshold=0.5, - label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}, - enforce_cpu=False, - enable_mkldnn=True) -# detect -layout = model.detect(image) +- **(2)安装其他依赖 ** -# show result -show_img = lp.draw_box(image, layout, box_width=3, show_element_type=True) -show_img.show() +```bash +cd PaddleDetection +python3 -m pip install -r requirements.txt ``` -The following figure shows the result, with different colored detection boxes representing different categories and displaying specific categories in the upper left corner of the box with `show_element_type` +## 3. 数据准备 -
-`PaddleDetectionLayoutModel`parameters are described as follows: +如果希望直接体验预测过程,可以跳过数据准备,下载我们提供的预训练模型。 + +### 3.1. 英文数据集 + +下载文档分析数据集[PubLayNet](https://developer.ibm.com/exchanges/data/all/publaynet/)(数据集96G),包含5个类:`{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}` -| parameter | description | default | remark | -| :------------: | :------------------------------------------------------: | :---------: | :----------------------------------------------------------: | -| config_path | model config path | None | Specify config_ path will automatically download the model (only for the first time,the model will exist and will not be downloaded again) | -| model_path | model path | None | local model path, config_ path and model_ path must be set to one, cannot be none at the same time | -| threshold | threshold of prediction score | 0.5 | \ | -| input_shape | picture size of reshape | [3,640,640] | \ | -| batch_size | testing batch size | 1 | \ | -| label_map | category mapping table | None | Setting config_ path, it can be none, and the label is automatically obtained according to the dataset name_ map, You need to specify it manually when setting model_path | -| enforce_cpu | whether to use CPU | False | False to use GPU, and True to force the use of CPU | -| enforce_mkldnn | whether mkldnn acceleration is enabled in CPU prediction | True | \ | -| thread_num | the number of CPU threads | 10 | \ | +``` +# 下载数据 +wget https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz +# 解压数据 +tar -xvf publaynet.tar.gz +``` -The following model configurations and label maps are currently supported, which you can use by modifying '--config_path' and '--label_map' to detect different types of content: +解压之后的**目录结构:** -| dataset | config_path | label_map | -| ------------------------------------------------------------ | ------------------------------------------------------------ | --------------------------------------------------------- | -| [TableBank](https://doc-analysis.github.io/tablebank-page/index.html) word | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_word/config | {0:"Table"} | -| TableBank latex | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_latex/config | {0:"Table"} | -| [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config | {0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"} | +``` +|-publaynet + |- test + |- PMC1277013_00004.jpg + |- PMC1291385_00002.jpg + | ... + |- train.json + |- train + |- PMC1291385_00002.jpg + |- PMC1277013_00004.jpg + | ... + |- val.json + |- val + |- PMC538274_00004.jpg + |- PMC539300_00004.jpg + | ... +``` -* TableBank word and TableBank latex are trained on datasets of word documents and latex documents respectively; -* Download TableBank dataset contains both word and latex。 +**数据分布:** -## 3. 
PostProcess +| File or Folder | Description | num | +| :------------- | :------------- | ------- | +| `train/` | 训练集图片 | 335,703 | +| `val/` | 验证集图片 | 11,245 | +| `test/` | 测试集图片 | 11,405 | +| `train.json` | 训练集标注文件 | - | +| `val.json` | 验证集标注文件 | - | -Layout parser contains multiple categories, if you only want to get the detection box for a specific category (such as the "Text" category), you can use the following code: +**标注格式:** -```python -# follow the above code -# filter areas for a specific text type -text_blocks = lp.Layout([b for b in layout if b.type=='Text']) -figure_blocks = lp.Layout([b for b in layout if b.type=='Figure']) +json文件包含所有图像的标注,数据以字典嵌套的方式存放,包含以下key: -# text areas may be detected within the image area, delete these areas -text_blocks = lp.Layout([b for b in text_blocks \ - if not any(b.is_in(b_fig) for b_fig in figure_blocks)]) +- info,表示标注文件info。 -# sort text areas and assign ID -h, w = image.shape[:2] +- licenses,表示标注文件licenses。 -left_interval = lp.Interval(0, w/2*1.05, axis='x').put_on_canvas(image) +- images,表示标注文件中图像信息列表,每个元素是一张图像的信息。如下为其中一张图像的信息: -left_blocks = text_blocks.filter_by(left_interval, center=True) -left_blocks.sort(key = lambda b:b.coordinates[1]) + ``` + { + 'file_name': 'PMC4055390_00006.jpg', # file_name + 'height': 601, # image height + 'width': 792, # image width + 'id': 341427 # image id + } + ``` -right_blocks = [b for b in text_blocks if b not in left_blocks] -right_blocks.sort(key = lambda b:b.coordinates[1]) +- annotations,表示标注文件中目标物体的标注信息列表,每个元素是一个目标物体的标注信息。如下为其中一个目标物体的标注信息: -# the two lists are merged and the indexes are added in order -text_blocks = lp.Layout([b.set(id = idx) for idx, b in enumerate(left_blocks + right_blocks)]) + ``` + { + + 'segmentation': # 物体的分割标注 + 'area': 60518.099043117836, # 物体的区域面积 + 'iscrowd': 0, # iscrowd + 'image_id': 341427, # image id + 'bbox': [50.58, 490.86, 240.15, 252.16], # bbox [x1,y1,w,h] + 'category_id': 1, # category_id + 'id': 3322348 # image id + } + ``` + +### 3.2. 更多数据集 + +我们提供了CDLA(中文版面分析)、TableBank(表格版面分析)等数据集的下连接,处理为上述标注文件json格式,即可以按相同方式进行训练。 + +| dataset | 简介 | +| ------------------------------------------------------------ | ------------------------------------------------------------ | +| [cTDaR2019_cTDaR](https://cndplab-founder.github.io/cTDaR2019/) | 用于表格检测(TRACKA)和表格识别(TRACKB)。图片类型包含历史数据集(以cTDaR_t0开头,如cTDaR_t00872.jpg)和现代数据集(以cTDaR_t1开头,cTDaR_t10482.jpg)。 | +| [IIIT-AR-13K](http://cvit.iiit.ac.in/usodi/iiitar13k.php) | 手动注释公开的年度报告中的图形或页面而构建的数据集,包含5类:table, figure, natural image, logo, and signature | +| [CDLA](https://github.com/buptlihang/CDLA) | 中文文档版面分析数据集,面向中文文献类(论文)场景,包含10类:Table、Figure、Figure caption、Table、Table caption、Header、Footer、Reference、Equation | +| [TableBank](https://github.com/doc-analysis/TableBank) | 用于表格检测和识别大型数据集,包含Word和Latex2种文档格式 | +| [DocBank](https://github.com/doc-analysis/DocBank) | 使用弱监督方法构建的大规模数据集(500K文档页面),用于文档布局分析,包含12类:Author、Caption、Date、Equation、Figure、Footer、List、Paragraph、Reference、Section、Table、Title | + + +## 4. 开始训练 + +提供了训练脚本、评估脚本和预测脚本,本节将以PubLayNet预训练模型为例进行讲解。 + +如果不希望训练,直接体验后面的模型评估、预测、动转静、推理的流程,可以下载提供的预训练模型,并跳过本部分。 -# display result -show_img = lp.draw_box(image, text_blocks, - box_width=3, - show_element_id=True) -show_img.show() +``` +mkdir pretrained_model +cd pretrained_model +# 下载并解压PubLayNet预训练模型 +wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout.pdparams ``` -Displays results with only the "Text" category: +### 4.1. 启动训练 + +开始训练: + +* 修改配置文件 + +如果你希望训练自己的数据集,需要修改配置文件中的数据配置、类别数。 -
- -
-## 4. Results +以`configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml` 为例,修改的内容如下所示。 -| Dataset | mAP | CPU time cost | GPU time cost | -| --------- | ---- | ------------- | ------------- | -| PubLayNet | 93.6 | 1713.7ms | 66.6ms | -| TableBank | 96.2 | 1968.4ms | 65.1ms | +```yaml +metric: COCO +# 类别数 +num_classes: 5 + +TrainDataset: + !COCODataSet + # 修改为你自己的训练数据目录 + image_dir: train + # 修改为你自己的训练数据标签文件 + anno_path: train.json + # 修改为你自己的训练数据根目录 + dataset_dir: /root/publaynet/ + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] + +EvalDataset: + !COCODataSet + # 修改为你自己的验证数据目录 + image_dir: val + # 修改为你自己的验证数据标签文件 + anno_path: val.json + # 修改为你自己的验证数据根目录 + dataset_dir: /root/publaynet/ + +TestDataset: + !ImageFolder + # 修改为你自己的测试数据标签文件 + anno_path: /root/publaynet/val.json +``` + +* 开始训练,在训练时,会默认下载PP-PicoDet预训练模型,这里无需预先下载。 + +```bash +# GPU训练 支持单卡,多卡训练 +# 训练日志会自动保存到 log 目录中 + +# 单卡训练 +python3 tools/train.py \ + -c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \ + --eval + +# 多卡训练,通过--gpus参数指定卡号 +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py \ + -c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \ + --eval +``` + +正常启动训练后,会看到以下log输出: + +``` +[08/15 04:02:30] ppdet.utils.checkpoint INFO: Finish loading model weights: /root/.cache/paddle/weights/LCNet_x1_0_pretrained.pdparams +[08/15 04:02:46] ppdet.engine INFO: Epoch: [0] [ 0/1929] learning_rate: 0.040000 loss_vfl: 1.216707 loss_bbox: 1.142163 loss_dfl: 0.544196 loss: 2.903065 eta: 17 days, 13:50:26 batch_cost: 15.7452 data_cost: 2.9112 ips: 1.5243 images/s +[08/15 04:03:19] ppdet.engine INFO: Epoch: [0] [ 20/1929] learning_rate: 0.064000 loss_vfl: 1.180627 loss_bbox: 0.939552 loss_dfl: 0.442436 loss: 2.628206 eta: 2 days, 12:18:53 batch_cost: 1.5770 data_cost: 0.0008 ips: 15.2184 images/s +[08/15 04:03:47] ppdet.engine INFO: Epoch: [0] [ 40/1929] learning_rate: 0.088000 loss_vfl: 0.543321 loss_bbox: 1.071401 loss_dfl: 0.457817 loss: 2.057003 eta: 2 days, 0:07:03 batch_cost: 1.3190 data_cost: 0.0007 ips: 18.1954 images/s +[08/15 04:04:12] ppdet.engine INFO: Epoch: [0] [ 60/1929] learning_rate: 0.112000 loss_vfl: 0.630989 loss_bbox: 0.859183 loss_dfl: 0.384702 loss: 1.883143 eta: 1 day, 19:01:29 batch_cost: 1.2177 data_cost: 0.0006 ips: 19.7087 images/s +``` + +- `--eval`表示训练的同时,进行评估, 评估过程中默认将最佳模型,保存为 `output/picodet_lcnet_x1_0_layout/best_accuracy` 。 + +**注意,预测/评估时的配置文件请务必与训练一致。** + +### 4.2. FGD蒸馏训练 + +PaddleDetection支持了基于FGD([Focal and Global Knowledge Distillation for Detectors](https://arxiv.org/abs/2111.11837v1))蒸馏的目标检测模型训练过程,FGD蒸馏分为两个部分`Focal`和`Global`。`Focal`蒸馏分离图像的前景和背景,让学生模型分别关注教师模型的前景和背景部分特征的关键像素;`Global`蒸馏部分重建不同像素之间的关系并将其从教师转移到学生,以补偿`Focal`蒸馏中丢失的全局信息。 + +更换数据集,修改【TODO】配置中的数据配置、类别数,具体可以参考4.1。启动训练: + +```bash +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py \ + -c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \ + --slim_config configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x2_5_layout.yml \ + --eval +``` -**Envrionment:** +- `-c`: 指定模型配置文件。 +- `--slim_config`: 指定压缩策略配置文件。 -​ **CPU:** Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz,24core +## 5. 模型评估与预测 -​ **GPU:** a single NVIDIA Tesla P40 +### 5.1. 指标评估 -## 5. 
Training
+训练中模型参数默认保存在`output/picodet_lcnet_x1_0_layout`目录下。在评估指标时,需要设置`weights`指向保存的参数文件。评估数据集可以通过 `configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml` 修改`EvalDataset`中的 `image_dir`、`anno_path`和`dataset_dir` 设置。
+
+```bash
+# GPU 评估, weights 为待测权重
+python3 tools/eval.py \
+    -c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \
+    -o weights=./output/picodet_lcnet_x1_0_layout/best_model
+```
+
+会输出以下信息,打印出mAP、AP0.5等指标。
+
+```py
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.935
+ Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.979
+ Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.956
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.404
+ Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.782
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.969
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.539
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.938
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.949
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.495
+ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.818
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.978
+[08/15 07:07:09] ppdet.engine INFO: Total sample number: 11245, averge FPS: 24.405059207157436
+[08/15 07:07:09] ppdet.engine INFO: Best test bbox ap is 0.935.
+```
+
+使用FGD蒸馏模型进行评估:
+
+```
+python3 tools/eval.py \
+    -c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \
+    --slim_config configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x2_5_layout.yml \
+    -o weights=output/picodet_lcnet_x2_5_layout/best_model
+```
+
+- `-c`: 指定模型配置文件。
+- `--slim_config`: 指定蒸馏策略配置文件。
+- `-o weights`: 指定蒸馏算法训好的模型路径。
+
+### 5.2. 测试版面分析结果
+
+预测使用的配置文件必须与训练一致,例如您通过 `python3 tools/train.py -c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml` 完成了模型训练,则预测时也需使用同一份配置文件。
+
+使用 PaddleDetection 训练好的模型,您可以使用如下命令进行模型预测。
+
+```bash
+python3 tools/infer.py \
+    -c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \
+    -o weights='output/picodet_lcnet_x1_0_layout/best_model.pdparams' \
+    --infer_img='docs/images/layout.jpg' \
+    --output_dir=output_dir/ \
+    --draw_threshold=0.4
+```
+
+- `--infer_img`: 推理单张图片,也可以通过`--infer_dir`推理文件夹中的所有图片。
+- `--output_dir`: 指定可视化结果保存路径。
+- `--draw_threshold`:指定绘制结果框的得分阈值。
+
+预测图片如下所示,图片会存储在`output_dir`路径中。
+
+使用FGD蒸馏模型进行测试:
+
+```
+python3 tools/infer.py \
+    -c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \
+    --slim_config configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x2_5_layout.yml \
+    -o weights='output/picodet_lcnet_x2_5_layout/best_model.pdparams' \
+    --infer_img='docs/images/layout.jpg' \
+    --output_dir=output_dir/ \
+    --draw_threshold=0.4
+```
+
+## 6. 
模型导出与预测 + + +### 6.1 模型导出 + +inference 模型(`paddle.jit.save`保存的模型) 一般是模型训练,把模型结构和模型参数保存在文件中的固化模型,多用于预测部署场景。 训练过程中保存的模型是checkpoints模型,保存的只有模型的参数,多用于恢复训练等。 与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。 + +版面分析模型转inference模型步骤如下: + +```bash +python3 tools/export_model.py \ + -c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \ + -o weights=output/picodet_lcnet_x1_0_layout/best_model \ + --output_dir=output_inference/ +``` + +* 如无需导出后处理,请指定:`-o export.benchmark=True`(如果-o已出现过,此处删掉-o) +* 如无需导出NMS,请指定:`-o export.nms=False` + +转换成功后,在目录下有三个文件: + +``` +output_inference/picodet_lcnet_x1_0_layout/ + ├── model.pdiparams # inference模型的参数文件 + ├── model.pdiparams.info # inference模型的参数信息,可忽略 + └── model.pdmodel # inference模型的模型结构文件 +``` + +FGD蒸馏模型转inference模型步骤如下: + +```bash +python3 tools/export_model.py \ + -c configs/picodet/legacy_model/application/publayernet_lcnet_x1_5/picodet_student.yml \ + --slim_config configs/picodet/legacy_model/application/publayernet_lcnet_x1_5/picodet_teacher.yml \ + -o weights=./output/picodet_lcnet_x2_5_layout/best_model \ + --output_dir=output_inference/ +``` + + + +### 6.2 模型推理 + +版面恢复任务进行推理,可以执行如下命令: + +```bash +python3 deploy/python/infer.py \ + --model_dir=output_inference/picodet_lcnet_x1_0_layout/ \ + --image_file=docs/images/layout.jpg \ + --device=CPU +``` + +- --device:指定GPU、CPU设备 + +模型推理完成,会看到以下log输出 + +``` +------------------------------------------ +----------- Model Configuration ----------- +Model Arch: PicoDet +Transform Order: +--transform op: Resize +--transform op: NormalizeImage +--transform op: Permute +--transform op: PadStride +-------------------------------------------- +class_id:0, confidence:0.9921, left_top:[20.18,35.66],right_bottom:[341.58,600.99] +class_id:0, confidence:0.9914, left_top:[19.77,611.42],right_bottom:[341.48,901.82] +class_id:0, confidence:0.9904, left_top:[369.36,375.10],right_bottom:[691.29,600.59] +class_id:0, confidence:0.9835, left_top:[369.60,608.60],right_bottom:[691.38,736.72] +class_id:0, confidence:0.9830, left_top:[369.58,805.38],right_bottom:[690.97,901.80] +class_id:0, confidence:0.9716, left_top:[383.68,271.44],right_bottom:[688.93,335.39] +class_id:0, confidence:0.9452, left_top:[370.82,34.48],right_bottom:[688.10,63.54] +class_id:1, confidence:0.8712, left_top:[370.84,771.03],right_bottom:[519.30,789.13] +class_id:3, confidence:0.9856, left_top:[371.28,67.85],right_bottom:[685.73,267.72] +save result to: output/layout.jpg +Test iter 0 +------------------ Inference Time Info ---------------------- +total_time(ms): 2196.0, img_num: 1 +average latency time(ms): 2196.00, QPS: 0.455373 +preprocess_time(ms): 2172.50, inference_time(ms): 11.90, postprocess_time(ms): 11.60 +``` + +- Model:模型结构 +- Transform Order:预处理操作 +- class_id、confidence、left_top、right_bottom:分别表示类别id、置信度、左上角坐标、右下角坐标 +- save result to:可视化版面分析结果保存路径,默认保存到`./output`文件夹 +- Inference Time Info:推理时间,其中preprocess_time表示预处理耗时,inference_time表示模型预测耗时,postprocess_time表示后处理耗时 + +可视化版面结果如下图所示 + +
+ +
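+上面`deploy/python/infer.py`的检测结果目前只以纯文本log形式打印。如果希望在自己的脚本中直接使用这些结果,下面给出一个仅供参考的解析示例(假设log格式与上文示例一致;类别映射取自前文PubLayNet的5个类别;`parse_layout_log`为本文示意用的函数名,并非PaddleOCR/PaddleDetection提供的接口):
+
+```python
+# 仅作演示:把推理log中的检测行解析为结构化结果
+import re
+
+# PubLayNet的5个类别,对应前文 {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
+LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
+
+LINE_PATTERN = re.compile(
+    r"class_id:(\d+), confidence:([\d.]+), "
+    r"left_top:\[([\d.]+),([\d.]+)\],right_bottom:\[([\d.]+),([\d.]+)\]")
+
+
+def parse_layout_log(log_text):
+    """把log中的每一行检测结果转换为 {label, confidence, bbox} 字典。"""
+    results = []
+    for match in LINE_PATTERN.finditer(log_text):
+        class_id, score, x1, y1, x2, y2 = match.groups()
+        results.append({
+            "label": LABELS.get(int(class_id), str(class_id)),
+            "confidence": float(score),
+            "bbox": [float(x1), float(y1), float(x2), float(y2)],  # [x1, y1, x2, y2]
+        })
+    return results
+
+
+if __name__ == "__main__":
+    sample = "class_id:0, confidence:0.9921, left_top:[20.18,35.66],right_bottom:[341.58,600.99]"
+    print(parse_layout_log(sample))
+    # [{'label': 'Text', 'confidence': 0.9921, 'bbox': [20.18, 35.66, 341.58, 600.99]}]
+```
+
+该示例只依赖Python标准库,仅用于快速查看结果;实际集成时也可以直接修改 `deploy/python/infer.py`,在保存可视化结果的同时返回结构化输出。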
+ + + +## Citations + +``` +@inproceedings{zhong2019publaynet, + title={PubLayNet: largest dataset ever for document layout analysis}, + author={Zhong, Xu and Tang, Jianbin and Yepes, Antonio Jimeno}, + booktitle={2019 International Conference on Document Analysis and Recognition (ICDAR)}, + year={2019}, + volume={}, + number={}, + pages={1015-1022}, + doi={10.1109/ICDAR.2019.00166}, + ISSN={1520-5363}, + month={Sep.}, + organization={IEEE} +} + +@inproceedings{yang2022focal, + title={Focal and global knowledge distillation for detectors}, + author={Yang, Zhendong and Li, Zhe and Jiang, Xiaohu and Gong, Yuan and Yuan, Zehuan and Zhao, Danpei and Yuan, Chun}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={4643--4652}, + year={2022} +} +``` -The above model is based on [PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection). If you want to train your own layout parser model,please refer to:[train_layoutparser_model](train_layoutparser_model.md) diff --git a/ppstructure/layout/README_ch.md b/ppstructure/layout/README_ch.md deleted file mode 100644 index 69419ad1eee3523d498b0d845a72133b619b3787..0000000000000000000000000000000000000000 --- a/ppstructure/layout/README_ch.md +++ /dev/null @@ -1,133 +0,0 @@ -[English](README.md) | 简体中文 - -# 版面分析使用说明 - -- [1. 安装whl包](#1) -- [2. 使用](#2) -- [3. 后处理](#3) -- [4. 指标](#4) -- [5. 训练版面分析模型](#5) - - - -## 1. 安装whl包 -```bash -pip install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl -``` - - -## 2. 使用 - -使用layoutparser识别给定文档的布局: - -```python -import cv2 -import layoutparser as lp -image = cv2.imread("ppstructure/docs/table/layout.jpg") -image = image[..., ::-1] - -# 加载模型 -model = lp.PaddleDetectionLayoutModel(config_path="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config", - threshold=0.5, - label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}, - enforce_cpu=False, - enable_mkldnn=True) -# 检测 -layout = model.detect(image) - -# 显示结果 -show_img = lp.draw_box(image, layout, box_width=3, show_element_type=True) -show_img.show() -``` - -下图展示了结果,不同颜色的检测框表示不同的类别,并通过`show_element_type`在框的左上角显示具体类别: - -
- -
- -`PaddleDetectionLayoutModel`函数参数说明如下: - -| 参数 | 含义 | 默认值 | 备注 | -| :------------: | :-------------------------: | :---------: | :----------------------------------------------------------: | -| config_path | 模型配置路径 | None | 指定config_path会自动下载模型(仅第一次,之后模型存在,不会再下载) | -| model_path | 模型路径 | None | 本地模型路径,config_path和model_path必须设置一个,不能同时为None | -| threshold | 预测得分的阈值 | 0.5 | \ | -| input_shape | reshape之后图片尺寸 | [3,640,640] | \ | -| batch_size | 测试batch size | 1 | \ | -| label_map | 类别映射表 | None | 设置config_path时,可以为None,根据数据集名称自动获取label_map,设置model_path时需要手动指定 | -| enforce_cpu | 代码是否使用CPU运行 | False | 设置为False表示使用GPU,True表示强制使用CPU | -| enforce_mkldnn | CPU预测中是否开启MKLDNN加速 | True | \ | -| thread_num | 设置CPU线程数 | 10 | \ | - -目前支持以下几种模型配置和label map,您可以通过修改 `--config_path`和 `--label_map`使用这些模型,从而检测不同类型的内容: - -| dataset | config_path | label_map | -| ------------------------------------------------------------ | ------------------------------------------------------------ | --------------------------------------------------------- | -| [TableBank](https://doc-analysis.github.io/tablebank-page/index.html) word | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_word/config | {0:"Table"} | -| TableBank latex | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_latex/config | {0:"Table"} | -| [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config | {0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"} | - -* TableBank word和TableBank latex分别在word文档、latex文档数据集训练; -* 下载的TableBank数据集里同时包含word和latex。 - - -## 3. 后处理 - -版面分析检测包含多个类别,如果只想获取指定类别(如"Text"类别)的检测框、可以使用下述代码: - -```python -# 接上面代码 -# 首先过滤特定文本类型的区域 -text_blocks = lp.Layout([b for b in layout if b.type=='Text']) -figure_blocks = lp.Layout([b for b in layout if b.type=='Figure']) - -# 因为在图像区域内可能检测到文本区域,所以只需要删除它们 -text_blocks = lp.Layout([b for b in text_blocks \ - if not any(b.is_in(b_fig) for b_fig in figure_blocks)]) - -# 对文本区域排序并分配id -h, w = image.shape[:2] - -left_interval = lp.Interval(0, w/2*1.05, axis='x').put_on_canvas(image) - -left_blocks = text_blocks.filter_by(left_interval, center=True) -left_blocks.sort(key = lambda b:b.coordinates[1]) - -right_blocks = [b for b in text_blocks if b not in left_blocks] -right_blocks.sort(key = lambda b:b.coordinates[1]) - -# 最终合并两个列表,并按顺序添加索引 -text_blocks = lp.Layout([b.set(id = idx) for idx, b in enumerate(left_blocks + right_blocks)]) - -# 显示结果 -show_img = lp.draw_box(image, text_blocks, - box_width=3, - show_element_id=True) -show_img.show() -``` - -显示只有"Text"类别的结果: - -
- -
- - -## 4. 指标 - -| Dataset | mAP | CPU time cost | GPU time cost | -| --------- | ---- | ------------- | ------------- | -| PubLayNet | 93.6 | 1713.7ms | 66.6ms | -| TableBank | 96.2 | 1968.4ms | 65.1ms | - -**Envrionment:** - -​ **CPU:** Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz,24core - -​ **GPU:** a single NVIDIA Tesla P40 - - -## 5. 训练版面分析模型 - -上述模型基于[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection) 训练,如果您想训练自己的版面分析模型,请参考:[train_layoutparser_model](train_layoutparser_model_ch.md) diff --git a/ppstructure/layout/train_layoutparser_model.md b/ppstructure/layout/train_layoutparser_model.md deleted file mode 100644 index e877c9c0c901e8be8299101daa5ce6248de0a1dc..0000000000000000000000000000000000000000 --- a/ppstructure/layout/train_layoutparser_model.md +++ /dev/null @@ -1,174 +0,0 @@ -English | [简体中文](train_layoutparser_model_ch.md) -- [Training layout-parse](#training-layout-parse) - - [1. Installation](#1--installation) - - [1.1 Requirements](#11-requirements) - - [1.2 Install PaddleDetection](#12-install-paddledetection) - - [2. Data preparation](#2-data-preparation) - - [3. Configuration](#3-configuration) - - [4. Training](#4-training) - - [5. Prediction](#5-prediction) - - [6. Deployment](#6-deployment) - - [6.1 Export model](#61-export-model) - - [6.2 Inference](#62-inference) - -# Training layout-parse - -## 1. Installation - -### 1.1 Requirements - -- PaddlePaddle 2.1 -- OS 64 bit -- Python 3(3.5.1+/3.6/3.7/3.8/3.9),64 bit -- pip/pip3(9.0.1+), 64 bit -- CUDA >= 10.1 -- cuDNN >= 7.6 - -### 1.2 Install PaddleDetection - -```bash -# Clone PaddleDetection repository -cd -git clone https://github.com/PaddlePaddle/PaddleDetection.git - -cd PaddleDetection -# Install other dependencies -pip install -r requirements.txt -``` - -For more installation tutorials, please refer to: [Install doc](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/INSTALL_cn.md) - -## 2. Data preparation - -Download the [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) dataset - -```bash -cd PaddleDetection/dataset/ -mkdir publaynet -# execute the command,download PubLayNet -wget -O publaynet.tar.gz https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz?_ga=2.104193024.1076900768.1622560733-649911202.1622560733 -# unpack -tar -xvf publaynet.tar.gz -``` - -PubLayNet directory structure after decompressing : - -| File or Folder | Description | num | -| :------------- | :----------------------------------------------- | ------- | -| `train/` | Images in the training subset | 335,703 | -| `val/` | Images in the validation subset | 11,245 | -| `test/` | Images in the testing subset | 11,405 | -| `train.json` | Annotations for training images | 1 | -| `val.json` | Annotations for validation images | 1 | -| `LICENSE.txt` | Plaintext version of the CDLA-Permissive license | 1 | -| `README.txt` | Text file with the file names and description | 1 | - -For other datasets,please refer to [the PrepareDataSet]((https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/PrepareDataSet.md) ) - -## 3. 
Configuration - -We use the `configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml` configuration for training,the configuration file is as follows - -```bash -_BASE_: [ - '../datasets/coco_detection.yml', - '../runtime.yml', - './_base_/ppyolov2_r50vd_dcn.yml', - './_base_/optimizer_365e.yml', - './_base_/ppyolov2_reader.yml', -] - -snapshot_epoch: 8 -weights: output/ppyolov2_r50vd_dcn_365e_coco/model_final -``` -The `ppyolov2_r50vd_dcn_365e_coco.yml` configuration depends on other configuration files, in this case: - -- coco_detection.yml:mainly explains the path of training data and verification data - -- runtime.yml:mainly describes the common parameters, such as whether to use the GPU and how many epoch to save model etc. - -- optimizer_365e.yml:mainly explains the learning rate and optimizer configuration - -- ppyolov2_r50vd_dcn.yml:mainly describes the model and the network - -- ppyolov2_reader.yml:mainly describes the configuration of data readers, such as batch size and number of concurrent loading child processes, and also includes post preprocessing, such as resize and data augmention etc. - - -Modify the preceding files, such as the dataset path and batch size etc. - -## 4. Training - -PaddleDetection provides single-card/multi-card training mode to meet various training needs of users: - -* GPU single card training - -```bash -export CUDA_VISIBLE_DEVICES=0 #Don't need to run this command on Windows and Mac -python tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml -``` - -* GPU multi-card training - -```bash -export CUDA_VISIBLE_DEVICES=0,1,2,3 -python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval -``` - ---eval: training while verifying - -* Model recovery training - -During the daily training, if training is interrupted due to some reasons, you can use the -r command to resume the training: - -```bash -export CUDA_VISIBLE_DEVICES=0,1,2,3 -python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval -r output/ppyolov2_r50vd_dcn_365e_coco/10000 -``` - -Note: If you encounter "`Out of memory error`" , try reducing `batch_size` in the `ppyolov2_reader.yml` file - -## 5. Prediction - -Set parameters and use PaddleDetection to predict: - -```bash -export CUDA_VISIBLE_DEVICES=0 -python tools/infer.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --infer_img=images/paper-image.jpg --output_dir=infer_output/ --draw_threshold=0.5 -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final --use_vdl=Ture -``` - -`--draw_threshold` is an optional parameter. According to the calculation of [NMS](https://ieeexplore.ieee.org/document/1699659), different threshold will produce different results, ` keep_top_k ` represent the maximum amount of output target, the default value is 10. You can set different value according to your own actual situation。 - -## 6. Deployment - -Use your trained model in Layout Parser - -### 6.1 Export model - -n the process of model training, the model file saved contains the process of forward prediction and back propagation. In the actual industrial deployment, there is no need for back propagation. Therefore, the model should be translated into the model format required by the deployment. The `tools/export_model.py` script is provided in PaddleDetection to export the model. - -The exported model name defaults to `model.*`, Layout Parser's code model is `inference.*`, So change [PaddleDetection/ppdet/engine/trainer. 
Py ](https://github.com/PaddlePaddle/PaddleDetection/blob/b87a1ea86fa18ce69e44a17ad1b49c1326f19ff9/ppdet/engine/trainer.py# L512) (click on the link to see the detailed line of code), change 'model' to 'inference'. - -Execute the script to export model: - -```bash -python tools/export_model.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --output_dir=./inference -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final.pdparams -``` - -The prediction model is exported to `inference/ppyolov2_r50vd_dcn_365e_coco` ,including:`infer_cfg.yml`(prediction not required), `inference.pdiparams`, `inference.pdiparams.info`,`inference.pdmodel` - -More model export tutorials, please refer to:[EXPORT_MODEL](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/deploy/EXPORT_MODEL.md) - -### 6.2 Inference - -`model_path` represent the trained model path, and layoutparser is used to predict: - -```bash -import layoutparser as lp -model = lp.PaddleDetectionLayoutModel(model_path="inference/ppyolov2_r50vd_dcn_365e_coco", threshold=0.5,label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"},enforce_cpu=True,enable_mkldnn=True) -``` - -*** - -More PaddleDetection training tutorials,please reference:[PaddleDetection Training](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/GETTING_STARTED_cn.md) - -*** diff --git a/ppstructure/layout/train_layoutparser_model_ch.md b/ppstructure/layout/train_layoutparser_model_ch.md deleted file mode 100644 index a89b0f3819b52c79b86d2ada13bac23e3d1656ed..0000000000000000000000000000000000000000 --- a/ppstructure/layout/train_layoutparser_model_ch.md +++ /dev/null @@ -1,176 +0,0 @@ -[English](train_layoutparser_model.md) | 简体中文 -- [训练版面分析](#训练版面分析) - - [1. 安装](#1-安装) - - [1.1 环境要求](#11-环境要求) - - [1.2 安装PaddleDetection](#12-安装paddledetection) - - [2. 准备数据](#2-准备数据) - - [3. 配置文件改动和说明](#3-配置文件改动和说明) - - [4. PaddleDetection训练](#4-paddledetection训练) - - [5. PaddleDetection预测](#5-paddledetection预测) - - [6. 预测部署](#6-预测部署) - - [6.1 模型导出](#61-模型导出) - - [6.2 layout_parser预测](#62-layout_parser预测) - -# 训练版面分析 - -## 1. 安装 - -### 1.1 环境要求 - -- PaddlePaddle 2.1 -- OS 64 bit -- Python 3(3.5.1+/3.6/3.7/3.8/3.9),64 bit -- pip/pip3(9.0.1+), 64 bit -- CUDA >= 10.1 -- cuDNN >= 7.6 - -### 1.2 安装PaddleDetection - -```bash -# 克隆PaddleDetection仓库 -cd -git clone https://github.com/PaddlePaddle/PaddleDetection.git - -cd PaddleDetection -# 安装其他依赖 -pip install -r requirements.txt -``` - -更多安装教程,请参考: [Install doc](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/INSTALL_cn.md) - -## 2. 
准备数据 - -下载 [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) 数据集: - -```bash -cd PaddleDetection/dataset/ -mkdir publaynet -# 执行命令,下载 -wget -O publaynet.tar.gz https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz?_ga=2.104193024.1076900768.1622560733-649911202.1622560733 -# 解压 -tar -xvf publaynet.tar.gz -``` - -解压之后PubLayNet目录结构: - -| File or Folder | Description | num | -| :------------- | :----------------------------------------------- | ------- | -| `train/` | Images in the training subset | 335,703 | -| `val/` | Images in the validation subset | 11,245 | -| `test/` | Images in the testing subset | 11,405 | -| `train.json` | Annotations for training images | 1 | -| `val.json` | Annotations for validation images | 1 | -| `LICENSE.txt` | Plaintext version of the CDLA-Permissive license | 1 | -| `README.txt` | Text file with the file names and description | 1 | - -如果使用其它数据集,请参考[准备训练数据](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/PrepareDataSet.md) - -## 3. 配置文件改动和说明 - -我们使用 `configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml`配置进行训练,配置文件摘要如下: - -```bash -_BASE_: [ - '../datasets/coco_detection.yml', - '../runtime.yml', - './_base_/ppyolov2_r50vd_dcn.yml', - './_base_/optimizer_365e.yml', - './_base_/ppyolov2_reader.yml', -] - -snapshot_epoch: 8 -weights: output/ppyolov2_r50vd_dcn_365e_coco/model_final -``` -从中可以看到 `ppyolov2_r50vd_dcn_365e_coco.yml` 配置需要依赖其他的配置文件,在该例子中需要依赖: - -- coco_detection.yml:主要说明了训练数据和验证数据的路径 - -- runtime.yml:主要说明了公共的运行参数,比如是否使用GPU、每多少个epoch存储checkpoint等 - -- optimizer_365e.yml:主要说明了学习率和优化器的配置 - -- ppyolov2_r50vd_dcn.yml:主要说明模型和主干网络的情况 - -- ppyolov2_reader.yml:主要说明数据读取器配置,如batch size,并发加载子进程数等,同时包含读取后预处理操作,如resize、数据增强等等 - - -根据实际情况,修改上述文件,比如数据集路径、batch size等。 - -## 4. PaddleDetection训练 - -PaddleDetection提供了单卡/多卡训练模式,满足用户多种训练需求 - -* GPU 单卡训练 - -```bash -export CUDA_VISIBLE_DEVICES=0 #windows和Mac下不需要执行该命令 -python tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml -``` - -* GPU多卡训练 - -```bash -export CUDA_VISIBLE_DEVICES=0,1,2,3 -python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval -``` - ---eval:表示边训练边验证 - -* 模型恢复训练 - -在日常训练过程中,有的用户由于一些原因导致训练中断,用户可以使用-r的命令恢复训练: - -```bash -export CUDA_VISIBLE_DEVICES=0,1,2,3 -python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval -r output/ppyolov2_r50vd_dcn_365e_coco/10000 -``` - -注意:如果遇到 "`Out of memory error`" 问题, 尝试在 `ppyolov2_reader.yml` 文件中调小`batch_size` - -## 5. PaddleDetection预测 - -设置参数,使用PaddleDetection预测: - -```bash -export CUDA_VISIBLE_DEVICES=0 -python tools/infer.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --infer_img=images/paper-image.jpg --output_dir=infer_output/ --draw_threshold=0.5 -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final --use_vdl=Ture -``` - -`--draw_threshold` 是个可选参数. 根据 [NMS](https://ieeexplore.ieee.org/document/1699659) 的计算,不同阈值会产生不同的结果 `keep_top_k`表示设置输出目标的最大数量,默认值为100,用户可以根据自己的实际情况进行设定。 - -## 6. 
预测部署 - -在layout parser中使用自己训练好的模型。 - -### 6.1 模型导出 - -在模型训练过程中保存的模型文件是包含前向预测和反向传播的过程,在实际的工业部署则不需要反向传播,因此需要将模型进行导成部署需要的模型格式。 在PaddleDetection中提供了 `tools/export_model.py`脚本来导出模型。 - -导出模型名称默认是`model.*`,layout parser代码模型名称是`inference.*`, 所以修改[PaddleDetection/ppdet/engine/trainer.py ](https://github.com/PaddlePaddle/PaddleDetection/blob/b87a1ea86fa18ce69e44a17ad1b49c1326f19ff9/ppdet/engine/trainer.py#L512) (点开链接查看详细代码行),将`model`改为`inference`即可。 - -执行导出模型脚本: - -```bash -python tools/export_model.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --output_dir=./inference -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final.pdparams -``` - -预测模型会导出到`inference/ppyolov2_r50vd_dcn_365e_coco`目录下,分别为`infer_cfg.yml`(预测不需要), `inference.pdiparams`, `inference.pdiparams.info`,`inference.pdmodel` 。 - -更多模型导出教程,请参考:[EXPORT_MODEL](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/deploy/EXPORT_MODEL.md) - -### 6.2 layout_parser预测 - -`model_path`指定训练好的模型路径,使用layout parser进行预测: - -```bash -import layoutparser as lp -model = lp.PaddleDetectionLayoutModel(model_path="inference/ppyolov2_r50vd_dcn_365e_coco", threshold=0.5,label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"},enforce_cpu=True,enable_mkldnn=True) -``` - - - -*** - -更多PaddleDetection训练教程,请参考:[PaddleDetection训练](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/GETTING_STARTED_cn.md) - -*** diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py index 053a8aac00ffe762dd05d7f8030db9aaa32c0f8a..68a84a53e6572d393e98260e1f180fe39645ad2c 100644 --- a/ppstructure/predict_system.py +++ b/ppstructure/predict_system.py @@ -28,13 +28,12 @@ import time import logging from copy import deepcopy -from ppocr.utils.utility import get_image_file_list, check_and_read_gif +from ppocr.utils.utility import get_image_file_list, check_and_read from ppocr.utils.logging import get_logger from tools.infer.predict_system import TextSystem from ppstructure.layout.predict_layout import LayoutPredictor from ppstructure.table.predict_table import TableSystem, to_excel from ppstructure.utility import parse_args, draw_structure_result -from ppstructure.recovery.recovery_to_doc import convert_info_docx logger = get_logger() @@ -78,7 +77,7 @@ class StructureSystem(object): elif self.mode == 'vqa': raise NotImplementedError - def __call__(self, img, return_ocr_result_in_table=False): + def __call__(self, img, img_idx=0, return_ocr_result_in_table=False): time_dict = { 'image_orientation': 0, 'layout': 0, @@ -143,8 +142,8 @@ class StructureSystem(object): time_dict['det'] += ocr_time_dict['det'] time_dict['rec'] += ocr_time_dict['rec'] - # remove style char, - # when using the recognition model trained on the PubtabNet dataset, + # remove style char, + # when using the recognition model trained on the PubtabNet dataset, # it will recognize the text format in the table, such as style_token = [ '', '', '', '', '', @@ -169,7 +168,8 @@ class StructureSystem(object): 'type': region['label'].lower(), 'bbox': [x1, y1, x2, y2], 'img': roi_img, - 'res': res + 'res': res, + 'img_idx': img_idx }) end = time.time() time_dict['all'] = end - start @@ -179,26 +179,29 @@ class StructureSystem(object): return None, None -def save_structure_res(res, save_folder, img_name): +def save_structure_res(res, save_folder, img_name, img_idx=0): excel_save_folder = os.path.join(save_folder, img_name) os.makedirs(excel_save_folder, exist_ok=True) res_cp = deepcopy(res) # save res with open( - os.path.join(excel_save_folder, 
'res.txt'), 'w', + os.path.join(excel_save_folder, 'res_{}.txt'.format(img_idx)), + 'w', encoding='utf8') as f: for region in res_cp: roi_img = region.pop('img') f.write('{}\n'.format(json.dumps(region))) - if region['type'] == 'table' and len(region[ + if region['type'].lower() == 'table' and len(region[ 'res']) > 0 and 'html' in region['res']: - excel_path = os.path.join(excel_save_folder, - '{}.xlsx'.format(region['bbox'])) + excel_path = os.path.join( + excel_save_folder, + '{}_{}.xlsx'.format(region['bbox'], img_idx)) to_excel(region['res']['html'], excel_path) - elif region['type'] == 'figure': - img_path = os.path.join(excel_save_folder, - '{}.jpg'.format(region['bbox'])) + elif region['type'].lower() == 'figure': + img_path = os.path.join( + excel_save_folder, + '{}_{}.jpg'.format(region['bbox'], img_idx)) cv2.imwrite(img_path, roi_img) @@ -214,28 +217,75 @@ def main(args): for i, image_file in enumerate(image_file_list): logger.info("[{}/{}] {}".format(i, img_num, image_file)) - img, flag = check_and_read_gif(image_file) + img, flag_gif, flag_pdf = check_and_read(image_file) img_name = os.path.basename(image_file).split('.')[0] - if not flag: + if not flag_gif and not flag_pdf: img = cv2.imread(image_file) - if img is None: - logger.error("error in loading image:{}".format(image_file)) - continue - res, time_dict = structure_sys(img) - if structure_sys.mode == 'structure': - save_structure_res(res, save_folder, img_name) - draw_img = draw_structure_result(img, res, args.vis_font_path) - img_save_path = os.path.join(save_folder, img_name, 'show.jpg') - elif structure_sys.mode == 'vqa': - raise NotImplementedError - # draw_img = draw_ser_results(img, res, args.vis_font_path) - # img_save_path = os.path.join(save_folder, img_name + '.jpg') - cv2.imwrite(img_save_path, draw_img) - logger.info('result save to {}'.format(img_save_path)) - if args.recovery: - convert_info_docx(img, res, save_folder, img_name) + if not flag_pdf: + if img is None: + logger.error("error in loading image:{}".format(image_file)) + continue + res, time_dict = structure_sys(img) + + if structure_sys.mode == 'structure': + save_structure_res(res, save_folder, img_name) + draw_img = draw_structure_result(img, res, args.vis_font_path) + img_save_path = os.path.join(save_folder, img_name, 'show.jpg') + elif structure_sys.mode == 'vqa': + raise NotImplementedError + # draw_img = draw_ser_results(img, res, args.vis_font_path) + # img_save_path = os.path.join(save_folder, img_name + '.jpg') + cv2.imwrite(img_save_path, draw_img) + logger.info('result save to {}'.format(img_save_path)) + if args.recovery: + try: + from ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx + h, w, _ = img.shape + res = sorted_layout_boxes(res, w) + convert_info_docx(img, res, save_folder, img_name, + args.save_pdf) + except Exception as ex: + logger.error( + "error in layout recovery image:{}, err msg: {}".format( + image_file, ex)) + continue + else: + pdf_imgs = img + all_res = [] + for index, img in enumerate(pdf_imgs): + + res, time_dict = structure_sys(img, index) + if structure_sys.mode == 'structure' and res != []: + save_structure_res(res, save_folder, img_name, index) + draw_img = draw_structure_result(img, res, + args.vis_font_path) + img_save_path = os.path.join(save_folder, img_name, + 'show_{}.jpg'.format(index)) + elif structure_sys.mode == 'vqa': + raise NotImplementedError + # draw_img = draw_ser_results(img, res, args.vis_font_path) + # img_save_path = os.path.join(save_folder, img_name + 
'.jpg') + if res != []: + cv2.imwrite(img_save_path, draw_img) + logger.info('result save to {}'.format(img_save_path)) + if args.recovery and res != []: + from ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx + h, w, _ = img.shape + res = sorted_layout_boxes(res, w) + all_res += res + + if args.recovery and all_res != []: + try: + convert_info_docx(img, all_res, save_folder, img_name, + args.save_pdf) + except Exception as ex: + logger.error( + "error in layout recovery image:{}, err msg: {}".format( + image_file, ex)) + continue + logger.info("Predict time : {:.3f}s".format(time_dict['all'])) diff --git a/ppstructure/recovery/README.md b/ppstructure/recovery/README.md index 883dbef3e829dfa213644b610af1ca279dac8641..713d0307dbbd66664db15d19df484af76efea75a 100644 --- a/ppstructure/recovery/README.md +++ b/ppstructure/recovery/README.md @@ -78,9 +78,27 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar # Download the ultra-lightweight English table inch model and unzip it wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar +# Download the layout model of publaynet dataset and unzip it +wget +https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout_infer.tar && tar picodet_lcnet_x1_0_layout_infer.tar cd .. # run -python3 predict_system.py --det_model_dir=inference/en_PP-OCRv3_det_infer --rec_model_dir=inference/en_PP-OCRv3_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --rec_char_dict_path=../ppocr/utils/en_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output ./output/table --rec_image_shape=3,48,320 --vis_font_path=../doc/fonts/simfang.ttf --recovery=True --image_dir=./docs/table/1.png +python3 predict_system.py \ + --image_dir=./docs/table/1.png \ + --det_model_dir=inference/en_PP-OCRv3_det_infer \ + --rec_model_dir=inference/en_PP-OCRv3_rec_infe \ + --rec_char_dict_path=../ppocr/utils/en_dict.txt \ + --output=../output/ \ + --table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \ + --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ + --table_max_len=488 \ + --layout_model_dir=inference/picodet_lcnet_x1_0_layout_infer \ + --layout_dict_path=../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt \ + --vis_font_path=../doc/fonts/simfang.ttf \ + --recovery=True \ + --save_pdf=False ``` -After running, the docx of each picture will be saved in the directory specified by the output field \ No newline at end of file +After running, the docx of each picture will be saved in the directory specified by the output field + +Recovery table to Word code[table_process.py] reference:https://github.com/pqzx/html2docx.git \ No newline at end of file diff --git a/ppstructure/recovery/README_ch.md b/ppstructure/recovery/README_ch.md index 5a05abffd0399387bc0d22d878e64d03d8894a79..14ca8836a0332a5b0e119be4bf6bcb36fb011d1e 100644 --- a/ppstructure/recovery/README_ch.md +++ b/ppstructure/recovery/README_ch.md @@ -35,21 +35,15 @@ python3 -m pip install --upgrade pip # GPU安装 -python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple +python3 -m pip install "paddlepaddle-gpu>=2.3" -i https://mirror.baidu.com/pypi/simple # CPU安装 -python3 -m pip install "paddlepaddle>=2.2" -i 
https://mirror.baidu.com/pypi/simple
+python3 -m pip install "paddlepaddle>=2.3" -i https://mirror.baidu.com/pypi/simple
 ```
 
 更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
 
-* **(2)安装依赖**
-
-```bash
-python3 -m pip install -r ppstructure/recovery/requirements.txt
-```
-
 ### 2.2 安装PaddleOCR
@@ -87,11 +81,28 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
 # 下载英文轻量级PP-OCRv3模型的识别模型并解压
 wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar
 # 下载超轻量级英文表格识别模型并解压
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
+# 下载英文版面分析模型
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout_infer.tar && tar xf picodet_lcnet_x1_0_layout_infer.tar
 cd ..
+
 # 执行预测
-python3 predict_system.py --det_model_dir=inference/en_PP-OCRv3_det_infer --rec_model_dir=inference/en_PP-OCRv3_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --rec_char_dict_path=../ppocr/utils/en_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output ./output/table --rec_image_shape=3,48,320 --vis_font_path=../doc/fonts/simfang.ttf --recovery=True --image_dir=./docs/table/1.png
+python3 predict_system.py \
+    --image_dir=./docs/table/1.png \
+    --det_model_dir=inference/en_PP-OCRv3_det_infer \
+    --rec_model_dir=inference/en_PP-OCRv3_rec_infer \
+    --rec_char_dict_path=../ppocr/utils/en_dict.txt \
+    --output=../output/ \
+    --table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \
+    --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
+    --table_max_len=488 \
+    --layout_model_dir=inference/picodet_lcnet_x1_0_layout_infer \
+    --layout_dict_path=../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt \
+    --vis_font_path=../doc/fonts/simfang.ttf \
+    --recovery=True \
+    --save_pdf=False
 ```
-运行完成后,每张图片的docx文档会保存到output字段指定的目录下
+运行完成后,每张图片的docx文档会保存到`output`字段指定的目录下
+
+表格恢复到Word代码[table_process.py]来自:https://github.com/pqzx/html2docx.git
diff --git a/ppstructure/recovery/recovery_to_doc.py b/ppstructure/recovery/recovery_to_doc.py
index 5278217d5b983008d357b6b1be3ab1b883a4939d..4401b1f27cf10f8483ee9b2b4a61315ad6aad264 100644
--- a/ppstructure/recovery/recovery_to_doc.py
+++ b/ppstructure/recovery/recovery_to_doc.py
@@ -22,21 +22,23 @@ from docx import shared
 from docx.enum.text import WD_ALIGN_PARAGRAPH
 from docx.enum.section import WD_SECTION
 from docx.oxml.ns import qn
+from docx.enum.table import WD_TABLE_ALIGNMENT
+
+from table_process import HtmlToDocx
 from ppocr.utils.logging import get_logger
 logger = get_logger()
-def convert_info_docx(img, 
res, save_folder, img_name): section._sectPr.xpath('./w:cols')[0].set(qn('w:num'), '2') flag = 2 - if region['type'] == 'Figure': + if region['type'].lower() == 'figure': excel_save_folder = os.path.join(save_folder, img_name) img_path = os.path.join(excel_save_folder, - '{}.jpg'.format(region['bbox'])) + '{}_{}.jpg'.format(region['bbox'], img_idx)) paragraph_pic = doc.add_paragraph() paragraph_pic.alignment = WD_ALIGN_PARAGRAPH.CENTER run = paragraph_pic.add_run("") @@ -57,40 +59,38 @@ def convert_info_docx(img, res, save_folder, img_name): run.add_picture(img_path, width=shared.Inches(5)) elif flag == 2: run.add_picture(img_path, width=shared.Inches(2)) - elif region['type'] == 'Title': + elif region['type'].lower() == 'title': doc.add_heading(region['res'][0]['text']) - elif region['type'] == 'Text': + elif region['type'].lower() == 'table': + paragraph = doc.add_paragraph() + new_parser = HtmlToDocx() + new_parser.table_style = 'TableGrid' + table = new_parser.handle_table(html=region['res']['html']) + new_table = deepcopy(table) + new_table.alignment = WD_TABLE_ALIGNMENT.CENTER + paragraph.add_run().element.addnext(new_table._tbl) + + else: paragraph = doc.add_paragraph() paragraph_format = paragraph.paragraph_format for i, line in enumerate(region['res']): if i == 0: paragraph_format.first_line_indent = shared.Inches(0.25) text_run = paragraph.add_run(line['text'] + ' ') - text_run.font.size = shared.Pt(9) - elif region['type'] == 'Table': - pypandoc.convert( - source=region['res']['html'], - format='html', - to='docx', - outputfile='tmp.docx') - tmp_doc = Document('tmp.docx') - paragraph = doc.add_paragraph() - - table = tmp_doc.tables[0] - new_table = deepcopy(table) - new_table.style = doc.styles['Table Grid'] - from docx.enum.table import WD_TABLE_ALIGNMENT - new_table.alignment = WD_TABLE_ALIGNMENT.CENTER - paragraph.add_run().element.addnext(new_table._tbl) - os.remove('tmp.docx') - else: - continue + text_run.font.size = shared.Pt(10) # save to docx docx_path = os.path.join(save_folder, '{}.docx'.format(img_name)) doc.save(docx_path) logger.info('docx save to {}'.format(docx_path)) + # save to pdf + if save_pdf: + pdf = os.path.join(save_folder, '{}.pdf'.format(img_name)) + from docx2pdf import convert + convert(docx_path, pdf_path) + logger.info('pdf save to {}'.format(pdf)) + def sorted_layout_boxes(res, w): """ diff --git a/ppstructure/recovery/requirements.txt b/ppstructure/recovery/requirements.txt index 04187baa2a72d2ac60f0a4e5ce643f882b7255fb..5ba3099d64574954c65ac8169798759dd7c053ac 100644 --- a/ppstructure/recovery/requirements.txt +++ b/ppstructure/recovery/requirements.txt @@ -1,3 +1,5 @@ -opencv-contrib-python==4.4.0.46 pypandoc -python-docx \ No newline at end of file +python-docx +docx2pdf +fitz +PyMuPDF \ No newline at end of file diff --git a/ppstructure/recovery/table_process.py b/ppstructure/recovery/table_process.py new file mode 100644 index 0000000000000000000000000000000000000000..243aaf8933791bf4704964d9665173fe70982f95 --- /dev/null +++ b/ppstructure/recovery/table_process.py @@ -0,0 +1,632 @@ + +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from:https://github.com/pqzx/html2docx/blob/8f6695a778c68befb302e48ac0ed5201ddbd4524/htmldocx/h2d.py + +""" +import re, argparse +import io, os +import urllib.request +from urllib.parse import urlparse +from html.parser import HTMLParser + +import docx, docx.table +from docx import Document +from docx.shared import RGBColor, Pt, Inches +from docx.enum.text import WD_COLOR, WD_ALIGN_PARAGRAPH +from docx.oxml import OxmlElement +from docx.oxml.ns import qn + +from bs4 import BeautifulSoup + +# values in inches +INDENT = 0.25 +LIST_INDENT = 0.5 +MAX_INDENT = 5.5 # To stop indents going off the page + +# Style to use with tables. By default no style is used. +DEFAULT_TABLE_STYLE = None + +# Style to use with paragraphs. By default no style is used. +DEFAULT_PARAGRAPH_STYLE = None + + +def get_filename_from_url(url): + return os.path.basename(urlparse(url).path) + +def is_url(url): + """ + Not to be used for actually validating a url, but in our use case we only + care if it's a url or a file path, and they're pretty distinguishable + """ + parts = urlparse(url) + return all([parts.scheme, parts.netloc, parts.path]) + +def fetch_image(url): + """ + Attempts to fetch an image from a url. + If successful returns a bytes object, else returns None + :return: + """ + try: + with urllib.request.urlopen(url) as response: + # security flaw? + return io.BytesIO(response.read()) + except urllib.error.URLError: + return None + +def remove_last_occurence(ls, x): + ls.pop(len(ls) - ls[::-1].index(x) - 1) + +def remove_whitespace(string, leading=False, trailing=False): + """Remove white space from a string. + Args: + string(str): The string to remove white space from. + leading(bool, optional): Remove leading new lines when True. + trailing(bool, optional): Remove trailing new lines when False. + Returns: + str: The input string with new line characters removed and white space squashed. + Examples: + Single or multiple new line characters are replaced with space. + >>> remove_whitespace("abc\\ndef") + 'abc def' + >>> remove_whitespace("abc\\n\\n\\ndef") + 'abc def' + New line characters surrounded by white space are replaced with a single space. + >>> remove_whitespace("abc \\n \\n \\n def") + 'abc def' + >>> remove_whitespace("abc \\n \\n \\n def") + 'abc def' + Leading and trailing new lines are replaced with a single space. 
+ >>> remove_whitespace("\\nabc") + ' abc' + >>> remove_whitespace(" \\n abc") + ' abc' + >>> remove_whitespace("abc\\n") + 'abc ' + >>> remove_whitespace("abc \\n ") + 'abc ' + Use ``leading=True`` to remove leading new line characters, including any surrounding + white space: + >>> remove_whitespace("\\nabc", leading=True) + 'abc' + >>> remove_whitespace(" \\n abc", leading=True) + 'abc' + Use ``trailing=True`` to remove trailing new line characters, including any surrounding + white space: + >>> remove_whitespace("abc \\n ", trailing=True) + 'abc' + """ + # Remove any leading new line characters along with any surrounding white space + if leading: + string = re.sub(r'^\s*\n+\s*', '', string) + + # Remove any trailing new line characters along with any surrounding white space + if trailing: + string = re.sub(r'\s*\n+\s*$', '', string) + + # Replace new line characters and absorb any surrounding space. + string = re.sub(r'\s*\n\s*', ' ', string) + # TODO need some way to get rid of extra spaces in e.g. text text + return re.sub(r'\s+', ' ', string) + +def delete_paragraph(paragraph): + # https://github.com/python-openxml/python-docx/issues/33#issuecomment-77661907 + p = paragraph._element + p.getparent().remove(p) + p._p = p._element = None + +font_styles = { + 'b': 'bold', + 'strong': 'bold', + 'em': 'italic', + 'i': 'italic', + 'u': 'underline', + 's': 'strike', + 'sup': 'superscript', + 'sub': 'subscript', + 'th': 'bold', +} + +font_names = { + 'code': 'Courier', + 'pre': 'Courier', +} + +styles = { + 'LIST_BULLET': 'List Bullet', + 'LIST_NUMBER': 'List Number', +} + +class HtmlToDocx(HTMLParser): + + def __init__(self): + super().__init__() + self.options = { + 'fix-html': True, + 'images': True, + 'tables': True, + 'styles': True, + } + self.table_row_selectors = [ + 'table > tr', + 'table > thead > tr', + 'table > tbody > tr', + 'table > tfoot > tr' + ] + self.table_style = DEFAULT_TABLE_STYLE + self.paragraph_style = DEFAULT_PARAGRAPH_STYLE + + def set_initial_attrs(self, document=None): + self.tags = { + 'span': [], + 'list': [], + } + if document: + self.doc = document + else: + self.doc = Document() + self.bs = self.options['fix-html'] # whether or not to clean with BeautifulSoup + self.document = self.doc + self.include_tables = True #TODO add this option back in? 
+ self.include_images = self.options['images'] + self.include_styles = self.options['styles'] + self.paragraph = None + self.skip = False + self.skip_tag = None + self.instances_to_skip = 0 + + def copy_settings_from(self, other): + """Copy settings from another instance of HtmlToDocx""" + self.table_style = other.table_style + self.paragraph_style = other.paragraph_style + + def get_cell_html(self, soup): + # Returns string of td element with opening and closing tags removed + # Cannot use find_all as it only finds element tags and does not find text which + # is not inside an element + return ' '.join([str(i) for i in soup.contents]) + + def add_styles_to_paragraph(self, style): + if 'text-align' in style: + align = style['text-align'] + if align == 'center': + self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.CENTER + elif align == 'right': + self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.RIGHT + elif align == 'justify': + self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.JUSTIFY + if 'margin-left' in style: + margin = style['margin-left'] + units = re.sub(r'[0-9]+', '', margin) + margin = int(float(re.sub(r'[a-z]+', '', margin))) + if units == 'px': + self.paragraph.paragraph_format.left_indent = Inches(min(margin // 10 * INDENT, MAX_INDENT)) + # TODO handle non px units + + def add_styles_to_run(self, style): + if 'color' in style: + if 'rgb' in style['color']: + color = re.sub(r'[a-z()]+', '', style['color']) + colors = [int(x) for x in color.split(',')] + elif '#' in style['color']: + color = style['color'].lstrip('#') + colors = tuple(int(color[i:i+2], 16) for i in (0, 2, 4)) + else: + colors = [0, 0, 0] + # TODO map colors to named colors (and extended colors...) + # For now set color to black to prevent crashing + self.run.font.color.rgb = RGBColor(*colors) + + if 'background-color' in style: + if 'rgb' in style['background-color']: + color = color = re.sub(r'[a-z()]+', '', style['background-color']) + colors = [int(x) for x in color.split(',')] + elif '#' in style['background-color']: + color = style['background-color'].lstrip('#') + colors = tuple(int(color[i:i+2], 16) for i in (0, 2, 4)) + else: + colors = [0, 0, 0] + # TODO map colors to named colors (and extended colors...) + # For now set color to black to prevent crashing + self.run.font.highlight_color = WD_COLOR.GRAY_25 #TODO: map colors + + def apply_paragraph_style(self, style=None): + try: + if style: + self.paragraph.style = style + elif self.paragraph_style: + self.paragraph.style = self.paragraph_style + except KeyError as e: + raise ValueError(f"Unable to apply style {self.paragraph_style}.") from e + + def parse_dict_string(self, string, separator=';'): + new_string = string.replace(" ", '').split(separator) + string_dict = dict([x.split(':') for x in new_string if ':' in x]) + return string_dict + + def handle_li(self): + # check list stack to determine style and depth + list_depth = len(self.tags['list']) + if list_depth: + list_type = self.tags['list'][-1] + else: + list_type = 'ul' # assign unordered if no tag + + if list_type == 'ol': + list_style = styles['LIST_NUMBER'] + else: + list_style = styles['LIST_BULLET'] + + self.paragraph = self.doc.add_paragraph(style=list_style) + self.paragraph.paragraph_format.left_indent = Inches(min(list_depth * LIST_INDENT, MAX_INDENT)) + self.paragraph.paragraph_format.line_spacing = 1 + + def add_image_to_cell(self, cell, image): + # python-docx doesn't have method yet for adding images to table cells. 
For now we use this + paragraph = cell.add_paragraph() + run = paragraph.add_run() + run.add_picture(image) + + def handle_img(self, current_attrs): + if not self.include_images: + self.skip = True + self.skip_tag = 'img' + return + src = current_attrs['src'] + # fetch image + src_is_url = is_url(src) + if src_is_url: + try: + image = fetch_image(src) + except urllib.error.URLError: + image = None + else: + image = src + # add image to doc + if image: + try: + if isinstance(self.doc, docx.document.Document): + self.doc.add_picture(image) + else: + self.add_image_to_cell(self.doc, image) + except FileNotFoundError: + image = None + if not image: + if src_is_url: + self.doc.add_paragraph("" % src) + else: + # avoid exposing filepaths in document + self.doc.add_paragraph("" % get_filename_from_url(src)) + + + def handle_table(self, html): + """ + To handle nested tables, we will parse tables manually as follows: + Get table soup + Create docx table + Iterate over soup and fill docx table with new instances of this parser + Tell HTMLParser to ignore any tags until the corresponding closing table tag + """ + doc = Document() + table_soup = BeautifulSoup(html, 'html.parser') + rows, cols_len = self.get_table_dimensions(table_soup) + table = doc.add_table(len(rows), cols_len) + table.style = doc.styles['Table Grid'] + cell_row = 0 + for index, row in enumerate(rows): + cols = self.get_table_columns(row) + cell_col = 0 + for col in cols: + colspan = int(col.attrs.get('colspan', 1)) + rowspan = int(col.attrs.get('rowspan', 1)) + + cell_html = self.get_cell_html(col) + + if col.name == 'th': + cell_html = "%s" % cell_html + docx_cell = table.cell(cell_row, cell_col) + while docx_cell.text != '': # Skip the merged cell + cell_col += 1 + docx_cell = table.cell(cell_row, cell_col) + + cell_to_merge = table.cell(cell_row + rowspan - 1, cell_col + colspan - 1) + if docx_cell != cell_to_merge: + docx_cell.merge(cell_to_merge) + + child_parser = HtmlToDocx() + child_parser.copy_settings_from(self) + + child_parser.add_html_to_cell(cell_html or ' ', docx_cell) # occupy the position + + cell_col += colspan + cell_row += 1 + + # skip all tags until corresponding closing tag + self.instances_to_skip = len(table_soup.find_all('table')) + self.skip_tag = 'table' + self.skip = True + self.table = None + return table + + def handle_link(self, href, text): + # Link requires a relationship + is_external = href.startswith('http') + rel_id = self.paragraph.part.relate_to( + href, + docx.opc.constants.RELATIONSHIP_TYPE.HYPERLINK, + is_external=True # don't support anchor links for this library yet + ) + + # Create the w:hyperlink tag and add needed values + hyperlink = docx.oxml.shared.OxmlElement('w:hyperlink') + hyperlink.set(docx.oxml.shared.qn('r:id'), rel_id) + + + # Create sub-run + subrun = self.paragraph.add_run() + rPr = docx.oxml.shared.OxmlElement('w:rPr') + + # add default color + c = docx.oxml.shared.OxmlElement('w:color') + c.set(docx.oxml.shared.qn('w:val'), "0000EE") + rPr.append(c) + + # add underline + u = docx.oxml.shared.OxmlElement('w:u') + u.set(docx.oxml.shared.qn('w:val'), 'single') + rPr.append(u) + + subrun._r.append(rPr) + subrun._r.text = text + + # Add subrun to hyperlink + hyperlink.append(subrun._r) + + # Add hyperlink to run + self.paragraph._p.append(hyperlink) + + def handle_starttag(self, tag, attrs): + if self.skip: + return + if tag == 'head': + self.skip = True + self.skip_tag = tag + self.instances_to_skip = 0 + return + elif tag == 'body': + return + + current_attrs = dict(attrs) 
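+        # NOTE: span and list tags are pushed onto stacks in self.tags so that
+        # nested styling can be resolved later in handle_data()/handle_endtag();
+        # block-level tags (p, li, hr, h1-h9, img, table) are handled below.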
+ + if tag == 'span': + self.tags['span'].append(current_attrs) + return + elif tag == 'ol' or tag == 'ul': + self.tags['list'].append(tag) + return # don't apply styles for now + elif tag == 'br': + self.run.add_break() + return + + self.tags[tag] = current_attrs + if tag in ['p', 'pre']: + self.paragraph = self.doc.add_paragraph() + self.apply_paragraph_style() + + elif tag == 'li': + self.handle_li() + + elif tag == "hr": + + # This implementation was taken from: + # https://github.com/python-openxml/python-docx/issues/105#issuecomment-62806373 + + self.paragraph = self.doc.add_paragraph() + pPr = self.paragraph._p.get_or_add_pPr() + pBdr = OxmlElement('w:pBdr') + pPr.insert_element_before(pBdr, + 'w:shd', 'w:tabs', 'w:suppressAutoHyphens', 'w:kinsoku', 'w:wordWrap', + 'w:overflowPunct', 'w:topLinePunct', 'w:autoSpaceDE', 'w:autoSpaceDN', + 'w:bidi', 'w:adjustRightInd', 'w:snapToGrid', 'w:spacing', 'w:ind', + 'w:contextualSpacing', 'w:mirrorIndents', 'w:suppressOverlap', 'w:jc', + 'w:textDirection', 'w:textAlignment', 'w:textboxTightWrap', + 'w:outlineLvl', 'w:divId', 'w:cnfStyle', 'w:rPr', 'w:sectPr', + 'w:pPrChange' + ) + bottom = OxmlElement('w:bottom') + bottom.set(qn('w:val'), 'single') + bottom.set(qn('w:sz'), '6') + bottom.set(qn('w:space'), '1') + bottom.set(qn('w:color'), 'auto') + pBdr.append(bottom) + + elif re.match('h[1-9]', tag): + if isinstance(self.doc, docx.document.Document): + h_size = int(tag[1]) + self.paragraph = self.doc.add_heading(level=min(h_size, 9)) + else: + self.paragraph = self.doc.add_paragraph() + + elif tag == 'img': + self.handle_img(current_attrs) + return + + elif tag == 'table': + self.handle_table() + return + + # set new run reference point in case of leading line breaks + if tag in ['p', 'li', 'pre']: + self.run = self.paragraph.add_run() + + # add style + if not self.include_styles: + return + if 'style' in current_attrs and self.paragraph: + style = self.parse_dict_string(current_attrs['style']) + self.add_styles_to_paragraph(style) + + def handle_endtag(self, tag): + if self.skip: + if not tag == self.skip_tag: + return + + if self.instances_to_skip > 0: + self.instances_to_skip -= 1 + return + + self.skip = False + self.skip_tag = None + self.paragraph = None + + if tag == 'span': + if self.tags['span']: + self.tags['span'].pop() + return + elif tag == 'ol' or tag == 'ul': + remove_last_occurence(self.tags['list'], tag) + return + elif tag == 'table': + self.table_no += 1 + self.table = None + self.doc = self.document + self.paragraph = None + + if tag in self.tags: + self.tags.pop(tag) + # maybe set relevant reference to None? + + def handle_data(self, data): + if self.skip: + return + + # Only remove white space if we're not in a pre block. 
+        if 'pre' not in self.tags:
+            # remove leading and trailing whitespace in all instances
+            data = remove_whitespace(data, True, True)
+
+        if not self.paragraph:
+            self.paragraph = self.doc.add_paragraph()
+            self.apply_paragraph_style()
+
+        # There can only be one nested link in a valid html document
+        # You cannot have interactive content in an A tag, this includes links
+        # https://html.spec.whatwg.org/#interactive-content
+        link = self.tags.get('a')
+        if link:
+            self.handle_link(link['href'], data)
+        else:
+            # If there's a link, don't put the data directly in the run
+            self.run = self.paragraph.add_run(data)
+            spans = self.tags['span']
+            for span in spans:
+                if 'style' in span:
+                    style = self.parse_dict_string(span['style'])
+                    self.add_styles_to_run(style)
+
+            # add font style and name
+            for tag in self.tags:
+                if tag in font_styles:
+                    font_style = font_styles[tag]
+                    setattr(self.run.font, font_style, True)
+
+                if tag in font_names:
+                    font_name = font_names[tag]
+                    self.run.font.name = font_name
+
+    def ignore_nested_tables(self, tables_soup):
+        """
+        Returns array containing only the highest level tables
+        Operates on the assumption that bs4 returns child elements immediately after
+        the parent element in `find_all`. If this changes in the future, this method will need to be updated
+        :return:
+        """
+        new_tables = []
+        nest = 0
+        for table in tables_soup:
+            if nest:
+                nest -= 1
+                continue
+            new_tables.append(table)
+            nest = len(table.find_all('table'))
+        return new_tables
+
+    def get_table_rows(self, table_soup):
+        # If there's a header, body, footer or direct child tr tags, add row dimensions from there
+        return table_soup.select(', '.join(self.table_row_selectors), recursive=False)
+
+    def get_table_columns(self, row):
+        # Get all columns for the specified row tag.
+        return row.find_all(['th', 'td'], recursive=False) if row else []
+
+    def get_table_dimensions(self, table_soup):
+        # Get rows for the table
+        rows = self.get_table_rows(table_soup)
+        # Table is either empty or has non-direct children between table and tr tags
+        # Thus the row dimensions and column dimensions are assumed to be 0
+
+        cols = self.get_table_columns(rows[0]) if rows else []
+        # Account for colspan when counting columns
+        col_count = 0
+        for col in cols:
+            colspan = col.attrs.get('colspan', 1)
+            col_count += int(colspan)
+
+        # return len(rows), col_count
+        return rows, col_count
+
+    def get_tables(self):
+        if not hasattr(self, 'soup'):
+            self.include_tables = False
+            return
+            # find other way to do it, or require this dependency?
+        self.tables = self.ignore_nested_tables(self.soup.find_all('table'))
+        self.table_no = 0
+
+    def run_process(self, html):
+        if self.bs and BeautifulSoup:
+            self.soup = BeautifulSoup(html, 'html.parser')
+            html = str(self.soup)
+        if self.include_tables:
+            self.get_tables()
+        self.feed(html)
+
+    def add_html_to_document(self, html, document):
+        if not isinstance(html, str):
+            raise ValueError('First argument needs to be a %s' % str)
+        elif not isinstance(document, docx.document.Document) and not isinstance(document, docx.table._Cell):
+            raise ValueError('Second argument needs to be a %s' % docx.document.Document)
+        self.set_initial_attrs(document)
+        self.run_process(html)
+
+    def add_html_to_cell(self, html, cell):
+        self.set_initial_attrs(cell)
+        self.run_process(html)
+
+    def parse_html_file(self, filename_html, filename_docx=None):
+        with open(filename_html, 'r') as infile:
+            html = infile.read()
+        self.set_initial_attrs()
+        self.run_process(html)
+        if not filename_docx:
+            path, filename = os.path.split(filename_html)
+            filename_docx = '%s/new_docx_file_%s' % (path, filename)
+        self.doc.save('%s.docx' % filename_docx)
+
+    def parse_html_string(self, html):
+        self.set_initial_attrs()
+        self.run_process(html)
+        return self.doc
\ No newline at end of file
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index cda4c063bccbd2aff34cf25768866feb4d68dc2d..2cf20eb53f87a8f8fbe2bdb4c3ead77f40120370 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -38,7 +38,7 @@ def init_args():
     parser.add_argument(
         "--layout_dict_path",
         type=str,
-        default="../ppocr/utils/dict/layout_publaynet_dict.txt")
+        default="../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt")
     parser.add_argument(
         "--layout_score_threshold",
         type=float,
@@ -89,6 +89,11 @@ def init_args():
         type=bool,
         default=False,
         help='Whether to enable layout of recovery')
+    parser.add_argument(
+        "--save_pdf",
+        type=bool,
+        default=False,
+        help='Whether to save pdf file')
     return parser
diff --git a/test_tipc/common_func.sh b/test_tipc/common_func.sh
index f7d8a1e04adee9d32332eda8cb5913bbaf168481..1bbf829165323b76341461b297b71102462d83af 100644
--- a/test_tipc/common_func.sh
+++ b/test_tipc/common_func.sh
@@ -58,10 +58,11 @@ function status_check(){
    run_command=$2
    run_log=$3
    model_name=$4
+   log_path=$5
    if [ $last_status -eq 0 ]; then
-       echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command}! \033[0m" | tee -a ${run_log}
+       echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log}
    else
-       echo -e "\033[33m Run failed with command - ${model_name} - ${run_command}!
\033[0m" | tee -a ${run_log} + echo -e "\033[33m Run failed with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log} fi } diff --git a/test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt b/test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt deleted file mode 100644 index df88c0e5434511fb48deac699e8f67fc535765d3..0000000000000000000000000000000000000000 --- a/test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt +++ /dev/null @@ -1,58 +0,0 @@ -===========================train_params=========================== -model_name:det_r18_db_v2_0 -python:python3.7 -gpu_list:0|0,1 -Global.use_gpu:True|True -Global.auto_cast:null -Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300 -Global.save_model_dir:./output/ -Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_lite_infer=4 -Global.pretrained_model:null -train_model_name:latest -train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ -null:null -## -trainer:norm_train -norm_train:tools/train.py -c configs/det/det_res18_db_v2.0.yml -o -quant_export:null -fpgm_export:null -distill_train:null -null:null -null:null -## -===========================eval_params=========================== -eval:null -null:null -## -===========================infer_params=========================== -Global.save_inference_dir:./output/ -Global.checkpoints: -norm_export:null -quant_export:null -fpgm_export:null -distill_export:null -export1:null -export2:null -## -train_model:null -infer_export:null -infer_quant:False -inference:tools/infer/predict_det.py ---use_gpu:True|False ---enable_mkldnn:False ---cpu_threads:6 ---rec_batch_num:1 ---use_tensorrt:False ---precision:fp32 ---det_model_dir: ---image_dir:./inference/ch_det_data_50/all-sum-510/ ---save_log_path:null ---benchmark:True -null:null -===========================infer_benchmark_params========================== -random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] -===========================train_benchmark_params========================== -batch_size:8|16 -fp_items:fp32|fp16 -epoch:15 ---profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile diff --git a/test_tipc/configs/en_table_structure/train_infer_python.txt b/test_tipc/configs/en_table_structure/train_infer_python.txt index 633b6185d976ac61408283025bd4ba305187317d..3fd5dc9f60a9621026d488e5654cd7e1421e8b65 100644 --- a/test_tipc/configs/en_table_structure/train_infer_python.txt +++ b/test_tipc/configs/en_table_structure/train_infer_python.txt @@ -54,6 +54,6 @@ random_infer_input:[{float32,[3,488,488]}] ===========================train_benchmark_params========================== batch_size:32 fp_items:fp32|fp16 -epoch:1 +epoch:2 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt index 34082bc193a2ebd8f4c7a9e7c9ce55dc8dbf8e40..5284ffabe2de4eb8bb000e7fb745ef2846ed6b64 100644 --- a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt +++ b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt @@ -52,7 +52,7 @@ null:null ===========================infer_benchmark_params========================== random_infer_input:[{float32,[3,224,224]}] ===========================train_benchmark_params========================== 
-batch_size:4 +batch_size:8 fp_items:fp32|fp16 epoch:3 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile diff --git a/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml b/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml new file mode 100644 index 0000000000000000000000000000000000000000..b5466d4478be27d6fd152ee467f7f25731c8dce0 --- /dev/null +++ b/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml @@ -0,0 +1,111 @@ +Global: + use_gpu: true + epoch_num: 5 + log_smooth_window: 20 + print_batch_step: 20 + save_model_dir: ./output/rec/rec_r31_robustscanner/ + save_epoch_step: 1 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: ./inference/rec_inference + # for data or label process + character_dict_path: ppocr/utils/dict90.txt + max_text_length: &max_text_length 40 + infer_mode: False + use_space_char: False + rm_symbol: True + save_res_path: ./output/rec/predicts_robustscanner.txt + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Piecewise + decay_epochs: [3, 4] + values: [0.001, 0.0001, 0.00001] + regularizer: + name: 'L2' + factor: 0 + +Architecture: + model_type: rec + algorithm: RobustScanner + Transform: + Backbone: + name: ResNet31 + init_type: KaimingNormal + Head: + name: RobustScannerHead + enc_outchannles: 128 + hybrid_dec_rnn_layers: 2 + hybrid_dec_dropout: 0 + position_dec_rnn_layers: 2 + start_idx: 91 + mask: True + padding_idx: 92 + encode_value: False + max_text_length: *max_text_length + +Loss: + name: SARLoss + +PostProcess: + name: SARLabelDecode + +Metric: + name: RecMetric + is_filter: True + + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data/ + label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SARLabelEncode: # Class handling label + - RobustScannerRecResizeImg: + image_shape: [3, 48, 48, 160] # h:48 w:[48,160] + width_downsample_ratio: 0.25 + max_text_length: *max_text_length + - KeepKeys: + keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 16 + drop_last: True + num_workers: 0 + use_shared_memory: False + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data + label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SARLabelEncode: # Class handling label + - RobustScannerRecResizeImg: + image_shape: [3, 48, 48, 160] + max_text_length: *max_text_length + width_downsample_ratio: 0.25 + - KeepKeys: + keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 16 + num_workers: 0 + use_shared_memory: False + diff --git a/test_tipc/configs/rec_r31_robustscanner/train_infer_python.txt b/test_tipc/configs/rec_r31_robustscanner/train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..07498c9e81ada9652343b8d8fff0f102d4684380 --- /dev/null +++ b/test_tipc/configs/rec_r31_robustscanner/train_infer_python.txt @@ -0,0 +1,54 @@ +===========================train_params=========================== +model_name:rec_r31_robustscanner +python:python 
+gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:null +Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=5 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64 +Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./inference/rec_inference +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +train_model:./inference/rec_r31_robustscanner/best_accuracy +infer_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o +infer_quant:False +inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.txt --rec_image_shape="3,48,48,160" --use_space_char=False --rec_algorithm="RobustScanner" +--use_gpu:True|False +--enable_mkldnn:True|False +--cpu_threads:1|6 +--rec_batch_num:1|6 +--use_tensorrt:False|False +--precision:fp32|int8 +--rec_model_dir: +--image_dir:./inference/rec_inference +--save_log_path:./test/output/ +--benchmark:True +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,48,160]}] + diff --git a/test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml b/test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml new file mode 100644 index 0000000000000000000000000000000000000000..860e4f53043138e7434d71a816fdf051048be6f7 --- /dev/null +++ b/test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml @@ -0,0 +1,108 @@ +Global: + use_gpu: true + epoch_num: 8 + log_smooth_window: 200 + print_batch_step: 200 + save_model_dir: ./output/rec/r45_visionlan + save_epoch_step: 1 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/en/word_2.png + # for data or label process + character_dict_path: + max_text_length: &max_text_length 25 + training_step: &training_step LA + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_visionlan.txt + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + clip_norm: 20.0 + group_lr: true + training_step: *training_step + lr: + name: Piecewise + decay_epochs: [6] + values: [0.0001, 0.00001] + regularizer: + name: 'L2' + factor: 0 + +Architecture: + model_type: rec + algorithm: VisionLAN + Transform: + Backbone: + name: ResNet45 + strides: [2, 2, 2, 1, 1] + Head: + name: VLHead + n_layers: 3 + n_position: 256 + n_dim: 512 + max_text_length: *max_text_length + training_step: *training_step + +Loss: + name: VLLoss + mode: *training_step + weight_res: 0.5 + weight_mas: 0.5 + +PostProcess: + name: VLLabelDecode + +Metric: + name: RecMetric + is_filter: true + + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data/ + label_file_list: 
["./train_data/ic15_data/rec_gt_train.txt"] + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - ABINetRecAug: + - VLLabelEncode: # Class handling label + - VLRecResizeImg: + image_shape: [3, 64, 256] + - KeepKeys: + keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 220 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data + label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VLLabelEncode: # Class handling label + - VLRecResizeImg: + image_shape: [3, 64, 256] + - KeepKeys: + keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 64 + num_workers: 4 + diff --git a/test_tipc/configs/rec_r45_visionlan/train_infer_python.txt b/test_tipc/configs/rec_r45_visionlan/train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..c08ae7beb6c867bf36283e60dc1e70cfd9ee06a7 --- /dev/null +++ b/test_tipc/configs/rec_r45_visionlan/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:rec_r45_visionlan +python:python3.7 +gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:null +Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=32|whole_train_whole_infer=64 +Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./inference/rec_inference +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml -o +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml -o +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml -o +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +train_model:./inference/rec_r45_visionlan_train/best_accuracy +infer_export:tools/export_model.py -c test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml -o +infer_quant:False +inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,64,256" --rec_algorithm="VisionLAN" --use_space_char=False +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1|6 +--use_tensorrt:False +--precision:fp32 +--rec_model_dir: +--image_dir:./inference/rec_inference +--save_log_path:./test/output/ +--benchmark:True +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,64,256]}] diff --git a/test_tipc/readme.md b/test_tipc/readme.md index f9e9d89e4198c1ad5fabdf58775c6f7b6d190322..1442ee1c86a7c1319446a0eb22c08287e1ce689a 100644 --- a/test_tipc/readme.md +++ b/test_tipc/readme.md @@ -54,6 +54,7 @@ | NRTR |rec_mtb_nrtr | 识别 | 支持 | 多机多卡
混合精度 | - | - |
| SAR |rec_r31_sar | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
| SPIN |rec_r32_gaspin_bilstm_att | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
+| RobustScanner |rec_r31_robustscanner | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
| PGNet |rec_r34_vd_none_none_ctc_v2.0 | 端到端| 支持 | 多机多卡 <br> 混合精度 | - | - |
| TableMaster |table_structure_tablemaster_train | 表格识别| 支持 | 多机多卡
混合精度 | - | - | diff --git a/test_tipc/test_inference_cpp.sh b/test_tipc/test_inference_cpp.sh index c0c7c18a38a46b00c839757e303049135a508691..aadaa8b0773632885138806861fc851ede503f3d 100644 --- a/test_tipc/test_inference_cpp.sh +++ b/test_tipc/test_inference_cpp.sh @@ -84,7 +84,7 @@ function func_cpp_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done done @@ -117,7 +117,7 @@ function func_cpp_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done diff --git a/test_tipc/test_inference_python.sh b/test_tipc/test_inference_python.sh index 2a31a468f0d54d1979e82c8f0da98cac6f4edcec..e9908df1f6049f9d38524dc6598499ddd2b58af8 100644 --- a/test_tipc/test_inference_python.sh +++ b/test_tipc/test_inference_python.sh @@ -88,7 +88,7 @@ function func_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done done @@ -119,7 +119,7 @@ function func_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done @@ -146,14 +146,15 @@ if [ ${MODE} = "whole_infer" ]; then for infer_model in ${infer_model_dir_list[*]}; do # run export if [ ${infer_run_exports[Count]} != "null" ];then + _save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}_infermodel_${infer_model}.log" save_infer_dir=$(dirname $infer_model) set_export_weight=$(func_set_params "${export_weight}" "${infer_model}") set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}") - export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key}" + export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${_save_log_path} 2>&1 " echo ${infer_run_exports[Count]} eval $export_cmd status_export=$? 
- status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" + status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${_save_log_path}" else save_infer_dir=${infer_model} fi diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh index 78d79d0b8eaac782f98c1e883d091a001443f41a..bace6b2d4684e0ad40ffbd76b37a78ddf1e70722 100644 --- a/test_tipc/test_paddle2onnx.sh +++ b/test_tipc/test_paddle2onnx.sh @@ -66,7 +66,7 @@ function func_paddle2onnx(){ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}" # trans rec set_dirname=$(func_set_params "--model_dir" "${rec_infer_model_dir_value}") set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}") @@ -78,7 +78,7 @@ function func_paddle2onnx(){ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}" elif [[ ${model_name} =~ "det" ]]; then # trans det set_dirname=$(func_set_params "--model_dir" "${det_infer_model_dir_value}") @@ -91,7 +91,7 @@ function func_paddle2onnx(){ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}" elif [[ ${model_name} =~ "rec" ]]; then # trans rec set_dirname=$(func_set_params "--model_dir" "${rec_infer_model_dir_value}") @@ -104,7 +104,7 @@ function func_paddle2onnx(){ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}" fi # python inference @@ -127,7 +127,7 @@ function func_paddle2onnx(){ eval $infer_model_cmd last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" "${_save_log_path}" elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then _save_log_path="${LOG_PATH}/paddle2onnx_infer_gpu.log" set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu}") @@ -146,7 +146,7 @@ function func_paddle2onnx(){ eval $infer_model_cmd last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${infer_model_cmd}" "${status_log}" 
"${model_name}" "${_save_log_path}" else echo "Does not support hardware other than CPU and GPU Currently!" fi @@ -158,4 +158,4 @@ echo "################### run test ###################" export Count=0 IFS="|" -func_paddle2onnx \ No newline at end of file +func_paddle2onnx diff --git a/test_tipc/test_ptq_inference_python.sh b/test_tipc/test_ptq_inference_python.sh index e2939fd5e638ad0f6b4c44422a6fec6459903d1c..caf3d506029ee066aa5abebc25b739439b6e9d75 100644 --- a/test_tipc/test_ptq_inference_python.sh +++ b/test_tipc/test_ptq_inference_python.sh @@ -84,7 +84,7 @@ function func_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done done @@ -109,7 +109,7 @@ function func_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done @@ -145,7 +145,7 @@ if [ ${MODE} = "whole_infer" ]; then echo $export_cmd eval $export_cmd status_export=$? - status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" + status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}" else save_infer_dir=${infer_model} fi diff --git a/test_tipc/test_serving_infer_cpp.sh b/test_tipc/test_serving_infer_cpp.sh index 0be6a45adf3105f088a96336dddfbe9ac612f19b..10ddecf3fa26805fef7bc6ae10d78ee5e741cd27 100644 --- a/test_tipc/test_serving_infer_cpp.sh +++ b/test_tipc/test_serving_infer_cpp.sh @@ -83,7 +83,7 @@ function func_serving(){ trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client} > ${trans_rec_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}" set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}") python_list=(${python_list}) cd ${serving_dir_value} @@ -95,14 +95,14 @@ function func_serving(){ web_service_cpp_cmd="nohup ${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} > ${server_log_path} 2>&1 &" eval $web_service_cpp_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}" "${server_log_path}" sleep 5s _save_log_path="${LOG_PATH}/cpp_client_cpu.log" cpp_client_cmd="${python_list[0]} ${cpp_client_py} ${det_client_value} ${rec_client_value} > ${_save_log_path} 2>&1" eval $cpp_client_cmd last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" "${_save_log_path}" ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9 else server_log_path="${LOG_PATH}/cpp_server_gpu.log" @@ -114,7 +114,7 @@ function func_serving(){ eval $cpp_client_cmd last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${cpp_client_cmd}" "${status_log}" 
"${model_name}" + status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" "${_save_log_path}" ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9 fi done diff --git a/test_tipc/test_serving_infer_python.sh b/test_tipc/test_serving_infer_python.sh index 4b7dfcf785a3c8459cce95d55744dbcd4f97027a..c7d305d5d2dcd2ea1bf5a7c3254eea4231d59879 100644 --- a/test_tipc/test_serving_infer_python.sh +++ b/test_tipc/test_serving_infer_python.sh @@ -126,19 +126,19 @@ function func_serving(){ web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &" eval $web_service_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}" elif [[ ${model_name} =~ "det" ]]; then set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}") web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} > ${server_log_path} 2>&1 &" eval $web_service_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}" elif [[ ${model_name} =~ "rec" ]]; then set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}") web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_rec_model_config} > ${server_log_path} 2>&1 &" eval $web_service_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}" fi sleep 2s for pipeline in ${pipeline_py[*]}; do @@ -147,7 +147,7 @@ function func_serving(){ eval $pipeline_cmd last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" "${_save_log_path}" sleep 2s done ps ux | grep -E 'web_service' | awk '{print $2}' | xargs kill -s 9 @@ -177,19 +177,19 @@ function func_serving(){ web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &" eval $web_service_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}" elif [[ ${model_name} =~ "det" ]]; then set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}") web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} > ${server_log_path} 2>&1 &" eval $web_service_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}" elif [[ ${model_name} =~ "rec" ]]; then set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}") web_service_cmd="nohup ${python} 
${web_service_py} ${set_tensorrt} ${set_precision} ${set_rec_model_config} > ${server_log_path} 2>&1 &" eval $web_service_cmd last_status=${PIPESTATUS[0]} - status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}" fi sleep 2s for pipeline in ${pipeline_py[*]}; do @@ -198,7 +198,7 @@ function func_serving(){ eval $pipeline_cmd last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" + status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" "${_save_log_path}" sleep 2s done ps ux | grep -E 'web_service' | awk '{print $2}' | xargs kill -s 9 diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh index 545cdbba2051c8123ef7f70f2aeb4b4b5a57b7c5..e182fa57f060c81af012a5da89b892bde02b4a2b 100644 --- a/test_tipc/test_train_inference_python.sh +++ b/test_tipc/test_train_inference_python.sh @@ -133,7 +133,7 @@ function func_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done done @@ -164,7 +164,7 @@ function func_inference(){ eval $command last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" - status_check $last_status "${command}" "${status_log}" "${model_name}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" done done @@ -201,7 +201,7 @@ if [ ${MODE} = "whole_infer" ]; then echo $export_cmd eval $export_cmd status_export=$? - status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" + status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}" else save_infer_dir=${infer_model} fi @@ -298,7 +298,7 @@ else # run train eval $cmd eval "cat ${save_log}/train.log >> ${save_log}.log" - status_check $? "${cmd}" "${status_log}" "${model_name}" + status_check $? "${cmd}" "${status_log}" "${model_name}" "${save_log}.log" set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}") @@ -309,7 +309,7 @@ else eval_log_path="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}_eval.log" eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1} > ${eval_log_path} 2>&1 " eval $eval_cmd - status_check $? "${eval_cmd}" "${status_log}" "${model_name}" + status_check $? "${eval_cmd}" "${status_log}" "${model_name}" "${eval_log_path}" fi # run export model if [ ${run_export} != "null" ]; then @@ -320,7 +320,7 @@ else set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_path}") export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 " eval $export_cmd - status_check $? "${export_cmd}" "${status_log}" "${model_name}" + status_check $? 
"${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}" #run inference eval $env diff --git a/tools/eval.py b/tools/eval.py index 2fc53488efa2c4c475d31af47f69b3560e6cc69a..38d72d178db45a4787ddc09c865afba9222f385a 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -73,7 +73,7 @@ def main(): config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"] + extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': for key in config['Architecture']["Models"]: diff --git a/tools/export_model.py b/tools/export_model.py index c6763374a634dfca125f64b63d7c85716f68f142..193988cc1b62a6c4536a8d2ec640e3e5fc81a79c 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -58,6 +58,8 @@ def export_single_model(model, other_shape = [ paddle.static.InputSpec( shape=[None, 3, 48, 160], dtype="float32"), + [paddle.static.InputSpec( + shape=[None], dtype="float32")] ] model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] == "SVTR": @@ -109,6 +111,22 @@ def export_single_model(model, shape=[None, 3, 64, 256], dtype="float32"), ] model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == "RobustScanner": + max_text_length = arch_config["Head"]["max_text_length"] + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 48, 160], dtype="float32"), + + [ + paddle.static.InputSpec( + shape=[None, ], + dtype="float32"), + paddle.static.InputSpec( + shape=[None, max_text_length], + dtype="int64") + ] + ] + model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]: input_spec = [ paddle.static.InputSpec( @@ -128,7 +146,7 @@ def export_single_model(model, else: infer_shape = [3, -1, -1] if arch_config["model_type"] == "rec": - infer_shape = [3, 48, -1] # for rec model, H must be 32 + infer_shape = [3, 32, -1] # for rec model, H must be 32 if "Transform" in arch_config and arch_config[ "Transform"] is not None and arch_config["Transform"][ "name"] == "TPS": @@ -234,4 +252,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 449f69ed6a22cb12af8a6a2ef8f2eedc1aca087c..53dab6f26d8b84a224360f2fa6fe5f411eea751f 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -68,7 +68,7 @@ class TextRecognizer(object): 'name': 'SARLabelDecode', "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char - } + } elif self.rec_algorithm == "VisionLAN": postprocess_params = { 'name': 'VLLabelDecode', @@ -93,6 +93,13 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } + elif self.rec_algorithm == "RobustScanner": + postprocess_params = { + 'name': 'SARLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char, + "rm_symbol": True + } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors, self.config = \ utility.create_predictor(args, 'rec', logger) @@ -390,6 +397,18 @@ class TextRecognizer(object): img_list[indices[ino]], self.rec_image_shape) norm_img = norm_img[np.newaxis, :] norm_img_batch.append(norm_img) + elif self.rec_algorithm == "RobustScanner": 
+ norm_img, _, _, valid_ratio = self.resize_norm_img_sar( + img_list[indices[ino]], self.rec_image_shape, width_downsample_ratio=0.25) + norm_img = norm_img[np.newaxis, :] + valid_ratio = np.expand_dims(valid_ratio, axis=0) + valid_ratios = [] + valid_ratios.append(valid_ratio) + norm_img_batch.append(norm_img) + word_positions_list = [] + word_positions = np.array(range(0, 40)).astype('int64') + word_positions = np.expand_dims(word_positions, axis=0) + word_positions_list.append(word_positions) else: norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) @@ -437,10 +456,40 @@ class TextRecognizer(object): preds = {"predict": outputs[2]} elif self.rec_algorithm == "SAR": valid_ratios = np.concatenate(valid_ratios) + inputs = [ + norm_img_batch, + np.array( + [valid_ratios], dtype=np.float32), + ] + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, + input_dict) + preds = outputs[0] + else: + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle( + input_names[i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() + preds = outputs[0] + elif self.rec_algorithm == "RobustScanner": + valid_ratios = np.concatenate(valid_ratios) + word_positions_list = np.concatenate(word_positions_list) inputs = [ norm_img_batch, valid_ratios, + word_positions_list ] + if self.use_onnx: input_dict = {} input_dict[self.input_tensor.name] = norm_img_batch diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 81d0196ccd6b86741e73524d9321618f3f5cc34b..1eebc73f31e6b48a473c20d907ca401ad919fe0b 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -231,89 +231,10 @@ def create_predictor(args, mode, logger): ) config.enable_tuned_tensorrt_dynamic_shape( args.shape_info_filename, True) - - use_dynamic_shape = True - if mode == "det": - min_input_shape = { - "x": [1, 3, 50, 50], - "conv2d_92.tmp_0": [1, 120, 20, 20], - "conv2d_91.tmp_0": [1, 24, 10, 10], - "conv2d_59.tmp_0": [1, 96, 20, 20], - "nearest_interp_v2_1.tmp_0": [1, 256, 10, 10], - "nearest_interp_v2_2.tmp_0": [1, 256, 20, 20], - "conv2d_124.tmp_0": [1, 256, 20, 20], - "nearest_interp_v2_3.tmp_0": [1, 64, 20, 20], - "nearest_interp_v2_4.tmp_0": [1, 64, 20, 20], - "nearest_interp_v2_5.tmp_0": [1, 64, 20, 20], - "elementwise_add_7": [1, 56, 2, 2], - "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2] - } - max_input_shape = { - "x": [1, 3, 1536, 1536], - "conv2d_92.tmp_0": [1, 120, 400, 400], - "conv2d_91.tmp_0": [1, 24, 200, 200], - "conv2d_59.tmp_0": [1, 96, 400, 400], - "nearest_interp_v2_1.tmp_0": [1, 256, 200, 200], - "conv2d_124.tmp_0": [1, 256, 400, 400], - "nearest_interp_v2_2.tmp_0": [1, 256, 400, 400], - "nearest_interp_v2_3.tmp_0": [1, 64, 400, 400], - "nearest_interp_v2_4.tmp_0": [1, 64, 400, 400], - "nearest_interp_v2_5.tmp_0": [1, 64, 400, 400], - "elementwise_add_7": [1, 56, 400, 400], - "nearest_interp_v2_0.tmp_0": [1, 256, 400, 400] - } - opt_input_shape = { - "x": [1, 3, 640, 640], - "conv2d_92.tmp_0": [1, 120, 160, 160], - "conv2d_91.tmp_0": [1, 24, 80, 80], - "conv2d_59.tmp_0": [1, 96, 160, 160], - "nearest_interp_v2_1.tmp_0": [1, 256, 80, 80], - "nearest_interp_v2_2.tmp_0": [1, 256, 160, 160], - "conv2d_124.tmp_0": [1, 256, 160, 160], - 
"nearest_interp_v2_3.tmp_0": [1, 64, 160, 160], - "nearest_interp_v2_4.tmp_0": [1, 64, 160, 160], - "nearest_interp_v2_5.tmp_0": [1, 64, 160, 160], - "elementwise_add_7": [1, 56, 40, 40], - "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40] - } - min_pact_shape = { - "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20], - "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20], - "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20], - "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20] - } - max_pact_shape = { - "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400], - "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400], - "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400], - "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400] - } - opt_pact_shape = { - "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160], - "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160], - "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160], - "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160] - } - min_input_shape.update(min_pact_shape) - max_input_shape.update(max_pact_shape) - opt_input_shape.update(opt_pact_shape) - elif mode == "rec": - if args.rec_algorithm not in ["CRNN", "SVTR_LCNet"]: - use_dynamic_shape = False - imgH = int(args.rec_image_shape.split(',')[-2]) - min_input_shape = {"x": [1, 3, imgH, 10]} - max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 2304]} - opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]} - config.exp_disable_tensorrt_ops(["transpose2"]) - elif mode == "cls": - min_input_shape = {"x": [1, 3, 48, 10]} - max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]} - opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]} else: - use_dynamic_shape = False - if use_dynamic_shape: - config.set_trt_dynamic_shape_info( - min_input_shape, max_input_shape, opt_input_shape) + logger.info( + f"when using tensorrt, dynamic shape is a suggested option, you can use '--shape_info_filename=shape.txt' for offline dygnamic shape tuning" + ) elif args.use_xpu: config.enable_xpu(10 * 1024 * 1024) diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 182694e6cda12ead0e263bb94a7d6483a6f7f212..14b14544eb11e9fb0a0c2cdf92aff9d7cb4b5ba7 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -96,6 +96,8 @@ def main(): ] elif config['Architecture']['algorithm'] == "SAR": op[op_name]['keep_keys'] = ['image', 'valid_ratio'] + elif config['Architecture']['algorithm'] == "RobustScanner": + op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons'] else: op[op_name]['keep_keys'] = ['image'] transforms.append(op) @@ -131,12 +133,20 @@ def main(): if config['Architecture']['algorithm'] == "SAR": valid_ratio = np.expand_dims(batch[-1], axis=0) img_metas = [paddle.to_tensor(valid_ratio)] + if config['Architecture']['algorithm'] == "RobustScanner": + valid_ratio = np.expand_dims(batch[1], axis=0) + word_positons = np.expand_dims(batch[2], axis=0) + img_metas = [paddle.to_tensor(valid_ratio), + paddle.to_tensor(word_positons), + ] images = np.expand_dims(batch[0], axis=0) images = paddle.to_tensor(images) if config['Architecture']['algorithm'] == "SRN": preds = model(images, others) elif config['Architecture']['algorithm'] == "SAR": preds = model(images, img_metas) + elif config['Architecture']['algorithm'] == "RobustScanner": + preds = model(images, img_metas) else: preds = model(images) post_result = post_process_class(preds) diff --git a/tools/infer_sr.py b/tools/infer_sr.py index 0bc2f6aaa7c4400676268ec64d37e721af0f99c2..df4334f3427e57b9062dd819aa16c110fd771e8c 100755 --- a/tools/infer_sr.py +++ b/tools/infer_sr.py @@ -63,14 +63,14 @@ def main(): 
elif op_name in ['SRResize']: op[op_name]['infer_mode'] = True elif op_name == 'KeepKeys': - op[op_name]['keep_keys'] = ['imge_lr'] + op[op_name]['keep_keys'] = ['img_lr'] transforms.append(op) global_config['infer_mode'] = True ops = create_operators(transforms, global_config) - save_res_path = config['Global'].get('save_res_path', "./infer_result") - if not os.path.exists(os.path.dirname(save_res_path)): - os.makedirs(os.path.dirname(save_res_path)) + save_visual_path = config['Global'].get('save_visual', "infer_result/") + if not os.path.exists(os.path.dirname(save_visual_path)): + os.makedirs(os.path.dirname(save_visual_path)) model.eval() for file in get_image_file_list(config['Global']['infer_img']): @@ -87,7 +87,7 @@ def main(): fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8) fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8) img_name_pure = os.path.split(file)[-1] - cv2.imwrite("infer_result/sr_{}".format(img_name_pure), + cv2.imwrite("{}/sr_{}".format(save_visual_path, img_name_pure), fm_sr[:, :, ::-1]) logger.info("The visualized image saved in infer_result/sr_{}".format( img_name_pure)) diff --git a/tools/program.py b/tools/program.py index b450bc5a3abf0be500b42712d72c81c190412f34..5a4d3ea4d2ec6832e6735d15096d46fbb62f86dd 100755 --- a/tools/program.py +++ b/tools/program.py @@ -162,18 +162,18 @@ def to_float32(preds): for k in preds: if isinstance(preds[k], dict) or isinstance(preds[k], list): preds[k] = to_float32(preds[k]) - else: - preds[k] = paddle.to_tensor(preds[k], dtype='float32') + elif isinstance(preds[k], paddle.Tensor): + preds[k] = preds[k].astype(paddle.float32) elif isinstance(preds, list): for k in range(len(preds)): if isinstance(preds[k], dict): preds[k] = to_float32(preds[k]) elif isinstance(preds[k], list): preds[k] = to_float32(preds[k]) - else: - preds[k] = paddle.to_tensor(preds[k], dtype='float32') - else: - preds = paddle.to_tensor(preds, dtype='float32') + elif isinstance(preds[k], paddle.Tensor): + preds[k] = preds[k].astype(paddle.float32) + elif isinstance(preds, paddle.Tensor): + preds = preds.astype(paddle.float32) return preds @@ -190,7 +190,8 @@ def train(config, pre_best_model_dict, logger, log_writer=None, - scaler=None): + scaler=None, + amp_level='O2'): cal_metric_during_train = config['Global'].get('cal_metric_during_train', False) calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1) @@ -230,7 +231,8 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" extra_input_models = [ - "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN" + "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", + "RobustScanner" ] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': @@ -276,7 +278,8 @@ def train(config, model_average = True # use amp if scaler: - with paddle.amp.auto_cast(level='O2'): + custom_black_list = config['Global'].get('amp_custom_black_list',[]) + with paddle.amp.auto_cast(level=amp_level, custom_black_list=custom_black_list): if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) elif model_type in ["kie", 'vqa']: @@ -502,18 +505,9 @@ def eval(model, preds = model(batch) sr_img = preds["sr_img"] lr_img = preds["lr_img"] - - for i in (range(sr_img.shape[0])): - fm_sr = (sr_img[i].numpy() * 255).transpose( - 1, 2, 0).astype(np.uint8) - fm_lr = (lr_img[i].numpy() * 255).transpose( - 1, 2, 0).astype(np.uint8) - cv2.imwrite("output/images/{}_{}_sr.jpg".format( - sum_images, i), fm_sr) - 
cv2.imwrite("output/images/{}_{}_lr.jpg".format( - sum_images, i), fm_lr) else: preds = model(images) + preds = to_float32(preds) else: if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) @@ -523,16 +517,6 @@ def eval(model, preds = model(batch) sr_img = preds["sr_img"] lr_img = preds["lr_img"] - - for i in (range(sr_img.shape[0])): - fm_sr = (sr_img[i].numpy() * 255).transpose( - 1, 2, 0).astype(np.uint8) - fm_lr = (lr_img[i].numpy() * 255).transpose( - 1, 2, 0).astype(np.uint8) - cv2.imwrite("output/images/{}_{}_sr.jpg".format( - sum_images, i), fm_sr) - cv2.imwrite("output/images/{}_{}_lr.jpg".format( - sum_images, i), fm_lr) else: preds = model(images) @@ -653,7 +637,7 @@ def preprocess(is_train=False): 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', - 'Gestalt', 'SLANet' + 'Gestalt', 'SLANet', 'RobustScanner' ] if use_xpu: diff --git a/tools/train.py b/tools/train.py index 0c881ecae8daf78860829b1419178358c2209f25..5f310938f3ae3488281b47ccdb436697595b5578 100755 --- a/tools/train.py +++ b/tools/train.py @@ -147,6 +147,7 @@ def main(config, device, logger, vdl_writer): len(valid_dataloader))) use_amp = config["Global"].get("use_amp", False) + amp_level = config["Global"].get("amp_level", 'O2') if use_amp: AMP_RELATED_FLAGS_SETTING = { 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, @@ -159,8 +160,9 @@ def main(config, device, logger, vdl_writer): scaler = paddle.amp.GradScaler( init_loss_scaling=scale_loss, use_dynamic_loss_scaling=use_dynamic_loss_scaling) - model, optimizer = paddle.amp.decorate( - models=model, optimizers=optimizer, level='O2', master_weight=True) + if amp_level == "O2": + model, optimizer = paddle.amp.decorate( + models=model, optimizers=optimizer, level=amp_level, master_weight=True) else: scaler = None @@ -169,7 +171,7 @@ def main(config, device, logger, vdl_writer): # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, - eval_class, pre_best_model_dict, logger, vdl_writer, scaler) + eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level) def test_reader(config, device, logger):