diff --git a/README.md b/README.md index 3243f3ce24fada59a0b6f509172b3277e080f7aa..4ebbf2f0067aa6faff3304c97b12afa7274ca554 100644 --- a/README.md +++ b/README.md @@ -189,7 +189,7 @@ PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训 请扫描下面二维码,完成问卷填写,获取加群二维码和OCR方向的炼丹秘籍
- +
diff --git a/README_en.md b/README_en.md index 37250da2cd3f6ccee76b522bf10745ecb8cd649e..c0f17b57710b68e8f33573f116a120607fa8847c 100644 --- a/README_en.md +++ b/README_en.md @@ -56,7 +56,6 @@ Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Andr - Algorithm introduction - [Text Detection Algorithm](#TEXTDETECTIONALGORITHM) - [Text Recognition Algorithm](#TEXTRECOGNITIONALGORITHM) - - [END-TO-END OCR Algorithm](#ENDENDOCRALGORITHM) - Model training/evaluation - [Text Detection](./doc/doc_en/detection_en.md) - [Text Recognition](./doc/doc_en/recognition_en.md) @@ -158,10 +157,6 @@ We use [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/ Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./doc/doc_en/recognition_en.md) - -## END-TO-END OCR Algorithm -- [ ] [End2End-PSL](https://arxiv.org/abs/1909.07808)(Baidu Self-Research, coming soon) - ## Visualization @@ -211,7 +206,7 @@ Please refer to the document for training guide and use of PaddleOCR text recogn Scan the QR code below with your wechat and completing the questionnaire, you can access to offical technical exchange group.
- +
diff --git a/deploy/cpp_infer/src/ocr_det.cpp b/deploy/cpp_infer/src/ocr_det.cpp index 56fbace8cc6fa27f8172bed248573f15d0c98dac..bf94abce236853410c15434d494058be03a62a81 100644 --- a/deploy/cpp_infer/src/ocr_det.cpp +++ b/deploy/cpp_infer/src/ocr_det.cpp @@ -26,6 +26,8 @@ void DBDetector::LoadModel(const std::string &model_dir) { config.DisableGpu(); if (this->use_mkldnn_) { config.EnableMKLDNN(); + // cache 10 different shapes for mkldnn to avoid memory leak + config.SetMkldnnCacheCapacity(10); } config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); } diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp index a3486db46f6eb6ad0df49619744924e6ef70dd01..b997d8291a64f9b6042bce648bcd358e34d55a95 100644 --- a/deploy/cpp_infer/src/ocr_rec.cpp +++ b/deploy/cpp_infer/src/ocr_rec.cpp @@ -126,6 +126,8 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { config.DisableGpu(); if (this->use_mkldnn_) { config.EnableMKLDNN(); + // cache 10 different shapes for mkldnn to avoid memory leak + config.SetMkldnnCacheCapacity(10); } config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); } diff --git a/deploy/cpp_infer/tools/config.txt b/deploy/cpp_infer/tools/config.txt index 40beea3a2e6f0260a42202d6411ffb10907bf871..6c53f29eeb310677815d106d3e0ae39fb03bc2e2 100644 --- a/deploy/cpp_infer/tools/config.txt +++ b/deploy/cpp_infer/tools/config.txt @@ -3,7 +3,7 @@ use_gpu 0 gpu_id 0 gpu_mem 4000 cpu_math_library_num_threads 10 -use_mkldnn 0 +use_mkldnn 1 use_zero_copy_run 1 # det config diff --git a/docker/hubserving/README.md b/deploy/docker/hubserving/README.md similarity index 99% rename from docker/hubserving/README.md rename to deploy/docker/hubserving/README.md index 71e2377dcc4f7524384752b95c53f02471353f34..62381073d4c7448f9a238ca4dda4b294ce864f7a 100644 --- a/docker/hubserving/README.md +++ b/deploy/docker/hubserving/README.md @@ -20,7 +20,7 @@ git clone https://github.com/PaddlePaddle/PaddleOCR.git ``` b. Goto Dockerfile directory(ps:Need to distinguish between cpu and gpu version, the following takes cpu as an example, gpu version needs to replace the keyword) ``` -cd docker/cpu +cd deploy/docker/cpu ``` c. Build image ``` diff --git a/docker/hubserving/README_cn.md b/deploy/docker/hubserving/README_cn.md similarity index 99% rename from docker/hubserving/README_cn.md rename to deploy/docker/hubserving/README_cn.md index 9b9e5f50f5b22f3a2125a656112a20542010ac68..f117a0ab4186fea0cb94881c65b2b353bee37ff7 100644 --- a/docker/hubserving/README_cn.md +++ b/deploy/docker/hubserving/README_cn.md @@ -20,7 +20,7 @@ git clone https://github.com/PaddlePaddle/PaddleOCR.git ``` b.切换至Dockerfile目录(注:需要区分cpu或gpu版本,下文以cpu为例,gpu版本需要替换一下关键字即可) ``` -cd docker/cpu +cd deploy/docker/cpu ``` c.生成镜像 ``` diff --git a/docker/hubserving/cpu/Dockerfile b/deploy/docker/hubserving/cpu/Dockerfile similarity index 100% rename from docker/hubserving/cpu/Dockerfile rename to deploy/docker/hubserving/cpu/Dockerfile diff --git a/docker/hubserving/gpu/Dockerfile b/deploy/docker/hubserving/gpu/Dockerfile similarity index 100% rename from docker/hubserving/gpu/Dockerfile rename to deploy/docker/hubserving/gpu/Dockerfile diff --git a/docker/hubserving/sample_request.txt b/deploy/docker/hubserving/sample_request.txt similarity index 100% rename from docker/hubserving/sample_request.txt rename to deploy/docker/hubserving/sample_request.txt diff --git a/deploy/slim/quantization/README.md b/deploy/slim/quantization/README.md new file mode 100755 index 0000000000000000000000000000000000000000..f2e92f54a5b456b25445282a38fe30e01fe4fd49 --- /dev/null +++ b/deploy/slim/quantization/README.md @@ -0,0 +1,34 @@ +> 运行示例前请先安装1.2.0或更高版本PaddleSlim + +# 模型量化压缩教程 + +## 概述 + +该示例使用PaddleSlim提供的[量化压缩API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/)对OCR模型进行压缩。 +在阅读该示例前,建议您先了解以下内容: + +- [OCR模型的常规训练方法](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/detection.md) +- [PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/) + +## 安装PaddleSlim +可按照[PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)中的步骤安装PaddleSlim。 + + + +## 量化训练 + +进入PaddleOCR根目录,通过以下命令对模型进行量化: + +```bash +python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=det_mv3_db/best_accuracy Global.save_model_dir=./output/quant_model +``` + + + +## 评估并导出 + +在得到量化训练保存的模型后,我们可以将其导出为inference_model,用于预测部署: + +```bash +python deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_model_dir=./output/quant_model +``` diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d08b300066044d3088f669045e0536006c3140 --- /dev/null +++ b/deploy/slim/quantization/export_model.py @@ -0,0 +1,129 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..'))) +sys.path.append( + os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools'))) + + +def set_paddle_flags(**kwargs): + for key, value in kwargs.items(): + if os.environ.get(key, None) is None: + os.environ[key] = str(value) + + +# NOTE(paddle-dev): All of these flags should be +# set before `import paddle`. Otherwise, it would +# not take any effect. +set_paddle_flags( + FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory +) + +import program +from paddle import fluid +from ppocr.utils.utility import initial_logger +logger = initial_logger() +from ppocr.utils.save_load import init_model, load_params +from ppocr.utils.character import CharacterOps +from ppocr.utils.utility import create_module +from ppocr.data.reader_main import reader_main + +from paddleslim.quant import quant_aware, convert +from paddle.fluid.layer_helper import LayerHelper +from eval_utils.eval_det_utils import eval_det_run +from eval_utils.eval_rec_utils import eval_rec_run + + +def main(): + # 1. quantization configs + quant_config = { + # weight quantize type, default is 'channel_wise_abs_max' + 'weight_quantize_type': 'channel_wise_abs_max', + # activation quantize type, default is 'moving_average_abs_max' + 'activation_quantize_type': 'moving_average_abs_max', + # weight quantize bit num, default is 8 + 'weight_bits': 8, + # activation quantize bit num, default is 8 + 'activation_bits': 8, + # ops of name_scope in not_quant_pattern list, will not be quantized + 'not_quant_pattern': ['skip_quant'], + # ops of type in quantize_op_types, will be quantized + 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], + # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' + 'dtype': 'int8', + # window size for 'range_abs_max' quantization. defaulf is 10000 + 'window_size': 10000, + # The decay coefficient of moving average, default is 0.9 + 'moving_rate': 0.9, + } + + startup_prog, eval_program, place, config, alg_type = program.preprocess() + + feeded_var_names, target_vars, fetches_var_name = program.build_export( + config, eval_program, startup_prog) + + eval_program = eval_program.clone(for_test=True) + exe = fluid.Executor(place) + exe.run(startup_prog) + + eval_program = quant_aware( + eval_program, place, quant_config, scope=None, for_test=True) + + init_model(config, eval_program, exe) + + # 2. Convert the program before save inference program + # The dtype of eval_program's weights is float32, but in int8 range. + + eval_program = convert(eval_program, place, quant_config, scope=None) + + eval_fetch_name_list = fetches_var_name + eval_fetch_varname_list = [v.name for v in target_vars] + eval_reader = reader_main(config=config, mode="eval") + quant_info_dict = {'program':eval_program,\ + 'reader':eval_reader,\ + 'fetch_name_list':eval_fetch_name_list,\ + 'fetch_varname_list':eval_fetch_varname_list} + + if alg_type == 'det': + final_metrics = eval_det_run(exe, config, quant_info_dict, "eval") + else: + final_metrics = eval_rec_run(exe, config, quant_info_dict, "eval") + print(final_metrics) + + # 3. Save inference model + model_path = "./quant_model" + if not os.path.isdir(model_path): + os.makedirs(model_path) + + fluid.io.save_inference_model( + dirname=model_path, + feeded_var_names=feeded_var_names, + target_vars=target_vars, + executor=exe, + main_program=eval_program, + model_filename=model_path + '/model', + params_filename=model_path + '/params') + print("model saved as {}".format(model_path)) + + +if __name__ == '__main__': + main() diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py new file mode 100755 index 0000000000000000000000000000000000000000..f54a328d4645910dcb24bd4d597e8d5f8867312a --- /dev/null +++ b/deploy/slim/quantization/quant.py @@ -0,0 +1,188 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..'))) +sys.path.append( + os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools'))) + + +def set_paddle_flags(**kwargs): + for key, value in kwargs.items(): + if os.environ.get(key, None) is None: + os.environ[key] = str(value) + + +# NOTE(paddle-dev): All of these flags should be +# set before `import paddle`. Otherwise, it would +# not take any effect. +set_paddle_flags( + FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory +) + +import tools.program as program +from paddle import fluid +from ppocr.utils.utility import initial_logger +logger = initial_logger() +from ppocr.data.reader_main import reader_main +from ppocr.utils.save_load import init_model +from paddle.fluid.contrib.model_stat import summary + +# quant dependencies +import paddle +import paddle.fluid as fluid +from paddleslim.quant import quant_aware, convert +from paddle.fluid.layer_helper import LayerHelper + + +def pact(x): + """ + Process a variable using the pact method you define + Args: + x(Tensor): Paddle Tensor, need to be preprocess before quantization + Returns: + The processed Tensor x. + """ + helper = LayerHelper("pact", **locals()) + dtype = 'float32' + init_thres = 20 + u_param_attr = fluid.ParamAttr( + name=x.name + '_pact', + initializer=fluid.initializer.ConstantInitializer(value=init_thres), + regularizer=fluid.regularizer.L2Decay(0.0001), + learning_rate=1) + u_param = helper.create_parameter(attr=u_param_attr, shape=[1], dtype=dtype) + x = fluid.layers.elementwise_sub( + x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param))) + x = fluid.layers.elementwise_add( + x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x))) + return x + + +def get_optimizer(): + """ + Build a program using a model and an optimizer + """ + return fluid.optimizer.AdamOptimizer(0.001) + + +def main(): + train_build_outputs = program.build( + config, train_program, startup_program, mode='train') + train_loader = train_build_outputs[0] + train_fetch_name_list = train_build_outputs[1] + train_fetch_varname_list = train_build_outputs[2] + train_opt_loss_name = train_build_outputs[3] + model_average = train_build_outputs[-1] + + eval_program = fluid.Program() + eval_build_outputs = program.build( + config, eval_program, startup_program, mode='eval') + eval_fetch_name_list = eval_build_outputs[1] + eval_fetch_varname_list = eval_build_outputs[2] + eval_program = eval_program.clone(for_test=True) + + train_reader = reader_main(config=config, mode="train") + train_loader.set_sample_list_generator(train_reader, places=place) + + eval_reader = reader_main(config=config, mode="eval") + + exe = fluid.Executor(place) + exe.run(startup_program) + + # 1. quantization configs + quant_config = { + # weight quantize type, default is 'channel_wise_abs_max' + 'weight_quantize_type': 'channel_wise_abs_max', + # activation quantize type, default is 'moving_average_abs_max' + 'activation_quantize_type': 'moving_average_abs_max', + # weight quantize bit num, default is 8 + 'weight_bits': 8, + # activation quantize bit num, default is 8 + 'activation_bits': 8, + # ops of name_scope in not_quant_pattern list, will not be quantized + 'not_quant_pattern': ['skip_quant'], + # ops of type in quantize_op_types, will be quantized + 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], + # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' + 'dtype': 'int8', + # window size for 'range_abs_max' quantization. defaulf is 10000 + 'window_size': 10000, + # The decay coefficient of moving average, default is 0.9 + 'moving_rate': 0.9, + } + + # 2. quantization transform programs (training aware) + # Make some quantization transforms in the graph before training and testing. + # According to the weight and activation quantization type, the graph will be added + # some fake quantize operators and fake dequantize operators. + act_preprocess_func = pact + optimizer_func = get_optimizer + executor = exe + + eval_program = quant_aware( + eval_program, + place, + quant_config, + scope=None, + act_preprocess_func=act_preprocess_func, + optimizer_func=optimizer_func, + executor=executor, + for_test=True) + quant_train_program = quant_aware( + train_program, + place, + quant_config, + scope=None, + act_preprocess_func=act_preprocess_func, + optimizer_func=optimizer_func, + executor=executor, + for_test=False, + return_program=True) + + # compile program for multi-devices + train_compile_program = program.create_multi_devices_program( + quant_train_program, train_opt_loss_name, for_quant=True) + + init_model(config, quant_train_program, exe) + + train_info_dict = {'compile_program':train_compile_program,\ + 'train_program':quant_train_program,\ + 'reader':train_loader,\ + 'fetch_name_list':train_fetch_name_list,\ + 'fetch_varname_list':train_fetch_varname_list,\ + 'model_average': model_average} + + eval_info_dict = {'program':eval_program,\ + 'reader':eval_reader,\ + 'fetch_name_list':eval_fetch_name_list,\ + 'fetch_varname_list':eval_fetch_varname_list} + + if train_alg_type == 'det': + program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict) + else: + program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) + + +if __name__ == '__main__': + startup_program, train_program, place, config, train_alg_type = program.preprocess( + ) + main() diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md index c554b9f11c96744ae928aaf9992606a364680557..1dc52efa8e6f65ef74c8e138f4f388027fe33f28 100644 --- a/doc/doc_ch/recognition.md +++ b/doc/doc_ch/recognition.md @@ -140,7 +140,7 @@ PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入 训练过程中每种扰动方式以50%的概率被选择,具体代码实现请参考:[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) -*由于OpenCV的兼容性问题,扰动操作暂时只支持GPU* +*由于OpenCV的兼容性问题,扰动操作暂时只支持Linux* - 训练 diff --git a/doc/doc_ch/serving.md b/doc/doc_ch/serving.md index 892745671e639ccd19bec2bc4c789d48d43dfad9..99fe3006fde8762930ef9a168da81cce9069f8e0 100644 --- a/doc/doc_ch/serving.md +++ b/doc/doc_ch/serving.md @@ -61,6 +61,14 @@ hub install deploy\hubserving\ocr_rec\ hub install deploy\hubserving\ocr_system\ ``` +#### 安装模型 +安装服务模块前,需要将训练好的模型放到对应的文件夹内。默认使用的是: +./inference/ch_det_mv3_db/ +和 +./inference/ch_rec_mv3_crnn/ +这两个模型可以在https://github.com/PaddlePaddle/PaddleOCR 下载 +可以在./deploy/hubserving/ocr_system/params.py 里面修改成自己的模型 + ### 3. 启动服务 #### 方式1. 命令行命令启动(仅支持CPU) **启动命令:** diff --git a/doc/joinus.PNG b/doc/joinus.PNG new file mode 100644 index 0000000000000000000000000000000000000000..fa11f286d7d2d56d18d94e9034c3be77c974d42f Binary files /dev/null and b/doc/joinus.PNG differ diff --git a/doc/joinus.jpg b/doc/joinus.jpg deleted file mode 100644 index 6a287f3145c1910a2e25db35a94f5cbb14380b9d..0000000000000000000000000000000000000000 Binary files a/doc/joinus.jpg and /dev/null differ diff --git a/docker/hubserving/readme.md b/docker/hubserving/readme.md deleted file mode 100644 index 71e2377dcc4f7524384752b95c53f02471353f34..0000000000000000000000000000000000000000 --- a/docker/hubserving/readme.md +++ /dev/null @@ -1,58 +0,0 @@ -English | [简体中文](README_cn.md) - -## Introduction -Many user hopes package the PaddleOCR service into an docker image, so that it can be quickly released and used in the docker or k8s environment. - -This page provide some standardized code to achieve this goal. You can quickly publish the PaddleOCR project into a callable Restful API service through the following steps. (At present, the deployment based on the HubServing mode is implemented first, and author plans to increase the deployment of the PaddleServing mode in the futrue) - -## 1. Prerequisites - -You need to install the following basic components first: -a. Docker -b. Graphics driver and CUDA 10.0+(GPU) -c. NVIDIA Container Toolkit(GPU,Docker 19.03+ can skip this) -d. cuDNN 7.6+(GPU) - -## 2. Build Image -a. Download PaddleOCR sourcecode -``` -git clone https://github.com/PaddlePaddle/PaddleOCR.git -``` -b. Goto Dockerfile directory(ps:Need to distinguish between cpu and gpu version, the following takes cpu as an example, gpu version needs to replace the keyword) -``` -cd docker/cpu -``` -c. Build image -``` -docker build -t paddleocr:cpu . -``` - -## 3. Start container -a. CPU version -``` -sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu -``` -b. GPU version (base on NVIDIA Container Toolkit) -``` -sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu -``` -c. GPU version (Docker 19.03++) -``` -sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu -``` -d. Check service status(If you can see the following statement then it means completed:Successfully installed ocr_system && Running on http://0.0.0.0:8866/) -``` -docker logs -f paddle_ocr -``` - -## 4. Test -a. Calculate the Base64 encoding of the picture to be recognized (if you just test, you can use a free online tool, like:https://freeonlinetools24.com/base64-image/) -b. Post a service request(sample request in sample_request.txt) - -``` -curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"Input image Base64 encode(need to delete the code 'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system -``` -c. Get resposne(If the call is successful, the following result will be returned) -``` -{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"} -``` diff --git a/ppocr/data/det/random_crop_data.py b/ppocr/data/det/random_crop_data.py index d0c081e785cb17282b5486c718446b97a580b6cc..3e8629092e9a74a0764a0c04de829a02d00b6844 100644 --- a/ppocr/data/det/random_crop_data.py +++ b/ppocr/data/det/random_crop_data.py @@ -1,4 +1,4 @@ -# -*- coding:utf-8 -*- +# -*- coding:utf-8 -*- from __future__ import absolute_import from __future__ import division @@ -121,24 +121,22 @@ def RandomCropData(data, size): all_care_polys = [ text_polys[i] for i, tag in enumerate(ignore_tags) if not tag ] - # 计算crop区域 crop_x, crop_y, crop_w, crop_h = crop_area(im, all_care_polys, min_crop_side_ratio, max_tries) - # crop 图片 保持比例填充 - scale_w = size[0] / crop_w - scale_h = size[1] / crop_h + dh, dw = size + scale_w = dw / crop_w + scale_h = dh / crop_h scale = min(scale_w, scale_h) h = int(crop_h * scale) w = int(crop_w * scale) if keep_ratio: - padimg = np.zeros((size[1], size[0], im.shape[2]), im.dtype) + padimg = np.zeros((dh, dw, im.shape[2]), im.dtype) padimg[:h, :w] = cv2.resize( im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) img = padimg else: img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], - tuple(size)) - # crop 文本框 + (dw, dh)) text_polys_crop = [] ignore_tags_crop = [] texts_crop = [] diff --git a/ppocr/modeling/architectures/det_model.py b/ppocr/modeling/architectures/det_model.py index 54d3a479f40a3f9f6ebb9e6ab739ae7a44796a2e..e4c32b8eba056f8f6483e9ed2170a1650023fb0a 100644 --- a/ppocr/modeling/architectures/det_model.py +++ b/ppocr/modeling/architectures/det_model.py @@ -67,6 +67,7 @@ class DetModel(object): image = fluid.layers.data( name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False if mode == "train": if self.algorithm == "EAST": h, w = int(image_shape[1] // 4), int(image_shape[2] // 4) @@ -108,7 +109,10 @@ class DetModel(object): name='tvo', shape=[9, 128, 128], dtype='float32') input_tco = fluid.layers.data( name='tco', shape=[3, 128, 128], dtype='float32') - feed_list = [image, input_score, input_border, input_mask, input_tvo, input_tco] + feed_list = [ + image, input_score, input_border, input_mask, input_tvo, + input_tco + ] labels = {'input_score': input_score,\ 'input_border': input_border,\ 'input_mask': input_mask,\ diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py index fe2d4c16dce3882980fe2238ecc16c7c08a89792..261462044a9000561517c3657f5b5a6090fd107a 100755 --- a/ppocr/modeling/architectures/rec_model.py +++ b/ppocr/modeling/architectures/rec_model.py @@ -68,6 +68,7 @@ class RecModel(object): image_shape.insert(0, -1) if mode == "train": image = fluid.data(name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False if self.loss_type == "attention": label_in = fluid.data( name='label_in', @@ -136,7 +137,7 @@ class RecModel(object): else: labels = None loader = None - if self.char_type == "ch" and self.infer_img: + if self.char_type == "ch" and self.infer_img and self.loss_type != "srn": image_shape[-1] = -1 if self.tps != None: logger.info( @@ -146,6 +147,7 @@ class RecModel(object): ) image_shape = deepcopy(self.image_shape) image = fluid.data(name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False if self.loss_type == "srn": encoder_word_pos = fluid.data( name="encoder_word_pos", @@ -172,16 +174,13 @@ class RecModel(object): self.max_text_length ], dtype="float32") - feed_list = [ - image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, - gsrm_slf_attn_bias2 - ] labels = { 'encoder_word_pos': encoder_word_pos, 'gsrm_word_pos': gsrm_word_pos, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1, 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2 } + return image, labels, loader def __call__(self, mode): @@ -218,8 +217,13 @@ class RecModel(object): if self.loss_type == "ctc": predict = fluid.layers.softmax(predict) if self.loss_type == "srn": - raise Exception( - "Warning! SRN does not support export model currently") + return [ + image, labels, { + 'decoded_out': decoded_out, + 'predicts': predict + } + ] + return [image, {'decoded_out': decoded_out, 'predicts': predict}] else: predict = predicts['predict'] diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py index 6b8635e4647f186390179b880e132641342df0d6..84948c2b20933d0f2086a42442a420d1b6b1eeee 100755 --- a/ppocr/modeling/heads/rec_ctc_head.py +++ b/ppocr/modeling/heads/rec_ctc_head.py @@ -35,12 +35,13 @@ class CTCPredict(object): self.fc_decay = params.get("fc_decay", 0.0004) def __call__(self, inputs, labels=None, mode=None): - encoder_features = self.encoder(inputs) - if self.encoder_type != "reshape": - encoder_features = fluid.layers.concat(encoder_features, axis=1) - name = "ctc_fc" - para_attr, bias_attr = get_para_bias_attr( - l2_decay=self.fc_decay, k=encoder_features.shape[1], name=name) + with fluid.scope_guard("skip_quant"): + encoder_features = self.encoder(inputs) + if self.encoder_type != "reshape": + encoder_features = fluid.layers.concat(encoder_features, axis=1) + name = "ctc_fc" + para_attr, bias_attr = get_para_bias_attr( + l2_decay=self.fc_decay, k=encoder_features.shape[1], name=name) predict = fluid.layers.fc(input=encoder_features, size=self.char_num + 1, param_attr=para_attr, diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py index e27dd1d8738a25c6a6669b99ad2b6eed4a9f25d0..2cf3c8f5c9ebba07ee1c21fe2248fe3f600126d9 100755 --- a/ppocr/utils/utility.py +++ b/ppocr/utils/utility.py @@ -90,15 +90,3 @@ def check_and_read_gif(img_path): return imgvalue, True return None, False - -def create_multi_devices_program(program, loss_var_name): - build_strategy = fluid.BuildStrategy() - build_strategy.memory_optimize = False - build_strategy.enable_inplace = True - exec_strategy = fluid.ExecutionStrategy() - exec_strategy.num_iteration_per_drop_scope = 1 - compile_program = fluid.CompiledProgram(program).with_data_parallel( - loss_name=loss_var_name, - build_strategy=build_strategy, - exec_strategy=exec_strategy) - return compile_program diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 6a379853a4a7d62cbffcbebbf09e2fb3e2207b27..06273e9f9e5b42a9ecc829c435662e9aabcdd224 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -40,6 +40,7 @@ class TextRecognizer(object): self.character_type = args.rec_char_type self.rec_batch_num = args.rec_batch_num self.rec_algorithm = args.rec_algorithm + self.text_len = args.max_text_length self.use_zero_copy_run = args.use_zero_copy_run char_ops_params = { "character_type": args.rec_char_type, @@ -47,12 +48,15 @@ class TextRecognizer(object): "use_space_char": args.use_space_char, "max_text_length": args.max_text_length } - if self.rec_algorithm != "RARE": + if self.rec_algorithm in ["CRNN", "Rosetta", "STAR-Net"]: char_ops_params['loss_type'] = 'ctc' self.loss_type = 'ctc' - else: + elif self.rec_algorithm == "RARE": char_ops_params['loss_type'] = 'attention' self.loss_type = 'attention' + elif self.rec_algorithm == "SRN": + char_ops_params['loss_type'] = 'srn' + self.loss_type = 'srn' self.char_ops = CharacterOps(char_ops_params) def resize_norm_img(self, img, max_wh_ratio): @@ -75,6 +79,83 @@ class TextRecognizer(object): padding_im[:, :, 0:resized_w] = resized_image return padding_im + def resize_norm_img_srn(self, img, image_shape): + imgC, imgH, imgW = image_shape + + img_black = np.zeros((imgH, imgW)) + im_hei = img.shape[0] + im_wid = img.shape[1] + + if im_wid <= im_hei * 1: + img_new = cv2.resize(img, (imgH * 1, imgH)) + elif im_wid <= im_hei * 2: + img_new = cv2.resize(img, (imgH * 2, imgH)) + elif im_wid <= im_hei * 3: + img_new = cv2.resize(img, (imgH * 3, imgH)) + else: + img_new = cv2.resize(img, (imgW, imgH)) + + img_np = np.asarray(img_new) + img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + img_black[:, 0:img_np.shape[1]] = img_np + img_black = img_black[:, :, np.newaxis] + + row, col, c = img_black.shape + c = 1 + + return np.reshape(img_black, (c, row, col)).astype(np.float32) + + def srn_other_inputs(self, image_shape, num_heads, max_text_length, + char_num): + + imgC, imgH, imgW = image_shape + feature_dim = int((imgH / 8) * (imgW / 8)) + + encoder_word_pos = np.array(range(0, feature_dim)).reshape( + (feature_dim, 1)).astype('int64') + gsrm_word_pos = np.array(range(0, max_text_length)).reshape( + (max_text_length, 1)).astype('int64') + + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias1 = np.tile( + gsrm_slf_attn_bias1, + [1, num_heads, 1, 1]).astype('float32') * [-1e9] + + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias2 = np.tile( + gsrm_slf_attn_bias2, + [1, num_heads, 1, 1]).astype('float32') * [-1e9] + + encoder_word_pos = encoder_word_pos[np.newaxis, :] + gsrm_word_pos = gsrm_word_pos[np.newaxis, :] + + return [ + encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2 + ] + + def process_image_srn(self, + img, + image_shape, + num_heads, + max_text_length, + char_ops=None): + norm_img = self.resize_norm_img_srn(img, image_shape) + norm_img = norm_img[np.newaxis, :] + char_num = char_ops.get_char_num() + + [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ + self.srn_other_inputs(image_shape, num_heads, max_text_length, char_num) + + gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32) + gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32) + + return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2) + def __call__(self, img_list): img_num = len(img_list) # Calculate the aspect ratio of all text bars @@ -84,7 +165,7 @@ class TextRecognizer(object): # Sorting can speed up the recognition process indices = np.argsort(np.array(width_list)) - # rec_res = [] + #rec_res = [] rec_res = [['', 0.0]] * img_num batch_num = self.rec_batch_num predict_time = 0 @@ -98,20 +179,62 @@ class TextRecognizer(object): wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): - # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio) - norm_img = self.resize_norm_img(img_list[indices[ino]], - max_wh_ratio) - norm_img = norm_img[np.newaxis, :] - norm_img_batch.append(norm_img) - norm_img_batch = np.concatenate(norm_img_batch) + if self.loss_type != "srn": + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + else: + norm_img = self.process_image_srn(img_list[indices[ino]], + self.rec_image_shape, 8, + 25, self.char_ops) + encoder_word_pos_list = [] + gsrm_word_pos_list = [] + gsrm_slf_attn_bias1_list = [] + gsrm_slf_attn_bias2_list = [] + encoder_word_pos_list.append(norm_img[1]) + gsrm_word_pos_list.append(norm_img[2]) + gsrm_slf_attn_bias1_list.append(norm_img[3]) + gsrm_slf_attn_bias2_list.append(norm_img[4]) + norm_img_batch.append(norm_img[0]) + + norm_img_batch = np.concatenate(norm_img_batch, axis=0) norm_img_batch = norm_img_batch.copy() - starttime = time.time() - if self.use_zero_copy_run: - self.input_tensor.copy_from_cpu(norm_img_batch) - self.predictor.zero_copy_run() - else: + + if self.loss_type == "srn": + starttime = time.time() + encoder_word_pos_list = np.concatenate(encoder_word_pos_list) + gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) + gsrm_slf_attn_bias1_list = np.concatenate( + gsrm_slf_attn_bias1_list) + gsrm_slf_attn_bias2_list = np.concatenate( + gsrm_slf_attn_bias2_list) + starttime = time.time() + norm_img_batch = fluid.core.PaddleTensor(norm_img_batch) - self.predictor.run([norm_img_batch]) + encoder_word_pos_list = fluid.core.PaddleTensor( + encoder_word_pos_list) + gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list) + gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor( + gsrm_slf_attn_bias1_list) + gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor( + gsrm_slf_attn_bias2_list) + + inputs = [ + norm_img_batch, encoder_word_pos_list, + gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list, + gsrm_word_pos_list + ] + + self.predictor.run(inputs) + else: + starttime = time.time() + if self.use_zero_copy_run: + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.zero_copy_run() + else: + norm_img_batch = fluid.core.PaddleTensor(norm_img_batch) + self.predictor.run([norm_img_batch]) if self.loss_type == "ctc": rec_idx_batch = self.output_tensors[0].copy_to_cpu() @@ -136,6 +259,26 @@ class TextRecognizer(object): score = np.mean(probs[valid_ind, ind[valid_ind]]) # rec_res.append([preds_text, score]) rec_res[indices[beg_img_no + rno]] = [preds_text, score] + elif self.loss_type == 'srn': + rec_idx_batch = self.output_tensors[0].copy_to_cpu() + probs = self.output_tensors[1].copy_to_cpu() + char_num = self.char_ops.get_char_num() + preds = rec_idx_batch.reshape(-1) + elapse = time.time() - starttime + predict_time += elapse + total_preds = preds.copy() + for ino in range(int(len(rec_idx_batch) / self.text_len)): + preds = total_preds[ino * self.text_len:(ino + 1) * + self.text_len] + ind = np.argmax(probs, axis=1) + valid_ind = np.where(preds != int(char_num - 1))[0] + if len(valid_ind) == 0: + continue + score = np.mean(probs[valid_ind, ind[valid_ind]]) + preds = preds[:valid_ind[-1] + 1] + preds_text = self.char_ops.decode(preds) + + rec_res[indices[beg_img_no + ino]] = [preds_text, score] else: rec_idx_batch = self.output_tensors[0].copy_to_cpu() predict_batch = self.output_tensors[1].copy_to_cpu() @@ -170,6 +313,7 @@ def main(args): continue valid_image_file_list.append(image_file) img_list.append(img) + try: rec_res, predict_time = text_recognizer(img_list) except Exception as e: diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 647a76b20496335cd059242890f86fffe1e3ac1a..ff5d53e94e8ac110d58f2fda9afeb575cd7f0971 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -122,7 +122,6 @@ def main(args): image_file_list = get_image_file_list(args.image_dir) text_sys = TextSystem(args) is_visualize = True - tackle_img_num = 0 for image_file in image_file_list: img, flag = check_and_read_gif(image_file) if not flag: @@ -131,9 +130,6 @@ def main(args): logger.info("error in loading image:{}".format(image_file)) continue starttime = time.time() - tackle_img_num += 1 - if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0: - text_sys = TextSystem(args) dt_boxes, rec_res = text_sys(img) elapse = time.time() - starttime print("Predict time of %s: %.3fs" % (image_file, elapse)) @@ -153,11 +149,7 @@ def main(args): scores = [rec_res[i][1] for i in range(len(rec_res))] draw_img = draw_ocr( - image, - boxes, - txts, - scores, - drop_score=drop_score) + image, boxes, txts, scores, drop_score=drop_score) draw_img_save = "./inference_results/" if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 9d7ce13d37567ac80e194a6500a0f629ede4b1d4..3e1f07b8a7127e64a994c34d296c945ad1cafd0a 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -101,6 +101,8 @@ def create_predictor(args, mode): config.disable_gpu() config.set_cpu_math_library_num_threads(6) if args.enable_mkldnn: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) config.enable_mkldnn() #config.enable_memory_optim() @@ -114,7 +116,8 @@ def create_predictor(args, mode): predictor = create_paddle_predictor(config) input_names = predictor.get_input_names() - input_tensor = predictor.get_input_tensor(input_names[0]) + for name in input_names: + input_tensor = predictor.get_input_tensor(name) output_names = predictor.get_output_names() output_tensors = [] for output_name in output_names: diff --git a/tools/infer_rec.py b/tools/infer_rec.py index fd70cd66dccc2cb755efbd10c4d16c9f7a97146d..29fc5b40a890cd6e8fa3ca7d3f0999835555d9bd 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -145,7 +145,7 @@ def main(): preds = preds.reshape(-1) probs = np.array(predict[1]) ind = np.argmax(probs, axis=1) - valid_ind = np.where(preds != int(char_num-1))[0] + valid_ind = np.where(preds != int(char_num - 1))[0] if len(valid_ind) == 0: continue score = np.mean(probs[valid_ind, ind[valid_ind]]) diff --git a/tools/program.py b/tools/program.py index 6d8b9937bff7e70f018b069467525669e7001aae..be133ac2f0605abc39026587baaf884687e48911 100755 --- a/tools/program.py +++ b/tools/program.py @@ -208,18 +208,29 @@ def build_export(config, main_prog, startup_prog): with fluid.unique_name.guard(): func_infor = config['Architecture']['function'] model = create_module(func_infor)(params=config) - image, outputs = model(mode='export') + algorithm = config['Global']['algorithm'] + if algorithm == "SRN": + image, others, outputs = model(mode='export') + else: + image, outputs = model(mode='export') fetches_var_name = sorted([name for name in outputs.keys()]) fetches_var = [outputs[name] for name in fetches_var_name] - feeded_var_names = [image.name] + if algorithm == "SRN": + others_var_names = sorted([name for name in others.keys()]) + feeded_var_names = [image.name] + others_var_names + else: + feeded_var_names = [image.name] + target_vars = fetches_var return feeded_var_names, target_vars, fetches_var_name -def create_multi_devices_program(program, loss_var_name): +def create_multi_devices_program(program, loss_var_name, for_quant=False): build_strategy = fluid.BuildStrategy() build_strategy.memory_optimize = False build_strategy.enable_inplace = True + if for_quant: + build_strategy.fuse_all_reduce_ops = False exec_strategy = fluid.ExecutionStrategy() exec_strategy.num_iteration_per_drop_scope = 1 compile_program = fluid.CompiledProgram(program).with_data_parallel( @@ -409,7 +420,9 @@ def preprocess(): check_gpu(use_gpu) alg = config['Global']['algorithm'] - assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'] + assert alg in [ + 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN' + ] if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']: config['Global']['char_ops'] = CharacterOps(config['Global'])