diff --git a/README.md b/README.md index 1ce6fabb55e231a7748bf2baeefc2d2fb4ffb2e5..a5c52bdb356fa765931fa7046aca5e4ef72caa95 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ For a new language request, please refer to [Guideline for new language_requests - [Quick Inference Based on PIP](./doc/doc_en/whl_en.md) - [Python Inference](./doc/doc_en/inference_en.md) - [C++ Inference](./deploy/cpp_infer/readme_en.md) - - [Serving](./deploy/hubserving/readme_en.md) + - [Serving](./deploy/pdserving/README.md) - [Mobile](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/deploy/lite/readme_en.md) - [Benchmark](./doc/doc_en/benchmark_en.md) - Data Annotation and Synthesis diff --git a/README_ch.md b/README_ch.md index 9fe853f92510e540e35b7452f6ab00e5898ddb55..0430fe759f62155ad97d73db06445bbfe551c181 100755 --- a/README_ch.md +++ b/README_ch.md @@ -8,9 +8,10 @@ PaddleOCR同时支持动态图与静态图两种编程范式 - 静态图版本:develop分支 **近期更新** +- 【预告】 PaddleOCR研发团队对最新发版内容技术深入解读,4月13日晚上19:00,[直播地址](https://live.bilibili.com/21689802) +- 2021.4.8 release 2.1版本,新增AAAI 2021论文[端到端识别算法PGNet](./doc/doc_ch/pgnet.md)开源,[多语言模型](./doc/doc_ch/multi_languages.md)支持种类增加到80+。 - 2021.2.1 [FAQ](./doc/doc_ch/FAQ.md)新增5个高频问题,总数162个,每周一都会更新,欢迎大家持续关注。 -- 2021.1.26,28,29 PaddleOCR官方研发团队带来技术深入解读三日直播课,1月26日、28日、29日晚上19:30,[直播地址](https://live.bilibili.com/21689802) -- 2021.1.21 更新多语言识别模型,目前支持语种超过27种,[多语言模型下载](./doc/doc_ch/models_list.md),包括中文简体、中文繁体、英文、法文、德文、韩文、日文、意大利文、西班牙文、葡萄牙文、俄罗斯文、阿拉伯文等,后续计划可以参考[多语言研发计划](https://github.com/PaddlePaddle/PaddleOCR/issues/1048) +- 2021.1.21 更新多语言识别模型,目前支持语种超过27种,包括中文简体、中文繁体、英文、法文、德文、韩文、日文、意大利文、西班牙文、葡萄牙文、俄罗斯文、阿拉伯文等,后续计划可以参考[多语言研发计划](https://github.com/PaddlePaddle/PaddleOCR/issues/1048) - 2020.12.15 更新数据合成工具[Style-Text](./StyleText/README_ch.md),可以批量合成大量与目标场景类似的图像,在多个场景验证,效果明显提升。 - 2020.11.25 更新半自动标注工具[PPOCRLabel](./PPOCRLabel/README_ch.md),辅助开发者高效完成标注任务,输出格式与PP-OCR训练任务完美衔接。 - 2020.9.22 更新PP-OCR技术文章,https://arxiv.org/abs/2009.09941 @@ -74,11 +75,13 @@ PaddleOCR同时支持动态图与静态图两种编程范式 ## 文档教程 - [快速安装](./doc/doc_ch/installation.md) - [中文OCR模型快速使用](./doc/doc_ch/quickstart.md) +- [多语言OCR模型快速使用](./doc/doc_ch/multi_languages.md) - [代码组织结构](./doc/doc_ch/tree.md) - 算法介绍 - [文本检测](./doc/doc_ch/algorithm_overview.md) - [文本识别](./doc/doc_ch/algorithm_overview.md) - [PP-OCR Pipline](#PP-OCR) + - [端到端PGNet算法](./doc/doc_ch/pgnet.md) - 模型训练/评估 - [文本检测](./doc/doc_ch/detection.md) - [文本识别](./doc/doc_ch/recognition.md) @@ -88,7 +91,7 @@ PaddleOCR同时支持动态图与静态图两种编程范式 - [基于pip安装whl包快速推理](./doc/doc_ch/whl.md) - [基于Python脚本预测引擎推理](./doc/doc_ch/inference.md) - [基于C++预测引擎推理](./deploy/cpp_infer/readme.md) - - [服务化部署](./deploy/hubserving/readme.md) + - [服务化部署](./deploy/pdserving/README_CN.md) - [端侧部署](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/deploy/lite/readme.md) - [Benchmark](./doc/doc_ch/benchmark.md) - 数据集 diff --git a/configs/det/det_r50_vd_sast_icdar15.yml b/configs/det/det_r50_vd_sast_icdar15.yml index c24cae90132c68d662e9edb7a7975e358fb40d9c..c90327b22b9c73111c997e84cfdd47d0721ee5b9 100755 --- a/configs/det/det_r50_vd_sast_icdar15.yml +++ b/configs/det/det_r50_vd_sast_icdar15.yml @@ -14,12 +14,13 @@ Global: load_static_weights: True cal_metric_during_train: False pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/ - checkpoints: + checkpoints: save_inference_dir: use_visualdl: False - infer_img: + infer_img: save_res_path: ./output/sast_r50_vd_ic15/predicts_sast.txt + Architecture: model_type: det algorithm: SAST diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml new file mode 100644 index 
0000000000000000000000000000000000000000..e4d868f98b5847fa064e14f87a69932806791320 --- /dev/null +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -0,0 +1,116 @@ +Global: + use_gpu: True + epoch_num: 600 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/pgnet_r50_vd_totaltext/ + save_epoch_step: 10 + # evaluation is run every 0 iterationss after the 1000th iteration + eval_batch_step: [ 0, 1000 ] + # 1. If pretrained_model is saved in static mode, such as classification pretrained model + # from static branch, load_static_weights must be set as True. + # 2. If you want to finetune the pretrained models we provide in the docs, + # you should set load_static_weights as False. + load_static_weights: False + cal_metric_during_train: False + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: + valid_set: totaltext # two mode: totaltext valid curved words, partvgg valid non-curved words + save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt + character_dict_path: ppocr/utils/ic15_dict.txt + character_type: EN + max_text_length: 50 # the max length in seq + max_text_nums: 30 # the max seq nums in a pic + tcl_len: 64 + +Architecture: + model_type: e2e + algorithm: PGNet + Transform: + Backbone: + name: ResNet + layers: 50 + Neck: + name: PGFPN + Head: + name: PGHead + +Loss: + name: PGLoss + tcl_bs: 64 + max_text_length: 50 # the same as Global: max_text_length + max_text_nums: 30 # the same as Global:max_text_nums + pad_num: 36 # the length of dict for pad + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + learning_rate: 0.001 + regularizer: + name: 'L2' + factor: 0 + + +PostProcess: + name: PGPostProcess + score_thresh: 0.5 + mode: fast # fast or slow two ways +Metric: + name: E2EMetric + gt_mat_dir: # the dir of gt_mat + character_dict_path: ppocr/utils/ic15_dict.txt + main_indicator: f_score_e2e + +Train: + dataset: + name: PGDataSet + label_file_list: [.././train_data/total_text/train/] + ratio_list: [1.0] + data_format: icdar #two data format: icdar/textnet + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - PGProcessTrain: + batch_size: 14 # same as loader: batch_size_per_card + min_crop_size: 24 + min_text_size: 4 + max_text_size: 512 + - KeepKeys: + keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order + loader: + shuffle: True + drop_last: True + batch_size_per_card: 14 + num_workers: 16 + +Eval: + dataset: + name: PGDataSet + data_dir: ./train_data/ + label_file_list: [./train_data/total_text/test/] + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - E2ELabelEncode: + - E2EResizeForTest: + max_side_len: 768 + - NormalizeImage: + scale: 1./255. 
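+      # NormalizeImage applies the standard ImageNet normalization,
+      # (pixel * scale - mean) / std, per channel while the image is still in
+      # HWC order; the ToCHWImage step below transposes it for the network.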
+ mean: [ 0.485, 0.456, 0.406 ] + std: [ 0.229, 0.224, 0.225 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 \ No newline at end of file diff --git a/deploy/hubserving/readme.md b/deploy/hubserving/readme.md index ce55b0f0da42b706cef30ab9b7a4c06f02e7c8eb..88f335812de191d860e46f7317bb303a20df8b41 100755 --- a/deploy/hubserving/readme.md +++ b/deploy/hubserving/readme.md @@ -2,7 +2,7 @@ PaddleOCR提供2种服务部署方式: - 基于PaddleHub Serving的部署:代码路径为"`./deploy/hubserving`",按照本教程使用; -- (coming soon)基于PaddleServing的部署:代码路径为"`./deploy/pdserving`",使用方法参考[文档](../../deploy/pdserving/readme.md)。 +- 基于PaddleServing的部署:代码路径为"`./deploy/pdserving`",使用方法参考[文档](../../deploy/pdserving/README_CN.md)。 # 基于PaddleHub Serving的服务部署 diff --git a/deploy/hubserving/readme_en.md b/deploy/hubserving/readme_en.md index 95223ffd82f8264d56158e8e7917c983b07f679d..c948fed1eefe9f5f83f63a82699cdac3548fad52 100755 --- a/deploy/hubserving/readme_en.md +++ b/deploy/hubserving/readme_en.md @@ -2,7 +2,7 @@ English | [简体中文](readme.md) PaddleOCR provides 2 service deployment methods: - Based on **PaddleHub Serving**: Code path is "`./deploy/hubserving`". Please follow this tutorial. -- (coming soon)Based on **PaddleServing**: Code path is "`./deploy/pdserving`". Please refer to the [tutorial](../../deploy/pdserving/readme.md) for usage. +- Based on **PaddleServing**: Code path is "`./deploy/pdserving`". Please refer to the [tutorial](../../deploy/pdserving/README.md) for usage. # Service deployment based on PaddleHub Serving diff --git a/deploy/pdserving/README.md b/deploy/pdserving/README.md new file mode 100644 index 0000000000000000000000000000000000000000..88426ba9c508a4020af0a6203010d683cb73eba9 --- /dev/null +++ b/deploy/pdserving/README.md @@ -0,0 +1,158 @@ +# OCR Pipeline WebService + +(English|[简体中文](./README_CN.md)) + +PaddleOCR provides two service deployment methods: +- Based on **PaddleHub Serving**: Code path is "`./deploy/hubserving`". Please refer to the [tutorial](../../deploy/hubserving/readme_en.md) +- Based on **PaddleServing**: Code path is "`./deploy/pdserving`". Please follow this tutorial. + +# Service deployment based on PaddleServing + +This document will introduce how to use the [PaddleServing](https://github.com/PaddlePaddle/Serving/blob/develop/README.md) to deploy the PPOCR dynamic graph model as a pipeline online service. + +Some Key Features of Paddle Serving: +- Integrate with Paddle training pipeline seamlessly, most paddle models can be deployed with one line command. +- Industrial serving features supported, such as models management, online loading, online A/B testing etc. +- Highly concurrent and efficient communication between clients and servers supported. + +The introduction and tutorial of Paddle Serving service deployment framework reference [document](https://github.com/PaddlePaddle/Serving/blob/develop/README.md). + + +## Contents +- [Environmental preparation](#environmental-preparation) +- [Model conversion](#model-conversion) +- [Paddle Serving pipeline deployment](#paddle-serving-pipeline-deployment) +- [FAQ](#faq) + + +## Environmental preparation + +PaddleOCR operating environment and Paddle Serving operating environment are needed. + +1. Please prepare PaddleOCR operating environment reference [link](../../doc/doc_ch/installation.md). + +2. 
The steps of PaddleServing operating environment prepare are as follows: + + Install serving which used to start the service + ``` + pip3 install paddle-serving-server==0.5.0 # for CPU + pip3 install paddle-serving-server-gpu==0.5.0 # for GPU + # Other GPU environments need to confirm the environment and then choose to execute the following commands + pip3 install paddle-serving-server-gpu==0.5.0.post9 # GPU with CUDA9.0 + pip3 install paddle-serving-server-gpu==0.5.0.post10 # GPU with CUDA10.0 + pip3 install paddle-serving-server-gpu==0.5.0.post101 # GPU with CUDA10.1 + TensorRT6 + pip3 install paddle-serving-server-gpu==0.5.0.post11 # GPU with CUDA10.1 + TensorRT7 + ``` + +3. Install the client to send requests to the service + ``` + pip3 install paddle-serving-client==0.5.0 # for CPU + + pip3 install paddle-serving-client-gpu==0.5.0 # for GPU + ``` + +4. Install serving-app + ``` + pip3 install paddle-serving-app==0.3.0 + # fix local_predict to support load dynamic model + # find the install directoory of paddle_serving_app + vim /usr/local/lib/python3.7/site-packages/paddle_serving_app/local_predict.py + # replace line 85 of local_predict.py config = AnalysisConfig(model_path) with: + if os.path.exists(os.path.join(model_path, "__params__")): + config = AnalysisConfig(os.path.join(model_path, "__model__"), os.path.join(model_path, "__params__")) + else: + config = AnalysisConfig(model_path) + ``` + + **note:** If you want to install the latest version of PaddleServing, refer to [link](https://github.com/PaddlePaddle/Serving/blob/develop/doc/LATEST_PACKAGES.md). + + + +## Model conversion +When using PaddleServing for service deployment, you need to convert the saved inference model into a serving model that is easy to deploy. + +Firstly, download the [inference model](https://github.com/PaddlePaddle/PaddleOCR#pp-ocr-20-series-model-listupdate-on-dec-15) of PPOCR +``` +# Download and unzip the OCR text detection model +wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_det_infer.tar +# Download and unzip the OCR text recognition model +wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar + +``` +Then, you can use installed paddle_serving_client tool to convert inference model to server model. +``` +# Detection model conversion +python3 -m paddle_serving_client.convert --dirname ./ch_ppocr_server_v2.0_det_infer/ \ + --model_filename inference.pdmodel \ + --params_filename inference.pdiparams \ + --serving_server ./ppocr_det_server_2.0_serving/ \ + --serving_client ./ppocr_det_server_2.0_client/ + +# Recognition model conversion +python3 -m paddle_serving_client.convert --dirname ./ch_ppocr_server_v2.0_rec_infer/ \ + --model_filename inference.pdmodel \ + --params_filename inference.pdiparams \ + --serving_server ./ppocr_rec_server_2.0_serving/ \ + --serving_client ./ppocr_rec_server_2.0_client/ + +``` + +After the detection model is converted, there will be additional folders of `ppocr_det_server_2.0_serving` and `ppocr_det_server_2.0_client` in the current folder, with the following format: +``` +|- ppocr_det_server_2.0_serving/ + |- __model__ + |- __params__ + |- serving_server_conf.prototxt + |- serving_server_conf.stream.prototxt + +|- ppocr_det_server_2.0_client + |- serving_client_conf.prototxt + |- serving_client_conf.stream.prototxt + +``` +The recognition model is the same. + + +## Paddle Serving pipeline deployment + +1. 
Download the PaddleOCR code, if you have already downloaded it, you can skip this step. + ``` + git clone https://github.com/PaddlePaddle/PaddleOCR + + # Enter the working directory + cd PaddleOCR/deploy/pdserver/ + ``` + + The pdserver directory contains the code to start the pipeline service and send prediction requests, including: + ``` + __init__.py + config.yml # Start the service configuration file + ocr_reader.py # OCR model pre-processing and post-processing code implementation + pipeline_http_client.py # Script to send pipeline prediction request + web_service.py # Start the script of the pipeline server + ``` + +2. Run the following command to start the service. + ``` + # Start the service and save the running log in log.txt + python3 web_service.py &>log.txt & + ``` + After the service is successfully started, a log similar to the following will be printed in log.txt + ![](./imgs/start_server.png) + +3. Send service request + ``` + python3 pipeline_http_client.py + ``` + After successfully running, the predicted result of the model will be printed in the cmd window. An example of the result is: + ![](./imgs/results.png) + + +## FAQ +**Q1**: No result return after sending the request. + +**A1**: Do not set the proxy when starting the service and sending the request. You can close the proxy before starting the service and before sending the request. The command to close the proxy is: +``` +unset https_proxy +unset http_proxy +``` diff --git a/deploy/pdserving/README_CN.md b/deploy/pdserving/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..3e3f1bde0e824fe6133a1c169b9b03e614904c26 --- /dev/null +++ b/deploy/pdserving/README_CN.md @@ -0,0 +1,160 @@ +# PPOCR 服务化部署 + +([English](./README.md)|简体中文) + +PaddleOCR提供2种服务部署方式: +- 基于PaddleHub Serving的部署:代码路径为"`./deploy/hubserving`",使用方法参考[文档](../../deploy/hubserving/readme.md); +- 基于PaddleServing的部署:代码路径为"`./deploy/pdserving`",按照本教程使用。 + +# 基于PaddleServing的服务部署 + +本文档将介绍如何使用[PaddleServing](https://github.com/PaddlePaddle/Serving/blob/develop/README_CN.md)工具部署PPOCR +动态图模型的pipeline在线服务。 + +相比较于hubserving部署,PaddleServing具备以下优点: +- 支持客户端和服务端之间高并发和高效通信 +- 支持 工业级的服务能力 例如模型管理,在线加载,在线A/B测试等 +- 支持 多种编程语言 开发客户端,例如C++, Python和Java + +更多有关PaddleServing服务化部署框架介绍和使用教程参考[文档](https://github.com/PaddlePaddle/Serving/blob/develop/README_CN.md)。 + +## 目录 +- [环境准备](#环境准备) +- [模型转换](#模型转换) +- [Paddle Serving pipeline部署](#部署) +- [FAQ](#FAQ) + + +## 环境准备 + +需要准备PaddleOCR的运行环境和Paddle Serving的运行环境。 + +- 准备PaddleOCR的运行环境参考[链接](../../doc/doc_ch/installation.md) + +- 准备PaddleServing的运行环境,步骤如下 + +1. 安装serving,用于启动服务 + ``` + pip3 install paddle-serving-server==0.5.0 # for CPU + pip3 install paddle-serving-server-gpu==0.5.0 # for GPU + # 其他GPU环境需要确认环境再选择执行如下命令 + pip3 install paddle-serving-server-gpu==0.5.0.post9 # GPU with CUDA9.0 + pip3 install paddle-serving-server-gpu==0.5.0.post10 # GPU with CUDA10.0 + pip3 install paddle-serving-server-gpu==0.5.0.post101 # GPU with CUDA10.1 + TensorRT6 + pip3 install paddle-serving-server-gpu==0.5.0.post11 # GPU with CUDA10.1 + TensorRT7 + ``` + +2. 安装client,用于向服务发送请求 + ``` + pip3 install paddle-serving-client==0.5.0 # for CPU + + pip3 install paddle-serving-client-gpu==0.5.0 # for GPU + ``` + +3. 
安装serving-app + ``` + pip3 install paddle-serving-app==0.3.0 + ``` + **note:** 安装0.3.0版本的serving-app后,为了能加载动态图模型,需要修改serving_app的源码,具体为: + ``` + # 找到paddle_serving_app的安装目录,找到并编辑local_predict.py文件 + vim /usr/local/lib/python3.7/site-packages/paddle_serving_app/local_predict.py + # 将local_predict.py 的第85行 config = AnalysisConfig(model_path) 替换为: + if os.path.exists(os.path.join(model_path, "__params__")): + config = AnalysisConfig(os.path.join(model_path, "__model__"), os.path.join(model_path, "__params__")) + else: + config = AnalysisConfig(model_path) + ``` + + **Note:** 如果要安装最新版本的PaddleServing参考[链接](https://github.com/PaddlePaddle/Serving/blob/develop/doc/LATEST_PACKAGES.md)。 + + +## 模型转换 + +使用PaddleServing做服务化部署时,需要将保存的inference模型转换为serving易于部署的模型。 + +首先,下载PPOCR的[inference模型](https://github.com/PaddlePaddle/PaddleOCR#pp-ocr-20-series-model-listupdate-on-dec-15) +``` +# 下载并解压 OCR 文本检测模型 +wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_det_infer.tar +# 下载并解压 OCR 文本识别模型 +wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar +``` + +接下来,用安装的paddle_serving_client把下载的inference模型转换成易于server部署的模型格式。 + +``` +# 转换检测模型 +python3 -m paddle_serving_client.convert --dirname ./ch_ppocr_server_v2.0_det_infer/ \ + --model_filename inference.pdmodel \ + --params_filename inference.pdiparams \ + --serving_server ./ppocr_det_server_2.0_serving/ \ + --serving_client ./ppocr_det_server_2.0_client/ + +# 转换识别模型 +python3 -m paddle_serving_client.convert --dirname ./ch_ppocr_server_v2.0_rec_infer/ \ + --model_filename inference.pdmodel \ + --params_filename inference.pdiparams \ + --serving_server ./ppocr_rec_server_2.0_serving/ \ + --serving_client ./ppocr_rec_server_2.0_client/ +``` + +检测模型转换完成后,会在当前文件夹多出`ppocr_det_server_2.0_serving` 和`ppocr_det_server_2.0_client`的文件夹,具备如下格式: +``` +|- ppocr_det_server_2.0_serving/ + |- __model__ + |- __params__ + |- serving_server_conf.prototxt + |- serving_server_conf.stream.prototxt + +|- ppocr_det_server_2.0_client + |- serving_client_conf.prototxt + |- serving_client_conf.stream.prototxt + +``` +识别模型同理。 + + +## Paddle Serving pipeline部署 + +1. 下载PaddleOCR代码,若已下载可跳过此步骤 + ``` + git clone https://github.com/PaddlePaddle/PaddleOCR + + # 进入到工作目录 + cd PaddleOCR/deploy/pdserver/ + ``` + pdserver目录包含启动pipeline服务和发送预测请求的代码,包括: + ``` + __init__.py + config.yml # 启动服务的配置文件 + ocr_reader.py # OCR模型预处理和后处理的代码实现 + pipeline_http_client.py # 发送pipeline预测请求的脚本 + web_service.py # 启动pipeline服务端的脚本 + ``` + +2. 启动服务可运行如下命令: + ``` + # 启动服务,运行日志保存在log.txt + python3 web_service.py &>log.txt & + ``` + 成功启动服务后,log.txt中会打印类似如下日志 + ![](./imgs/start_server.png) + +3. 发送服务请求: + ``` + python3 pipeline_http_client.py + ``` + 成功运行后,模型预测的结果会打印在cmd窗口中,结果示例为: + ![](./imgs/results.png) + + + +## FAQ +**Q1**: 发送请求后没有结果返回或者提示输出解码报错 + +**A1**: 启动服务和发送请求时不要设置代理,可以在启动服务前和发送请求前关闭代理,关闭代理的命令是: +``` +unset https_proxy +unset http_proxy +``` diff --git a/deploy/pdserving/__init__.py b/deploy/pdserving/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e --- /dev/null +++ b/deploy/pdserving/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/deploy/pdserving/config.yml b/deploy/pdserving/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..aef735dbfab5b314f9209a7cc91e7fd5b6fc615c --- /dev/null +++ b/deploy/pdserving/config.yml @@ -0,0 +1,71 @@ +#rpc端口, rpc_port和http_port不允许同时为空。当rpc_port为空且http_port不为空时,会自动将rpc_port设置为http_port+1 +rpc_port: 18090 + +#http端口, rpc_port和http_port不允许同时为空。当rpc_port可用且http_port为空时,不自动生成http_port +http_port: 9999 + +#worker_num, 最大并发数。当build_dag_each_worker=True时, 框架会创建worker_num个进程,每个进程内构建grpcSever和DAG +##当build_dag_each_worker=False时,框架会设置主线程grpc线程池的max_workers=worker_num +worker_num: 20 + +#build_dag_each_worker, False,框架在进程内创建一条DAG;True,框架会每个进程内创建多个独立的DAG +build_dag_each_worker: false + +dag: + #op资源类型, True, 为线程模型;False,为进程模型 + is_thread_op: False + + #重试次数 + retry: 1 + + #使用性能分析, True,生成Timeline性能数据,对性能有一定影响;False为不使用 + use_profile: False + + tracer: + interval_s: 10 +op: + det: + #并发数,is_thread_op=True时,为线程并发;否则为进程并发 + concurrency: 4 + + #当op配置没有server_endpoints时,从local_service_conf读取本地服务配置 + local_service_conf: + #client类型,包括brpc, grpc和local_predictor.local_predictor不启动Serving服务,进程内预测 + client_type: local_predictor + + #det模型路径 + model_config: /paddle/serving/models/det_serving_server/ #ocr_det_model + + #Fetch结果列表,以client_config中fetch_var的alias_name为准 + fetch_list: ["save_infer_model/scale_0.tmp_1"] + + #计算硬件ID,当devices为""或不写时为CPU预测;当devices为"0", "0,1,2"时为GPU预测,表示使用的GPU卡 + devices: "2" + + ir_optim: True + rec: + #并发数,is_thread_op=True时,为线程并发;否则为进程并发 + concurrency: 1 + + #超时时间, 单位ms + timeout: -1 + + #Serving交互重试次数,默认不重试 + retry: 1 + + #当op配置没有server_endpoints时,从local_service_conf读取本地服务配置 + local_service_conf: + + #client类型,包括brpc, grpc和local_predictor。local_predictor不启动Serving服务,进程内预测 + client_type: local_predictor + + #rec模型路径 + model_config: /paddle/serving/models/rec_serving_server/ #ocr_rec_model + + #Fetch结果列表,以client_config中fetch_var的alias_name为准 + fetch_list: ["save_infer_model/scale_0.tmp_1"] #["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] + + #计算硬件ID,当devices为""或不写时为CPU预测;当devices为"0", "0,1,2"时为GPU预测,表示使用的GPU卡 + devices: "2" + + ir_optim: True diff --git a/deploy/imgs/cpp_infer_pred_12.png b/deploy/pdserving/imgs/cpp_infer_pred_12.png similarity index 100% rename from deploy/imgs/cpp_infer_pred_12.png rename to deploy/pdserving/imgs/cpp_infer_pred_12.png diff --git a/deploy/imgs/demo.png b/deploy/pdserving/imgs/demo.png similarity index 100% rename from deploy/imgs/demo.png rename to deploy/pdserving/imgs/demo.png diff --git a/deploy/pdserving/imgs/results.png b/deploy/pdserving/imgs/results.png new file mode 100644 index 0000000000000000000000000000000000000000..35322bf9462c859bb2d158dd1d50ab52f35bad41 Binary files /dev/null and b/deploy/pdserving/imgs/results.png differ diff --git a/deploy/pdserving/imgs/start_server.png b/deploy/pdserving/imgs/start_server.png new file mode 100644 index 0000000000000000000000000000000000000000..60e19ccaed6bd4382f5342101a54d8957879bd60 Binary files /dev/null and b/deploy/pdserving/imgs/start_server.png differ diff --git a/deploy/pdserving/ocr_reader.py b/deploy/pdserving/ocr_reader.py new file 
mode 100644 index 0000000000000000000000000000000000000000..95110706af13662de11ef0f668558d0dd3abcf52 --- /dev/null +++ b/deploy/pdserving/ocr_reader.py @@ -0,0 +1,438 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import copy +import numpy as np +import math +import re +import sys +import argparse +import string +from copy import deepcopy +import paddle + + +class DetResizeForTest(object): + def __init__(self, **kwargs): + super(DetResizeForTest, self).__init__() + self.resize_type = 0 + if 'image_shape' in kwargs: + self.image_shape = kwargs['image_shape'] + self.resize_type = 1 + elif 'limit_side_len' in kwargs: + self.limit_side_len = kwargs['limit_side_len'] + self.limit_type = kwargs.get('limit_type', 'min') + elif 'resize_long' in kwargs: + self.resize_type = 2 + self.resize_long = kwargs.get('resize_long', 960) + else: + self.limit_side_len = 736 + self.limit_type = 'min' + + def __call__(self, data): + img = deepcopy(data) + src_h, src_w, _ = img.shape + + if self.resize_type == 0: + img, [ratio_h, ratio_w] = self.resize_image_type0(img) + elif self.resize_type == 2: + img, [ratio_h, ratio_w] = self.resize_image_type2(img) + else: + img, [ratio_h, ratio_w] = self.resize_image_type1(img) + + return img + + def resize_image_type1(self, img): + resize_h, resize_w = self.image_shape + ori_h, ori_w = img.shape[:2] # (h, w, c) + ratio_h = float(resize_h) / ori_h + ratio_w = float(resize_w) / ori_w + img = cv2.resize(img, (int(resize_w), int(resize_h))) + return img, [ratio_h, ratio_w] + + def resize_image_type0(self, img): + """ + resize image to a size multiple of 32 which is required by the network + args: + img(array): array with shape [h, w, c] + return(tuple): + img, (ratio_h, ratio_w) + """ + limit_side_len = self.limit_side_len + h, w, _ = img.shape + + # limit the max side + if self.limit_type == 'max': + if max(h, w) > limit_side_len: + if h > w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. + else: + if min(h, w) < limit_side_len: + if h < w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. 
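+        # With limit_type 'max' the longer side is capped at limit_side_len,
+        # while with 'min' (the default here) the shorter side is raised to it;
+        # the target size is then rounded to a multiple of 32 below, as required
+        # by the detection network's downsampling.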
+ resize_h = int(h * ratio) + resize_w = int(w * ratio) + + resize_h = int(round(resize_h / 32) * 32) + resize_w = int(round(resize_w / 32) * 32) + + try: + if int(resize_w) <= 0 or int(resize_h) <= 0: + return None, (None, None) + img = cv2.resize(img, (int(resize_w), int(resize_h))) + except: + print(img.shape, resize_w, resize_h) + sys.exit(0) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + # return img, np.array([h, w]) + return img, [ratio_h, ratio_w] + + def resize_image_type2(self, img): + h, w, _ = img.shape + + resize_w = w + resize_h = h + + # Fix the longer side + if resize_h > resize_w: + ratio = float(self.resize_long) / resize_h + else: + ratio = float(self.resize_long) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + img = cv2.resize(img, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return img, [ratio_h, ratio_w] + + +class BaseRecLabelDecode(object): + """ Convert between text-label and text-index """ + + def __init__(self, config): + support_character_type = [ + 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean', + 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc', + 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr', + 'ne', 'EN' + ] + character_type = config['character_type'] + character_dict_path = config['character_dict_path'] + use_space_char = True + assert character_type in support_character_type, "Only {} are supported now but get {}".format( + support_character_type, character_type) + + self.beg_str = "sos" + self.end_str = "eos" + + if character_type == "en": + self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" + dict_character = list(self.character_str) + elif character_type == "EN_symbol": + # same with ASTER setting (use 94 char). + self.character_str = string.printable[:-6] + dict_character = list(self.character_str) + elif character_type in support_character_type: + self.character_str = "" + assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format( + character_type) + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + self.character_str += line + if use_space_char: + self.character_str += " " + dict_character = list(self.character_str) + + else: + raise NotImplementedError + self.character_type = character_type + dict_character = self.add_special_char(dict_character) + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.character = dict_character + + def add_special_char(self, dict_character): + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. 
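+
+        Tokens in get_ignored_tokens() (the CTC blank, index 0) are skipped;
+        when is_remove_duplicate is True, consecutive repeated indices are
+        collapsed as well, implementing the standard CTC greedy-decoding rule.
+        Returns a list of (text, mean confidence) tuples, one per sample.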
""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list))) + return result_list + + def get_ignored_tokens(self): + return [0] # for ctc blank + + +class CTCLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__( + self, + config, + #character_dict_path=None, + #character_type='ch', + #use_space_char=False, + **kwargs): + super(CTCLabelDecode, self).__init__(config) + + def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) + if label is None: + return text + label = self.decode(label) + return text, label + + def add_special_char(self, dict_character): + dict_character = ['blank'] + dict_character + return dict_character + + +class CharacterOps(object): + """ Convert between text-label and text-index """ + + def __init__(self, config): + self.character_type = config['character_type'] + self.loss_type = config['loss_type'] + if self.character_type == "en": + self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" + dict_character = list(self.character_str) + elif self.character_type == "ch": + character_dict_path = config['character_dict_path'] + self.character_str = "" + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + self.character_str += line + dict_character = list(self.character_str) + elif self.character_type == "en_sensitive": + # same with ASTER setting (use 94 char). + self.character_str = string.printable[:-6] + dict_character = list(self.character_str) + else: + self.character_str = None + assert self.character_str is not None, \ + "Nonsupport type of the character: {}".format(self.character_str) + self.beg_str = "sos" + self.end_str = "eos" + if self.loss_type == "attention": + dict_character = [self.beg_str, self.end_str] + dict_character + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.character = dict_character + + def encode(self, text): + """convert text-label into text-index. + input: + text: text labels of each image. [batch_size] + + output: + text: concatenated text index for CTCLoss. + [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] + length: length of each text. [batch_size] + """ + if self.character_type == "en": + text = text.lower() + + text_list = [] + for char in text: + if char not in self.dict: + continue + text_list.append(self.dict[char]) + text = np.array(text_list) + return text + + def decode(self, text_index, is_remove_duplicate=False): + """ convert text-index into text-label. 
""" + char_list = [] + char_num = self.get_char_num() + + if self.loss_type == "attention": + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + ignored_tokens = [beg_idx, end_idx] + else: + ignored_tokens = [char_num] + + for idx in range(len(text_index)): + if text_index[idx] in ignored_tokens: + continue + if is_remove_duplicate: + if idx > 0 and text_index[idx - 1] == text_index[idx]: + continue + char_list.append(self.character[text_index[idx]]) + text = ''.join(char_list) + return text + + def get_char_num(self): + return len(self.character) + + def get_beg_end_flag_idx(self, beg_or_end): + if self.loss_type == "attention": + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx"\ + % beg_or_end + return idx + else: + err = "error in get_beg_end_flag_idx when using the loss %s"\ + % (self.loss_type) + assert False, err + + +class OCRReader(object): + def __init__(self, + algorithm="CRNN", + image_shape=[3, 32, 320], + char_type="ch", + batch_num=1, + char_dict_path="./ppocr_keys_v1.txt"): + self.rec_image_shape = image_shape + self.character_type = char_type + self.rec_batch_num = batch_num + char_ops_params = {} + char_ops_params["character_type"] = char_type + char_ops_params["character_dict_path"] = char_dict_path + char_ops_params['loss_type'] = 'ctc' + self.char_ops = CharacterOps(char_ops_params) + self.label_ops = CTCLabelDecode(char_ops_params) + + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + if self.character_type == "ch": + imgW = int(32 * max_wh_ratio) + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def preprocess(self, img_list): + img_num = len(img_list) + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(img_num): + h, w = img_list[ino].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + + for ino in range(img_num): + norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + + return norm_img_batch[0] + + def postprocess_old(self, outputs, with_score=False): + rec_res = [] + rec_idx_lod = outputs["ctc_greedy_decoder_0.tmp_0.lod"] + rec_idx_batch = outputs["ctc_greedy_decoder_0.tmp_0"] + if with_score: + predict_lod = outputs["softmax_0.tmp_0.lod"] + for rno in range(len(rec_idx_lod) - 1): + beg = rec_idx_lod[rno] + end = rec_idx_lod[rno + 1] + if isinstance(rec_idx_batch, list): + rec_idx_tmp = [x[0] for x in rec_idx_batch[beg:end]] + else: #nd array + rec_idx_tmp = rec_idx_batch[beg:end, 0] + preds_text = self.char_ops.decode(rec_idx_tmp) + if with_score: + beg = predict_lod[rno] + end = predict_lod[rno + 1] + if isinstance(outputs["softmax_0.tmp_0"], list): + outputs["softmax_0.tmp_0"] = np.array(outputs[ + "softmax_0.tmp_0"]).astype(np.float32) + probs = 
outputs["softmax_0.tmp_0"][beg:end, :] + ind = np.argmax(probs, axis=1) + blank = probs.shape[1] + valid_ind = np.where(ind != (blank - 1))[0] + score = np.mean(probs[valid_ind, ind[valid_ind]]) + rec_res.append([preds_text, score]) + else: + rec_res.append([preds_text]) + return rec_res + + def postprocess(self, outputs, with_score=False): + preds = outputs["save_infer_model/scale_0.tmp_1"] + try: + preds = preds.numpy() + except: + pass + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.label_ops.decode( + preds_idx, preds_prob, is_remove_duplicate=True) + return text diff --git a/deploy/pdserving/pipeline_http_client.py b/deploy/pdserving/pipeline_http_client.py new file mode 100644 index 0000000000000000000000000000000000000000..88c4a81ea8bbed80d37b5fbfea6bf01b38f9613a --- /dev/null +++ b/deploy/pdserving/pipeline_http_client.py @@ -0,0 +1,40 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import requests +import json +import base64 +import os + + +def cv2_to_base64(image): + return base64.b64encode(image).decode('utf8') + + +url = "http://127.0.0.1:9999/ocr/prediction" +test_img_dir = "../doc/imgs/" +for idx, img_file in enumerate(os.listdir(test_img_dir)): + with open(os.path.join(test_img_dir, img_file), 'rb') as file: + image_data1 = file.read() + + image = cv2_to_base64(image_data1) + + for i in range(1): + data = {"key": ["image"], "value": [image]} + r = requests.post(url=url, data=json.dumps(data)) + print(r.json()) + +test_img_dir = "../doc/imgs/" +print("==> total number of test imgs: ", len(os.listdir(test_img_dir))) diff --git a/deploy/pdserving/pipeline_rpc_client.py b/deploy/pdserving/pipeline_rpc_client.py new file mode 100644 index 0000000000000000000000000000000000000000..7471f7ed6c1254d550bcf2c19f6ee7c610a2e20e --- /dev/null +++ b/deploy/pdserving/pipeline_rpc_client.py @@ -0,0 +1,42 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+try: + from paddle_serving_server_gpu.pipeline import PipelineClient +except ImportError: + from paddle_serving_server.pipeline import PipelineClient +import numpy as np +import requests +import json +import cv2 +import base64 +import os + +client = PipelineClient() +client.connect(['127.0.0.1:18090']) + + +def cv2_to_base64(image): + return base64.b64encode(image).decode('utf8') + + +test_img_dir = "imgs/" +for img_file in os.listdir(test_img_dir): + with open(os.path.join(test_img_dir, img_file), 'rb') as file: + image_data = file.read() + image = cv2_to_base64(image_data) + +for i in range(1): + ret = client.predict(feed_dict={"image": image}, fetch=["res"]) + print(ret) + #print(ret) diff --git a/deploy/pdserving/web_service.py b/deploy/pdserving/web_service.py new file mode 100644 index 0000000000000000000000000000000000000000..b47ef65d09dd7aad0e4d00ca852a5c32161ad45b --- /dev/null +++ b/deploy/pdserving/web_service.py @@ -0,0 +1,127 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + from paddle_serving_server_gpu.web_service import WebService, Op +except ImportError: + from paddle_serving_server.web_service import WebService, Op + +import logging +import numpy as np +import cv2 +import base64 +# from paddle_serving_app.reader import OCRReader +from ocr_reader import OCRReader, DetResizeForTest +from paddle_serving_app.reader import Sequential, ResizeByFactor +from paddle_serving_app.reader import Div, Normalize, Transpose +from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes + +_LOGGER = logging.getLogger() + + +class DetOp(Op): + def init_op(self): + self.det_preprocess = Sequential([ + DetResizeForTest(), Div(255), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( + (2, 0, 1)) + ]) + self.filter_func = FilterBoxes(10, 10) + self.post_func = DBPostProcess({ + "thresh": 0.3, + "box_thresh": 0.5, + "max_candidates": 1000, + "unclip_ratio": 1.5, + "min_size": 3 + }) + + def preprocess(self, input_dicts, data_id, log_id): + (_, input_dict), = input_dicts.items() + data = base64.b64decode(input_dict["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + # Note: class variables(self.var) can only be used in process op mode + im = cv2.imdecode(data, cv2.IMREAD_COLOR) + self.im = im + self.ori_h, self.ori_w, _ = im.shape + + det_img = self.det_preprocess(self.im) + _, self.new_h, self.new_w = det_img.shape + print("det image shape", det_img.shape) + return {"x": det_img[np.newaxis, :].copy()}, False, None, "" + + def postprocess(self, input_dicts, fetch_dict, log_id): + print("input_dicts: ", input_dicts) + det_out = fetch_dict["save_infer_model/scale_0.tmp_1"] + ratio_list = [ + float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w + ] + dt_boxes_list = self.post_func(det_out, [ratio_list]) + dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w]) + out_dict = {"dt_boxes": dt_boxes, "image": self.im} + + print("out dict", 
out_dict["dt_boxes"]) + return out_dict, None, "" + + +class RecOp(Op): + def init_op(self): + self.ocr_reader = OCRReader( + char_dict_path="../../ppocr/utils/ppocr_keys_v1.txt") + + self.get_rotate_crop_image = GetRotateCropImage() + self.sorted_boxes = SortedBoxes() + + def preprocess(self, input_dicts, data_id, log_id): + (_, input_dict), = input_dicts.items() + im = input_dict["image"] + dt_boxes = input_dict["dt_boxes"] + dt_boxes = self.sorted_boxes(dt_boxes) + feed_list = [] + img_list = [] + max_wh_ratio = 0 + for i, dtbox in enumerate(dt_boxes): + boximg = self.get_rotate_crop_image(im, dt_boxes[i]) + img_list.append(boximg) + h, w = boximg.shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + _, w, h = self.ocr_reader.resize_norm_img(img_list[0], + max_wh_ratio).shape + + imgs = np.zeros((len(img_list), 3, w, h)).astype('float32') + for id, img in enumerate(img_list): + norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) + imgs[id] = norm_img + print("rec image shape", imgs.shape) + feed = {"x": imgs.copy()} + return feed, False, None, "" + + def postprocess(self, input_dicts, fetch_dict, log_id): + rec_res = self.ocr_reader.postprocess(fetch_dict, with_score=True) + res_lst = [] + for res in rec_res: + res_lst.append(res[0]) + res = {"res": str(res_lst)} + return res, None, "" + + +class OcrService(WebService): + def get_pipeline_response(self, read_op): + det_op = DetOp(name="det", input_ops=[read_op]) + rec_op = RecOp(name="rec", input_ops=[det_op]) + return rec_op + + +uci_service = OcrService(name="ocr") +uci_service.prepare_pipeline_config("config.yml") +uci_service.run_service() diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index c8fc280d80056395bbc841a973004b06844b1214..19d7a69c7fb08a8e7fb36c3043aa211de19b9295 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -28,7 +28,9 @@ PaddleOCR开源的文本检测算法列表: | --- | --- | --- | --- | --- | --- | |SAST|ResNet50_vd|89.63%|78.44%|83.66%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)| -**说明:** SAST模型训练额外加入了icdar2013、icdar2017、COCO-Text、ArT等公开数据集进行调优。PaddleOCR用到的经过整理格式的英文公开数据集下载:[百度云地址](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (提取码: 2bpi) +**说明:** SAST模型训练额外加入了icdar2013、icdar2017、COCO-Text、ArT等公开数据集进行调优。PaddleOCR用到的经过整理格式的英文公开数据集下载: +* [百度云地址](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (提取码: 2bpi) +* [Google Drive下载地址](https://drive.google.com/drive/folders/1ll2-XEVyCQLpJjawLDiRlvo_i4BqHCJe?usp=sharing) PaddleOCR文本检测算法的训练和使用请参考文档教程中[模型训练/评估中的文本检测部分](./detection.md)。 diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md index 7968b355ea936d465b3c173c0fcdb3e08f12f16e..f0f7401538a9f8940f671fdcc170aca6c003040d 100755 --- a/doc/doc_ch/inference.md +++ b/doc/doc_ch/inference.md @@ -12,7 +12,7 @@ inference 模型(`paddle.jit.save`保存的模型) - [一、训练模型转inference模型](#训练模型转inference模型) - [检测模型转inference模型](#检测模型转inference模型) - [识别模型转inference模型](#识别模型转inference模型) - - [方向分类模型转inference模型](#方向分类模型转inference模型) + - [方向分类模型转inference模型](#方向分类模型转inference模型) - [二、文本检测模型推理](#文本检测模型推理) - [1. 
超轻量中文检测模型推理](#超轻量中文检测模型推理) diff --git a/doc/doc_ch/multi_languages.md b/doc/doc_ch/multi_languages.md new file mode 100644 index 0000000000000000000000000000000000000000..eec09535e7242f62cdebda97cee11086e45c1096 --- /dev/null +++ b/doc/doc_ch/multi_languages.md @@ -0,0 +1,313 @@ +# 多语言模型 + +**近期更新** + +- 2021.4.9 支持**80种**语言的检测和识别 +- 2021.4.9 支持**轻量高精度**英文模型检测识别 + +PaddleOCR 旨在打造一套丰富、领先、且实用的OCR工具库,不仅提供了通用场景下的中英文模型,也提供了专门在英文场景下训练的模型, +和覆盖[80个语言](#语种缩写)的小语种模型。 + +其中英文模型支持,大小写字母和常见标点的检测识别,并优化了空格字符的识别: + +
+ +小语种模型覆盖了拉丁语系、阿拉伯语系、中文繁体、韩语、日语等等:
+ + +本文档将简要介绍小语种模型的使用方法。 + +- [1 安装](#安装) + - [1.1 paddle 安装](#paddle安装) + - [1.2 paddleocr package 安装](#paddleocr_package_安装) + +- [2 快速使用](#快速使用) + - [2.1 命令行运行](#命令行运行) + - [2.1.1 整图预测](#bash_检测+识别) + - [2.1.2 识别预测](#bash_识别) + - [2.1.3 检测预测](#bash_检测) + - [2.2 python 脚本运行](#python_脚本运行) + - [2.2.1 整图预测](#python_检测+识别) + - [2.2.2 识别预测](#python_识别) + - [2.2.3 检测预测](#python_检测) +- [3 自定义训练](#自定义训练) +- [4 支持语种及缩写](#语种缩写) + + +## 1 安装 + + +### 1.1 paddle 安装 +``` +# cpu +pip install paddlepaddle + +# gpu +pip instll paddlepaddle-gpu +``` + + +### 1.2 paddleocr package 安装 + + +pip 安装 +``` +pip install "paddleocr>=2.0.4" # 推荐使用2.0.4版本 +``` +本地构建并安装 +``` +python3 setup.py bdist_wheel +pip3 install dist/paddleocr-x.x.x-py3-none-any.whl # x.x.x是paddleocr的版本号 +``` + + +## 2 快速使用 + + +### 2.1 命令行运行 + +查看帮助信息 + +``` +paddleocr -h +``` + +* 整图预测(检测+识别) + +Paddleocr目前支持80个语种,可以通过修改--lang参数进行切换,具体支持的[语种](#语种缩写)可查看表格。 + +``` bash + +paddleocr --image_dir doc/imgs/japan_2.jpg --lang=japan +``` + +
+ + +结果是一个list,每个item包含了文本框,文字和识别置信度 +```text +[[[671.0, 60.0], [847.0, 63.0], [847.0, 104.0], [671.0, 102.0]], ('もちもち', 0.9993342)] +[[[394.0, 82.0], [536.0, 77.0], [538.0, 127.0], [396.0, 132.0]], ('天然の', 0.9919842)] +[[[880.0, 89.0], [1014.0, 93.0], [1013.0, 127.0], [879.0, 124.0]], ('とろっと', 0.9976762)] +[[[1067.0, 101.0], [1294.0, 101.0], [1294.0, 138.0], [1067.0, 138.0]], ('後味のよい', 0.9988712)] +...... +``` + +* 识别预测 + +```bash +paddleocr --image_dir doc/imgs_words/japan/1.jpg --det false --lang=japan +``` + +![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/imgs_words/japan/1.jpg) + +结果是一个tuple,返回识别结果和识别置信度 + +```text +('したがって', 0.99965394) +``` + +* 检测预测 + +``` +paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false +``` + +结果是一个list,每个item只包含文本框 + +``` +[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]] +[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]] +[[128.0, 397.0], [273.0, 397.0], [273.0, 414.0], [128.0, 414.0]] +...... +``` + + +### 2.2 python 脚本运行 + +ppocr 也支持在python脚本中运行,便于嵌入到您自己的代码中: + +* 整图预测(检测+识别) + +``` +from paddleocr import PaddleOCR, draw_ocr + +# 同样也是通过修改 lang 参数切换语种 +ocr = PaddleOCR(lang="korean") # 首次执行会自动下载模型文件 +img_path = 'doc/imgs/korean_1.jpg ' +result = ocr.ocr(img_path) +# 打印检测框和识别结果 +for line in result: + print(line) + +# 可视化 +from PIL import Image +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/korean.ttf') +im_show = Image.fromarray(im_show) +im_show.save('result.jpg') +``` + +结果可视化: + +
+ +* 识别预测 + +``` +from paddleocr import PaddleOCR +ocr = PaddleOCR(lang="german") +img_path = 'PaddleOCR/doc/imgs_words/german/1.jpg' +result = ocr.ocr(img_path, det=False, cls=True) +for line in result: + print(line) +``` + + +![](../imgs_words/german/1.jpg) + +结果是一个tuple,只包含识别结果和识别置信度 + +``` +('leider auch jetzt', 0.97538936) +``` + +* 检测预测 + +```python +from paddleocr import PaddleOCR, draw_ocr +ocr = PaddleOCR() # need to run only once to download and load model into memory +img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg' +result = ocr.ocr(img_path, rec=False) +for line in result: + print(line) + +# 显示结果 +from PIL import Image + +image = Image.open(img_path).convert('RGB') +im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.save('result.jpg') +``` +结果是一个list,每个item只包含文本框 +```bash +[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]] +[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]] +[[128.0, 397.0], [273.0, 397.0], [273.0, 414.0], [128.0, 414.0]] +...... +``` + +结果可视化 : + +
+ +ppocr 还支持方向分类, 更多使用方式请参考:[whl包使用说明](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_ch/whl.md)。 + + +## 3 自定义训练 + +ppocr 支持使用自己的数据进行自定义训练或finetune, 其中识别模型可以参考 [法语配置文件](../../configs/rec/multi_language/rec_french_lite_train.yml) +修改训练数据路径、字典等参数。 + +具体数据准备、训练过程可参考:[文本检测](../doc_ch/detection.md)、[文本识别](../doc_ch/recognition.md),更多功能如预测部署、 +数据标注等功能可以阅读完整的[文档教程](../../README_ch.md)。 + + +## 4 支持语种及缩写 + +| 语种 | 描述 | 缩写 | +| --- | --- | --- | +|中文|chinese and english|ch| +|英文|english|en| +|法文|french|fr| +|德文|german|german| +|日文|japan|japan| +|韩文|korean|korean| +|中文繁体|chinese traditional |ch_tra| +|意大利文| Italian |it| +|西班牙文|Spanish |es| +|葡萄牙文| Portuguese|pt| +|俄罗斯文|Russia|ru| +|阿拉伯文|Arabic|ar| +|印地文|Hindi|hi| +|维吾尔|Uyghur|ug| +|波斯文|Persian|fa| +|乌尔都文|Urdu|ur| +|塞尔维亚文(latin)| Serbian(latin) |rs_latin| +|欧西坦文|Occitan |oc| +|马拉地文|Marathi|mr| +|尼泊尔文|Nepali|ne| +|塞尔维亚文(cyrillic)|Serbian(cyrillic)|rs_cyrillic| +|保加利亚文|Bulgarian |bg| +|乌克兰文|Ukranian|uk| +|白俄罗斯文|Belarusian|be| +|泰卢固文|Telugu |te| +|卡纳达文|Kannada |kn| +|泰米尔文|Tamil |ta| +|南非荷兰文 |Afrikaans |af| +|阿塞拜疆文 |Azerbaijani |az| +|波斯尼亚文|Bosnian|bs| +|捷克文|Czech|cs| +|威尔士文 |Welsh |cy| +|丹麦文 |Danish|da| +|爱沙尼亚文 |Estonian |et| +|爱尔兰文 |Irish |ga| +|克罗地亚文|Croatian |hr| +|匈牙利文|Hungarian |hu| +|印尼文|Indonesian|id| +|冰岛文 |Icelandic|is| +|库尔德文 |Kurdish|ku| +|立陶宛文|Lithuanian |lt| +|拉脱维亚文 |Latvian |lv| +|毛利文|Maori|mi| +|马来文 |Malay|ms| +|马耳他文 |Maltese |mt| +|荷兰文 |Dutch |nl| +|挪威文 |Norwegian |no| +|波兰文|Polish |pl| +| 罗马尼亚文|Romanian |ro| +| 斯洛伐克文|Slovak |sk| +| 斯洛文尼亚文|Slovenian |sl| +| 阿尔巴尼亚文|Albanian |sq| +| 瑞典文|Swedish |sv| +| 西瓦希里文|Swahili |sw| +| 塔加洛文|Tagalog |tl| +| 土耳其文|Turkish |tr| +| 乌兹别克文|Uzbek |uz| +| 越南文|Vietnamese |vi| +| 蒙古文|Mongolian |mn| +| 阿巴扎文|Abaza |abq| +| 阿迪赫文|Adyghe |ady| +| 卡巴丹文|Kabardian |kbd| +| 阿瓦尔文|Avar |ava| +| 达尔瓦文|Dargwa |dar| +| 因古什文|Ingush |inh| +| 拉克文|Lak |lbe| +| 莱兹甘文|Lezghian |lez| +|塔巴萨兰文 |Tabassaran |tab| +| 比尔哈文|Bihari |bh| +| 迈蒂利文|Maithili |mai| +| 昂加文|Angika |ang| +| 孟加拉文|Bhojpuri |bho| +| 摩揭陀文 |Magahi |mah| +| 那格浦尔文|Nagpur |sck| +| 尼瓦尔文|Newari |new| +| 保加利亚文 |Goan Konkani|gom| +| 沙特阿拉伯文|Saudi Arabia|sa| diff --git a/doc/doc_ch/pgnet.md b/doc/doc_ch/pgnet.md new file mode 100644 index 0000000000000000000000000000000000000000..265853860854317ab00f40b1f447edfad47dc557 --- /dev/null +++ b/doc/doc_ch/pgnet.md @@ -0,0 +1,187 @@ +# 端对端OCR算法-PGNet +- [一、简介](#简介) +- [二、环境配置](#环境配置) +- [三、快速使用](#快速使用) +- [四、模型训练、评估、推理](#模型训练、评估、推理) + + +## 一、简介 +OCR算法可以分为两阶段算法和端对端的算法。二阶段OCR算法一般分为两个部分,文本检测和文本识别算法,文件检测算法从图像中得到文本行的检测框,然后识别算法去识别文本框中的内容。而端对端OCR算法可以在一个算法中完成文字检测和文字识别,其基本思想是设计一个同时具有检测单元和识别模块的模型,共享其中两者的CNN特征,并联合训练。由于一个算法即可完成文字识别,端对端模型更小,速度更快。 + +### PGNet算法介绍 +近些年来,端对端OCR算法得到了良好的发展,包括MaskTextSpotter系列、TextSnake、TextDragon、PGNet系列等算法。在这些算法中,PGNet算法具备其他算法不具备的优势,包括: +- 设计PGNet loss指导训练,不需要字符级别的标注 +- 不需要NMS和ROI相关操作,加速预测 +- 提出预测文本行内的阅读顺序模块; +- 提出基于图的修正模块(GRM)来进一步提高模型识别性能 +- 精度更高,预测速度更快 + +PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf) ,算法原理图如下所示: +![](../pgnet_framework.png) +输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。 +其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。 + +其检测识别效果图如下: + +![](../imgs_results/e2e_res_img293_pgnet.png) +![](../imgs_results/e2e_res_img295_pgnet.png) + +### 性能指标 + +测试集: Total Text + +测试环境: NVIDIA Tesla V100-SXM2-16GB + +|PGNetA|det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS|下载| +| --- | --- | --- | --- | --- | --- | --- | --- | --- | +|Paper|85.30|86.80|86.1|-|-|61.7|38.20 (size=640)|-| 
+|Ours|87.03|82.48|84.69|61.71|58.43|60.03|48.73 (size=768)|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)| + +*note:PaddleOCR里的PGNet实现针对预测速度做了优化,在精度下降可接受范围内,可以显著提升端对端预测速度* + + + + +## 二、环境配置 +请先参考[快速安装](./installation.md)配置PaddleOCR运行环境。 + + +## 三、快速使用 +### inference模型下载 +本节以训练好的端到端模型为例,快速使用模型预测,首先下载训练好的端到端inference模型[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar) +``` +mkdir inference && cd inference +# 下载英文端到端模型并解压 +wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar && tar xf e2e_server_pgnetA_infer.tar +``` +* windows 环境下如果没有安装wget,下载模型时可将链接复制到浏览器中下载,并解压放置在相应目录下 + +解压完毕后应有如下文件结构: +``` +├── e2e_server_pgnetA_infer +│ ├── inference.pdiparams +│ ├── inference.pdiparams.info +│ └── inference.pdmodel +``` +### 单张图像或者图像集合预测 +```bash +# 预测image_dir指定的单张图像 +python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True + +# 预测image_dir指定的图像集合 +python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True + +# 如果想使用CPU进行预测,需设置use_gpu参数为False +python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True --use_gpu=False +``` +### 可视化结果 +可视化文本检测结果默认保存到./inference_results文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下: +![](../imgs_results/e2e_res_img623_pgnet.jpg) + + +## 四、模型训练、评估、推理 +本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。 + +### 准备数据 +下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) 数据集到PaddleOCR/train_data/目录,数据集组织结构: +``` +/PaddleOCR/train_data/total_text/train/ + |- rgb/ # total_text数据集的训练数据 + |- gt_0.png + | ... 
+ |- total_text.txt # total_text数据集的训练标注 +``` + +total_text.txt标注文件格式如下,文件名和标注信息中间用"\t"分隔: +``` +" 图像文件名 json.dumps编码的图像标注信息" +rgb/gt_0.png [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.0],[1034.0,708.0],[1049.0,718.0],[1064.0,728.0],[1079.0,738.0],[1095.0,748.0],[1094.0,774.0],[1079.0,765.0],[1065.0,756.0],[1050.0,747.0],[1036.0,738.0],[1021.0,729.0],[1007.0,721.0]]}, {...}] +``` +json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。 +`transcription` 表示当前文本框的文字,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。** +如果您想在其他数据集上训练,可以按照上述形式构建标注文件。 + +### 启动训练 + +PGNet训练分为两个步骤:step1: 在合成数据上训练,得到预训练模型,此时模型精度依然较低;step2: 加载预训练模型,在totaltext数据集上训练;为快速训练,我们直接提供了step1的预训练模型。 +```shell +cd PaddleOCR/ +下载step1 预训练模型 +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/train_step1.tar +可以得到以下的文件格式 +./pretrain_models/train_step1/ + └─ best_accuracy.pdopt + └─ best_accuracy.states + └─ best_accuracy.pdparams +``` +*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false* + +```shell +# 单机单卡训练 e2e 模型 +python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False +# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False +``` + +上述指令中,通过-c 选择训练使用configs/e2e/e2e_r50_vd_pg.yml配置文件。 +有关配置文件的详细解释,请参考[链接](./config.md)。 + +您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001 +```shell +python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Optimizer.base_lr=0.0001 +``` + +#### 断点训练 +如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径: +```shell +python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model +``` + +**注意**:`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。 + +PaddleOCR计算三个OCR端到端相关的指标,分别是:Precision、Recall、Hmean。 + +运行如下代码,根据配置文件`e2e_r50_vd_pg.yml`中`save_res_path`指定的测试集检测结果文件,计算评估指标。 + +评估时设置后处理参数`max_side_len=768`,使用不同数据集、不同模型训练,可调整参数进行优化 +训练中模型参数默认保存在`Global.save_model_dir`目录下。在评估指标时,需要设置`Global.checkpoints`指向保存的参数文件。 +```shell +python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" +``` + +### 模型预测 +测试单张图像的端到端识别效果 +```shell +python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false +``` + +测试文件夹下所有图像的端到端识别效果 +```shell +python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false +``` + +### 预测推理 +#### (1). 
+
+#### Resume training
+If training is interrupted and you want to resume from the interrupted model, specify the path of the model to load via Global.checkpoints:
+```shell
+python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model
+```
+
+**Note**: `Global.checkpoints` has a higher priority than `Global.pretrain_weights`; when both parameters are specified, the model specified by `Global.checkpoints` is loaded first. If the model path specified by `Global.checkpoints` is wrong, the model specified by `Global.pretrain_weights` is loaded instead.
+
+PaddleOCR calculates three metrics for the OCR end-to-end task: Precision, Recall and Hmean.
+
+Run the following command to calculate the evaluation metrics from the test-set result file specified by `save_res_path` in the configuration file `e2e_r50_vd_pg.yml`.
+
+For evaluation, the post-processing parameter `max_side_len=768` is set; it can be tuned when training with other datasets or models.
+The model parameters produced during training are saved in the `Global.save_model_dir` directory by default. When calculating the metrics, set `Global.checkpoints` to the saved parameter file.
+```shell
+python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{path/to/weights}/best_accuracy"
+```
+
+### Model prediction
+Test the end-to-end result on a single image:
+```shell
+python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
+```
+
+Test the end-to-end result on all images in a folder:
+```shell
+python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
+```
+
+### Inference
+#### (1) Quadrangle text detection model (ICDAR2015)
+First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model trained on the English dataset with a ResNet50_vd backbone as an example ([model download address](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)), you can convert it with the following commands:
+```
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
+python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
+```
+**For PGNet end-to-end model inference, you need to set the parameter `--e2e_algorithm="PGNet"`**; run the following command:
+```
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
+```
+The visualized text detection results are saved to the `./inference_results` folder by default, and the names of the result files are prefixed with 'e2e_res'. Example:
+
+![](../imgs_results/e2e_res_img_10_pgnet.jpg)
+
+#### (2) Curved text detection model (Total-Text)
+For curved text samples, **you need to set the parameter `--e2e_algorithm="PGNet"` and additionally `--e2e_pgnet_polygon=True`**; run the following command:
+```
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
+```
+The visualized end-to-end results are saved to the `./inference_results` folder by default, and the names of the result files are prefixed with 'e2e_res'. Example:
+
+![](../imgs_results/e2e_res_img623_pgnet.jpg)
diff --git a/doc/doc_ch/whl.md b/doc/doc_ch/whl.md
index 032d7ae642ad7b2a10253dbfd6310bd5299fb5a0..2e93c487c2f2071c7c89c753cf86eef61ce20805 100644
--- a/doc/doc_ch/whl.md
+++ b/doc/doc_ch/whl.md
@@ -36,7 +36,7 @@ image = Image.open(img_path).convert('RGB')
 boxes = [line[0] for line in result]
 txts = [line[1][0] for line in result]
 scores = [line[1][1] for line in result]
-im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
+im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
 im_show = Image.fromarray(im_show)
 im_show.save('result.jpg')
 ```
@@ -69,7 +69,7 @@ image = Image.open(img_path).convert('RGB')
 boxes = [line[0] for line in result]
 txts = [line[1][0] for line in result]
 scores = [line[1][1] for line in result]
-im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
+im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
 im_show = Image.fromarray(im_show)
 im_show.save('result.jpg')
 ```
@@ -114,7 +114,7 @@ for line in result:
 
 from PIL import Image
 image = Image.open(img_path).convert('RGB')
-im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
+im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
 im_show = Image.fromarray(im_show)
 im_show.save('result.jpg')
 ```
@@ -253,7 +253,7 @@ image = Image.open(img_path).convert('RGB')
 boxes = [line[0] for line in result]
 txts = [line[1][0] for line in result]
 scores = [line[1][1] for line in result]
-im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
+im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
 im_show = Image.fromarray(im_show)
 im_show.save('result.jpg')
 ```
@@ -285,7 +285,7 @@ image = Image.open(img_path).convert('RGB')
 boxes = [line[0] for line in result]
 txts = [line[1][0] for line in result]
 scores = [line[1][1] for line in result]
-im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
+im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
 im_show = Image.fromarray(im_show)
 im_show.save('result.jpg')
 ```
@@ -314,7 +314,7 @@ image = Image.open(img_path).convert('RGB')
 boxes = [line[0] for line in result]
 txts = [line[1][0] for line in result]
 scores = [line[1][1] for line in result]
-im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
+im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
 im_show = Image.fromarray(im_show)
 im_show.save('result.jpg')
 ```
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index 77b9642e3b880547b1df6620d931689982db6d29..d70f99bb5c5b0bdcb7d39209dfc9a77c56918260 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -31,7 +31,9 @@ On Total-Text dataset, the text detection result is as follows:
 | --- | --- | --- | --- | --- | --- |
 |SAST|ResNet50_vd|89.63%|78.44%|83.66%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)|
 
-**Note:** Additional data, like icdar2013, icdar2017, COCO-Text, ArT, was added to the model training of SAST. Download English public dataset in organized format used by PaddleOCR from [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi).
+**Note:** Additional data, like icdar2013, icdar2017, COCO-Text, ArT, was added to the model training of SAST. Download the English public datasets in the organized format used by PaddleOCR from:
+* [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi).
+* [Google Drive](https://drive.google.com/drive/folders/1ll2-XEVyCQLpJjawLDiRlvo_i4BqHCJe?usp=sharing)
 
 For the training guide and use of PaddleOCR text detection algorithms, please refer to the document [Text detection model training/evaluation/prediction](./detection_en.md)
diff --git a/doc/doc_en/multi_languages_en.md b/doc/doc_en/multi_languages_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..3f4f02266f4708efb7fdd3e9211678af7028c9df
--- /dev/null
+++ b/doc/doc_en/multi_languages_en.md
@@ -0,0 +1,305 @@
+# Multi-language model
+
+**Recent Update**
+
+- 2021.4.9 supports the detection and recognition of 80 languages
+- 2021.4.9 adds support for **lightweight, high-precision** English detection and recognition models
+
+PaddleOCR aims to create a rich, leading and practical OCR tool library, which not only provides
+general-purpose Chinese and English models, but also provides models specifically trained
+for English scenarios, as well as multilingual models covering [80 languages](#language_abbreviations).
+
+Among them, the English model supports the detection and recognition of uppercase and lowercase
+letters and common punctuation, and the recognition of space characters is optimized:
+
+ +
+ +The multilingual models cover Latin, Arabic, Traditional Chinese, Korean, Japanese, etc.: + +
+ + +
+
+This document will briefly introduce how to use the multilingual model.
+
+- [1 Installation](#Install)
+    - [1.1 paddle installation](#paddleinstallation)
+    - [1.2 paddleocr package installation](#paddleocr_package_install)
+
+- [2 Quick Use](#Quick_Use)
+    - [2.1 Command line operation](#Command_line_operation)
+        - [2.1.1 Prediction of the whole image](#bash_detection+recognition)
+        - [2.1.2 Recognition](#bash_Recognition)
+        - [2.1.3 Detection](#bash_detection)
+    - [2.2 Running with a Python script](#python_Script_running)
+        - [2.2.1 Whole image prediction](#python_detection+recognition)
+        - [2.2.2 Recognition](#python_Recognition)
+        - [2.2.3 Detection](#python_detection)
+- [3 Custom Training](#Custom_Training)
+- [4 Supported languages and abbreviations](#language_abbreviations)
+
+<a name="Install"></a>
+## 1 Installation
+
+<a name="paddleinstallation"></a>
+### 1.1 paddle installation
+```
+# cpu
+pip install paddlepaddle
+
+# gpu
+pip install paddlepaddle-gpu
+```
+
+<a name="paddleocr_package_install"></a>
+### 1.2 paddleocr package installation
+
+
+Install with pip
+```
+pip install "paddleocr>=2.0.4" # 2.0.4 version is recommended
+```
+Build and install locally
+```
+python3 setup.py bdist_wheel
+pip3 install dist/paddleocr-x.x.x-py3-none-any.whl # x.x.x is the version number of paddleocr
+```
+
+<a name="Quick_Use"></a>
+## 2 Quick use
+
+<a name="Command_line_operation"></a>
+### 2.1 Command line operation
+
+View the help information:
+
+```
+paddleocr -h
+```
+
+<a name="bash_detection+recognition"></a>
+* Whole image prediction (detection + recognition)
+
+PaddleOCR currently supports 80 languages, which can be switched with the --lang parameter.
+The supported [languages](#language_abbreviations) are listed in the table below.
+
+``` bash
+
+paddleocr --image_dir doc/imgs/japan_2.jpg --lang=japan
+```
+![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.1/doc/imgs/japan_2.jpg)
+
+The result is a list; each item contains a text box, the recognized text and the recognition confidence:
+```text
+[[[671.0, 60.0], [847.0, 63.0], [847.0, 104.0], [671.0, 102.0]], ('もちもち', 0.9993342)]
+[[[394.0, 82.0], [536.0, 77.0], [538.0, 127.0], [396.0, 132.0]], ('自然の', 0.9919842)]
+[[[880.0, 89.0], [1014.0, 93.0], [1013.0, 127.0], [879.0, 124.0]], ('とろっと', 0.9976762)]
+[[[1067.0, 101.0], [1294.0, 101.0], [1294.0, 138.0], [1067.0, 138.0]], ('后味のよい', 0.9988712)]
+......
+```
+
+<a name="bash_Recognition"></a>
+* Recognition
+
+```bash
+paddleocr --image_dir doc/imgs_words/japan/1.jpg --det false --lang=japan
+```
+
+![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.1/doc/imgs_words/japan/1.jpg)
+
+The result is a tuple containing the recognized text and the recognition confidence:
+
+```text
+('したがって', 0.99965394)
+```
+
+<a name="bash_detection"></a>
+* Detection
+
+```
+paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
+```
+
+The result is a list; each item contains only a text box:
+
+```
+[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
+[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
+[[128.0, 397.0], [273.0, 397.0], [273.0, 414.0], [128.0, 414.0]]
+......
+```
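+Before moving on to the Python API, note that the whole-image result above is just a list of `[box, (text, confidence)]` items, so it is easy to post-process. A small illustrative sketch (the `result` value below is a hard-coded stand-in for the output of one of the calls above):
+```python
+# 'result' stands in for a whole-image prediction: a list of [box, (text, confidence)]
+result = [
+    [[[671.0, 60.0], [847.0, 63.0], [847.0, 104.0], [671.0, 102.0]], ('もちもち', 0.9993342)],
+    [[[394.0, 82.0], [536.0, 77.0], [538.0, 127.0], [396.0, 132.0]], ('自然の', 0.9919842)],
+]
+
+texts = [text for _box, (text, _conf) in result]
+mean_conf = sum(conf for _box, (_text, conf) in result) / len(result)
+print(" ".join(texts), f"(mean confidence: {mean_conf:.4f})")
+```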
+
+<a name="python_Script_running"></a>
+### 2.2 Running with a Python script
+
+ppocr also supports running in Python scripts, which makes it easy to embed in your own code:
+
+<a name="python_detection+recognition"></a>
+* Whole image prediction (detection + recognition)
+
+```python
+from paddleocr import PaddleOCR, draw_ocr
+
+# Switch the language by modifying the lang parameter
+ocr = PaddleOCR(lang="korean")  # The model files are downloaded automatically on first execution
+img_path = 'doc/imgs/korean_1.jpg'
+result = ocr.ocr(img_path)
+# Print the detection boxes and recognition results
+for line in result:
+    print(line)
+
+# Visualization
+from PIL import Image
+image = Image.open(img_path).convert('RGB')
+boxes = [line[0] for line in result]
+txts = [line[1][0] for line in result]
+scores = [line[1][1] for line in result]
+im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/korean.ttf')
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
+
+Visualization of results:
+![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.1/doc/imgs_results/korean.jpg)
+
+<a name="python_Recognition"></a>
+* Recognition
+
+```python
+from paddleocr import PaddleOCR
+ocr = PaddleOCR(lang="german")
+img_path = 'PaddleOCR/doc/imgs_words/german/1.jpg'
+result = ocr.ocr(img_path, det=False, cls=True)
+for line in result:
+    print(line)
+```
+
+![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.1/doc/imgs_words/german/1.jpg)
+
+The result is a tuple containing only the recognized text and the recognition confidence:
+
+```
+('leider auch jetzt', 0.97538936)
+```
+
+<a name="python_detection"></a>
+* Detection
+
+```python
+from paddleocr import PaddleOCR, draw_ocr
+ocr = PaddleOCR()  # need to run only once to download and load the model into memory
+img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg'
+result = ocr.ocr(img_path, rec=False)
+for line in result:
+    print(line)
+
+# show result
+from PIL import Image
+
+image = Image.open(img_path).convert('RGB')
+im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
+The result is a list; each item contains only a text box:
+```bash
+[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
+[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
+[[128.0, 397.0], [273.0, 397.0], [273.0, 414.0], [128.0, 414.0]]
+......
+```
+
+Visualization of results:
+![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.1/doc/imgs_results/whl/12_det.jpg)
+
+ppocr also supports direction classification. For more usage, please refer to the [whl package instructions](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_ch/whl.md).
+
+<a name="Custom_Training"></a>
+## 3 Custom training
+
+ppocr supports custom training or fine-tuning with your own data. For the recognition model, you can start from the [French configuration file](../../configs/rec/multi_language/rec_french_lite_train.yml) and modify the training data path, dictionary and other parameters.
+
+For data preparation and the training process, please refer to [Text Detection](../doc_en/detection_en.md) and [Text Recognition](../doc_en/recognition_en.md); for more functions such as inference deployment and data annotation, please read the complete [Document Tutorial](../../README.md). A trained custom model can then be plugged into the whl package, as sketched below.
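+As a sketch (the model directories and dictionary below are placeholders, not shipped models), a fine-tuned model exported as an inference model can be loaded through the whl package's det_model_dir / rec_model_dir / rec_char_dict_path parameters:
+```python
+from paddleocr import PaddleOCR
+
+# hypothetical paths to your own exported inference models and dictionary
+ocr = PaddleOCR(det_model_dir='./my_det_model/',
+                rec_model_dir='./my_rec_model/',
+                rec_char_dict_path='./my_dict.txt')
+result = ocr.ocr('doc/imgs_en/img_12.jpg')
+for line in result:
+    print(line)
+```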
+
+<a name="language_abbreviations"></a>
+## 4 Supported languages and abbreviations
+
+| Language | Abbreviation |
+| --- | --- |
+|chinese and english|ch|
+|english|en|
+|french|fr|
+|german|german|
+|japan|japan|
+|korean|korean|
+|chinese traditional|ch_tra|
+|Italian|it|
+|Spanish|es|
+|Portuguese|pt|
+|Russian|ru|
+|Arabic|ar|
+|Hindi|hi|
+|Uyghur|ug|
+|Persian|fa|
+|Urdu|ur|
+|Serbian (latin)|rs_latin|
+|Occitan|oc|
+|Marathi|mr|
+|Nepali|ne|
+|Serbian (cyrillic)|rs_cyrillic|
+|Bulgarian|bg|
+|Ukrainian|uk|
+|Belarusian|be|
+|Telugu|te|
+|Kannada|kn|
+|Tamil|ta|
+|Afrikaans|af|
+|Azerbaijani|az|
+|Bosnian|bs|
+|Czech|cs|
+|Welsh|cy|
+|Danish|da|
+|Estonian|et|
+|Irish|ga|
+|Croatian|hr|
+|Hungarian|hu|
+|Indonesian|id|
+|Icelandic|is|
+|Kurdish|ku|
+|Lithuanian|lt|
+|Latvian|lv|
+|Maori|mi|
+|Malay|ms|
+|Maltese|mt|
+|Dutch|nl|
+|Norwegian|no|
+|Polish|pl|
+|Romanian|ro|
+|Slovak|sk|
+|Slovenian|sl|
+|Albanian|sq|
+|Swedish|sv|
+|Swahili|sw|
+|Tagalog|tl|
+|Turkish|tr|
+|Uzbek|uz|
+|Vietnamese|vi|
+|Mongolian|mn|
+|Abaza|abq|
+|Adyghe|ady|
+|Kabardian|kbd|
+|Avar|ava|
+|Dargwa|dar|
+|Ingush|inh|
+|Lak|lbe|
+|Lezghian|lez|
+|Tabassaran|tab|
+|Bihari|bh|
+|Maithili|mai|
+|Angika|ang|
+|Bhojpuri|bho|
+|Magahi|mah|
+|Nagpur|sck|
+|Newari|new|
+|Goan Konkani|gom|
+|Saudi Arabia|sa|
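+Note that many of these abbreviations share a single multilingual recognition model. As the paddleocr.py change later in this diff shows, language codes are first mapped to a model family; a condensed sketch of that dispatch:
+```python
+# Condensed from the lang-mapping logic added to paddleocr.py in this diff
+latin_lang = ['af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'en', 'es', 'et', 'fr',
+              'ga', 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi',
+              'ms', 'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin',
+              'sk', 'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi']
+arabic_lang = ['ar', 'fa', 'ug', 'ur']
+cyrillic_lang = ['ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady',
+                 'kbd', 'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab']
+devanagari_lang = ['hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck',
+                   'new', 'gom', 'sa', 'bgc']
+
+def model_family(lang):
+    """Map a user-facing language code to the recognition model it uses."""
+    for family, langs in [('latin', latin_lang), ('arabic', arabic_lang),
+                          ('cyrillic', cyrillic_lang),
+                          ('devanagari', devanagari_lang)]:
+        if lang in langs:
+            return family
+    return lang  # e.g. 'ch', 'japan', 'korean' keep their own models
+
+print(model_family('fr'))  # latin
+```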
diff --git a/doc/doc_en/pgnet_en.md b/doc/doc_en/pgnet_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..2ab1116ce7085e2e322b4be45ee5628c247040ea
--- /dev/null
+++ b/doc/doc_en/pgnet_en.md
@@ -0,0 +1,185 @@
+# End-to-end OCR Algorithm: PGNet
+- [1. Brief Introduction](#Brief_Introduction)
+- [2. Environment Configuration](#Environment_Configuration)
+- [3. Quick Use](#Quick_Use)
+- [4. Model Training, Evaluation and Inference](#Model_Training_Evaluation_And_Inference)
+
+<a name="Brief_Introduction"></a>
+## 1. Brief Introduction
+OCR algorithms can be divided into two-stage algorithms and end-to-end algorithms. A two-stage OCR pipeline generally consists of two parts: a text detection algorithm that obtains the boxes of text lines from the image, and a text recognition algorithm that reads the content of each box. An end-to-end OCR algorithm completes text detection and recognition in a single model: it contains both a detection unit and a recognition module, shares the CNN features between them, and trains them jointly. Because a single model completes the whole task, end-to-end models are smaller and faster.
+### Introduction of the PGNet Algorithm
+In recent years, end-to-end OCR algorithms have developed rapidly, including the MaskTextSpotter series, TextSnake, TextDragon, the PGNet series and so on. Among these algorithms, PGNet has several advantages that the others lack:
+- A PGNet loss is designed to guide training, so no character-level annotations are needed
+- NMS and RoI-related operations are not needed, which accelerates prediction
+- A reading order prediction module is proposed
+- A graph-based modification module (GRM) is proposed to further improve the recognition performance
+- Higher accuracy and faster prediction speed
+
+For details of the PGNet algorithm, please refer to the [paper](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf). The schematic diagram of the algorithm is as follows:
+![](../pgnet_framework.png)
+After feature extraction, the input image is sent to four branches: the TBO module for text border offset prediction, the TCL module for text center line prediction, the TDO module for text direction offset prediction, and the TCC module for text character classification map prediction.
+The outputs of TBO and TCL yield the text detection results after post-processing; TCL, TDO and TCC are responsible for text recognition.
+
+The results of detection and recognition are as follows:
+![](../imgs_results/e2e_res_img293_pgnet.png)
+![](../imgs_results/e2e_res_img295_pgnet.png)
+### Performance
+#### Test set: Total-Text
+
+#### Test environment: NVIDIA Tesla V100-SXM2-16GB
+|PGNetA|det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS|download|
+| --- | --- | --- | --- | --- | --- | --- | --- | --- |
+|Paper|85.30|86.80|86.1|-|-|61.7|38.20 (size=640)|-|
+|Ours|87.03|82.48|84.69|61.71|58.43|60.03|48.73 (size=768)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)|
+
+*Note: the PGNet implementation in PaddleOCR is optimized for prediction speed; it significantly improves end-to-end prediction speed at an acceptable cost in accuracy.*
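+The f_score columns are consistent with the harmonic mean (Hmean) of the corresponding precision and recall, which is also the Hmean metric PaddleOCR reports during evaluation below. A quick check of the table values:
+```python
+def hmean(precision, recall):
+    """Harmonic mean of precision and recall (a.k.a. F1 score)."""
+    return 2 * precision * recall / (precision + recall)
+
+print(round(hmean(87.03, 82.48), 2))  # 84.69 -> det_f_score of "Ours"
+print(round(hmean(61.71, 58.43), 2))  # 60.03 -> e2e_f_score of "Ours"
+```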
+
+<a name="Environment_Configuration"></a>
+## 2. Environment Configuration
+Please refer to [Quick Installation](./installation_en.md) to set up the PaddleOCR running environment.
+
+<a name="Quick_Use"></a>
+## 3. Quick Use
+### Download the inference model
+This section takes a trained end-to-end model as an example to quickly run model prediction. First, download the trained end-to-end inference model ([download address](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar)):
+```
+mkdir inference && cd inference
+# Download the English end-to-end model and unzip it
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar && tar xf e2e_server_pgnetA_infer.tar
+```
+* On Windows, if wget is not installed, copy the link into a browser to download the model, then unzip it into the corresponding directory.
+
+After decompression, there should be the following file structure:
+```
+├── e2e_server_pgnetA_infer
+│   ├── inference.pdiparams
+│   ├── inference.pdiparams.info
+│   └── inference.pdmodel
+```
+### Single image or image set prediction
+```bash
+# Predict the single image specified by image_dir
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
+
+# Predict all images in the directory specified by image_dir
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
+
+# To predict on CPU, set the use_gpu parameter to False
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True --use_gpu=False
+```
+### Visualization results
+The visualized end-to-end results are saved to the `./inference_results` folder by default, and the names of the result files are prefixed with 'e2e_res'. Example:
+![](../imgs_results/e2e_res_img623_pgnet.jpg)
+
+<a name="Model_Training_Evaluation_And_Inference"></a>
+## 4. Model Training, Evaluation and Inference
+This section takes the totaltext dataset as an example to introduce the training, evaluation and testing of the end-to-end model in PaddleOCR.
+
+### Data Preparation
+Download and unzip the [totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) dataset into PaddleOCR/train_data/. The dataset is organized as follows:
+```
+/PaddleOCR/train_data/total_text/train/
+  |- rgb/            # training images of the total_text dataset
+      |- gt_0.png
+      | ...
+  |- total_text.txt  # training annotations of the total_text dataset
+```
+
+The format of the total_text.txt annotation file is as follows; the file name and the annotation information are separated by "\t":
+```
+" Image file name             Image annotation information encoded by json.dumps"
+rgb/gt_0.png    [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.0],[1034.0,708.0],[1049.0,718.0],[1064.0,728.0],[1079.0,738.0],[1095.0,748.0],[1094.0,774.0],[1079.0,765.0],[1065.0,756.0],[1050.0,747.0],[1036.0,738.0],[1021.0,729.0],[1007.0,721.0]]}, {...}]
+```
+The image annotation after **json.dumps()** encoding is a list containing multiple dictionaries.
+
+The `points` in each dictionary are the (x, y) coordinates of the points of the text polygon, arranged clockwise starting from the point at the upper-left corner.
+
+`transcription` is the text of the current text box. **When its content is "###", the text box is invalid and will be skipped during training.**
+
+If you want to train PaddleOCR on other datasets, please build the annotation files according to the above format; the sketch below shows how the annotation geometry can be checked visually.
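+As an illustration of the annotation geometry (not part of the training pipeline), the clockwise `points` of each box can be drawn onto the image with OpenCV; the paths below are assumptions:
+```python
+import json
+import cv2
+import numpy as np
+
+root = "train_data/total_text/train/"  # hypothetical dataset root
+with open(root + "total_text.txt", "r", encoding="utf-8") as f:
+    img_name, annotation = f.readline().rstrip("\n").split("\t")
+
+img = cv2.imread(root + img_name)
+for box in json.loads(annotation):
+    if box["transcription"] == "###":
+        continue  # invalid box, ignored during training
+    pts = np.array(box["points"], dtype=np.int32).reshape(-1, 1, 2)
+    cv2.polylines(img, [pts], isClosed=True, color=(0, 255, 0), thickness=2)
+cv2.imwrite("annotated_gt_0.png", img)
+```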
+
+### Start Training
+
+PGNet training consists of two steps. Step 1: train on synthetic data to obtain a pretrained model, whose accuracy is still low at this point. Step 2: load the pretrained model and train on the totaltext dataset. To speed up training, we directly provide the step-1 pretrained model ([download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/train_step1.tar)).
+```shell
+cd PaddleOCR/
+# Download the step-1 pretrained model
+wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/train_step1.tar
+# After decompression, you get the following files
+./pretrain_models/train_step1/
+  └─ best_accuracy.pdopt
+  └─ best_accuracy.states
+  └─ best_accuracy.pdparams
+```
+*If the CPU version is installed, please set the parameter `use_gpu` to `false` in the configuration file.*
+
+```shell
+# single-GPU training
+python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False
+# multi-GPU training
+# Set the GPU IDs to use with the '--gpus' parameter.
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False
+```
+
+In the above commands, `-c` selects the `configs/e2e/e2e_r50_vd_pg.yml` configuration file for training.
+For a detailed explanation of the configuration file, please refer to [config](./config_en.md).
+
+You can also use `-o` to change training parameters without modifying the yml file. For example, to adjust the learning rate to 0.0001:
+```shell
+python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Optimizer.base_lr=0.0001
+```
+
+#### Load a trained model and continue training
+If you want to load a trained model and continue training, specify the parameter `Global.checkpoints` as the model path to be loaded:
+```shell
+python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model
+```
+
+**Note**: `Global.checkpoints` has a higher priority than `Global.pretrain_weights`; when both parameters are specified, the model specified by `Global.checkpoints` is loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrain_weights` is loaded instead.
+
+PaddleOCR calculates three metrics for evaluating the OCR end-to-end task: Precision, Recall, and Hmean.
+
+Run the following command to calculate the evaluation metrics from the test result file specified by `save_res_path` in the configuration file `e2e_r50_vd_pg.yml`.
+
+For evaluation, the post-processing parameter `max_side_len=768` is set; it can be tuned when training with other datasets or models (see the resizing sketch after the command below).
+The model parameters produced during training are saved in the `Global.save_model_dir` directory by default. When evaluating, set `Global.checkpoints` to the saved parameter file.
+```shell
+python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{path/to/weights}/best_accuracy"
+```
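+For reference, `max_side_len` is applied by the `E2EResizeForTest` operator added in this diff: the longer image side is scaled to `max_side_len`, then both sides are rounded up to a multiple of 128 (the stride required by the network). A condensed sketch of that computation:
+```python
+def e2e_resize_shape(h, w, max_side_len=768, max_stride=128):
+    """Target (height, width) used by E2EResizeForTest for non-totaltext sets."""
+    ratio = float(max_side_len) / (h if h > w else w)  # scale the longer side
+    resize_h, resize_w = int(h * ratio), int(w * ratio)
+    # round both sides up to a multiple of the network stride
+    resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+    resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+    return resize_h, resize_w
+
+print(e2e_resize_shape(720, 1280))  # (512, 768)
+```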
+
+### Model Test
+Test the end-to-end result on a single image:
+```shell
+python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
+```
+
+Test the end-to-end result on all images in a folder:
+```shell
+python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
+```
+
+### Model inference
+#### (1) Quadrangle text detection model (ICDAR2015)
+First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model trained on the English dataset with a ResNet50_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)), you can convert it with the following commands:
+```
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
+python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
+```
+**For PGNet quadrangle end-to-end model inference, you need to set the parameter `--e2e_algorithm="PGNet"`**; run the following command:
+```
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
+```
+The visualized text detection results are saved to the `./inference_results` folder by default, and the names of the result files are prefixed with 'e2e_res'. Examples of results are as follows:
+
+![](../imgs_results/e2e_res_img_10_pgnet.jpg)
+
+#### (2) Curved text detection model (Total-Text)
+For curved text samples, we use the same model as in the quadrangle case.
+**For PGNet end-to-end curved text detection model inference, you need to set the parameter `--e2e_algorithm="PGNet"` and additionally `--e2e_pgnet_polygon=True`**; run the following command:
+```
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
+```
+The visualized text detection results are saved to the `./inference_results` folder by default, and the names of the result files are prefixed with 'e2e_res'.
Examples of results are as follows: + +![](../imgs_results/e2e_res_img623_pgnet.jpg) diff --git a/doc/doc_en/whl_en.md b/doc/doc_en/whl_en.md index ae4d34923535779d89594ee7f6fa4259f38ba497..69abf085556f466853798077bb116b3986582bcc 100644 --- a/doc/doc_en/whl_en.md +++ b/doc/doc_en/whl_en.md @@ -35,7 +35,7 @@ image = Image.open(img_path).convert('RGB') boxes = [line[0] for line in result] txts = [line[1][0] for line in result] scores = [line[1][1] for line in result] -im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf') im_show = Image.fromarray(im_show) im_show.save('result.jpg') ``` @@ -69,7 +69,7 @@ image = Image.open(img_path).convert('RGB') boxes = [line[0] for line in result] txts = [line[1][0] for line in result] scores = [line[1][1] for line in result] -im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf') im_show = Image.fromarray(im_show) im_show.save('result.jpg') ``` @@ -116,7 +116,7 @@ for line in result: from PIL import Image image = Image.open(img_path).convert('RGB') -im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf') im_show = Image.fromarray(im_show) im_show.save('result.jpg') ``` @@ -262,7 +262,7 @@ image = Image.open(img_path).convert('RGB') boxes = [line[0] for line in result] txts = [line[1][0] for line in result] scores = [line[1][1] for line in result] -im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf') im_show = Image.fromarray(im_show) im_show.save('result.jpg') ``` @@ -292,7 +292,7 @@ image = Image.open(img_path).convert('RGB') boxes = [line[0] for line in result] txts = [line[1][0] for line in result] scores = [line[1][1] for line in result] -im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf') im_show = Image.fromarray(im_show) im_show.save('result.jpg') ``` @@ -320,7 +320,7 @@ image = Image.open(img_path).convert('RGB') boxes = [line[0] for line in result] txts = [line[1][0] for line in result] scores = [line[1][1] for line in result] -im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf') im_show = Image.fromarray(im_show) im_show.save('result.jpg') ``` diff --git a/doc/imgs_results/e2e_res_img293_pgnet.png b/doc/imgs_results/e2e_res_img293_pgnet.png new file mode 100644 index 0000000000000000000000000000000000000000..232f8293adbe00c46b1ea14e6206c3b0eb31ca8f Binary files /dev/null and b/doc/imgs_results/e2e_res_img293_pgnet.png differ diff --git a/doc/imgs_results/e2e_res_img295_pgnet.png b/doc/imgs_results/e2e_res_img295_pgnet.png new file mode 100644 index 0000000000000000000000000000000000000000..69337e3adfe444429cf513343ef12f7697b6889f Binary files /dev/null and b/doc/imgs_results/e2e_res_img295_pgnet.png differ diff --git a/doc/imgs_results/e2e_res_img623_pgnet.jpg 
b/doc/imgs_results/e2e_res_img623_pgnet.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b45dc05f7bfae05bbaa338f59397c2458b80638b Binary files /dev/null and b/doc/imgs_results/e2e_res_img623_pgnet.jpg differ diff --git a/doc/imgs_results/e2e_res_img_10_pgnet.jpg b/doc/imgs_results/e2e_res_img_10_pgnet.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a0962993f81628da04f0838aed6240599f9eaec2 Binary files /dev/null and b/doc/imgs_results/e2e_res_img_10_pgnet.jpg differ diff --git a/doc/imgs_results/multi_lang/en_1.jpg b/doc/imgs_results/multi_lang/en_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2dc84d3f04610effa7888e81848b015443e8091f Binary files /dev/null and b/doc/imgs_results/multi_lang/en_1.jpg differ diff --git a/doc/imgs_results/multi_lang/en_2.jpg b/doc/imgs_results/multi_lang/en_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..455ec98ed337c6af9c45f2861b8078b31abb4054 Binary files /dev/null and b/doc/imgs_results/multi_lang/en_2.jpg differ diff --git a/doc/imgs_results/multi_lang/en_3.jpg b/doc/imgs_results/multi_lang/en_3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..36eb063d78466a9000306dab11507e0a43664568 Binary files /dev/null and b/doc/imgs_results/multi_lang/en_3.jpg differ diff --git a/doc/imgs_results/multi_lang/french_0.jpg b/doc/imgs_results/multi_lang/french_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3c2abe6304b93e3025dd19b75980c548f70bd3c7 Binary files /dev/null and b/doc/imgs_results/multi_lang/french_0.jpg differ diff --git a/doc/imgs_results/multi_lang/japan_2.jpg b/doc/imgs_results/multi_lang/japan_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7038ba2effab289b2895ea8d0fddaa365149ff80 Binary files /dev/null and b/doc/imgs_results/multi_lang/japan_2.jpg differ diff --git a/doc/imgs_results/whl/12_det.jpg b/doc/imgs_results/whl/12_det.jpg index 1d5ccf2a6b5d3fa9516560e0cb2646ad6b917da6..71627f0b8db8fdc6e1bf0c4601f0311160d3164d 100644 Binary files a/doc/imgs_results/whl/12_det.jpg and b/doc/imgs_results/whl/12_det.jpg differ diff --git a/doc/pgnet_framework.png b/doc/pgnet_framework.png new file mode 100644 index 0000000000000000000000000000000000000000..88fbca3947cdd2f88c3b8bb90b8f9b05356cbab1 Binary files /dev/null and b/doc/pgnet_framework.png differ diff --git a/paddleocr.py b/paddleocr.py index 7c126261eff1168a1888d72f71fb284e347f9ec9..47e1267ac40effbe8b4ab80723c66eb5378be179 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -66,6 +66,46 @@ model_urls = { 'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar', 'dict_path': './ppocr/utils/dict/japan_dict.txt' + }, + 'chinese_cht': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt' + }, + 'ta': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/ta_dict.txt' + }, + 'te': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/te_dict.txt' + }, + 'ka': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/ka_dict.txt' + }, + 'latin': { + 'url': + 
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
+        'dict_path': './ppocr/utils/dict/latin_dict.txt'
+    },
+    'arabic': {
+        'url':
+        'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
+        'dict_path': './ppocr/utils/dict/arabic_dict.txt'
+    },
+    'cyrillic': {
+        'url':
+        'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
+        'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
+    },
+    'devanagari': {
+        'url':
+        'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
+        'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
+    }
 },
 'cls':
@@ -233,10 +273,35 @@ class PaddleOCR(predict_system.TextSystem):
         postprocess_params.__dict__.update(**kwargs)
         self.use_angle_cls = postprocess_params.use_angle_cls
         lang = postprocess_params.lang
+        latin_lang = [
+            'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'en', 'es', 'et', 'fr',
+            'ga', 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi',
+            'ms', 'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin',
+            'sk', 'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
+        ]
+        arabic_lang = ['ar', 'fa', 'ug', 'ur']
+        cyrillic_lang = [
+            'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd',
+            'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
+        ]
+        devanagari_lang = [
+            'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new',
+            'gom', 'sa', 'bgc'
+        ]
+        if lang in latin_lang:
+            lang = "latin"
+        elif lang in arabic_lang:
+            lang = "arabic"
+        elif lang in cyrillic_lang:
+            lang = "cyrillic"
+        elif lang in devanagari_lang:
+            lang = "devanagari"
         assert lang in model_urls[
             'rec'], 'param lang must in {}, but got {}'.format(
                 model_urls['rec'].keys(), lang)
+        use_inner_dict = False
         if postprocess_params.rec_char_dict_path is None:
+            use_inner_dict = True
             postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
                 'dict_path']
@@ -263,9 +328,9 @@ class PaddleOCR(predict_system.TextSystem):
         if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
             logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
             sys.exit(0)
-
-        postprocess_params.rec_char_dict_path = str(
-            Path(__file__).parent / postprocess_params.rec_char_dict_path)
+        if use_inner_dict:
+            postprocess_params.rec_char_dict_path = str(
+                Path(__file__).parent / postprocess_params.rec_char_dict_path)
 
         # init det_model and rec_model
         super().__init__(postprocess_params)
@@ -282,8 +347,13 @@ class PaddleOCR(predict_system.TextSystem):
         if isinstance(img, list) and det == True:
             logger.error('When input a list of images, det must be false')
             exit(0)
+        if cls == False:
+            self.use_angle_cls = False
+        elif cls == True and self.use_angle_cls == False:
+            logger.warning(
+                'Since the angle classifier is not initialized, the angle classifier will not be used during the forward process'
+            )
 
-        self.use_angle_cls = cls
         if isinstance(img, str):
             # download net image
             if img.startswith('http'):
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 7cb50d7a62aa3f24811e517768e0635ac7b7321a..728b8317f54687ee76b519cba18f4d7807493821 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -34,6 +34,7 @@ import paddle.distributed as dist
 from ppocr.data.imaug import transform, create_operators
 from ppocr.data.simple_dataset import SimpleDataSet
 from ppocr.data.lmdb_dataset import LMDBDataSet
+from ppocr.data.pgnet_dataset import PGDataSet
 
 __all__ = ['build_dataloader', 'transform', 'create_operators']
 
@@
-54,7 +55,7 @@ signal.signal(signal.SIGTERM, term_mp) def build_dataloader(config, mode, device, logger, seed=None): config = copy.deepcopy(config) - support_dict = ['SimpleDataSet', 'LMDBDataSet'] + support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet'] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( 'DataSet only support {}'.format(support_dict)) @@ -72,14 +73,14 @@ def build_dataloader(config, mode, device, logger, seed=None): else: use_shared_memory = True if mode == "Train": - #Distribute data to multiple cards + # Distribute data to multiple cards batch_sampler = DistributedBatchSampler( dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) else: - #Distribute data to single card + # Distribute data to single card batch_sampler = BatchSampler( dataset=dataset, batch_size=batch_size, diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 250ac75e7683df2353d9fad02ef42b9e133681d3..a808fd586b0676751da1ee31d379179b026fd51d 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -28,6 +28,7 @@ from .label_ops import * from .east_process import * from .sast_process import * +from .pg_process import * def transform(data, ops=None): diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 7a32d870bfc7f532896ce6b11aac5508a6369993..cbb110090cfff3ebee4b30b009f88fc9aaba1617 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -187,6 +187,32 @@ class CTCLabelEncode(BaseRecLabelEncode): return dict_character +class E2ELabelEncode(BaseRecLabelEncode): + def __init__(self, + max_text_length, + character_dict_path=None, + character_type='EN', + use_space_char=False, + **kwargs): + super(E2ELabelEncode, + self).__init__(max_text_length, character_dict_path, + character_type, use_space_char) + self.pad_num = len(self.dict) # the length to pad + + def __call__(self, data): + texts = data['strs'] + temp_texts = [] + for text in texts: + text = text.lower() + text = self.encode(text) + if text is None: + return None + text = text + [self.pad_num] * (self.max_text_len - len(text)) + temp_texts.append(text) + data['strs'] = np.array(temp_texts) + return data + + class AttnLabelEncode(BaseRecLabelEncode): """ Convert between text-label and text-index """ diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index eacfdf3b243af5b9051ad726ced6edacddff45ed..9c48b09647527cf718113ea1b5df152ff7befa04 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -197,7 +197,6 @@ class DetResizeForTest(object): sys.exit(0) ratio_h = resize_h / float(h) ratio_w = resize_w / float(w) - # return img, np.array([h, w]) return img, [ratio_h, ratio_w] def resize_image_type2(self, img): @@ -206,7 +205,6 @@ class DetResizeForTest(object): resize_w = w resize_h = h - # Fix the longer side if resize_h > resize_w: ratio = float(self.resize_long) / resize_h else: @@ -223,3 +221,72 @@ class DetResizeForTest(object): ratio_w = resize_w / float(w) return img, [ratio_h, ratio_w] + + +class E2EResizeForTest(object): + def __init__(self, **kwargs): + super(E2EResizeForTest, self).__init__() + self.max_side_len = kwargs['max_side_len'] + self.valid_set = kwargs['valid_set'] + + def __call__(self, data): + img = data['image'] + src_h, src_w, _ = img.shape + if self.valid_set == 'totaltext': + im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext( + img, max_side_len=self.max_side_len) + else: + im_resized, 
(ratio_h, ratio_w) = self.resize_image( + img, max_side_len=self.max_side_len) + data['image'] = im_resized + data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + return data + + def resize_image_for_totaltext(self, im, max_side_len=512): + + h, w, _ = im.shape + resize_w = w + resize_h = h + ratio = 1.25 + if h * ratio > max_side_len: + ratio = float(max_side_len) / resize_h + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return im, (ratio_h, ratio_w) + + def resize_image(self, im, max_side_len=512): + """ + resize image to a size multiple of max_stride which is required by the network + :param im: the resized image + :param max_side_len: limit of max image size to avoid out of memory in gpu + :return: the resized image and the resize ratio + """ + h, w, _ = im.shape + + resize_w = w + resize_h = h + + # Fix the longer side + if resize_h > resize_w: + ratio = float(max_side_len) / resize_h + else: + ratio = float(max_side_len) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return im, (ratio_h, ratio_w) diff --git a/ppocr/data/imaug/pg_process.py b/ppocr/data/imaug/pg_process.py new file mode 100644 index 0000000000000000000000000000000000000000..0c9439d7a274af27ca8d296d5e737bafdec3bd1f --- /dev/null +++ b/ppocr/data/imaug/pg_process.py @@ -0,0 +1,906 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import math
+import cv2
+import numpy as np
+
+__all__ = ['PGProcessTrain']
+
+
+class PGProcessTrain(object):
+    def __init__(self,
+                 character_dict_path,
+                 max_text_length,
+                 max_text_nums,
+                 tcl_len,
+                 batch_size=14,
+                 min_crop_size=24,
+                 min_text_size=4,
+                 max_text_size=512,
+                 **kwargs):
+        self.tcl_len = tcl_len
+        self.max_text_length = max_text_length
+        self.max_text_nums = max_text_nums
+        self.batch_size = batch_size
+        self.min_crop_size = min_crop_size
+        self.min_text_size = min_text_size
+        self.max_text_size = max_text_size
+        self.Lexicon_Table = self.get_dict(character_dict_path)
+        self.pad_num = len(self.Lexicon_Table)  # the length to pad
+        self.img_id = 0
+
+    def get_dict(self, character_dict_path):
+        character_str = ""
+        with open(character_dict_path, "rb") as fin:
+            lines = fin.readlines()
+            for line in lines:
+                line = line.decode('utf-8').strip("\n").strip("\r\n")
+                character_str += line
+            dict_character = list(character_str)
+        return dict_character
+
+    def quad_area(self, poly):
+        """
+        compute the area of a quad
+        :param poly: quad with shape (4, 2)
+        :return: signed area
+        """
+        edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
+                (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
+                (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
+                (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
+        return np.sum(edge) / 2.
+
+    def gen_quad_from_poly(self, poly):
+        """
+        Generate min area quad from poly.
+        """
+        point_num = poly.shape[0]
+        min_area_quad = np.zeros((4, 2), dtype=np.float32)
+        rect = cv2.minAreaRect(poly.astype(
+            np.int32))  # (center (x,y), (width, height), angle of rotation)
+        box = np.array(cv2.boxPoints(rect))
+
+        first_point_idx = 0
+        min_dist = 1e4
+        for i in range(4):
+            dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
+                   np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
+                   np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
+                   np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+            if dist < min_dist:
+                min_dist = dist
+                first_point_idx = i
+        for i in range(4):
+            min_area_quad[i] = box[(first_point_idx + i) % 4]
+
+        return min_area_quad
+
+    def check_and_validate_polys(self, polys, tags, im_size):
+        """
+        check so that the text polys are in the same direction,
+        and also filter some invalid polygons
+        :param polys: text polygons
+        :param tags: ignore tags
+        :param im_size: (height, width) of the image
+        :return: validated polys, tags and horizontal/vertical tags
+        """
+        (h, w) = im_size
+        if polys.shape[0] == 0:
+            return polys, np.array([]), np.array([])
+        polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
+        polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
+
+        validated_polys = []
+        validated_tags = []
+        hv_tags = []
+        for poly, tag in zip(polys, tags):
+            quad = self.gen_quad_from_poly(poly)
+            p_area = self.quad_area(quad)
+            if abs(p_area) < 1:
+                print('invalid poly')
+                continue
+            if p_area > 0:
+                if not tag:
+                    print('poly in wrong direction')
+                    tag = True  # reversed cases should be ignored
+                poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
+                             1), :]
+                quad = quad[(0, 3, 2, 1), :]
+
+            len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
+                                                                       quad[2])
+            len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
+                                                                       quad[2])
+            hv_tag = 1
+
+            if len_w * 2.0 < len_h:
+                hv_tag = 0
+
+            validated_polys.append(poly)
+            validated_tags.append(tag)
+            hv_tags.append(hv_tag)
+        return np.array(validated_polys), np.array(validated_tags), np.array(
+            hv_tags)
+
+    def crop_area(self,
+                  im,
+                  polys,
+                  tags,
+                  hv_tags,
+                  txts,
+                  crop_background=False,
+                  max_tries=25):
+        """
+        make a random crop from the input image
+        :param im: input image
+        :param polys: text polygons with shape [b, 4, 2]
+        :param
tags: + :param crop_background: + :param max_tries: 50 -> 25 + :return: + """ + h, w, _ = im.shape + pad_h = h // 10 + pad_w = w // 10 + h_array = np.zeros((h + pad_h * 2), dtype=np.int32) + w_array = np.zeros((w + pad_w * 2), dtype=np.int32) + for poly in polys: + poly = np.round(poly, decimals=0).astype(np.int32) + minx = np.min(poly[:, 0]) + maxx = np.max(poly[:, 0]) + w_array[minx + pad_w:maxx + pad_w] = 1 + miny = np.min(poly[:, 1]) + maxy = np.max(poly[:, 1]) + h_array[miny + pad_h:maxy + pad_h] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + if len(h_axis) == 0 or len(w_axis) == 0: + return im, polys, tags, hv_tags, txts + for i in range(max_tries): + xx = np.random.choice(w_axis, size=2) + xmin = np.min(xx) - pad_w + xmax = np.max(xx) - pad_w + xmin = np.clip(xmin, 0, w - 1) + xmax = np.clip(xmax, 0, w - 1) + yy = np.random.choice(h_axis, size=2) + ymin = np.min(yy) - pad_h + ymax = np.max(yy) - pad_h + ymin = np.clip(ymin, 0, h - 1) + ymax = np.clip(ymax, 0, h - 1) + if xmax - xmin < self.min_crop_size or \ + ymax - ymin < self.min_crop_size: + continue + if polys.shape[0] != 0: + poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \ + & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax) + selected_polys = np.where( + np.sum(poly_axis_in_area, axis=1) == 4)[0] + else: + selected_polys = [] + if len(selected_polys) == 0: + # no text in this area + if crop_background: + txts_tmp = [] + for selected_poly in selected_polys: + txts_tmp.append(txts[selected_poly]) + txts = txts_tmp + return im[ymin: ymax + 1, xmin: xmax + 1, :], \ + polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts + else: + continue + im = im[ymin:ymax + 1, xmin:xmax + 1, :] + polys = polys[selected_polys] + tags = tags[selected_polys] + hv_tags = hv_tags[selected_polys] + txts_tmp = [] + for selected_poly in selected_polys: + txts_tmp.append(txts[selected_poly]) + txts = txts_tmp + polys[:, :, 0] -= xmin + polys[:, :, 1] -= ymin + return im, polys, tags, hv_tags, txts + + return im, polys, tags, hv_tags, txts + + def fit_and_gather_tcl_points_v2(self, + min_area_quad, + poly, + max_h, + max_w, + fixed_point_num=64, + img_id=0, + reference_height=3): + """ + Find the center point of poly as key_points, then fit and gather. 
+ """ + key_point_xys = [] + point_num = poly.shape[0] + for idx in range(point_num // 2): + center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0 + key_point_xys.append(center_point) + + tmp_image = np.zeros( + shape=( + max_h, + max_w, ), dtype='float32') + cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')], + False, 1.0) + ys, xs = np.where(tmp_image > 0) + xy_text = np.array(list(zip(xs, ys)), dtype='float32') + + left_center_pt = ( + (min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2) + right_center_pt = ( + (min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2) + proj_unit_vec = (right_center_pt - left_center_pt) / ( + np.linalg.norm(right_center_pt - left_center_pt) + 1e-6) + proj_unit_vec_tile = np.tile(proj_unit_vec, + (xy_text.shape[0], 1)) # (n, 2) + left_center_pt_tile = np.tile(left_center_pt, + (xy_text.shape[0], 1)) # (n, 2) + xy_text_to_left_center = xy_text - left_center_pt_tile + proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1) + xy_text = xy_text[np.argsort(proj_value)] + + # convert to np and keep the num of point not greater then fixed_point_num + pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx + point_num = len(pos_info) + if point_num > fixed_point_num: + keep_ids = [ + int((point_num * 1.0 / fixed_point_num) * x) + for x in range(fixed_point_num) + ] + pos_info = pos_info[keep_ids, :] + + keep = int(min(len(pos_info), fixed_point_num)) + if np.random.rand() < 0.2 and reference_height >= 3: + dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3 + random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape( + [keep, 1]) + pos_info += random_float + pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1) + pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1) + + # padding to fixed length + pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32) + pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id + pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32) + pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32) + pos_m[:keep] = 1.0 + return pos_l, pos_m + + def generate_direction_map(self, poly_quads, n_char, direction_map): + """ + """ + width_list = [] + height_list = [] + for quad in poly_quads: + quad_w = (np.linalg.norm(quad[0] - quad[1]) + + np.linalg.norm(quad[2] - quad[3])) / 2.0 + quad_h = (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[2] - quad[1])) / 2.0 + width_list.append(quad_w) + height_list.append(quad_h) + norm_width = max(sum(width_list) / n_char, 1.0) + average_height = max(sum(height_list) / len(height_list), 1.0) + k = 1 + for quad in poly_quads: + direct_vector_full = ( + (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0 + direct_vector = direct_vector_full / ( + np.linalg.norm(direct_vector_full) + 1e-6) * norm_width + direction_label = tuple( + map(float, + [direct_vector[0], direct_vector[1], 1.0 / average_height])) + cv2.fillPoly(direction_map, + quad.round().astype(np.int32)[np.newaxis, :, :], + direction_label) + k += 1 + return direction_map + + def calculate_average_height(self, poly_quads): + """ + """ + height_list = [] + for quad in poly_quads: + quad_h = (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[2] - quad[1])) / 2.0 + height_list.append(quad_h) + average_height = max(sum(height_list) / len(height_list), 1.0) + return average_height + + def generate_tcl_ctc_label(self, + h, + w, + polys, + tags, + text_strs, + ds_ratio, + tcl_ratio=0.3, + shrink_ratio_of_width=0.15): + """ + Generate polygon. 
+ """ + score_map_big = np.zeros( + ( + h, + w, ), dtype=np.float32) + h, w = int(h * ds_ratio), int(w * ds_ratio) + polys = polys * ds_ratio + + score_map = np.zeros( + ( + h, + w, ), dtype=np.float32) + score_label_map = np.zeros( + ( + h, + w, ), dtype=np.float32) + tbo_map = np.zeros((h, w, 5), dtype=np.float32) + training_mask = np.ones( + ( + h, + w, ), dtype=np.float32) + direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape( + [1, 1, 3]).astype(np.float32) + + label_idx = 0 + score_label_map_text_label_list = [] + pos_list, pos_mask, label_list = [], [], [] + for poly_idx, poly_tag in enumerate(zip(polys, tags)): + poly = poly_tag[0] + tag = poly_tag[1] + + # generate min_area_quad + min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly) + min_area_quad_h = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + + np.linalg.norm(min_area_quad[1] - min_area_quad[2])) + min_area_quad_w = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + + np.linalg.norm(min_area_quad[2] - min_area_quad[3])) + + if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \ + or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio: + continue + + if tag: + cv2.fillPoly(training_mask, + poly.astype(np.int32)[np.newaxis, :, :], 0.15) + else: + text_label = text_strs[poly_idx] + text_label = self.prepare_text_label(text_label, + self.Lexicon_Table) + + text_label_index_list = [[self.Lexicon_Table.index(c_)] + for c_ in text_label + if c_ in self.Lexicon_Table] + if len(text_label_index_list) < 1: + continue + + tcl_poly = self.poly2tcl(poly, tcl_ratio) + tcl_quads = self.poly2quads(tcl_poly) + poly_quads = self.poly2quads(poly) + + stcl_quads, quad_index = self.shrink_poly_along_width( + tcl_quads, + shrink_ratio_of_width=shrink_ratio_of_width, + expand_height_ratio=1.0 / tcl_ratio) + + cv2.fillPoly(score_map, + np.round(stcl_quads).astype(np.int32), 1.0) + cv2.fillPoly(score_map_big, + np.round(stcl_quads / ds_ratio).astype(np.int32), + 1.0) + + for idx, quad in enumerate(stcl_quads): + quad_mask = np.zeros((h, w), dtype=np.float32) + quad_mask = cv2.fillPoly( + quad_mask, + np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0) + tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]], + quad_mask, tbo_map) + + # score label map and score_label_map_text_label_list for refine + if label_idx == 0: + text_pos_list_ = [[len(self.Lexicon_Table)], ] + score_label_map_text_label_list.append(text_pos_list_) + + label_idx += 1 + cv2.fillPoly(score_label_map, + np.round(poly_quads).astype(np.int32), label_idx) + score_label_map_text_label_list.append(text_label_index_list) + + # direction info, fix-me + n_char = len(text_label_index_list) + direction_map = self.generate_direction_map(poly_quads, n_char, + direction_map) + + # pos info + average_shrink_height = self.calculate_average_height( + stcl_quads) + pos_l, pos_m = self.fit_and_gather_tcl_points_v2( + min_area_quad, + poly, + max_h=h, + max_w=w, + fixed_point_num=64, + img_id=self.img_id, + reference_height=average_shrink_height) + + label_l = text_label_index_list + if len(text_label_index_list) < 2: + continue + + pos_list.append(pos_l) + pos_mask.append(pos_m) + label_list.append(label_l) + + # use big score_map for smooth tcl lines + score_map_big_resized = cv2.resize( + score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio) + score_map = np.array(score_map_big_resized > 1e-3, dtype='float32') + + return score_map, score_label_map, tbo_map, direction_map, training_mask, \ + 
pos_list, pos_mask, label_list, score_label_map_text_label_list + + def adjust_point(self, poly): + """ + adjust point order. + """ + point_num = poly.shape[0] + if point_num == 4: + len_1 = np.linalg.norm(poly[0] - poly[1]) + len_2 = np.linalg.norm(poly[1] - poly[2]) + len_3 = np.linalg.norm(poly[2] - poly[3]) + len_4 = np.linalg.norm(poly[3] - poly[0]) + + if (len_1 + len_3) * 1.5 < (len_2 + len_4): + poly = poly[[1, 2, 3, 0], :] + + elif point_num > 4: + vector_1 = poly[0] - poly[1] + vector_2 = poly[1] - poly[2] + cos_theta = np.dot(vector_1, vector_2) / ( + np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6) + theta = np.arccos(np.round(cos_theta, decimals=4)) + + if abs(theta) > (70 / 180 * math.pi): + index = list(range(1, point_num)) + [0] + poly = poly[np.array(index), :] + return poly + + def gen_min_area_quad_from_poly(self, poly): + """ + Generate min area quad from poly. + """ + point_num = poly.shape[0] + min_area_quad = np.zeros((4, 2), dtype=np.float32) + if point_num == 4: + min_area_quad = poly + center_point = np.sum(poly, axis=0) / 4 + else: + rect = cv2.minAreaRect(poly.astype( + np.int32)) # (center (x,y), (width, height), angle of rotation) + center_point = rect[0] + box = np.array(cv2.boxPoints(rect)) + + first_point_idx = 0 + min_dist = 1e4 + for i in range(4): + dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \ + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \ + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \ + np.linalg.norm(box[(i + 3) % 4] - poly[-1]) + if dist < min_dist: + min_dist = dist + first_point_idx = i + + for i in range(4): + min_area_quad[i] = box[(first_point_idx + i) % 4] + + return min_area_quad, center_point + + def shrink_quad_along_width(self, + quad, + begin_width_ratio=0., + end_width_ratio=1.): + """ + Generate shrink_quad_along_width. + """ + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + def shrink_poly_along_width(self, + quads, + shrink_ratio_of_width, + expand_height_ratio=1.0): + """ + shrink poly with given length. + """ + upper_edge_list = [] + + def get_cut_info(edge_len_list, cut_len): + for idx, edge_len in enumerate(edge_len_list): + cut_len -= edge_len + if cut_len <= 0.000001: + ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx] + return idx, ratio + + for quad in quads: + upper_edge_len = np.linalg.norm(quad[0] - quad[1]) + upper_edge_list.append(upper_edge_len) + + # length of left edge and right edge. 
+        left_length = np.linalg.norm(quads[0][0] - quads[0][
+            3]) * expand_height_ratio
+        right_length = np.linalg.norm(quads[-1][1] - quads[-1][
+            2]) * expand_height_ratio
+
+        shrink_length = min(left_length, right_length,
+                            sum(upper_edge_list)) * shrink_ratio_of_width
+        # shrinking length
+        upper_len_left = shrink_length
+        upper_len_right = sum(upper_edge_list) - shrink_length
+
+        left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
+        left_quad = self.shrink_quad_along_width(
+            quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
+        right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
+        right_quad = self.shrink_quad_along_width(
+            quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
+
+        out_quad_list = []
+        if left_idx == right_idx:
+            out_quad_list.append(
+                [left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
+        else:
+            out_quad_list.append(left_quad)
+            for idx in range(left_idx + 1, right_idx):
+                out_quad_list.append(quads[idx])
+            out_quad_list.append(right_quad)
+
+        return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
+
+    def prepare_text_label(self, label_str, Lexicon_Table):
+        """
+        Prepare text label by given Lexicon_Table.
+        """
+        if len(Lexicon_Table) == 36:
+            return label_str.lower()
+        else:
+            return label_str
+
+    def vector_angle(self, A, B):
+        """
+        Calculate the angle between vector AB and the positive x-axis.
+        """
+        AB = np.array([B[1] - A[1], B[0] - A[0]])
+        return np.arctan2(*AB)
+
+    def theta_line_cross_point(self, theta, point):
+        """
+        Calculate the line through the given point at angle theta, in ax + by + c = 0 form.
+        """
+        x, y = point
+        cos = np.cos(theta)
+        sin = np.sin(theta)
+        return [sin, -cos, cos * y - sin * x]
+
+    def line_cross_two_point(self, A, B):
+        """
+        Calculate the line through given points A and B, in ax + by + c = 0 form.
+        """
+        angle = self.vector_angle(A, B)
+        return self.theta_line_cross_point(angle, A)
+
+    def average_angle(self, poly):
+        """
+        Calculate the average angle of the left and right edges of the given poly.
+        """
+        p0, p1, p2, p3 = poly
+        angle30 = self.vector_angle(p3, p0)
+        angle21 = self.vector_angle(p2, p1)
+        return (angle30 + angle21) / 2
+
+    def line_cross_point(self, line1, line2):
+        """
+        line1 and line2 are in ax + by + c = 0 form; compute their cross point.
+        """
+        a1, b1, c1 = line1
+        a2, b2, c2 = line2
+        d = a1 * b2 - a2 * b1
+
+        if d == 0:
+            print('Cross point does not exist')
+            return np.array([0, 0], dtype=np.float32)
+        else:
+            # solve the 2x2 linear system by Cramer's rule
+            x = (b1 * c2 - b2 * c1) / d
+            y = (a2 * c1 - a1 * c2) / d
+
+        return np.array([x, y], dtype=np.float32)
+
+    def quad2tcl(self, poly, ratio):
+        """
+        Generate the center line from clock-wise quad points, shape (4, 2).
+        """
+        ratio_pair = np.array(
+            [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+        p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
+        p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
+        return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
+
+    def poly2tcl(self, poly, ratio):
+        """
+        Generate the center line from clock-wise poly points.
+        """
+        ratio_pair = np.array(
+            [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+        tcl_poly = np.zeros_like(poly)
+        point_num = poly.shape[0]
+
+        for idx in range(point_num // 2):
+            point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
+                                      ) * ratio_pair
+            tcl_poly[idx] = point_pair[0]
+            tcl_poly[point_num - 1 - idx] = point_pair[1]
+        return tcl_poly
+
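For intuition about the helper above: `poly2tcl` pairs the idx-th point on the upper edge with its mirror `point_num - 1 - idx` on the lower edge and interpolates both toward the 0.5 mid-line, so `ratio=0.3` collapses a text polygon into a band covering the central 30% of its height, the text-center-line (TCL) region that the score map is filled from. A minimal standalone sketch of the same computation on a toy quad (illustrative only, not part of the patch):

    import numpy as np

    def poly2tcl(poly, ratio):
        # interpolate each upper/lower point pair toward the 0.5 mid-line
        ratio_pair = np.array(
            [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
        tcl = np.zeros_like(poly)
        n = poly.shape[0]
        for i in range(n // 2):
            pair = poly[i] + (poly[n - 1 - i] - poly[i]) * ratio_pair
            tcl[i], tcl[n - 1 - i] = pair[0], pair[1]
        return tcl

    quad = np.array([[0, 0], [10, 0], [10, 4], [0, 4]], dtype=np.float32)
    print(poly2tcl(quad, 0.3))  # y now spans 1.4 .. 2.6: the central 30% of the height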
+ """ + # upper and lower line function: ax + by + c = 0; + up_line = self.line_cross_two_point(quad[0], quad[1]) + lower_line = self.line_cross_two_point(quad[3], quad[2]) + + quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[1] - quad[2])) + quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) + + np.linalg.norm(quad[2] - quad[3])) + + # average angle of left and right line. + angle = self.average_angle(quad) + + xy_in_poly = np.argwhere(tcl_mask == 1) + for y, x in xy_in_poly: + point = (x, y) + line = self.theta_line_cross_point(angle, point) + cross_point_upper = self.line_cross_point(up_line, line) + cross_point_lower = self.line_cross_point(lower_line, line) + ##FIX, offset reverse + upper_offset_x, upper_offset_y = cross_point_upper - point + lower_offset_x, lower_offset_y = cross_point_lower - point + tbo_map[y, x, 0] = upper_offset_y + tbo_map[y, x, 1] = upper_offset_x + tbo_map[y, x, 2] = lower_offset_y + tbo_map[y, x, 3] = lower_offset_x + tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2 + return tbo_map + + def poly2quads(self, poly): + """ + Split poly into quads. + """ + quad_list = [] + point_num = poly.shape[0] + + # point pair + point_pair_list = [] + for idx in range(point_num // 2): + point_pair = [poly[idx], poly[point_num - 1 - idx]] + point_pair_list.append(point_pair) + + quad_num = point_num // 2 - 1 + for idx in range(quad_num): + # reshape and adjust to clock-wise + quad_list.append((np.array(point_pair_list)[[idx, idx + 1]] + ).reshape(4, 2)[[0, 2, 3, 1]]) + + return np.array(quad_list) + + def rotate_im_poly(self, im, text_polys): + """ + rotate image with 90 / 180 / 270 degre + """ + im_w, im_h = im.shape[1], im.shape[0] + dst_im = im.copy() + dst_polys = [] + rand_degree_ratio = np.random.rand() + rand_degree_cnt = 1 + if rand_degree_ratio > 0.5: + rand_degree_cnt = 3 + for i in range(rand_degree_cnt): + dst_im = np.rot90(dst_im) + rot_degree = -90 * rand_degree_cnt + rot_angle = rot_degree * math.pi / 180.0 + n_poly = text_polys.shape[0] + cx, cy = 0.5 * im_w, 0.5 * im_h + ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0] + for i in range(n_poly): + wordBB = text_polys[i] + poly = [] + for j in range(4): # 16->4 + sx, sy = wordBB[j][0], wordBB[j][1] + dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * ( + sy - cy) + ncx + dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * ( + sy - cy) + ncy + poly.append([dx, dy]) + dst_polys.append(poly) + return dst_im, np.array(dst_polys, dtype=np.float32) + + def __call__(self, data): + input_size = 512 + im = data['image'] + text_polys = data['polys'] + text_tags = data['tags'] + text_strs = data['strs'] + h, w, _ = im.shape + text_polys, text_tags, hv_tags = self.check_and_validate_polys( + text_polys, text_tags, (h, w)) + if text_polys.shape[0] <= 0: + return None + # set aspect ratio and keep area fix + asp_scales = np.arange(1.0, 1.55, 0.1) + asp_scale = np.random.choice(asp_scales) + if np.random.rand() < 0.5: + asp_scale = 1.0 / asp_scale + asp_scale = math.sqrt(asp_scale) + + asp_wx = asp_scale + asp_hy = 1.0 / asp_scale + im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy) + text_polys[:, :, 0] *= asp_wx + text_polys[:, :, 1] *= asp_hy + + h, w, _ = im.shape + if max(h, w) > 2048: + rd_scale = 2048.0 / max(h, w) + im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) + text_polys *= rd_scale + h, w, _ = im.shape + if min(h, w) < 16: + return None + + # no background + im, text_polys, text_tags, hv_tags, text_strs = self.crop_area( + im, + text_polys, 
+ text_tags, + hv_tags, + text_strs, + crop_background=False) + + if text_polys.shape[0] == 0: + return None + # # continue for all ignore case + if np.sum((text_tags * 1.0)) >= text_tags.size: + return None + new_h, new_w, _ = im.shape + if (new_h is None) or (new_w is None): + return None + # resize image + std_ratio = float(input_size) / max(new_w, new_h) + rand_scales = np.array( + [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0]) + rz_scale = std_ratio * np.random.choice(rand_scales) + im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale) + text_polys[:, :, 0] *= rz_scale + text_polys[:, :, 1] *= rz_scale + + # add gaussian blur + if np.random.rand() < 0.1 * 0.5: + ks = np.random.permutation(5)[0] + 1 + ks = int(ks / 2) * 2 + 1 + im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0) + # add brighter + if np.random.rand() < 0.1 * 0.5: + im = im * (1.0 + np.random.rand() * 0.5) + im = np.clip(im, 0.0, 255.0) + # add darker + if np.random.rand() < 0.1 * 0.5: + im = im * (1.0 - np.random.rand() * 0.5) + im = np.clip(im, 0.0, 255.0) + + # Padding the im to [input_size, input_size] + new_h, new_w, _ = im.shape + if min(new_w, new_h) < input_size * 0.5: + return None + im_padded = np.ones((input_size, input_size, 3), dtype=np.float32) + im_padded[:, :, 2] = 0.485 * 255 + im_padded[:, :, 1] = 0.456 * 255 + im_padded[:, :, 0] = 0.406 * 255 + + # Random the start position + del_h = input_size - new_h + del_w = input_size - new_w + sh, sw = 0, 0 + if del_h > 1: + sh = int(np.random.rand() * del_h) + if del_w > 1: + sw = int(np.random.rand() * del_w) + + # Padding + im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy() + text_polys[:, :, 0] += sw + text_polys[:, :, 1] += sh + + score_map, score_label_map, border_map, direction_map, training_mask, \ + pos_list, pos_mask, label_list, score_label_map_text_label = self.generate_tcl_ctc_label(input_size, + input_size, + text_polys, + text_tags, + text_strs, 0.25) + if len(label_list) <= 0: # eliminate negative samples + return None + pos_list_temp = np.zeros([64, 3]) + pos_mask_temp = np.zeros([64, 1]) + label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num + + for i, label in enumerate(label_list): + n = len(label) + if n > self.max_text_length: + label_list[i] = label[:self.max_text_length] + continue + while n < self.max_text_length: + label.append([self.pad_num]) + n += 1 + + for i in range(len(label_list)): + label_list[i] = np.array(label_list[i]) + + if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums: + return None + for __ in range(self.max_text_nums - len(pos_list), 0, -1): + pos_list.append(pos_list_temp) + pos_mask.append(pos_mask_temp) + label_list.append(label_list_temp) + + if self.img_id == self.batch_size - 1: + self.img_id = 0 + else: + self.img_id += 1 + + im_padded[:, :, 2] -= 0.485 * 255 + im_padded[:, :, 1] -= 0.456 * 255 + im_padded[:, :, 0] -= 0.406 * 255 + im_padded[:, :, 2] /= (255.0 * 0.229) + im_padded[:, :, 1] /= (255.0 * 0.224) + im_padded[:, :, 0] /= (255.0 * 0.225) + im_padded = im_padded.transpose((2, 0, 1)) + images = im_padded[::-1, :, :] + tcl_maps = score_map[np.newaxis, :, :] + tcl_label_maps = score_label_map[np.newaxis, :, :] + border_maps = border_map.transpose((2, 0, 1)) + direction_maps = direction_map.transpose((2, 0, 1)) + training_masks = training_mask[np.newaxis, :, :] + pos_list = np.array(pos_list) + pos_mask = np.array(pos_mask) + label_list = np.array(label_list) + data['images'] = images + data['tcl_maps'] = tcl_maps + data['tcl_label_maps'] = 
tcl_label_maps
+        data['border_maps'] = border_maps
+        data['direction_maps'] = direction_maps
+        data['training_masks'] = training_masks
+        data['label_list'] = label_list
+        data['pos_list'] = pos_list
+        data['pos_mask'] = pos_mask
+        return data
diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..543dbe79ef6ff548c91b4e17bf6424797cdeeea7
--- /dev/null
+++ b/ppocr/data/pgnet_dataset.py
@@ -0,0 +1,170 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+from paddle.io import Dataset
+from .imaug import transform, create_operators
+import random
+
+
+class PGDataSet(Dataset):
+    def __init__(self, config, mode, logger, seed=None):
+        super(PGDataSet, self).__init__()
+
+        self.logger = logger
+        self.seed = seed
+        self.mode = mode
+        global_config = config['Global']
+        dataset_config = config[mode]['dataset']
+        loader_config = config[mode]['loader']
+
+        label_file_list = dataset_config.pop('label_file_list')
+        data_source_num = len(label_file_list)
+        ratio_list = dataset_config.get("ratio_list", [1.0])
+        if isinstance(ratio_list, (float, int)):
+            ratio_list = [float(ratio_list)] * int(data_source_num)
+        self.data_format = dataset_config.get('data_format', 'icdar')
+        assert len(
+            ratio_list
+        ) == data_source_num, "The length of ratio_list should be the same as the file_list."
+        self.do_shuffle = loader_config['shuffle']
+
+        logger.info("Initialize indexes of datasets:%s" % label_file_list)
+        self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
+                                                   self.data_format)
+        self.data_idx_order_list = list(range(len(self.data_lines)))
+        if mode.lower() == "train":
+            self.shuffle_data_random()
+
+        self.ops = create_operators(dataset_config['transforms'], global_config)
+
+    def shuffle_data_random(self):
+        if self.do_shuffle:
+            random.seed(self.seed)
+            random.shuffle(self.data_lines)
+        return
+
+    def extract_polys(self, poly_txt_path):
+        """
+        Read text_polys, txt_tags, txts from a given txt file.
+        """
+        text_polys, txt_tags, txts = [], [], []
+        with open(poly_txt_path) as f:
+            for line in f.readlines():
+                poly_str, txt = line.strip().split('\t')
+                poly = list(map(float, poly_str.split(',')))
+                text_polys.append(
+                    np.array(
+                        poly, dtype=np.float32).reshape(-1, 2))
+                txts.append(txt)
+                txt_tags.append(txt == '###')
+
+        return np.array(list(map(np.array, text_polys))), \
+            np.array(txt_tags, dtype=np.bool), txts
+
+    def extract_info_textnet(self, im_fn, img_dir=''):
+        """
+        Extract information from a line in textnet format.
+        """
+        info_list = im_fn.split('\t')
+        img_path = ''
+        for ext in [
+                'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
+        ]:
+            if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)):
+                img_path = os.path.join(img_dir, info_list[0] + "."
+ ext) + break + + if img_path == '': + print('Image {0} NOT found in {1}, and it will be ignored.'.format( + info_list[0], img_dir)) + + nBox = (len(info_list) - 1) // 9 + wordBBs, txts, txt_tags = [], [], [] + for n in range(0, nBox): + wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9])) + txt = info_list[(n + 1) * 9] + wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]], + [wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]]) + txts.append(txt) + if txt == '###': + txt_tags.append(True) + else: + txt_tags.append(False) + return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts + + def get_image_info_list(self, file_list, ratio_list, data_format='textnet'): + if isinstance(file_list, str): + file_list = [file_list] + data_lines = [] + for idx, data_source in enumerate(file_list): + image_files = [] + if data_format == 'icdar': + image_files = [(data_source, x) for x in + os.listdir(os.path.join(data_source, 'rgb')) + if x.split('.')[-1] in [ + 'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', + 'tiff', 'gif', 'JPG' + ]] + elif data_format == 'textnet': + with open(data_source) as f: + image_files = [(data_source, x.strip()) + for x in f.readlines()] + else: + print("Unrecognized data format...") + exit(-1) + random.seed(self.seed) + image_files = random.sample( + image_files, round(len(image_files) * ratio_list[idx])) + data_lines.extend(image_files) + return data_lines + + def __getitem__(self, idx): + file_idx = self.data_idx_order_list[idx] + data_path, data_line = self.data_lines[file_idx] + try: + if self.data_format == 'icdar': + im_path = os.path.join(data_path, 'rgb', data_line) + poly_path = os.path.join(data_path, 'poly', + data_line.split('.')[0] + '.txt') + text_polys, text_tags, text_strs = self.extract_polys(poly_path) + else: + image_dir = os.path.join(os.path.dirname(data_path), 'image') + im_path, text_polys, text_tags, text_strs = self.extract_info_textnet( + data_line, image_dir) + img_id = int(data_line.split(".")[0][3:]) + + data = { + 'img_path': im_path, + 'polys': text_polys, + 'tags': text_tags, + 'strs': text_strs, + 'img_id': img_id + } + with open(data['img_path'], 'rb') as f: + img = f.read() + data['image'] = img + outs = transform(data, self.ops) + + except Exception as e: + self.logger.error( + "When parsing line {}, error happened with msg: {}".format( + self.data_idx_order_list[idx], e)) + outs = None + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs + + def __len__(self): + return len(self.data_idx_order_list) diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index d2a86b0f855c3d25754d78d5fa00c8fb4bcf4db9..ea57b785460fe31414ab37b0b649ec174dcf3122 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -23,6 +23,7 @@ class SimpleDataSet(Dataset): def __init__(self, config, mode, logger, seed=None): super(SimpleDataSet, self).__init__() self.logger = logger + self.mode = mode.lower() global_config = config['Global'] dataset_config = config[mode]['dataset'] @@ -45,7 +46,7 @@ class SimpleDataSet(Dataset): logger.info("Initialize indexs of datasets:%s" % label_file_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_idx_order_list = list(range(len(self.data_lines))) - if mode.lower() == "train": + if self.mode == "train" and self.do_shuffle: self.shuffle_data_random() self.ops = create_operators(dataset_config['transforms'], global_config) @@ -56,16 +57,16 @@ class SimpleDataSet(Dataset): for idx, file in 
enumerate(file_list):
+            with open(file, "rb") as f:
+                lines = f.readlines()
-                random.seed(self.seed)
-                lines = random.sample(lines,
-                                      round(len(lines) * ratio_list[idx]))
+                if self.mode == "train" or ratio_list[idx] < 1.0:
+                    random.seed(self.seed)
+                    lines = random.sample(lines,
+                                          round(len(lines) * ratio_list[idx]))
                 data_lines.extend(lines)
         return data_lines

     def shuffle_data_random(self):
-        if self.do_shuffle:
-            random.seed(self.seed)
-            random.shuffle(self.data_lines)
+        random.seed(self.seed)
+        random.shuffle(self.data_lines)
         return

     def __getitem__(self, idx):
@@ -90,7 +91,10 @@ class SimpleDataSet(Dataset):
                     data_line, e))
             outs = None
         if outs is None:
-            return self.__getitem__(np.random.randint(self.__len__()))
+            # during evaluation, fix the idx so that repeated runs of evaluation return the same results.
+            rnd_idx = np.random.randint(self.__len__(
+            )) if self.mode == "train" else (idx + 1) % self.__len__()
+            return self.__getitem__(rnd_idx)
         return outs

     def __len__(self):
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 3881abf7741b8be78306bd070afb11df15606327..223ae6b1da996478ac607e29dd37173ca51d9903 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -29,10 +29,11 @@ def build_loss(config):

     # cls loss
     from .cls_loss import ClsLoss
+    # e2e loss
+    from .e2e_pg_loss import PGLoss
     support_dict = [
         'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
-        'SRNLoss'
-    ]
+        'SRNLoss', 'PGLoss']

     config = copy.deepcopy(config)
     module_name = config.pop('name')
diff --git a/ppocr/losses/e2e_pg_loss.py b/ppocr/losses/e2e_pg_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..10a8ed0aa907123b155976ba498426604f23c2b0
--- /dev/null
+++ b/ppocr/losses/e2e_pg_loss.py
@@ -0,0 +1,140 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
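The PGLoss added below supervises four branches at once; its forward pass (further down in this file) sums them as `score + border + direction + 5 * ctc`. The border and direction regressions build a Smooth-L1 penalty inline from a `|diff| < 1` indicator. A small numpy sketch of just that penalty, for reference (the names are illustrative, not the patch's API):

    import numpy as np

    def smooth_l1(pred, target):
        # 0.5 * d^2 where |d| < 1, and |d| - 0.5 elsewhere -- the same
        # masked form that border_loss / direction_loss assemble with paddle ops
        d = np.abs(pred - target)
        inside = (d < 1.0).astype(np.float32)
        return 0.5 * d * d * inside + (d - 0.5) * (1.0 - inside)

    print(smooth_l1(np.array([0.2, 3.0]), np.zeros(2)))  # -> [0.02 2.5]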
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn +import paddle + +from .det_basic_loss import DiceLoss +from ppocr.utils.e2e_utils.extract_batchsize import pre_process + + +class PGLoss(nn.Layer): + def __init__(self, + tcl_bs, + max_text_length, + max_text_nums, + pad_num, + eps=1e-6, + **kwargs): + super(PGLoss, self).__init__() + self.tcl_bs = tcl_bs + self.max_text_nums = max_text_nums + self.max_text_length = max_text_length + self.pad_num = pad_num + self.dice_loss = DiceLoss(eps=eps) + + def border_loss(self, f_border, l_border, l_score, l_mask): + l_border_split, l_border_norm = paddle.tensor.split( + l_border, num_or_sections=[4, 1], axis=1) + f_border_split = f_border + b, c, h, w = l_border_norm.shape + l_border_norm_split = paddle.expand( + x=l_border_norm, shape=[b, 4 * c, h, w]) + b, c, h, w = l_score.shape + l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w]) + b, c, h, w = l_mask.shape + l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w]) + border_diff = l_border_split - f_border_split + abs_border_diff = paddle.abs(border_diff) + border_sign = abs_border_diff < 1.0 + border_sign = paddle.cast(border_sign, dtype='float32') + border_sign.stop_gradient = True + border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \ + (abs_border_diff - 0.5) * (1.0 - border_sign) + border_out_loss = l_border_norm_split * border_in_loss + border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \ + (paddle.sum(l_border_score * l_border_mask) + 1e-5) + return border_loss + + def direction_loss(self, f_direction, l_direction, l_score, l_mask): + l_direction_split, l_direction_norm = paddle.tensor.split( + l_direction, num_or_sections=[2, 1], axis=1) + f_direction_split = f_direction + b, c, h, w = l_direction_norm.shape + l_direction_norm_split = paddle.expand( + x=l_direction_norm, shape=[b, 2 * c, h, w]) + b, c, h, w = l_score.shape + l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w]) + b, c, h, w = l_mask.shape + l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w]) + direction_diff = l_direction_split - f_direction_split + abs_direction_diff = paddle.abs(direction_diff) + direction_sign = abs_direction_diff < 1.0 + direction_sign = paddle.cast(direction_sign, dtype='float32') + direction_sign.stop_gradient = True + direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \ + (abs_direction_diff - 0.5) * (1.0 - direction_sign) + direction_out_loss = l_direction_norm_split * direction_in_loss + direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \ + (paddle.sum(l_direction_score * l_direction_mask) + 1e-5) + return direction_loss + + def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t): + f_char = paddle.transpose(f_char, [0, 2, 3, 1]) + tcl_pos = paddle.reshape(tcl_pos, [-1, 3]) + tcl_pos = paddle.cast(tcl_pos, dtype=int) + f_tcl_char = paddle.gather_nd(f_char, tcl_pos) + f_tcl_char = paddle.reshape(f_tcl_char, + [-1, 64, 37]) # len(Lexicon_Table)+1 + f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2) + f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0 + b, c, l = tcl_mask.shape + tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l]) + tcl_mask_fg.stop_gradient = True + f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * ( + -20.0) + f_tcl_char_mask = 
paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2) + f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2)) + N, B, _ = f_tcl_char_ld.shape + input_lengths = paddle.to_tensor([N] * B, dtype='int64') + cost = paddle.nn.functional.ctc_loss( + log_probs=f_tcl_char_ld, + labels=tcl_label, + input_lengths=input_lengths, + label_lengths=label_t, + blank=self.pad_num, + reduction='none') + cost = cost.mean() + return cost + + def forward(self, predicts, labels): + images, tcl_maps, tcl_label_maps, border_maps \ + , direction_maps, training_masks, label_list, pos_list, pos_mask = labels + # for all the batch_size + pos_list, pos_mask, label_list, label_t = pre_process( + label_list, pos_list, pos_mask, self.max_text_length, + self.max_text_nums, self.pad_num, self.tcl_bs) + + f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \ + predicts['f_char'] + score_loss = self.dice_loss(f_score, tcl_maps, training_masks) + border_loss = self.border_loss(f_border, border_maps, tcl_maps, + training_masks) + direction_loss = self.direction_loss(f_direction, direction_maps, + tcl_maps, training_masks) + ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t) + loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss + + losses = { + 'loss': loss_all, + "score_loss": score_loss, + "border_loss": border_loss, + "direction_loss": direction_loss, + "ctc_loss": ctc_loss + } + return losses diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index a0e7d91207277d5c1696d99473f6bc5f685591fc..f913010dbd994633d3df1cf996abb994d246a11a 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -26,8 +26,9 @@ def build_metric(config): from .det_metric import DetMetric from .rec_metric import RecMetric from .cls_metric import ClsMetric + from .e2e_metric import E2EMetric - support_dict = ['DetMetric', 'RecMetric', 'ClsMetric'] + support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric'] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..8a604192fa455071202eec157e3832e2804bfdfd --- /dev/null +++ b/ppocr/metrics/e2e_metric.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
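The E2EMetric added below scores one image per call against the .mat ground truth and defers aggregation to `combine_results` from `ppocr/utils/e2e_metric/Deteval.py`, which this patch does not show. Assuming each per-image result carries matched / ground-truth / detection counts, the combination is a micro-averaged harmonic mean along these lines (the field names are hypothetical stand-ins, not Deteval's real keys):

    def combine_results_sketch(results):
        # 'num_gt', 'num_det', 'num_match' are assumed per-image count
        # fields; the real dict produced by get_socre may differ
        num_gt = sum(r['num_gt'] for r in results)
        num_det = sum(r['num_det'] for r in results)
        num_match = sum(r['num_match'] for r in results)
        recall = 0.0 if num_gt == 0 else num_match / num_gt
        precision = 0.0 if num_det == 0 else num_match / num_det
        hmean = 0.0 if precision + recall == 0 else \
            2.0 * precision * recall / (precision + recall)
        return {'precision': precision, 'recall': recall, 'f_score_e2e': hmean}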
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+__all__ = ['E2EMetric']
+
+from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results
+from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
+
+
+class E2EMetric(object):
+    def __init__(self,
+                 gt_mat_dir,
+                 character_dict_path,
+                 main_indicator='f_score_e2e',
+                 **kwargs):
+        self.gt_mat_dir = gt_mat_dir
+        self.label_list = get_dict(character_dict_path)
+        self.max_index = len(self.label_list)
+        self.main_indicator = main_indicator
+        self.reset()
+
+    def __call__(self, preds, batch, **kwargs):
+        img_id = batch[5][0]
+        e2e_info_list = [{
+            'points': det_polygon,
+            'text': pred_str
+        } for det_polygon, pred_str in zip(preds['points'], preds['strs'])]
+        result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
+        self.results.append(result)
+
+    def get_metric(self):
+        metrics = combine_results(self.results)
+        self.reset()
+        return metrics
+
+    def reset(self):
+        self.results = []  # clear results
diff --git a/ppocr/metrics/eval_det_iou.py b/ppocr/metrics/eval_det_iou.py
index a2a3f41833a9ef7615b73b70808fcb3ba2f22aa4..0e32b2d19281de9a18a1fe0343bd7e8237825b7b 100644
--- a/ppocr/metrics/eval_det_iou.py
+++ b/ppocr/metrics/eval_det_iou.py
@@ -150,7 +150,7 @@ class DetectionIoUEvaluator(object):
                     pairs.append({'gt': gtNum, 'det': detNum})
                     detMatchedNums.append(detNum)
                     evaluationLog += "Match GT #" + \
-                        str(gtNum) + " with Det #" + str(detNum) + "\n"
+                                     str(gtNum) + " with Det #" + str(detNum) + "\n"

             numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
             numDetCare = (len(detPols) - len(detDontCarePolsNum))
@@ -162,7 +162,7 @@ class DetectionIoUEvaluator(object):
             precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
             hmean = 0 if (precision + recall) == 0 else 2.0 * \
-                precision * recall / (precision + recall)
+                    precision * recall / (precision + recall)

         matchedSum += detMatched
         numGlobalCareGt += numGtCare
@@ -200,7 +200,8 @@ class DetectionIoUEvaluator(object):
         methodPrecision = 0 if numGlobalCareDet == 0 else float(
             matchedSum) / numGlobalCareDet
         methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
-            methodRecall * methodPrecision / (methodRecall + methodPrecision)
+            methodRecall * methodPrecision / (
+                methodRecall + methodPrecision)
         # print(methodRecall, methodPrecision, methodHmean)
         # sys.exit(-1)
         methodMetrics = {
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 03c15508a58313b234a72bb3ef47ac27dc3ebb7e..fe2c9bc30a4f2abd1ba7d3d6989b9ef9b20c1f4f 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -26,6 +26,9 @@ def build_backbone(config, model_type):
         from .rec_resnet_vd import ResNet
         from .rec_resnet_fpn import ResNetFPN
         support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
+    elif model_type == 'e2e':
+        from .e2e_resnet_vd_pg import ResNet
+        support_dict = ['ResNet']
     else:
         raise NotImplementedError

diff --git a/ppocr/modeling/backbones/e2e_resnet_vd_pg.py b/ppocr/modeling/backbones/e2e_resnet_vd_pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..97afd3460d03dc078b53064fb45b6fb6d3542df9
--- /dev/null
+++ b/ppocr/modeling/backbones/e2e_resnet_vd_pg.py
@@ -0,0 +1,265 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F + +__all__ = ["ResNet"] + + +class ConvBNLayer(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_vd_mode=False, + act=None, + name=None, ): + super(ConvBNLayer, self).__init__() + + self.is_vd_mode = is_vd_mode + self._pool2d_avg = nn.AvgPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self._conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self._batch_norm = nn.BatchNorm( + out_channels, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class BottleneckBlock(nn.Layer): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + name=None): + super(BottleneckBlock, self).__init__() + + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2b") + self.conv2 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels * 4, + kernel_size=1, + act=None, + name=name + "_branch2c") + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels * 4, + kernel_size=1, + stride=stride, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + conv2 = self.conv2(conv1) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.add(x=short, y=conv2) + y = F.relu(y) + return y + + +class BasicBlock(nn.Layer): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + name=None): + super(BasicBlock, self).__init__() + self.stride = stride + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + act=None, + name=name + "_branch2b") + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = 
self.conv0(inputs) + conv1 = self.conv1(y) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.add(x=short, y=conv1) + y = F.relu(y) + return y + + +class ResNet(nn.Layer): + def __init__(self, in_channels=3, layers=50, **kwargs): + super(ResNet, self).__init__() + + self.layers = layers + supported_layers = [18, 34, 50, 101, 152, 200] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format( + supported_layers, layers) + + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: + # depth = [3, 4, 6, 3] + depth = [3, 4, 6, 3, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + elif layers == 200: + depth = [3, 12, 48, 3] + num_channels = [64, 256, 512, 1024, + 2048] if layers >= 50 else [64, 64, 128, 256] + num_filters = [64, 128, 256, 512, 512] + + self.conv1_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=7, + stride=2, + act='relu', + name="conv1_1") + self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + + self.stages = [] + self.out_channels = [3, 64] + # num_filters = [64, 128, 256, 512, 512] + if layers >= 50: + for block in range(len(depth)): + block_list = [] + shortcut = False + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + bottleneck_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BottleneckBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block] * 4, + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(bottleneck_block) + self.out_channels.append(num_filters[block] * 4) + self.stages.append(nn.Sequential(*block_list)) + else: + for block in range(len(depth)): + block_list = [] + shortcut = False + for i in range(depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + basic_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BasicBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block], + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(basic_block) + self.out_channels.append(num_filters[block]) + self.stages.append(nn.Sequential(*block_list)) + + def forward(self, inputs): + out = [inputs] + y = self.conv1_1(inputs) + out.append(y) + y = self.pool2d_max(y) + for block in self.stages: + y = block(y) + out.append(y) + return out diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index efe05718506e94a5ae8ad5ff47bcff26d44c1473..4852c7f2d14d72b9e4d59f40532469f7226c966d 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -20,6 +20,7 @@ def build_head(config): from .det_db_head import DBHead from .det_east_head import EASTHead from .det_sast_head import SASTHead + from .e2e_pg_head import PGHead # rec head from .rec_ctc_head import CTCHead @@ -30,8 +31,8 @@ def build_head(config): from .cls_head import ClsHead support_dict = [ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - 'SRNHead' - ] + 'SRNHead', 'PGHead'] + module_name = config.pop('name') assert module_name in support_dict, 
Exception('head only support {}'.format( diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0da9de7580a0ceb473f971b2246c966497026a5d --- /dev/null +++ b/ppocr/modeling/heads/e2e_pg_head.py @@ -0,0 +1,253 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups=1, + if_act=True, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + self.if_act = if_act + self.act = act + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name="bn_" + name + "_scale"), + bias_attr=ParamAttr(name="bn_" + name + "_offset"), + moving_mean_name="bn_" + name + "_mean", + moving_variance_name="bn_" + name + "_variance", + use_global_stats=False) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class PGHead(nn.Layer): + """ + """ + + def __init__(self, in_channels, **kwargs): + super(PGHead, self).__init__() + self.conv_f_score1 = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_score{}".format(1)) + self.conv_f_score2 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_score{}".format(2)) + self.conv_f_score3 = ConvBNLayer( + in_channels=64, + out_channels=128, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_score{}".format(3)) + + self.conv1 = nn.Conv2D( + in_channels=128, + out_channels=1, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=ParamAttr(name="conv_f_score{}".format(4)), + bias_attr=False) + + self.conv_f_boder1 = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_boder{}".format(1)) + self.conv_f_boder2 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_boder{}".format(2)) + self.conv_f_boder3 = ConvBNLayer( + in_channels=64, + out_channels=128, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_boder{}".format(3)) + self.conv2 = nn.Conv2D( + in_channels=128, + out_channels=4, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=ParamAttr(name="conv_f_boder{}".format(4)), + bias_attr=False) + self.conv_f_char1 = 
ConvBNLayer( + in_channels=in_channels, + out_channels=128, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_char{}".format(1)) + self.conv_f_char2 = ConvBNLayer( + in_channels=128, + out_channels=128, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_char{}".format(2)) + self.conv_f_char3 = ConvBNLayer( + in_channels=128, + out_channels=256, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_char{}".format(3)) + self.conv_f_char4 = ConvBNLayer( + in_channels=256, + out_channels=256, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_char{}".format(4)) + self.conv_f_char5 = ConvBNLayer( + in_channels=256, + out_channels=256, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_char{}".format(5)) + self.conv3 = nn.Conv2D( + in_channels=256, + out_channels=37, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=ParamAttr(name="conv_f_char{}".format(6)), + bias_attr=False) + + self.conv_f_direc1 = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_direc{}".format(1)) + self.conv_f_direc2 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_direc{}".format(2)) + self.conv_f_direc3 = ConvBNLayer( + in_channels=64, + out_channels=128, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_direc{}".format(3)) + self.conv4 = nn.Conv2D( + in_channels=128, + out_channels=2, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=ParamAttr(name="conv_f_direc{}".format(4)), + bias_attr=False) + + def forward(self, x): + f_score = self.conv_f_score1(x) + f_score = self.conv_f_score2(f_score) + f_score = self.conv_f_score3(f_score) + f_score = self.conv1(f_score) + f_score = F.sigmoid(f_score) + + # f_border + f_border = self.conv_f_boder1(x) + f_border = self.conv_f_boder2(f_border) + f_border = self.conv_f_boder3(f_border) + f_border = self.conv2(f_border) + + f_char = self.conv_f_char1(x) + f_char = self.conv_f_char2(f_char) + f_char = self.conv_f_char3(f_char) + f_char = self.conv_f_char4(f_char) + f_char = self.conv_f_char5(f_char) + f_char = self.conv3(f_char) + + f_direction = self.conv_f_direc1(x) + f_direction = self.conv_f_direc2(f_direction) + f_direction = self.conv_f_direc3(f_direction) + f_direction = self.conv4(f_direction) + + predicts = {} + predicts['f_score'] = f_score + predicts['f_border'] = f_border + predicts['f_char'] = f_char + predicts['f_direction'] = f_direction + return predicts diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py index 0d222714ff7edebfc717daa81d48ce7424dfbd03..4286d7691d1abcf80c283d1c1ab76f8cd1f4a634 100644 --- a/ppocr/modeling/heads/rec_att_head.py +++ b/ppocr/modeling/heads/rec_att_head.py @@ -38,7 +38,7 @@ class AttentionHead(nn.Layer): return input_ont_hot def forward(self, inputs, targets=None, batch_max_length=25): - batch_size = inputs.shape[0] + batch_size = paddle.shape(inputs)[0] num_steps = batch_max_length hidden = paddle.zeros((batch_size, self.hidden_size)) diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py index 405e062b352da759b743d5997fd6e4c8b89c038b..37a5cf7863cb386884d82ed88c756c9fc06a541d 100644 --- a/ppocr/modeling/necks/__init__.py +++ b/ppocr/modeling/necks/__init__.py @@ -14,12 +14,14 @@ __all__ = ['build_neck'] + def build_neck(config): from .db_fpn import DBFPN from .east_fpn 
import EASTFPN from .sast_fpn import SASTFPN from .rnn import SequenceEncoder - support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder'] + from .pg_fpn import PGFPN + support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN'] module_name = config.pop('name') assert module_name in support_dict, Exception('neck only support {}'.format( diff --git a/ppocr/modeling/necks/pg_fpn.py b/ppocr/modeling/necks/pg_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..3f64539f790b55bb1f95adc8d3c78b84ca2fccc5 --- /dev/null +++ b/ppocr/modeling/necks/pg_fpn.py @@ -0,0 +1,314 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_vd_mode=False, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + + self.is_vd_mode = is_vd_mode + self._pool2d_avg = nn.AvgPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self._conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self._batch_norm = nn.BatchNorm( + out_channels, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', + use_global_stats=False) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class DeConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size=4, + stride=2, + padding=1, + groups=1, + if_act=True, + act=None, + name=None): + super(DeConvBNLayer, self).__init__() + + self.if_act = if_act + self.act = act + self.deconv = nn.Conv2DTranspose( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name="bn_" + name + "_scale"), + bias_attr=ParamAttr(name="bn_" + name + "_offset"), + moving_mean_name="bn_" + name + "_mean", + moving_variance_name="bn_" + name + "_variance", + use_global_stats=False) + + def forward(self, x): + x = self.deconv(x) + x = self.bn(x) + return x + + +class PGFPN(nn.Layer): + def __init__(self, in_channels, **kwargs): + super(PGFPN, self).__init__() + num_inputs = [2048, 2048, 1024, 512, 256] + num_outputs = [256, 256, 192, 192, 128] + self.out_channels 
= 128 + self.conv_bn_layer_1 = ConvBNLayer( + in_channels=3, + out_channels=32, + kernel_size=3, + stride=1, + act=None, + name='FPN_d1') + self.conv_bn_layer_2 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + act=None, + name='FPN_d2') + self.conv_bn_layer_3 = ConvBNLayer( + in_channels=256, + out_channels=128, + kernel_size=3, + stride=1, + act=None, + name='FPN_d3') + self.conv_bn_layer_4 = ConvBNLayer( + in_channels=32, + out_channels=64, + kernel_size=3, + stride=2, + act=None, + name='FPN_d4') + self.conv_bn_layer_5 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + act='relu', + name='FPN_d5') + self.conv_bn_layer_6 = ConvBNLayer( + in_channels=64, + out_channels=128, + kernel_size=3, + stride=2, + act=None, + name='FPN_d6') + self.conv_bn_layer_7 = ConvBNLayer( + in_channels=128, + out_channels=128, + kernel_size=3, + stride=1, + act='relu', + name='FPN_d7') + self.conv_bn_layer_8 = ConvBNLayer( + in_channels=128, + out_channels=128, + kernel_size=1, + stride=1, + act=None, + name='FPN_d8') + + self.conv_h0 = ConvBNLayer( + in_channels=num_inputs[0], + out_channels=num_outputs[0], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(0)) + self.conv_h1 = ConvBNLayer( + in_channels=num_inputs[1], + out_channels=num_outputs[1], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(1)) + self.conv_h2 = ConvBNLayer( + in_channels=num_inputs[2], + out_channels=num_outputs[2], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(2)) + self.conv_h3 = ConvBNLayer( + in_channels=num_inputs[3], + out_channels=num_outputs[3], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(3)) + self.conv_h4 = ConvBNLayer( + in_channels=num_inputs[4], + out_channels=num_outputs[4], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(4)) + + self.dconv0 = DeConvBNLayer( + in_channels=num_outputs[0], + out_channels=num_outputs[0 + 1], + name="dconv_{}".format(0)) + self.dconv1 = DeConvBNLayer( + in_channels=num_outputs[1], + out_channels=num_outputs[1 + 1], + act=None, + name="dconv_{}".format(1)) + self.dconv2 = DeConvBNLayer( + in_channels=num_outputs[2], + out_channels=num_outputs[2 + 1], + act=None, + name="dconv_{}".format(2)) + self.dconv3 = DeConvBNLayer( + in_channels=num_outputs[3], + out_channels=num_outputs[3 + 1], + act=None, + name="dconv_{}".format(3)) + self.conv_g1 = ConvBNLayer( + in_channels=num_outputs[1], + out_channels=num_outputs[1], + kernel_size=3, + stride=1, + act='relu', + name="conv_g{}".format(1)) + self.conv_g2 = ConvBNLayer( + in_channels=num_outputs[2], + out_channels=num_outputs[2], + kernel_size=3, + stride=1, + act='relu', + name="conv_g{}".format(2)) + self.conv_g3 = ConvBNLayer( + in_channels=num_outputs[3], + out_channels=num_outputs[3], + kernel_size=3, + stride=1, + act='relu', + name="conv_g{}".format(3)) + self.conv_g4 = ConvBNLayer( + in_channels=num_outputs[4], + out_channels=num_outputs[4], + kernel_size=3, + stride=1, + act='relu', + name="conv_g{}".format(4)) + self.convf = ConvBNLayer( + in_channels=num_outputs[4], + out_channels=num_outputs[4], + kernel_size=1, + stride=1, + act=None, + name="conv_f{}".format(4)) + + def forward(self, x): + c0, c1, c2, c3, c4, c5, c6 = x + # FPN_Down_Fusion + f = [c0, c1, c2] + g = [None, None, None] + h = [None, None, None] + h[0] = self.conv_bn_layer_1(f[0]) + h[1] = self.conv_bn_layer_2(f[1]) + h[2] = self.conv_bn_layer_3(f[2]) + + g[0] = self.conv_bn_layer_4(h[0]) + g[1] = paddle.add(g[0], 
h[1]) + g[1] = F.relu(g[1]) + g[1] = self.conv_bn_layer_5(g[1]) + g[1] = self.conv_bn_layer_6(g[1]) + + g[2] = paddle.add(g[1], h[2]) + g[2] = F.relu(g[2]) + g[2] = self.conv_bn_layer_7(g[2]) + f_down = self.conv_bn_layer_8(g[2]) + + # FPN UP Fusion + f1 = [c6, c5, c4, c3, c2] + g = [None, None, None, None, None] + h = [None, None, None, None, None] + h[0] = self.conv_h0(f1[0]) + h[1] = self.conv_h1(f1[1]) + h[2] = self.conv_h2(f1[2]) + h[3] = self.conv_h3(f1[3]) + h[4] = self.conv_h4(f1[4]) + + g[0] = self.dconv0(h[0]) + g[1] = paddle.add(g[0], h[1]) + g[1] = F.relu(g[1]) + g[1] = self.conv_g1(g[1]) + g[1] = self.dconv1(g[1]) + + g[2] = paddle.add(g[1], h[2]) + g[2] = F.relu(g[2]) + g[2] = self.conv_g2(g[2]) + g[2] = self.dconv2(g[2]) + + g[3] = paddle.add(g[2], h[3]) + g[3] = F.relu(g[3]) + g[3] = self.conv_g3(g[3]) + g[3] = self.dconv3(g[3]) + + g[4] = paddle.add(x=g[3], y=h[4]) + g[4] = F.relu(g[4]) + g[4] = self.conv_g4(g[4]) + f_up = self.convf(g[4]) + f_common = paddle.add(f_down, f_up) + f_common = F.relu(f_common) + return f_common diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 0156e438e9e24820943c9e48b04565710ea2fd4b..042654a19d2d2d2f1363fedbb9ac3530696e6903 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -28,10 +28,11 @@ def build_post_process(config, global_config=None): from .sast_postprocess import SASTPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode from .cls_postprocess import ClsPostProcess + from .pg_postprocess import PGPostProcess support_dict = [ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', - 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode' + 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..0b1455181fddb0adb5347406bb2eb3093ee6fb30 --- /dev/null +++ b/ppocr/postprocess/pg_postprocess.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.join(__dir__, '..')) +from ppocr.utils.e2e_utils.pgnet_pp_utils import PGNet_PostProcess + + +class PGPostProcess(object): + """ + The post process for PGNet. 
+ """ + + def __init__(self, character_dict_path, valid_set, score_thresh, mode, + **kwargs): + self.character_dict_path = character_dict_path + self.valid_set = valid_set + self.score_thresh = score_thresh + self.mode = mode + + # c++ la-nms is faster, but only support python 3.5 + self.is_python35 = False + if sys.version_info.major == 3 and sys.version_info.minor == 5: + self.is_python35 = True + + def __call__(self, outs_dict, shape_list): + post = PGNet_PostProcess(self.character_dict_path, self.valid_set, + self.score_thresh, outs_dict, shape_list) + if self.mode == 'fast': + data = post.pg_postprocess_fast() + else: + data = post.pg_postprocess_slow() + return data diff --git a/ppocr/postprocess/sast_postprocess.py b/ppocr/postprocess/sast_postprocess.py index f011e7e571cf4c2297a81a7f7772aa0c09f0aaf1..bee75c05b1a3ea59193d566f91378c96797f533b 100755 --- a/ppocr/postprocess/sast_postprocess.py +++ b/ppocr/postprocess/sast_postprocess.py @@ -18,6 +18,7 @@ from __future__ import print_function import os import sys + __dir__ = os.path.dirname(__file__) sys.path.append(__dir__) sys.path.append(os.path.join(__dir__, '..')) @@ -49,12 +50,12 @@ class SASTPostProcess(object): self.shrink_ratio_of_width = shrink_ratio_of_width self.expand_scale = expand_scale self.tcl_map_thresh = tcl_map_thresh - + # c++ la-nms is faster, but only support python 3.5 self.is_python35 = False if sys.version_info.major == 3 and sys.version_info.minor == 5: self.is_python35 = True - + def point_pair2poly(self, point_pair_list): """ Transfer vertical point_pairs into poly point in clockwise. @@ -66,31 +67,42 @@ class SASTPostProcess(object): point_list[idx] = point_pair[0] point_list[point_num - 1 - idx] = point_pair[1] return np.array(point_list).reshape(-1, 2) - - def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.): + + def shrink_quad_along_width(self, + quad, + begin_width_ratio=0., + end_width_ratio=1.): """ Generate shrink_quad_along_width. """ - ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) - + def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3): """ expand poly along width. 
""" point_num = poly.shape[0] - left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_quad = np.array( + [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ - (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) - left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0) - right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1], - poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32) + (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, + 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], poly[point_num // 2 - 1], + poly[point_num // 2], poly[point_num // 2 + 1] + ], + dtype=np.float32) right_ratio = 1.0 + \ - shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ - (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) - right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio) + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ + (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, + right_ratio) poly[0] = left_quad_expand[0] poly[-1] = left_quad_expand[-1] poly[point_num // 2 - 1] = right_quad_expand[1] @@ -100,7 +112,7 @@ class SASTPostProcess(object): def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map): """Restore quad.""" xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) - xy_text = xy_text[:, ::-1] # (n, 2) + xy_text = xy_text[:, ::-1] # (n, 2) # Sort the text boxes via the y axis xy_text = xy_text[np.argsort(xy_text[:, 1])] @@ -112,7 +124,7 @@ class SASTPostProcess(object): point_num = int(tvo_map.shape[-1] / 2) assert point_num == 4 tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :] - xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2) + xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2) quads = xy_text_tile - tvo_map return scores, quads, xy_text @@ -121,14 +133,12 @@ class SASTPostProcess(object): """ compute area of a quad. """ - edge = [ - (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]), - (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]), - (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]), - (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1]) - ] + edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]), + (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]), + (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]), + (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])] return np.sum(edge) / 2. - + def nms(self, dets): if self.is_python35: import lanms @@ -141,7 +151,7 @@ class SASTPostProcess(object): """ Cluster pixels in tcl_map based on quads. 
""" - instance_count = quads.shape[0] + 1 # contain background + instance_count = quads.shape[0] + 1 # contain background instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32) if instance_count == 1: return instance_count, instance_label_map @@ -149,18 +159,19 @@ class SASTPostProcess(object): # predict text center xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) n = xy_text.shape[0] - xy_text = xy_text[:, ::-1] # (n, 2) - tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2) + xy_text = xy_text[:, ::-1] # (n, 2) + tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2) pred_tc = xy_text - tco - + # get gt text center m = quads.shape[0] - gt_tc = np.mean(quads, axis=1) # (m, 2) + gt_tc = np.mean(quads, axis=1) # (m, 2) - pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2) - gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2) - dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m) - xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,) + pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], + (1, m, 1)) # (n, m, 2) + gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2) + dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m) + xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,) instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign return instance_count, instance_label_map @@ -169,26 +180,47 @@ class SASTPostProcess(object): """ Estimate sample points number. """ - eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0 - ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0 + eh = (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[1] - quad[2])) / 2.0 + ew = (np.linalg.norm(quad[0] - quad[1]) + + np.linalg.norm(quad[2] - quad[3])) / 2.0 dense_sample_pts_num = max(2, int(ew)) - dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num, - endpoint=True, dtype=np.float32).astype(np.int32)] - - dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1] - estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1)) + dense_xy_center_line = xy_text[np.linspace( + 0, + xy_text.shape[0] - 1, + dense_sample_pts_num, + endpoint=True, + dtype=np.float32).astype(np.int32)] + + dense_xy_center_line_diff = dense_xy_center_line[ + 1:] - dense_xy_center_line[:-1] + estimate_arc_len = np.sum( + np.linalg.norm( + dense_xy_center_line_diff, axis=1)) sample_pts_num = max(2, int(estimate_arc_len / eh)) return sample_pts_num - def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h, - shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0): + def detect_sast(self, + tcl_map, + tvo_map, + tbo_map, + tco_map, + ratio_w, + ratio_h, + src_w, + src_h, + shrink_ratio_of_width=0.3, + tcl_map_thresh=0.5, + offset_expand=1.0, + out_strid=4.0): """ first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys """ # restore quad - scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map) + scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, + tvo_map) dets = np.hstack((quads, scores)).astype(np.float32, copy=False) dets = self.nms(dets) if dets.shape[0] == 0: @@ -202,7 +234,8 @@ class SASTPostProcess(object): # instance segmentation # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8) - instance_count, 
instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map) + instance_count, instance_label_map = self.cluster_by_quads_tco( + tcl_map, tcl_map_thresh, quads, tco_map) # restore single poly with tcl instance. poly_list = [] @@ -212,10 +245,10 @@ class SASTPostProcess(object): q_area = quad_areas[instance_idx - 1] if q_area < 5: continue - + # - len1 = float(np.linalg.norm(quad[0] -quad[1])) - len2 = float(np.linalg.norm(quad[1] -quad[2])) + len1 = float(np.linalg.norm(quad[0] - quad[1])) + len2 = float(np.linalg.norm(quad[1] - quad[2])) min_len = min(len1, len2) if min_len < 3: continue @@ -225,16 +258,18 @@ class SASTPostProcess(object): continue # filter low confidence instance - xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0] + xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0] if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1: - # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05: + # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05: continue # sort xy_text - left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0, - (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2) - right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0, - (quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2) + left_center_pt = np.array( + [[(quad[0, 0] + quad[-1, 0]) / 2.0, + (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2) + right_center_pt = np.array( + [[(quad[1, 0] + quad[2, 0]) / 2.0, + (quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2) proj_unit_vec = (right_center_pt - left_center_pt) / \ (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6) proj_value = np.sum(xy_text * proj_unit_vec, axis=1) @@ -245,33 +280,45 @@ class SASTPostProcess(object): sample_pts_num = self.estimate_sample_pts_num(quad, xy_text) else: sample_pts_num = self.sample_pts_num - xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num, - endpoint=True, dtype=np.float32).astype(np.int32)] + xy_center_line = xy_text[np.linspace( + 0, + xy_text.shape[0] - 1, + sample_pts_num, + endpoint=True, + dtype=np.float32).astype(np.int32)] point_pair_list = [] for x, y in xy_center_line: # get corresponding offset offset = tbo_map[y, x, :].reshape(2, 2) if offset_expand != 1.0: - offset_length = np.linalg.norm(offset, axis=1, keepdims=True) - expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0) + offset_length = np.linalg.norm( + offset, axis=1, keepdims=True) + expand_length = np.clip( + offset_length * (offset_expand - 1), + a_min=0.5, + a_max=3.0) offset_detal = offset / offset_length * expand_length - offset = offset + offset_detal - # original point + offset = offset + offset_detal + # original point ori_yx = np.array([y, x], dtype=np.float32) - point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2) + point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array( + [ratio_w, ratio_h]).reshape(-1, 2) point_pair_list.append(point_pair) # ndarry: (x, 2), expand poly along width detected_poly = self.point_pair2poly(point_pair_list) - detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width) - detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w) - detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h) + detected_poly = self.expand_poly_along_width(detected_poly, + shrink_ratio_of_width) + detected_poly[:, 0] = np.clip( + detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip( + detected_poly[:, 1], 
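
The fixed-count center-line sampling in `detect_sast` is just evenly spaced indexing into the sorted TCL pixel list. A toy run on a fabricated 100-pixel line:

```python
import numpy as np

xy_text = np.array([[x, 10] for x in range(100)])  # fake sorted center line
sample_pts_num = 5
idx = np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
                  endpoint=True, dtype=np.float32).astype(np.int32)
xy_center_line = xy_text[idx]
print(xy_center_line[:, 0].tolist())  # -> [0, 24, 49, 74, 99]
```
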
a_min=0, a_max=src_h) poly_list.append(detected_poly) return poly_list - def __call__(self, outs_dict, shape_list): + def __call__(self, outs_dict, shape_list): score_list = outs_dict['f_score'] border_list = outs_dict['f_border'] tvo_list = outs_dict['f_tvo'] @@ -281,20 +328,28 @@ class SASTPostProcess(object): border_list = border_list.numpy() tvo_list = tvo_list.numpy() tco_list = tco_list.numpy() - + img_num = len(shape_list) poly_lists = [] for ino in range(img_num): - p_score = score_list[ino].transpose((1,2,0)) - p_border = border_list[ino].transpose((1,2,0)) - p_tvo = tvo_list[ino].transpose((1,2,0)) - p_tco = tco_list[ino].transpose((1,2,0)) + p_score = score_list[ino].transpose((1, 2, 0)) + p_border = border_list[ino].transpose((1, 2, 0)) + p_tvo = tvo_list[ino].transpose((1, 2, 0)) + p_tco = tco_list[ino].transpose((1, 2, 0)) src_h, src_w, ratio_h, ratio_w = shape_list[ino] - poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h, - shrink_ratio_of_width=self.shrink_ratio_of_width, - tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale) + poly_list = self.detect_sast( + p_score, + p_tvo, + p_border, + p_tco, + ratio_w, + ratio_h, + src_w, + src_h, + shrink_ratio_of_width=self.shrink_ratio_of_width, + tcl_map_thresh=self.tcl_map_thresh, + offset_expand=self.expand_scale) poly_lists.append({'points': np.array(poly_list)}) return poly_lists - diff --git a/ppocr/utils/dict/arabic_dict.txt b/ppocr/utils/dict/arabic_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..e97abf39274df77fbad066ee4635aebc6743140c --- /dev/null +++ b/ppocr/utils/dict/arabic_dict.txt @@ -0,0 +1,162 @@ + +! +# +$ +% +& +' +( ++ +, +- +. +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +? +@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +_ +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +É +é +ء +آ +أ +ؤ +إ +ئ +ا +ب +ة +ت +ث +ج +ح +خ +د +ذ +ر +ز +س +ش +ص +ض +ط +ظ +ع +غ +ف +ق +ك +ل +م +ن +ه +و +ى +ي +ً +ٌ +ٍ +َ +ُ +ِ +ّ +ْ +ٓ +ٔ +ٰ +ٱ +ٹ +پ +چ +ڈ +ڑ +ژ +ک +ڭ +گ +ں +ھ +ۀ +ہ +ۂ +ۃ +ۆ +ۇ +ۈ +ۋ +ی +ې +ے +ۓ +ە +١ +٢ +٣ +٤ +٥ +٦ +٧ +٨ +٩ diff --git a/ppocr/utils/dict/cyrillic_dict.txt b/ppocr/utils/dict/cyrillic_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..2b6f66494d5417e18bbd225719aa72690e09e126 --- /dev/null +++ b/ppocr/utils/dict/cyrillic_dict.txt @@ -0,0 +1,163 @@ + +! +# +$ +% +& +' +( ++ +, +- +. +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +? +@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +_ +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +É +é +Ё +Є +І +Ј +Љ +Ў +А +Б +В +Г +Д +Е +Ж +З +И +Й +К +Л +М +Н +О +П +Р +С +Т +У +Ф +Х +Ц +Ч +Ш +Щ +Ъ +Ы +Ь +Э +Ю +Я +а +б +в +г +д +е +ж +з +и +й +к +л +м +н +о +п +р +с +т +у +ф +х +ц +ч +ш +щ +ъ +ы +ь +э +ю +я +ё +ђ +є +і +ј +љ +њ +ћ +ў +џ +Ґ +ґ diff --git a/ppocr/utils/dict/devanagari_dict.txt b/ppocr/utils/dict/devanagari_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..f55923061bfd480b875bb3679d7a75a9157387a9 --- /dev/null +++ b/ppocr/utils/dict/devanagari_dict.txt @@ -0,0 +1,167 @@ + +! +# +$ +% +& +' +( ++ +, +- +. +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +? 
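
The new per-script dictionaries added here (`arabic_dict.txt`, `cyrillic_dict.txt`, `devanagari_dict.txt`, `latin_dict.txt`) are plain text with one character per line; the first line appears to hold the space character. They are consumed by loaders like `get_dict` in the extract_textpoint utilities added later in this patch (body copied from there, usage line illustrative):

```python
def get_dict(character_dict_path):
    # one character per line; only the newline is stripped, so a line
    # containing a single space yields the space character
    character_str = ""
    with open(character_dict_path, "rb") as fin:
        for line in fin.readlines():
            character_str += line.decode('utf-8').strip("\n").strip("\r\n")
    return list(character_str)

# e.g. get_dict('ppocr/utils/dict/latin_dict.txt') -> [' ', '!', '"', '#', ...]
```
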
+@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +_ +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +É +é +ँ +ं +ः +अ +आ +इ +ई +उ +ऊ +ऋ +ए +ऐ +ऑ +ओ +औ +क +ख +ग +घ +ङ +च +छ +ज +झ +ञ +ट +ठ +ड +ढ +ण +त +थ +द +ध +न +ऩ +प +फ +ब +भ +म +य +र +ऱ +ल +ळ +व +श +ष +स +ह +़ +ा +ि +ी +ु +ू +ृ +ॅ +े +ै +ॉ +ो +ौ +् +॒ +क़ +ख़ +ग़ +ज़ +ड़ +ढ़ +फ़ +ॠ +। +० +१ +२ +३ +४ +५ +६ +७ +८ +९ +॰ diff --git a/ppocr/utils/dict/latin_dict.txt b/ppocr/utils/dict/latin_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..e166bf33ecfbdc90ddb3d9743fded23306acabd5 --- /dev/null +++ b/ppocr/utils/dict/latin_dict.txt @@ -0,0 +1,185 @@ + +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +; +< += +> +? +@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +[ +] +_ +` +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +{ +} +¡ +£ +§ +ª +« +­ +° +² +³ +´ +µ +· +º +» +¿ +À +Á + +Ä +Å +Ç +È +É +Ê +Ë +Ì +Í +Î +Ï +Ò +Ó +Ô +Õ +Ö +Ú +Ü +Ý +ß +à +á +â +ã +ä +å +æ +ç +è +é +ê +ë +ì +í +î +ï +ñ +ò +ó +ô +õ +ö +ø +ù +ú +û +ü +ý +ą +Ć +ć +Č +č +Đ +đ +ę +ı +Ł +ł +ō +Œ +œ +Š +š +Ÿ +Ž +ž +ʒ +β +δ +ε +з +Ṡ +‘ +€ +™ diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py new file mode 100755 index 0000000000000000000000000000000000000000..e30a498eaf2e24f7a337ee48536466e7c4f0d91c --- /dev/null +++ b/ppocr/utils/e2e_metric/Deteval.py @@ -0,0 +1,437 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import scipy.io as io +from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area + + +def get_socre(gt_dir, img_id, pred_dict): + allInputs = 1 + + def input_reading_mod(pred_dict): + """This helper reads input from txt files""" + det = [] + n = len(pred_dict) + for i in range(n): + points = pred_dict[i]['points'] + text = pred_dict[i]['text'] + point = ",".join(map(str, points.reshape(-1, ))) + det.append([point, text]) + return det + + def gt_reading_mod(gt_dir, gt_id): + gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id)) + gt = gt['polygt'] + return gt + + def detection_filtering(detections, groundtruths, threshold=0.5): + for gt_id, gt in enumerate(groundtruths): + if (gt[5] == '#') and (gt[1].shape[1] > 1): + gt_x = list(map(int, np.squeeze(gt[1]))) + gt_y = list(map(int, np.squeeze(gt[3]))) + for det_id, detection in enumerate(detections): + detection_orig = detection + detection = [float(x) for x in detection[0].split(',')] + detection = list(map(int, detection)) + det_x = detection[0::2] + det_y = detection[1::2] + det_gt_iou = iod(det_x, det_y, gt_x, gt_y) + if det_gt_iou > threshold: + detections[det_id] = [] + + detections[:] = [item for item in detections if item != []] + return detections + + def sigma_calculation(det_x, det_y, gt_x, gt_y): + """ + sigma = inter_area / gt_area + """ + return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / + area(gt_x, gt_y)), 2) + + def tau_calculation(det_x, det_y, gt_x, gt_y): + if area(det_x, det_y) == 0.0: + return 0 + return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / + area(det_x, det_y)), 2) + + ##############################Initialization################################### + # global_sigma = [] + # global_tau = [] + # global_pred_str = [] + # global_gt_str = [] + ############################################################################### + + for input_id in range(allInputs): + if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and ( + input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and ( + input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \ + and (input_id != 'Deteval_result_non_curved.txt'): + detections = input_reading_mod(pred_dict) + groundtruths = gt_reading_mod(gt_dir, img_id).tolist() + detections = detection_filtering( + detections, + groundtruths) # filters detections overlapping with DC area + dc_id = [] + for i in range(len(groundtruths)): + if groundtruths[i][5] == '#': + dc_id.append(i) + cnt = 0 + for a in dc_id: + num = a - cnt + del groundtruths[num] + cnt += 1 + + local_sigma_table = np.zeros((len(groundtruths), len(detections))) + local_tau_table = np.zeros((len(groundtruths), len(detections))) + local_pred_str = {} + local_gt_str = {} + + for gt_id, gt in enumerate(groundtruths): + if len(detections) > 0: + for det_id, detection in enumerate(detections): + detection_orig = detection + detection = [float(x) for x in detection[0].split(',')] + detection = list(map(int, detection)) + pred_seq_str = detection_orig[1].strip() + det_x = detection[0::2] + det_y = detection[1::2] + gt_x = list(map(int, np.squeeze(gt[1]))) + gt_y = list(map(int, np.squeeze(gt[3]))) + gt_seq_str = str(gt[4].tolist()[0]) + + local_sigma_table[gt_id, det_id] = sigma_calculation( + det_x, det_y, gt_x, gt_y) + local_tau_table[gt_id, det_id] = tau_calculation( + det_x, det_y, gt_x, gt_y) + local_pred_str[det_id] = pred_seq_str + local_gt_str[gt_id] = gt_seq_str + + global_sigma = 
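
`sigma_calculation` and `tau_calculation` above reduce to intersection-over-ground-truth-area and intersection-over-detection-area. A self-contained check with shapely (the same library `polygon_fast.py` below relies on), using illustrative rectangles:

```python
from shapely.geometry import Polygon

def inter_area(a, b):
    # buffer(0) cleans up self-intersections, as area_of_intersection does
    return Polygon(a).buffer(0).intersection(Polygon(b).buffer(0)).area

gt = [(0, 0), (10, 0), (10, 4), (0, 4)]
det = [(5, 0), (15, 0), (15, 4), (5, 4)]
inter = inter_area(gt, det)                 # overlap is 5 x 4 = 20
sigma = inter / Polygon(gt).area            # 20 / 40 = 0.5
tau = inter / Polygon(det).area             # 20 / 40 = 0.5
print(round(sigma, 2), round(tau, 2))       # 0.5 0.5
```
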
local_sigma_table + global_tau = local_tau_table + global_pred_str = local_pred_str + global_gt_str = local_gt_str + + single_data = {} + single_data['sigma'] = global_sigma + single_data['global_tau'] = global_tau + single_data['global_pred_str'] = global_pred_str + single_data['global_gt_str'] = global_gt_str + return single_data + + +def combine_results(all_data): + tr = 0.7 + tp = 0.6 + fsc_k = 0.8 + k = 2 + global_sigma = [] + global_tau = [] + global_pred_str = [] + global_gt_str = [] + for data in all_data: + global_sigma.append(data['sigma']) + global_tau.append(data['global_tau']) + global_pred_str.append(data['global_pred_str']) + global_gt_str.append(data['global_gt_str']) + + global_accumulative_recall = 0 + global_accumulative_precision = 0 + total_num_gt = 0 + total_num_det = 0 + hit_str_count = 0 + hit_count = 0 + + def one_to_one(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idy): + hit_str_num = 0 + for gt_id in range(num_gt): + gt_matching_qualified_sigma_candidates = np.where( + local_sigma_table[gt_id, :] > tr) + gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[ + 0].shape[0] + gt_matching_qualified_tau_candidates = np.where( + local_tau_table[gt_id, :] > tp) + gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[ + 0].shape[0] + + det_matching_qualified_sigma_candidates = np.where( + local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] + > tr) + det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[ + 0].shape[0] + det_matching_qualified_tau_candidates = np.where( + local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > + tp) + det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[ + 0].shape[0] + + if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \ + (det_matching_num_qualified_sigma_candidates == 1) and ( + det_matching_num_qualified_tau_candidates == 1): + global_accumulative_recall = global_accumulative_recall + 1.0 + global_accumulative_precision = global_accumulative_precision + 1.0 + local_accumulative_recall = local_accumulative_recall + 1.0 + local_accumulative_precision = local_accumulative_precision + 1.0 + + gt_flag[0, gt_id] = 1 + matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) + # recg start + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[ + 0]] + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + # recg end + det_flag[0, matched_det_id] = 1 + return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num + + def one_to_many(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idy): + hit_str_num = 0 + for gt_id in range(num_gt): + # skip the following if the groundtruth was matched + if gt_flag[0, gt_id] > 0: + continue + + non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0) + num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0] + + if num_non_zero_in_sigma >= k: + ####search for all detections that overlaps with this groundtruth + qualified_tau_candidates = 
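
`one_to_one` in essence: a ground truth and a detection pair up only when each is the other's unique candidate above both thresholds (tr = 0.7 for sigma, tp = 0.6 for tau). A simplified sketch of the row-wise test on a fabricated 2x2 table; the real code additionally verifies the column-wise candidate counts:

```python
import numpy as np

tr, tp = 0.7, 0.6
sigma = np.array([[0.9, 0.1],    # rows: ground truths
                  [0.0, 0.8]])   # cols: detections
tau = np.array([[0.8, 0.0],
                [0.1, 0.7]])

for gt_id in range(sigma.shape[0]):
    sig_cand = np.where(sigma[gt_id, :] > tr)[0]
    tau_cand = np.where(tau[gt_id, :] > tp)[0]
    if len(sig_cand) == 1 and len(tau_cand) == 1 and sig_cand[0] == tau_cand[0]:
        print("gt", gt_id, "<-> det", sig_cand[0])  # both gts match one-to-one
```
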
np.where((local_tau_table[ + gt_id, :] >= tp) & (det_flag[0, :] == 0)) + num_qualified_tau_candidates = qualified_tau_candidates[ + 0].shape[0] + + if num_qualified_tau_candidates == 1: + if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp) + and + (local_sigma_table[gt_id, qualified_tau_candidates] >= + tr)): + # became an one-to-one case + global_accumulative_recall = global_accumulative_recall + 1.0 + global_accumulative_precision = global_accumulative_precision + 1.0 + local_accumulative_recall = local_accumulative_recall + 1.0 + local_accumulative_precision = local_accumulative_precision + 1.0 + + gt_flag[0, gt_id] = 1 + det_flag[0, qualified_tau_candidates] = 1 + # recg start + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][ + qualified_tau_candidates[0].tolist()[0]] + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + # recg end + elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) + >= tr): + gt_flag[0, gt_id] = 1 + det_flag[0, qualified_tau_candidates] = 1 + # recg start + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][ + qualified_tau_candidates[0].tolist()[0]] + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + # recg end + + global_accumulative_recall = global_accumulative_recall + fsc_k + global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k + + local_accumulative_recall = local_accumulative_recall + fsc_k + local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k + + return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num + + def many_to_one(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idy): + hit_str_num = 0 + for det_id in range(num_det): + # skip the following if the detection was matched + if det_flag[0, det_id] > 0: + continue + + non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0) + num_non_zero_in_tau = non_zero_in_tau[0].shape[0] + + if num_non_zero_in_tau >= k: + ####search for all detections that overlaps with this groundtruth + qualified_sigma_candidates = np.where(( + local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0)) + num_qualified_sigma_candidates = qualified_sigma_candidates[ + 0].shape[0] + + if num_qualified_sigma_candidates == 1: + if ((local_tau_table[qualified_sigma_candidates, det_id] >= + tp) and + (local_sigma_table[qualified_sigma_candidates, det_id] + >= tr)): + # became an one-to-one case + global_accumulative_recall = global_accumulative_recall + 1.0 + global_accumulative_precision = global_accumulative_precision + 1.0 + local_accumulative_recall = local_accumulative_recall + 1.0 + local_accumulative_precision = local_accumulative_precision + 1.0 + + gt_flag[0, qualified_sigma_candidates] = 1 + det_flag[0, det_id] = 1 + # recg start + pred_str_cur = global_pred_str[idy][det_id] + gt_len = len(qualified_sigma_candidates[0]) + for idx in range(gt_len): + ele_gt_id = qualified_sigma_candidates[0].tolist()[ + idx] + if ele_gt_id not in global_gt_str[idy]: + continue + gt_str_cur = global_gt_str[idy][ele_gt_id] + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + break + else: + if 
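
The split and merge cases are scored with the fsc_k = 0.8 discount: a ground truth fragmented over k detections contributes fsc_k to recall but k * fsc_k to precision (and symmetrically for many-to-one). In numbers:

```python
fsc_k = 0.8
k = 3                        # one gt covered by three detection fragments
recall_gain = fsc_k          # 0.8 added to the accumulated recall
precision_gain = k * fsc_k   # 2.4 added to the accumulated precision
print(recall_gain, precision_gain)
```
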
pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + break + # recg end + elif (np.sum(local_tau_table[qualified_sigma_candidates, + det_id]) >= tp): + det_flag[0, det_id] = 1 + gt_flag[0, qualified_sigma_candidates] = 1 + # recg start + pred_str_cur = global_pred_str[idy][det_id] + gt_len = len(qualified_sigma_candidates[0]) + for idx in range(gt_len): + ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] + if ele_gt_id not in global_gt_str[idy]: + continue + gt_str_cur = global_gt_str[idy][ele_gt_id] + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + break + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + break + # recg end + + global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k + global_accumulative_precision = global_accumulative_precision + fsc_k + + local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k + local_accumulative_precision = local_accumulative_precision + fsc_k + return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num + + for idx in range(len(global_sigma)): + local_sigma_table = np.array(global_sigma[idx]) + local_tau_table = global_tau[idx] + + num_gt = local_sigma_table.shape[0] + num_det = local_sigma_table.shape[1] + + total_num_gt = total_num_gt + num_gt + total_num_det = total_num_det + num_det + + local_accumulative_recall = 0 + local_accumulative_precision = 0 + gt_flag = np.zeros((1, num_gt)) + det_flag = np.zeros((1, num_det)) + + #######first check for one-to-one case########## + local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ + gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idx) + + hit_str_count += hit_str_num + #######then check for one-to-many case########## + local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ + gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idx) + hit_str_count += hit_str_num + #######then check for many-to-one case########## + local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ + gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idx) + hit_str_count += hit_str_num + + try: + recall = global_accumulative_recall / total_num_gt + except ZeroDivisionError: + recall = 0 + + try: + precision = global_accumulative_precision / total_num_det + except ZeroDivisionError: + precision = 0 + + try: + f_score = 2 * precision * recall / (precision + recall) + except ZeroDivisionError: + f_score = 0 + + try: + seqerr = 1 - float(hit_str_count) / global_accumulative_recall + except ZeroDivisionError: + seqerr = 1 + + try: + recall_e2e = float(hit_str_count) / total_num_gt + except ZeroDivisionError: + recall_e2e = 0 + + try: + precision_e2e = float(hit_str_count) / total_num_det + except ZeroDivisionError: + 
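
`combine_results` then turns the accumulated counts into detection and end-to-end scores, guarding every ratio with a `ZeroDivisionError` handler as above. The harmonic-mean step, with fabricated counts:

```python
def f_score(precision, recall):
    try:
        return 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        return 0

recall = 7.8 / 10     # global_accumulative_recall / total_num_gt
precision = 7.8 / 9   # global_accumulative_precision / total_num_det
print(f_score(precision, recall))   # ~0.821
```
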
precision_e2e = 0 + + try: + f_score_e2e = 2 * precision_e2e * recall_e2e / ( + precision_e2e + recall_e2e) + except ZeroDivisionError: + f_score_e2e = 0 + + final = { + 'total_num_gt': total_num_gt, + 'total_num_det': total_num_det, + 'global_accumulative_recall': global_accumulative_recall, + 'hit_str_count': hit_str_count, + 'recall': recall, + 'precision': precision, + 'f_score': f_score, + 'seqerr': seqerr, + 'recall_e2e': recall_e2e, + 'precision_e2e': precision_e2e, + 'f_score_e2e': f_score_e2e + } + return final diff --git a/ppocr/utils/e2e_metric/polygon_fast.py b/ppocr/utils/e2e_metric/polygon_fast.py new file mode 100755 index 0000000000000000000000000000000000000000..81c9ad70675bb37a95968283b6dc6f42f709df27 --- /dev/null +++ b/ppocr/utils/e2e_metric/polygon_fast.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from shapely.geometry import Polygon +""" +:param det_x: [1, N] Xs of detection's vertices +:param det_y: [1, N] Ys of detection's vertices +:param gt_x: [1, N] Xs of groundtruth's vertices +:param gt_y: [1, N] Ys of groundtruth's vertices + +############## +All the calculation of 'AREA' in this script is handled by: +1) First generating a binary mask with the polygon area filled up with 1's +2) Summing up all the 1's +""" + + +def area(x, y): + polygon = Polygon(np.stack([x, y], axis=1)) + return float(polygon.area) + + +def approx_area_of_intersection(det_x, det_y, gt_x, gt_y): + """ + This helper determine if both polygons are intersecting with each others with an approximation method. 
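
`iod` (defined just below) is what `detection_filtering` in `Deteval.py` uses to drop detections that mostly fall inside `#` don't-care regions; note the `+ 1.0` in the denominator, which also guards against degenerate polygons. Illustrative boxes:

```python
from shapely.geometry import Polygon

def iod(det, gt):
    # intersection over detection area, mirroring polygon_fast.iod
    p1, p2 = Polygon(det).buffer(0), Polygon(gt).buffer(0)
    return p1.intersection(p2).area / (p1.area + 1.0)

dc = [(0, 0), (10, 0), (10, 10), (0, 10)]   # don't-care region
det = [(2, 2), (8, 2), (8, 8), (2, 8)]      # detection fully inside it
print(iod(det, dc) > 0.5)                   # True -> detection is filtered out
```
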
+ Area of intersection represented by the minimum bounding rectangular [xmin, ymin, xmax, ymax] + """ + det_ymax = np.max(det_y) + det_xmax = np.max(det_x) + det_ymin = np.min(det_y) + det_xmin = np.min(det_x) + + gt_ymax = np.max(gt_y) + gt_xmax = np.max(gt_x) + gt_ymin = np.min(gt_y) + gt_xmin = np.min(gt_x) + + all_min_ymax = np.minimum(det_ymax, gt_ymax) + all_max_ymin = np.maximum(det_ymin, gt_ymin) + + intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin)) + + all_min_xmax = np.minimum(det_xmax, gt_xmax) + all_max_xmin = np.maximum(det_xmin, gt_xmin) + intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin)) + + return intersect_heights * intersect_widths + + +def area_of_intersection(det_x, det_y, gt_x, gt_y): + p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0) + p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0) + return float(p1.intersection(p2).area) + + +def area_of_union(det_x, det_y, gt_x, gt_y): + p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0) + p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0) + return float(p1.union(p2).area) + + +def iou(det_x, det_y, gt_x, gt_y): + return area_of_intersection(det_x, det_y, gt_x, gt_y) / ( + area_of_union(det_x, det_y, gt_x, gt_y) + 1.0) + + +def iod(det_x, det_y, gt_x, gt_y): + """ + This helper determine the fraction of intersection area over detection area + """ + return area_of_intersection(det_x, det_y, gt_x, gt_y) / ( + area(det_x, det_y) + 1.0) diff --git a/ppocr/utils/e2e_utils/extract_batchsize.py b/ppocr/utils/e2e_utils/extract_batchsize.py new file mode 100644 index 0000000000000000000000000000000000000000..e99a833ea76a81e02d39b16fe1a01e22f15bf3a4 --- /dev/null +++ b/ppocr/utils/e2e_utils/extract_batchsize.py @@ -0,0 +1,87 @@ +import paddle +import numpy as np +import copy + + +def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs): + """ + """ + pos_lists_, pos_masks_, label_lists_ = [], [], [] + img_bs = batch_size + ngpu = int(batch_size / img_bs) + img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy() + pos_lists_split, pos_masks_split, label_lists_split = [], [], [] + for i in range(ngpu): + pos_lists_split.append([]) + pos_masks_split.append([]) + label_lists_split.append([]) + + for i in range(img_ids.shape[0]): + img_id = img_ids[i] + gpu_id = int(img_id / img_bs) + img_id = img_id % img_bs + pos_list = pos_lists[i].copy() + pos_list[:, 0] = img_id + pos_lists_split[gpu_id].append(pos_list) + pos_masks_split[gpu_id].append(pos_masks[i].copy()) + label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i])) + # repeat or delete + for i in range(ngpu): + vp_len = len(pos_lists_split[i]) + if vp_len <= tcl_bs: + for j in range(0, tcl_bs - vp_len): + pos_list = pos_lists_split[i][j].copy() + pos_lists_split[i].append(pos_list) + pos_mask = pos_masks_split[i][j].copy() + pos_masks_split[i].append(pos_mask) + label_list = copy.deepcopy(label_lists_split[i][j]) + label_lists_split[i].append(label_list) + else: + for j in range(0, vp_len - tcl_bs): + c_len = len(pos_lists_split[i]) + pop_id = np.random.permutation(c_len)[0] + pos_lists_split[i].pop(pop_id) + pos_masks_split[i].pop(pop_id) + label_lists_split[i].pop(pop_id) + # merge + for i in range(ngpu): + pos_lists_.extend(pos_lists_split[i]) + pos_masks_.extend(pos_masks_split[i]) + label_lists_.extend(label_lists_split[i]) + return pos_lists_, pos_masks_, label_lists_ + + +def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums, + pad_num, tcl_bs): + label_list = 
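
`org_tcl_rois` above normalizes the number of TCL ROIs per (virtual) GPU to exactly `tcl_bs`: short lists are padded by repeating early entries, long lists are trimmed by random removal. The pad-or-drop core, on illustrative data:

```python
import numpy as np

tcl_bs = 4
pos_lists = [np.zeros((3, 2)), np.ones((3, 2))]   # only 2 ROIs available
vp_len = len(pos_lists)
if vp_len <= tcl_bs:
    for j in range(tcl_bs - vp_len):              # pad by repetition
        pos_lists.append(pos_lists[j].copy())
else:
    for _ in range(vp_len - tcl_bs):              # drop random entries
        pos_lists.pop(np.random.permutation(len(pos_lists))[0])
print(len(pos_lists))  # 4
```
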
label_list.numpy() + batch, _, _, _ = label_list.shape + pos_list = pos_list.numpy() + pos_mask = pos_mask.numpy() + pos_list_t = [] + pos_mask_t = [] + label_list_t = [] + for i in range(batch): + for j in range(max_text_nums): + if pos_mask[i, j].any(): + pos_list_t.append(pos_list[i][j]) + pos_mask_t.append(pos_mask[i][j]) + label_list_t.append(label_list[i][j]) + pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t, + label_list_t, tcl_bs) + label = [] + tt = [l.tolist() for l in label_list] + for i in range(tcl_bs): + k = 0 + for j in range(max_text_length): + if tt[i][j][0] != pad_num: + k += 1 + else: + break + label.append(k) + label = paddle.to_tensor(label) + label = paddle.cast(label, dtype='int64') + pos_list = paddle.to_tensor(pos_list) + pos_mask = paddle.to_tensor(pos_mask) + label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2) + label_list = paddle.cast(label_list, dtype='int32') + return pos_list, pos_mask, label_list, label diff --git a/ppocr/utils/e2e_utils/extract_textpoint_fast.py b/ppocr/utils/e2e_utils/extract_textpoint_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..787cd3017fafa6fc554bead0cc05b5bfe682df42 --- /dev/null +++ b/ppocr/utils/e2e_utils/extract_textpoint_fast.py @@ -0,0 +1,457 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains various CTC decoders.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import cv2 +import math + +import numpy as np +from itertools import groupby +from skimage.morphology._skeletonize import thin + + +def get_dict(character_dict_path): + character_str = "" + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + character_str += line + dict_character = list(character_str) + return dict_character + + +def softmax(logits): + """ + logits: N x d + """ + max_value = np.max(logits, axis=1, keepdims=True) + exp = np.exp(logits - max_value) + exp_sum = np.sum(exp, axis=1, keepdims=True) + dist = exp / exp_sum + return dist + + +def get_keep_pos_idxs(labels, remove_blank=None): + """ + Remove duplicate and get pos idxs of keep items. + The value of keep_blank should be [None, 95]. 
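
The `softmax` above is the standard max-subtracted form, which keeps `np.exp` from overflowing on large logits:

```python
import numpy as np

def softmax(logits):
    # subtract the per-row max before exponentiating for numerical stability
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    return exp / np.sum(exp, axis=1, keepdims=True)

print(softmax(np.array([[1000.0, 1000.0], [0.0, 1.0]])))
# -> [[0.5, 0.5], [0.269, 0.731]] with no overflow warnings
```
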
+ """ + duplicate_len_list = [] + keep_pos_idx_list = [] + keep_char_idx_list = [] + for k, v_ in groupby(labels): + current_len = len(list(v_)) + if k != remove_blank: + current_idx = int(sum(duplicate_len_list) + current_len // 2) + keep_pos_idx_list.append(current_idx) + keep_char_idx_list.append(k) + duplicate_len_list.append(current_len) + return keep_char_idx_list, keep_pos_idx_list + + +def remove_blank(labels, blank=0): + new_labels = [x for x in labels if x != blank] + return new_labels + + +def insert_blank(labels, blank=0): + new_labels = [blank] + for l in labels: + new_labels += [l, blank] + return new_labels + + +def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): + """ + CTC greedy (best path) decoder. + """ + raw_str = np.argmax(np.array(probs_seq), axis=1) + remove_blank_in_pos = None if keep_blank_in_idxs else blank + dedup_str, keep_idx_list = get_keep_pos_idxs( + raw_str, remove_blank=remove_blank_in_pos) + dst_str = remove_blank(dedup_str, blank=blank) + return dst_str, keep_idx_list + + +def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4): + _, _, C = logits_map.shape + ys, xs = zip(*gather_info) + logits_seq = logits_map[list(ys), list(xs)] + probs_seq = logits_seq + labels = np.argmax(probs_seq, axis=1) + dst_str = [k for k, v_ in groupby(labels) if k != C - 1] + detal = len(gather_info) // (pts_num - 1) + keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1] + keep_gather_list = [gather_info[idx] for idx in keep_idx_list] + return dst_str, keep_gather_list + + +def ctc_decoder_for_image(gather_info_list, + logits_map, + Lexicon_Table, + pts_num=6): + """ + CTC decoder using multiple processes. + """ + decoder_str = [] + decoder_xys = [] + for gather_info in gather_info_list: + if len(gather_info) < pts_num: + continue + dst_str, xys_list = instance_ctc_greedy_decoder( + gather_info, logits_map, pts_num=pts_num) + dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str]) + if len(dst_str_readable) < 2: + continue + decoder_str.append(dst_str_readable) + decoder_xys.append(xys_list) + return decoder_str, decoder_xys + + +def sort_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] 
+ """ + + def sort_part_with_direction(pos_list, point_direction): + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, + point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point, np.array(sorted_direction) + + +def add_id(pos_list, image_id=0): + """ + Add id for gather feature, for inference. + """ + new_list = [] + for item in pos_list: + new_list.append((image_id, item[0], item[1])) + return new_list + + +def sort_and_expand_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len:, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max( + int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + left_list = [] + right_list = [] + for i in range(append_num): + ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ly < h and lx < w and (ly, lx) not in left_list: + left_list.append((ly, lx)) + ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ry < h and rx < w and (ry, rx) not in right_list: + right_list.append((ry, rx)) + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + +def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] 
+ binary_tcl_map: h x w + """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len:, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max( + int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + max_append_num = 2 * append_num + + left_list = [] + right_list = [] + for i in range(max_append_num): + ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ly < h and lx < w and (ly, lx) not in left_list: + if binary_tcl_map[ly, lx] > 0.5: + left_list.append((ly, lx)) + else: + break + + for i in range(max_append_num): + ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ry < h and rx < w and (ry, rx) not in right_list: + if binary_tcl_map[ry, rx] > 0.5: + right_list.append((ry, rx)) + else: + break + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + +def point_pair2poly(point_pair_list): + """ + Transfer vertical point_pairs into poly point in clockwise. + """ + point_num = len(point_pair_list) * 2 + point_list = [0] * point_num + for idx, point_pair in enumerate(point_pair_list): + point_list[idx] = point_pair[0] + point_list[point_num - 1 - idx] = point_pair[1] + return np.array(point_list).reshape(-1, 2) + + +def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.): + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + +def expand_poly_along_width(poly, shrink_ratio_of_width=0.3): + """ + expand poly along width. 
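
The `_v2` variant above extends the sorted line the same way, but stops a side as soon as a step leaves the binary TCL region. A simplified sketch of the right-hand walk on a toy map (the original bounds check is `ly < h and lx < w`; both bounds are checked here for safety):

```python
import numpy as np

binary_tcl_map = np.zeros((10, 10))
binary_tcl_map[5, 2:9] = 1.0            # TCL pixels at y=5, x=2..8
start = np.array([5.0, 7.0])            # (y, x): right end of the sorted line
step = np.array([0.0, 1.0])             # unit direction, pointing right
right_list = []
for i in range(4):                      # up to max_append_num steps
    y, x = np.round(start + step * (i + 1)).astype('int32').tolist()
    if 0 <= y < 10 and 0 <= x < 10 and binary_tcl_map[y, x] > 0.5:
        right_list.append((y, x))
    else:
        break                           # left the TCL region: stop this side
print(right_list)                       # [(5, 8)] -- the step to x=9 stops it
```
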
+ """ + point_num = poly.shape[0] + left_quad = np.array( + [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ + (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], poly[point_num // 2 - 1], + poly[point_num // 2], poly[point_num // 2 + 1] + ], + dtype=np.float32) + right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ + (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio) + poly[0] = left_quad_expand[0] + poly[-1] = left_quad_expand[-1] + poly[point_num // 2 - 1] = right_quad_expand[1] + poly[point_num // 2] = right_quad_expand[2] + return poly + + +def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w, + src_h, valid_set): + poly_list = [] + keep_str_list = [] + for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs): + if len(keep_str) < 2: + print('--> too short, {}'.format(keep_str)) + continue + + offset_expand = 1.0 + if valid_set == 'totaltext': + offset_expand = 1.2 + + point_pair_list = [] + for y, x in yx_center_line: + offset = p_border[:, y, x].reshape(2, 2) * offset_expand + ori_yx = np.array([y, x], dtype=np.float32) + point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array( + [ratio_w, ratio_h]).reshape(-1, 2) + point_pair_list.append(point_pair) + + detected_poly = point_pair2poly(point_pair_list) + detected_poly = expand_poly_along_width( + detected_poly, shrink_ratio_of_width=0.2) + detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h) + + keep_str_list.append(keep_str) + if valid_set == 'partvgg': + middle_point = len(detected_poly) // 2 + detected_poly = detected_poly[ + [0, middle_point - 1, middle_point, -1], :] + poly_list.append(detected_poly) + elif valid_set == 'totaltext': + poly_list.append(detected_poly) + else: + print('--> Not supported format.') + exit(-1) + return poly_list, keep_str_list + + +def generate_pivot_list_fast(p_score, + p_char_maps, + f_direction, + Lexicon_Table, + score_thresh=0.5): + """ + return center point and end point of TCL instance; filter with the char maps; + """ + p_score = p_score[0] + f_direction = f_direction.transpose(1, 2, 0) + p_tcl_map = (p_score > score_thresh) * 1.0 + skeleton_map = thin(p_tcl_map.astype(np.uint8)) + instance_count, instance_label_map = cv2.connectedComponents( + skeleton_map.astype(np.uint8), connectivity=8) + + # get TCL Instance + all_pos_yxs = [] + if instance_count > 0: + for instance_id in range(1, instance_count): + pos_list = [] + ys, xs = np.where(instance_label_map == instance_id) + pos_list = list(zip(ys, xs)) + + if len(pos_list) < 3: + continue + + pos_list_sorted = sort_and_expand_with_direction_v2( + pos_list, f_direction, p_tcl_map) + all_pos_yxs.append(pos_list_sorted) + + p_char_maps = p_char_maps.transpose([1, 2, 0]) + decoded_str, keep_yxs_list = ctc_decoder_for_image( + all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table) + return keep_yxs_list, decoded_str + + +def extract_main_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] 
+ """ + pos_list = np.array(pos_list) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + average_direction = average_direction / ( + np.linalg.norm(average_direction) + 1e-6) + return average_direction + + +def sort_by_direction_with_image_id_deprecated(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[id, y, x], [id, y, x], [id, y, x] ...] + """ + pos_list_full = np.array(pos_list).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + return sorted_list + + +def sort_by_direction_with_image_id(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + + def sort_part_with_direction(pos_list_full, point_direction): + pos_list_full = np.array(pos_list_full).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 3) + point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, + point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point diff --git a/ppocr/utils/e2e_utils/extract_textpoint_slow.py b/ppocr/utils/e2e_utils/extract_textpoint_slow.py new file mode 100644 index 0000000000000000000000000000000000000000..db0c30e67bea472da6c7ed5176b1c70f0ab1cbc6 --- /dev/null +++ b/ppocr/utils/e2e_utils/extract_textpoint_slow.py @@ -0,0 +1,590 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Contains various CTC decoders.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import cv2 +import math + +import numpy as np +from itertools import groupby +from skimage.morphology._skeletonize import thin + + +def get_dict(character_dict_path): + character_str = "" + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + character_str += line + dict_character = list(character_str) + return dict_character + + +def point_pair2poly(point_pair_list): + """ + Transfer vertical point_pairs into poly point in clockwise. + """ + pair_length_list = [] + for point_pair in point_pair_list: + pair_length = np.linalg.norm(point_pair[0] - point_pair[1]) + pair_length_list.append(pair_length) + pair_length_list = np.array(pair_length_list) + pair_info = (pair_length_list.max(), pair_length_list.min(), + pair_length_list.mean()) + + point_num = len(point_pair_list) * 2 + point_list = [0] * point_num + for idx, point_pair in enumerate(point_pair_list): + point_list[idx] = point_pair[0] + point_list[point_num - 1 - idx] = point_pair[1] + return np.array(point_list).reshape(-1, 2), pair_info + + +def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.): + """ + Generate shrink_quad_along_width. + """ + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + +def expand_poly_along_width(poly, shrink_ratio_of_width=0.3): + """ + expand poly along width. + """ + point_num = poly.shape[0] + left_quad = np.array( + [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ + (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], poly[point_num // 2 - 1], + poly[point_num // 2], poly[point_num // 2 + 1] + ], + dtype=np.float32) + right_ratio = 1.0 + \ + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ + (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio) + poly[0] = left_quad_expand[0] + poly[-1] = left_quad_expand[-1] + poly[point_num // 2 - 1] = right_quad_expand[1] + poly[point_num // 2] = right_quad_expand[2] + return poly + + +def softmax(logits): + """ + logits: N x d + """ + max_value = np.max(logits, axis=1, keepdims=True) + exp = np.exp(logits - max_value) + exp_sum = np.sum(exp, axis=1, keepdims=True) + dist = exp / exp_sum + return dist + + +def get_keep_pos_idxs(labels, remove_blank=None): + """ + Remove duplicate and get pos idxs of keep items. + The value of keep_blank should be [None, 95]. 
+ """ + duplicate_len_list = [] + keep_pos_idx_list = [] + keep_char_idx_list = [] + for k, v_ in groupby(labels): + current_len = len(list(v_)) + if k != remove_blank: + current_idx = int(sum(duplicate_len_list) + current_len // 2) + keep_pos_idx_list.append(current_idx) + keep_char_idx_list.append(k) + duplicate_len_list.append(current_len) + return keep_char_idx_list, keep_pos_idx_list + + +def remove_blank(labels, blank=0): + new_labels = [x for x in labels if x != blank] + return new_labels + + +def insert_blank(labels, blank=0): + new_labels = [blank] + for l in labels: + new_labels += [l, blank] + return new_labels + + +def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): + """ + CTC greedy (best path) decoder. + """ + raw_str = np.argmax(np.array(probs_seq), axis=1) + remove_blank_in_pos = None if keep_blank_in_idxs else blank + dedup_str, keep_idx_list = get_keep_pos_idxs( + raw_str, remove_blank=remove_blank_in_pos) + dst_str = remove_blank(dedup_str, blank=blank) + return dst_str, keep_idx_list + + +def instance_ctc_greedy_decoder(gather_info, + logits_map, + keep_blank_in_idxs=True): + """ + gather_info: [[x, y], [x, y] ...] + logits_map: H x W X (n_chars + 1) + """ + _, _, C = logits_map.shape + ys, xs = zip(*gather_info) + logits_seq = logits_map[list(ys), list(xs)] # n x 96 + probs_seq = softmax(logits_seq) + dst_str, keep_idx_list = ctc_greedy_decoder( + probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs) + keep_gather_list = [gather_info[idx] for idx in keep_idx_list] + return dst_str, keep_gather_list + + +def ctc_decoder_for_image(gather_info_list, logits_map, + keep_blank_in_idxs=True): + """ + CTC decoder using multiple processes. + """ + decoder_results = [] + for gather_info in gather_info_list: + res = instance_ctc_greedy_decoder( + gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs) + decoder_results.append(res) + return decoder_results + + +def sort_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] 
+ """ + + def sort_part_with_direction(pos_list, point_direction): + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, + point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point, np.array(sorted_direction) + + +def add_id(pos_list, image_id=0): + """ + Add id for gather feature, for inference. + """ + new_list = [] + for item in pos_list: + new_list.append((image_id, item[0], item[1])) + return new_list + + +def sort_and_expand_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + # expand along + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len:, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max( + int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + left_list = [] + right_list = [] + for i in range(append_num): + ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ly < h and lx < w and (ly, lx) not in left_list: + left_list.append((ly, lx)) + ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ry < h and rx < w and (ry, rx) not in right_list: + right_list.append((ry, rx)) + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + +def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] 
+    binary_tcl_map: h x w
+    """
+    h, w, _ = f_direction.shape
+    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+    # expand the sorted center line at both ends along the text direction
+    point_num = len(sorted_list)
+    sub_direction_len = max(point_num // 3, 2)
+    left_direction = point_direction[:sub_direction_len, :]
+    right_direction = point_direction[point_num - sub_direction_len:, :]
+
+    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+    left_average_len = np.linalg.norm(left_average_direction)
+    left_start = np.array(sorted_list[0])
+    left_step = left_average_direction / (left_average_len + 1e-6)
+
+    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
+    right_average_len = np.linalg.norm(right_average_direction)
+    right_step = right_average_direction / (right_average_len + 1e-6)
+    right_start = np.array(sorted_list[-1])
+
+    append_num = max(
+        int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+    max_append_num = 2 * append_num
+
+    left_list = []
+    right_list = []
+    for i in range(max_append_num):
+        ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
+            'int32').tolist()
+        if ly < h and lx < w and (ly, lx) not in left_list:
+            if binary_tcl_map[ly, lx] > 0.5:
+                left_list.append((ly, lx))
+            else:
+                break
+
+    for i in range(max_append_num):
+        ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
+            'int32').tolist()
+        if ry < h and rx < w and (ry, rx) not in right_list:
+            if binary_tcl_map[ry, rx] > 0.5:
+                right_list.append((ry, rx))
+            else:
+                break
+
+    all_list = left_list[::-1] + sorted_list + right_list
+    return all_list
+
+
+def generate_pivot_list_curved(p_score,
+                               p_char_maps,
+                               f_direction,
+                               score_thresh=0.5,
+                               is_expand=True,
+                               is_backbone=False,
+                               image_id=0):
+    """
+    Return the center points and end points of each TCL instance, filtered
+    with the character maps.
+    """
+    p_score = p_score[0]
+    f_direction = f_direction.transpose(1, 2, 0)
+    p_tcl_map = (p_score > score_thresh) * 1.0
+    skeleton_map = thin(p_tcl_map)
+    instance_count, instance_label_map = cv2.connectedComponents(
+        skeleton_map.astype(np.uint8), connectivity=8)
+
+    # get TCL instances
+    all_pos_yxs = []
+    center_pos_yxs = []
+    end_points_yxs = []
+    instance_center_pos_yxs = []
+    if instance_count > 0:
+        for instance_id in range(1, instance_count):
+            pos_list = []
+            ys, xs = np.where(instance_label_map == instance_id)
+            pos_list = list(zip(ys, xs))
+
+            ### FIX-ME, eliminate outlier
+            if len(pos_list) < 3:
+                continue
+
+            if is_expand:
+                pos_list_sorted = sort_and_expand_with_direction_v2(
+                    pos_list, f_direction, p_tcl_map)
+            else:
+                pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
+            all_pos_yxs.append(pos_list_sorted)
+
+    # use the decoder to filter background points
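+    # CTC greedy decoding collapses each run of repeated labels to its middle
+    # position and strips blanks from the decoded string; with
+    # keep_blank_in_idxs=True the blank positions are still kept as pivots.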
+    p_char_maps = p_char_maps.transpose([1, 2, 0])
+    decode_res = ctc_decoder_for_image(
+        all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+    for decoded_str, keep_yxs_list in decode_res:
+        if is_backbone:
+            keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
+            instance_center_pos_yxs.append(keep_yxs_list_with_id)
+        else:
+            end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
+            center_pos_yxs.extend(keep_yxs_list)
+
+    if is_backbone:
+        return instance_center_pos_yxs
+    else:
+        return center_pos_yxs, end_points_yxs
+
+
+def generate_pivot_list_horizontal(p_score,
+                                   p_char_maps,
+                                   f_direction,
+                                   score_thresh=0.5,
+                                   is_backbone=False,
+                                   image_id=0):
+    """
+    Return the center points and end points of each TCL instance, filtered
+    with the character maps.
+    """
+    p_score = p_score[0]
+    f_direction = f_direction.transpose(1, 2, 0)
+    p_tcl_map_bi = (p_score > score_thresh) * 1.0
+    instance_count, instance_label_map = cv2.connectedComponents(
+        p_tcl_map_bi.astype(np.uint8), connectivity=8)
+
+    # get TCL instances
+    all_pos_yxs = []
+    center_pos_yxs = []
+    end_points_yxs = []
+    instance_center_pos_yxs = []
+
+    if instance_count > 0:
+        for instance_id in range(1, instance_count):
+            pos_list = []
+            ys, xs = np.where(instance_label_map == instance_id)
+            pos_list = list(zip(ys, xs))
+
+            ### FIX-ME, eliminate outlier
+            if len(pos_list) < 5:
+                continue
+
+            # decide whether the instance is horizontal by angle and extent
+            main_direction = extract_main_direction(pos_list,
+                                                    f_direction)  # y x
+            reference_direction = np.array([0, 1]).reshape([-1, 2])  # y x
+            is_h_angle = abs(np.sum(
+                main_direction * reference_direction)) < math.cos(math.pi / 180 *
+                                                                  70)
+
+            point_yxs = np.array(pos_list)
+            max_y, max_x = np.max(point_yxs, axis=0)
+            min_y, min_x = np.min(point_yxs, axis=0)
+            is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
+
+            pos_list_final = []
+            if is_h_len:
+                xs = np.unique(xs)
+                for x in xs:
+                    ys = instance_label_map[:, x].copy().reshape((-1, ))
+                    y = int(np.where(ys == instance_id)[0].mean())
+                    pos_list_final.append((y, x))
+            else:
+                ys = np.unique(ys)
+                for y in ys:
+                    xs = instance_label_map[y, :].copy().reshape((-1, ))
+                    x = int(np.where(xs == instance_id)[0].mean())
+                    pos_list_final.append((y, x))
+
+            pos_list_sorted, _ = sort_with_direction(pos_list_final,
+                                                     f_direction)
+            all_pos_yxs.append(pos_list_sorted)
+
+    # use the decoder to filter background points
+    p_char_maps = p_char_maps.transpose([1, 2, 0])
+    decode_res = ctc_decoder_for_image(
+        all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+    for decoded_str, keep_yxs_list in decode_res:
+        if is_backbone:
+            keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
+            instance_center_pos_yxs.append(keep_yxs_list_with_id)
+        else:
+            end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
+            center_pos_yxs.extend(keep_yxs_list)
+
+    if is_backbone:
+        return instance_center_pos_yxs
+    else:
+        return center_pos_yxs, end_points_yxs
+
+
+def generate_pivot_list_slow(p_score,
+                             p_char_maps,
+                             f_direction,
+                             score_thresh=0.5,
+                             is_backbone=False,
+                             is_curved=True,
+                             image_id=0):
+    """
+    Wrap all the functions together.
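+    Dispatches to generate_pivot_list_curved when is_curved is True,
+    otherwise to generate_pivot_list_horizontal.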
+ """ + if is_curved: + return generate_pivot_list_curved( + p_score, + p_char_maps, + f_direction, + score_thresh=score_thresh, + is_expand=True, + is_backbone=is_backbone, + image_id=image_id) + else: + return generate_pivot_list_horizontal( + p_score, + p_char_maps, + f_direction, + score_thresh=score_thresh, + is_backbone=is_backbone, + image_id=image_id) + + +# for refine module +def extract_main_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + pos_list = np.array(pos_list) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + average_direction = average_direction / ( + np.linalg.norm(average_direction) + 1e-6) + return average_direction + + +def sort_by_direction_with_image_id_deprecated(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[id, y, x], [id, y, x], [id, y, x] ...] + """ + pos_list_full = np.array(pos_list).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + return sorted_list + + +def sort_by_direction_with_image_id(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + + def sort_part_with_direction(pos_list_full, point_direction): + pos_list_full = np.array(pos_list_full).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 3) + point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, + point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point + + +def generate_pivot_list_tt_inference(p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_backbone=False, + is_curved=True, + image_id=0): + """ + return center point and end point of TCL instance; filter with the char maps; + """ + p_score = p_score[0] + f_direction = f_direction.transpose(1, 2, 0) + p_tcl_map = (p_score > score_thresh) * 1.0 + skeleton_map = thin(p_tcl_map) + instance_count, instance_label_map = 
+        skeleton_map.astype(np.uint8), connectivity=8)
+
+    # get TCL instances
+    all_pos_yxs = []
+    if instance_count > 0:
+        for instance_id in range(1, instance_count):
+            pos_list = []
+            ys, xs = np.where(instance_label_map == instance_id)
+            pos_list = list(zip(ys, xs))
+            ### FIX-ME, eliminate outlier
+            if len(pos_list) < 3:
+                continue
+            pos_list_sorted = sort_and_expand_with_direction_v2(
+                pos_list, f_direction, p_tcl_map)
+            pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
+            all_pos_yxs.append(pos_list_sorted_with_id)
+    return all_pos_yxs
diff --git a/ppocr/utils/e2e_utils/pgnet_pp_utils.py b/ppocr/utils/e2e_utils/pgnet_pp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..64bfd372cc75ab533ce3ef46216a33345dae40c4
--- /dev/null
+++ b/ppocr/utils/e2e_utils/pgnet_pp_utils.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import paddle
+import os
+import sys
+
+__dir__ = os.path.dirname(__file__)
+sys.path.append(__dir__)
+sys.path.append(os.path.join(__dir__, '..'))
+from extract_textpoint_slow import *
+from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
+
+
+class PGNet_PostProcess(object):
+    # two post-processing modes: 'fast' and 'slow'
+    def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict,
+                 shape_list):
+        self.Lexicon_Table = get_dict(character_dict_path)
+        self.valid_set = valid_set
+        self.score_thresh = score_thresh
+        self.outs_dict = outs_dict
+        self.shape_list = shape_list
+
+    def pg_postprocess_fast(self):
+        p_score = self.outs_dict['f_score']
+        p_border = self.outs_dict['f_border']
+        p_char = self.outs_dict['f_char']
+        p_direction = self.outs_dict['f_direction']
+        if isinstance(p_score, paddle.Tensor):
+            p_score = p_score[0].numpy()
+            p_border = p_border[0].numpy()
+            p_direction = p_direction[0].numpy()
+            p_char = p_char[0].numpy()
+        else:
+            p_score = p_score[0]
+            p_border = p_border[0]
+            p_direction = p_direction[0]
+            p_char = p_char[0]
+
+        src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
+        instance_yxs_list, seq_strs = generate_pivot_list_fast(
+            p_score,
+            p_char,
+            p_direction,
+            self.Lexicon_Table,
+            score_thresh=self.score_thresh)
+        poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
+                                                p_border, ratio_w, ratio_h,
+                                                src_w, src_h, self.valid_set)
+        data = {
+            'points': poly_list,
+            'strs': keep_str_list,
+        }
+        return data
+
+    def pg_postprocess_slow(self):
+        p_score = self.outs_dict['f_score']
+        p_border = self.outs_dict['f_border']
+        p_char = self.outs_dict['f_char']
+        p_direction = self.outs_dict['f_direction']
+        if isinstance(p_score, paddle.Tensor):
+            p_score = p_score[0].numpy()
+            p_border = p_border[0].numpy()
+            p_direction = p_direction[0].numpy()
+            p_char = p_char[0].numpy()
+        else:
+            p_score = p_score[0]
+            p_border = p_border[0]
+            p_direction = p_direction[0]
+            p_char = p_char[0]
+        src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
+        is_curved = self.valid_set == "totaltext"
+        instance_yxs_list = generate_pivot_list_slow(
+            p_score,
+            p_char,
+            p_direction,
+            score_thresh=self.score_thresh,
+            is_backbone=True,
+            is_curved=is_curved)
+        p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
+        char_seq_idx_set = []
+        for i in range(len(instance_yxs_list)):
+            gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
+            f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
+            feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
+            feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
+            feature_len = [len(feature_seq[0])]
+            feature_seq = paddle.to_tensor(feature_seq)
+            feature_len = np.array([feature_len]).astype(np.int64)
+            length = paddle.to_tensor(feature_len)
+            seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
+                input=feature_seq, blank=36, input_length=length)
+            seq_pred_str = seq_pred[0].numpy().tolist()[0]
+            seq_len = seq_pred[1].numpy()[0][0]
+            temp_t = []
+            for c in seq_pred_str[:seq_len]:
+                temp_t.append(c)
+            char_seq_idx_set.append(temp_t)
+        seq_strs = []
+        for char_idx_set in char_seq_idx_set:
+            pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
+            seq_strs.append(pr_str)
+        poly_list = []
+        keep_str_list = []
+        all_point_list = []
+        all_point_pair_list = []
+        for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
+            if len(yx_center_line) == 1:
+                yx_center_line.append(yx_center_line[-1])
+
+            offset_expand = 1.0
+            if self.valid_set == 'totaltext':
+                offset_expand = 1.2
+
+            point_pair_list = []
+            for batch_id, y, x in yx_center_line:
+                offset = p_border[:, y, x].reshape(2, 2)
+                if offset_expand != 1.0:
+                    offset_length = np.linalg.norm(
+                        offset, axis=1, keepdims=True)
+                    expand_length = np.clip(
+                        offset_length * (offset_expand - 1),
+                        a_min=0.5,
+                        a_max=3.0)
+                    offset_delta = offset / offset_length * expand_length
+                    offset = offset + offset_delta
+                ori_yx = np.array([y, x], dtype=np.float32)
+                point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
+                    [ratio_w, ratio_h]).reshape(-1, 2)
+                point_pair_list.append(point_pair)
+
+                all_point_list.append([
+                    int(round(x * 4.0 / ratio_w)),
+                    int(round(y * 4.0 / ratio_h))
+                ])
+                all_point_pair_list.append(point_pair.round().astype(np.int32)
+                                           .tolist())
+
+            detected_poly, pair_length_info = point_pair2poly(point_pair_list)
+            detected_poly = expand_poly_along_width(
+                detected_poly, shrink_ratio_of_width=0.2)
+            detected_poly[:, 0] = np.clip(
+                detected_poly[:, 0], a_min=0, a_max=src_w)
+            detected_poly[:, 1] = np.clip(
+                detected_poly[:, 1], a_min=0, a_max=src_h)
+
+            if len(keep_str) < 2:
+                continue
+
+            keep_str_list.append(keep_str)
+            detected_poly = np.round(detected_poly).astype('int32')
+            if self.valid_set == 'partvgg':
+                middle_point = len(detected_poly) // 2
+                detected_poly = detected_poly[
+                    [0, middle_point - 1, middle_point, -1], :]
+                poly_list.append(detected_poly)
+            elif self.valid_set == 'totaltext':
+                poly_list.append(detected_poly)
+            else:
+                print('--> Unsupported valid_set: {}'.format(self.valid_set))
+                exit(-1)
+        data = {
+            'points': poly_list,
+            'strs': keep_str_list,
+        }
+        return data
diff --git a/ppocr/utils/e2e_utils/visual.py b/ppocr/utils/e2e_utils/visual.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6e4fd0667dbf4a42dbc0fd9bf26e6fd91be0d82
--- /dev/null
+++ b/ppocr/utils/e2e_utils/visual.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import cv2
+import time
+
+
+def resize_image(im, max_side_len=512):
+    """
+    Resize the image so that its longer side is max_side_len, then round both
+    sides up to a multiple of max_stride as required by the network.
+    :param im: the input image
+    :param max_side_len: limit of max image size to avoid out of memory in gpu
+    :return: the resized image and the resize ratio
+    """
+    h, w, _ = im.shape
+
+    resize_w = w
+    resize_h = h
+
+    if resize_h > resize_w:
+        ratio = float(max_side_len) / resize_h
+    else:
+        ratio = float(max_side_len) / resize_w
+
+    resize_h = int(resize_h * ratio)
+    resize_w = int(resize_w * ratio)
+
+    max_stride = 128
+    resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+    resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+    im = cv2.resize(im, (int(resize_w), int(resize_h)))
+    ratio_h = resize_h / float(h)
+    ratio_w = resize_w / float(w)
+
+    return im, (ratio_h, ratio_w)
+
+
+def resize_image_min(im, max_side_len=512):
+    """
+    Resize the image so that its shorter side is max_side_len, then round both
+    sides up to a multiple of 128.
+    """
+    h, w, _ = im.shape
+
+    resize_w = w
+    resize_h = h
+
+    if resize_h < resize_w:
+        ratio = float(max_side_len) / resize_h
+    else:
+        ratio = float(max_side_len) / resize_w
+
+    resize_h = int(resize_h * ratio)
+    resize_w = int(resize_w * ratio)
+
+    max_stride = 128
+    resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+    resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+    im = cv2.resize(im, (int(resize_w), int(resize_h)))
+    ratio_h = resize_h / float(h)
+    ratio_w = resize_w / float(w)
+    return im, (ratio_h, ratio_w)
+
+
+def resize_image_for_totaltext(im, max_side_len=512):
+    """
+    Scale the image by 1.25x, capping the height at max_side_len, then round
+    both sides up to a multiple of 128.
+    """
+    h, w, _ = im.shape
+
+    resize_w = w
+    resize_h = h
+    ratio = 1.25
+    if h * ratio > max_side_len:
+        ratio = float(max_side_len) / resize_h
+
+    resize_h = int(resize_h * ratio)
+    resize_w = int(resize_w * ratio)
+
+    max_stride = 128
+    resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+    resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+    im = cv2.resize(im, (int(resize_w), int(resize_h)))
+    ratio_h = resize_h / float(h)
+    ratio_w = resize_w / float(w)
+    return im, (ratio_h, ratio_w)
+
+
+def point_pair2poly(point_pair_list):
+    """
+    Convert vertical point pairs into polygon points in clockwise order.
+    """
+    pair_length_list = []
+    for point_pair in point_pair_list:
+        pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
+        pair_length_list.append(pair_length)
+    pair_length_list = np.array(pair_length_list)
+    pair_info = (pair_length_list.max(), pair_length_list.min(),
+                 pair_length_list.mean())
+
+    point_num = len(point_pair_list) * 2
+    point_list = [0] * point_num
+    for idx, point_pair in enumerate(point_pair_list):
+        point_list[idx] = point_pair[0]
+        point_list[point_num - 1 - idx] = point_pair[1]
+    return np.array(point_list).reshape(-1, 2), pair_info
+
+
+def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
+    """
+    Shrink a quad along its width.
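+    Edges 0->1 and 3->2 are linearly interpolated at the two width ratios;
+    ratios outside [0, 1] expand the quad instead of shrinking it.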
+ """ + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + +def expand_poly_along_width(poly, shrink_ratio_of_width=0.3): + """ + expand poly along width. + """ + point_num = poly.shape[0] + left_quad = np.array( + [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ + (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], poly[point_num // 2 - 1], + poly[point_num // 2], poly[point_num // 2 + 1] + ], + dtype=np.float32) + right_ratio = 1.0 + \ + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ + (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio) + poly[0] = left_quad_expand[0] + poly[-1] = left_quad_expand[-1] + poly[point_num // 2 - 1] = right_quad_expand[1] + poly[point_num // 2] = right_quad_expand[2] + return poly + + +def norm2(x, axis=None): + if axis: + return np.sqrt(np.sum(x**2, axis=axis)) + return np.sqrt(np.sum(x**2)) + + +def cos(p1, p2): + return (p1 * p2).sum() / (norm2(p1) * norm2(p2)) diff --git a/ppocr/utils/en_dict.txt b/ppocr/utils/en_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..7677d31b9d3f08eef2823c2cf051beeab1f0470b --- /dev/null +++ b/ppocr/utils/en_dict.txt @@ -0,0 +1,95 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +; +< += +> +? +@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +[ +\ +] +^ +_ +` +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +{ +| +} +~ +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. 
+/ + diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py index 29576d971486326aec3c93601656d7b982ef3336..7bb4c906d298af54ed56e2805f487a2c22d1894b 100755 --- a/ppocr/utils/utility.py +++ b/ppocr/utils/utility.py @@ -61,6 +61,7 @@ def get_image_file_list(img_file): imgs_lists.append(file_path) if len(imgs_lists) == 0: raise Exception("not found any img file in {}".format(img_file)) + imgs_lists = sorted(imgs_lists) return imgs_lists diff --git a/requirements.txt b/requirements.txt index 2401d52b48c10bad5ea5b244a0fd4c4365b94f09..1b01e690f77d2bf5e570c86b268c7128c8bf79fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ opencv-python==4.2.0.32 tqdm numpy visualdl -python-Levenshtein \ No newline at end of file +python-Levenshtein +opencv-contrib-python \ No newline at end of file diff --git a/setup.py b/setup.py index 58f6de48548d494a7fde8528130b8e881bc7620d..d491adb17e6251355c0190d0ddecb9a82b09bc2e 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ setup( package_dir={'paddleocr': ''}, include_package_data=True, entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]}, - version='2.0.2', + version='2.0.4', install_requires=requirements, license='Apache License 2.0', description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices', diff --git a/tools/eval.py b/tools/eval.py index 4afed469c875ef8d2200cdbfd89e5a8af4c6b7c3..9817fa75093dd5127e3d11501ebc0473c9b53365 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -59,10 +59,10 @@ def main(): eval_class = build_metric(config['Metric']) # start eval - metirc = program.eval(model, valid_dataloader, post_process_class, + metric = program.eval(model, valid_dataloader, post_process_class, eval_class, use_srn) logger.info('metric eval ***************') - for k, v in metirc.items(): + for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) diff --git a/tools/export_model.py b/tools/export_model.py index 1e9526e03d6b9001249d5891c37bee071c1f36a3..f587b2bb363e01ab4c0b2429fc95f243085649d1 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -31,14 +31,6 @@ from ppocr.utils.logging import get_logger from tools.program import load_config, merge_config, ArgsParser -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("-c", "--config", help="configuration file to use") - parser.add_argument( - "-o", "--output_path", type=str, default='./output/infer/') - return parser.parse_args() - - def main(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) diff --git a/tools/infer/predict_e2e.py b/tools/infer/predict_e2e.py new file mode 100755 index 0000000000000000000000000000000000000000..8b94f24a9ffa16ca5683795afa6392fa23c24a94 --- /dev/null +++ b/tools/infer/predict_e2e.py @@ -0,0 +1,159 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
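+
+# Example usage (paths are placeholders; all flags below are defined in
+# tools/infer/utility.py):
+#   python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" \
+#       --e2e_model_dir="./inference/e2e/" --image_dir="./doc/imgs_en/"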
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import numpy as np
+import time
+
+import tools.infer.utility as utility
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppocr.data import create_operators, transform
+from ppocr.postprocess import build_post_process
+
+logger = get_logger()
+
+
+class TextE2E(object):
+    def __init__(self, args):
+        self.args = args
+        self.e2e_algorithm = args.e2e_algorithm
+        pre_process_list = [{
+            'E2EResizeForTest': {}
+        }, {
+            'NormalizeImage': {
+                'std': [0.229, 0.224, 0.225],
+                'mean': [0.485, 0.456, 0.406],
+                'scale': '1./255.',
+                'order': 'hwc'
+            }
+        }, {
+            'ToCHWImage': None
+        }, {
+            'KeepKeys': {
+                'keep_keys': ['image', 'shape']
+            }
+        }]
+        postprocess_params = {}
+        if self.e2e_algorithm == "PGNet":
+            pre_process_list[0] = {
+                'E2EResizeForTest': {
+                    'max_side_len': args.e2e_limit_side_len,
+                    'valid_set': 'totaltext'
+                }
+            }
+            postprocess_params['name'] = 'PGPostProcess'
+            postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh
+            postprocess_params["character_dict_path"] = args.e2e_char_dict_path
+            postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
+            postprocess_params["mode"] = args.e2e_pgnet_mode
+            self.e2e_pgnet_polygon = args.e2e_pgnet_polygon
+        else:
+            logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
+            sys.exit(0)
+
+        self.preprocess_op = create_operators(pre_process_list)
+        self.postprocess_op = build_post_process(postprocess_params)
+        self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
+            args, 'e2e', logger)  # paddle.jit.load(args.det_model_dir)
+        # self.predictor.eval()
+
+    def clip_det_res(self, points, img_height, img_width):
+        for pno in range(points.shape[0]):
+            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+        return points
+
+    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
+        img_height, img_width = image_shape[0:2]
+        dt_boxes_new = []
+        for box in dt_boxes:
+            box = self.clip_det_res(box, img_height, img_width)
+            dt_boxes_new.append(box)
+        dt_boxes = np.array(dt_boxes_new)
+        return dt_boxes
+
+    def __call__(self, img):
+
+        ori_im = img.copy()
+        data = {'image': img}
+        data = transform(data, self.preprocess_op)
+        img, shape_list = data
+        if img is None:
+            return None, None, 0
+        img = np.expand_dims(img, axis=0)
+        shape_list = np.expand_dims(shape_list, axis=0)
+        img = img.copy()
+        starttime = time.time()
+
+        self.input_tensor.copy_from_cpu(img)
+        self.predictor.run()
+        outputs = []
+        for output_tensor in self.output_tensors:
+            output = output_tensor.copy_to_cpu()
+            outputs.append(output)
+
+        preds = {}
+        if self.e2e_algorithm == 'PGNet':
+            preds['f_border'] = outputs[0]
+            preds['f_char'] = outputs[1]
+            preds['f_direction'] = outputs[2]
+            preds['f_score'] = outputs[3]
+        else:
+            raise NotImplementedError
+        post_result = self.postprocess_op(preds, shape_list)
+        points, strs = post_result['points'], post_result['strs']
+        dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
+        elapse = time.time() - starttime
+        return dt_boxes, strs, elapse
+
+
+if __name__ == "__main__":
+    args = utility.parse_args()
+    image_file_list = get_image_file_list(args.image_dir)
+    text_detector = TextE2E(args)
+    count = 0
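+    # the first image is excluded from the timing average below, since its
+    # latency includes one-off predictor setup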
+    total_time = 0
+    draw_img_save = "./inference_results"
+    if not os.path.exists(draw_img_save):
+        os.makedirs(draw_img_save)
+    for image_file in image_file_list:
+        img, flag = check_and_read_gif(image_file)
+        if not flag:
+            img = cv2.imread(image_file)
+        if img is None:
+            logger.info("error in loading image:{}".format(image_file))
+            continue
+        points, strs, elapse = text_detector(img)
+        if count > 0:
+            total_time += elapse
+        count += 1
+        logger.info("Predict time of {}: {}".format(image_file, elapse))
+        src_im = utility.draw_e2e_res(points, strs, image_file)
+        img_name_pure = os.path.split(image_file)[-1]
+        img_path = os.path.join(draw_img_save,
+                                "e2e_res_{}".format(img_name_pure))
+        cv2.imwrite(img_path, src_im)
+        logger.info("The visualized image is saved in {}".format(img_path))
+    if count > 1:
+        logger.info("Avg Time: {}".format(total_time / (count - 1)))
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 1cb6e01b087ff98efb0a57be3cc58a79425fea57..24388026b8f395427c93e285ed550446e3aa9b9c 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -41,6 +41,7 @@ class TextRecognizer(object):
         self.character_type = args.rec_char_type
         self.rec_batch_num = args.rec_batch_num
         self.rec_algorithm = args.rec_algorithm
+        self.max_text_length = args.max_text_length
         postprocess_params = {
             'name': 'CTCLabelDecode',
             "character_type": args.rec_char_type,
@@ -186,8 +187,9 @@ class TextRecognizer(object):
                 norm_img = norm_img[np.newaxis, :]
                 norm_img_batch.append(norm_img)
             else:
-                norm_img = self.process_image_srn(
-                    img_list[indices[ino]], self.rec_image_shape, 8, 25)
+                norm_img = self.process_image_srn(img_list[indices[ino]],
+                                                  self.rec_image_shape, 8,
+                                                  self.max_text_length)
                 encoder_word_pos_list = []
                 gsrm_word_pos_list = []
                 gsrm_slf_attn_bias1_list = []
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index de7ee9d342063161f2e329c99d2428051c0ecf8c..ba81aff0a940fbee234e59e98f73c62fc7f69f09 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import sys
+import subprocess
 
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
@@ -141,6 +142,7 @@ def sorted_boxes(dt_boxes):
 
 def main(args):
     image_file_list = get_image_file_list(args.image_dir)
+    image_file_list = image_file_list[args.process_id::args.total_process_num]
     text_sys = TextSystem(args)
     is_visualize = True
     font_path = args.vis_font_path
@@ -184,4 +186,18 @@ def main(args):
 
 
 if __name__ == "__main__":
-    main(utility.parse_args())
+    args = utility.parse_args()
+    if args.use_mp:
+        p_list = []
+        total_process_num = args.total_process_num
+        for process_id in range(total_process_num):
+            cmd = [sys.executable, "-u"] + sys.argv + [
+                "--process_id={}".format(process_id),
+                "--use_mp={}".format(False)
+            ]
+            p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
+            p_list.append(p)
+        for p in p_list:
+            p.wait()
+    else:
+        main(args)
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 911ca7fcb727d9f04002f367cda8928a2aefde90..9fa51d80b41ae80abe206ee779accd44d49cebba 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -74,6 +74,20 @@ def parse_args():
         "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
     parser.add_argument("--drop_score", type=float, default=0.5)
 
+    # params for e2e
+    parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
+    parser.add_argument("--e2e_model_dir", type=str)
+    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
+    parser.add_argument("--e2e_limit_type", type=str, default='max')
+
+    # PGNet params
+    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
+    parser.add_argument(
+        "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
+    parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
+    parser.add_argument("--e2e_pgnet_polygon", type=bool, default=True)
+    parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')
+
     # params for text classifier
     parser.add_argument("--use_angle_cls", type=str2bool, default=False)
     parser.add_argument("--cls_model_dir", type=str)
@@ -85,6 +99,10 @@ def parse_args():
     parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
 
     parser.add_argument("--use_pdserving", type=str2bool, default=False)
+    parser.add_argument("--use_mp", type=str2bool, default=False)
+    parser.add_argument("--total_process_num", type=int, default=1)
+    parser.add_argument("--process_id", type=int, default=0)
+
     return parser.parse_args()
 
 
@@ -93,8 +111,10 @@ def create_predictor(args, mode, logger):
         model_dir = args.det_model_dir
     elif mode == 'cls':
         model_dir = args.cls_model_dir
-    else:
+    elif mode == 'rec':
         model_dir = args.rec_model_dir
+    else:
+        model_dir = args.e2e_model_dir
 
     if model_dir is None:
         logger.info("not find {} model file path {}".format(mode, model_dir))
@@ -148,6 +168,22 @@ def create_predictor(args, mode, logger):
     return predictor, input_tensor, output_tensors
 
 
+def draw_e2e_res(dt_boxes, strs, img_path):
+    src_im = cv2.imread(img_path)
+    for box, text in zip(dt_boxes, strs):
+        box = box.astype(np.int32).reshape((-1, 1, 2))
+        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
+        cv2.putText(
+            src_im,
+            text,
+            org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
+            fontFace=cv2.FONT_HERSHEY_COMPLEX,
+            fontScale=0.7,
+            color=(0, 255, 0),
+            thickness=1)
+    return src_im
+
+
 def draw_text_det_res(dt_boxes, img_path):
     src_im = cv2.imread(img_path)
     for box in dt_boxes:
diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py
new file mode 100755
index 0000000000000000000000000000000000000000..b7503adb94eb797d4fb12cf47b377fa72d02158b
--- /dev/null
+++ b/tools/infer_e2e.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import json
+import paddle
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import init_model
+from ppocr.utils.utility import get_image_file_list
+import tools.program as program
+
+
+def draw_e2e_res(dt_boxes, strs, config, img, img_name):
+    if len(dt_boxes) > 0:
+        src_im = img
+        for box, text in zip(dt_boxes, strs):
+            box = box.astype(np.int32).reshape((-1, 1, 2))
+            cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
+            cv2.putText(
+                src_im,
+                text,
+                org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
+                fontFace=cv2.FONT_HERSHEY_COMPLEX,
+                fontScale=0.7,
+                color=(0, 255, 0),
+                thickness=1)
+        save_det_path = os.path.dirname(config['Global'][
+            'save_res_path']) + "/e2e_results/"
+        if not os.path.exists(save_det_path):
+            os.makedirs(save_det_path)
+        save_path = os.path.join(save_det_path, os.path.basename(img_name))
+        cv2.imwrite(save_path, src_im)
+        logger.info("The e2e image is saved in {}".format(save_path))
+
+
+def main():
+    global_config = config['Global']
+
+    # build model
+    model = build_model(config['Architecture'])
+
+    init_model(config, model, logger)
+
+    # build post process
+    post_process_class = build_post_process(config['PostProcess'],
+                                            global_config)
+
+    # create data ops
+    transforms = []
+    for op in config['Eval']['dataset']['transforms']:
+        op_name = list(op)[0]
+        if 'Label' in op_name:
+            continue
+        elif op_name == 'KeepKeys':
+            op[op_name]['keep_keys'] = ['image', 'shape']
+        transforms.append(op)
+
+    ops = create_operators(transforms, global_config)
+
+    save_res_path = config['Global']['save_res_path']
+    if not os.path.exists(os.path.dirname(save_res_path)):
+        os.makedirs(os.path.dirname(save_res_path))
+
+    model.eval()
+    with open(save_res_path, "wb") as fout:
+        for file in get_image_file_list(config['Global']['infer_img']):
+            logger.info("infer_img: {}".format(file))
+            with open(file, 'rb') as f:
+                img = f.read()
+                data = {'image': img}
+            batch = transform(data, ops)
+            images = np.expand_dims(batch[0], axis=0)
+            shape_list = np.expand_dims(batch[1], axis=0)
+            images = paddle.to_tensor(images)
+            preds = model(images)
+            post_result = post_process_class(preds, shape_list)
+            points, strs = post_result['points'], post_result['strs']
+            # write result
+            dt_boxes_json = []
+            for poly, text in zip(points, strs):
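+                # pair each polygon with its transcription for the JSON dump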
+                tmp_json = {"transcription": text}
+                tmp_json['points'] = poly.tolist()
+                dt_boxes_json.append(tmp_json)
+            otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
+            fout.write(otstr.encode())
+            src_img = cv2.imread(file)
+            draw_e2e_res(points, strs, config, src_img, file)
+    logger.info("success!")
+
+
+if __name__ == '__main__':
+    config, device, logger, vdl_writer = program.preprocess()
+    main()
diff --git a/tools/program.py b/tools/program.py
index ae6491768ceff0379c917c45fd29b30514eba9c1..c22bf18b991a8aed6d47a1ea242aa3b7bb02aacc 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -237,8 +237,9 @@ def train(config,
                 vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
             vdl_writer.add_scalar('TRAIN/lr', lr, global_step)
 
-        if dist.get_rank(
-        ) == 0 and global_step > 0 and global_step % print_batch_step == 0:
+        if dist.get_rank() == 0 and (
+            (global_step > 0 and global_step % print_batch_step == 0) or
+            (idx >= len(train_dataloader) - 1)):
             logs = train_stats.log()
             strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
                 epoch, epoch_num, global_step, logs, train_reader_cost /
@@ -374,7 +375,8 @@ def preprocess(is_train=False):
 
     alg = config['Architecture']['algorithm']
     assert alg in [
-        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
+        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
+        'CLS', 'PGNet'
     ]
 
     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
diff --git a/tools/train.py b/tools/train.py
index c12cf005638ad8be8f8b8fce42e682fed2f29be7..47358ca43da46b7eb6a04cd1f7fe284efd7e96f7 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -52,7 +52,10 @@ def main(config, device, logger, vdl_writer):
     train_dataloader = build_dataloader(config, 'Train', device, logger)
     if len(train_dataloader) == 0:
         logger.error(
-            'No Images in train dataset, please check annotation file and path in the configuration file'
+            "No images in the train dataset, please ensure that\n" +
+            "\t1. The number of images in the train label_file_list is greater than or equal to the batch size.\n"
+            +
+            "\t2. The annotation file and image paths in the configuration file are correct."
         )
         return