diff --git a/deploy/hubserving/ocr_cls/__init__.py b/deploy/hubserving/ocr_cls/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deploy/hubserving/ocr_cls/config.json b/deploy/hubserving/ocr_cls/config.json new file mode 100644 index 0000000000000000000000000000000000000000..2ced861b57efe3b9a426c129ff8a3f56e9d83a60 --- /dev/null +++ b/deploy/hubserving/ocr_cls/config.json @@ -0,0 +1,15 @@ +{ + "modules_info": { + "ocr_cls": { + "init_args": { + "version": "1.0.0", + "use_gpu": true + }, + "predict_args": { + } + } + }, + "port": 8866, + "use_multiprocess": false, + "workers": 2 +} diff --git a/deploy/hubserving/ocr_cls/module.py b/deploy/hubserving/ocr_cls/module.py new file mode 100644 index 0000000000000000000000000000000000000000..1b91580ca37027e512a73138270d051497af5b89 --- /dev/null +++ b/deploy/hubserving/ocr_cls/module.py @@ -0,0 +1,121 @@ +# -*- coding:utf-8 -*- +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +sys.path.insert(0, ".") + +from paddlehub.common.logger import logger +from paddlehub.module.module import moduleinfo, runnable, serving +import cv2 +import paddlehub as hub + +from tools.infer.utility import base64_to_cv2 +from tools.infer.predict_cls import TextClassifier + + +@moduleinfo( + name="ocr_cls", + version="1.0.0", + summary="ocr recognition service", + author="paddle-dev", + author_email="paddle-dev@baidu.com", + type="cv/text_recognition") +class OCRCls(hub.Module): + def _initialize(self, use_gpu=False, enable_mkldnn=False): + """ + initialize with the necessary elements + """ + from ocr_cls.params import read_params + cfg = read_params() + + cfg.use_gpu = use_gpu + if use_gpu: + try: + _places = os.environ["CUDA_VISIBLE_DEVICES"] + int(_places[0]) + print("use gpu: ", use_gpu) + print("CUDA_VISIBLE_DEVICES: ", _places) + cfg.gpu_mem = 8000 + except: + raise RuntimeError( + "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id." + ) + cfg.ir_optim = True + cfg.enable_mkldnn = enable_mkldnn + + self.text_classifier = TextClassifier(cfg) + + def read_images(self, paths=[]): + images = [] + for img_path in paths: + assert os.path.isfile( + img_path), "The {} isn't a valid file.".format(img_path) + img = cv2.imread(img_path) + if img is None: + logger.info("error in loading image:{}".format(img_path)) + continue + images.append(img) + return images + + def predict(self, images=[], paths=[]): + """ + Get the text angle in the predicted images. + Args: + images (list(numpy.ndarray)): images data, shape of each is [H, W, C]. If images not paths + paths (list[str]): The paths of images. If paths not images + Returns: + res (list): The result of text detection box and save path of images. + """ + + if images != [] and isinstance(images, list) and paths == []: + predicted_data = images + elif images == [] and isinstance(paths, list) and paths != []: + predicted_data = self.read_images(paths) + else: + raise TypeError("The input data is inconsistent with expectations.") + + assert predicted_data != [], "There is not any image to be predicted. Please check the input data." + + img_list = [] + for img in predicted_data: + if img is None: + continue + img_list.append(img) + + rec_res_final = [] + try: + img_list, cls_res, predict_time = self.text_classifier(img_list) + for dno in range(len(cls_res)): + angle, score = cls_res[dno] + rec_res_final.append({ + 'angle': angle, + 'confidence': float(score), + }) + except Exception as e: + print(e) + return [[]] + + return [rec_res_final] + + @serving + def serving_method(self, images, **kwargs): + """ + Run as a service. + """ + images_decode = [base64_to_cv2(image) for image in images] + results = self.predict(images_decode, **kwargs) + return results + + +if __name__ == '__main__': + ocr = OCRCls() + image_path = [ + './doc/imgs_words/ch/word_1.jpg', + './doc/imgs_words/ch/word_2.jpg', + './doc/imgs_words/ch/word_3.jpg', + ] + res = ocr.predict(paths=image_path) + print(res) diff --git a/deploy/hubserving/ocr_cls/params.py b/deploy/hubserving/ocr_cls/params.py new file mode 100644 index 0000000000000000000000000000000000000000..bcdb2d6e3800c0ba7897b71f0b0999cafdc223af --- /dev/null +++ b/deploy/hubserving/ocr_cls/params.py @@ -0,0 +1,24 @@ +# -*- coding:utf-8 -*- +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class Config(object): + pass + + +def read_params(): + cfg = Config() + + #params for text classifier + cfg.cls_model_dir = "./inference/ch_ppocr_mobile_v1.1_cls_infer/" + cfg.cls_image_shape = "3, 48, 192" + cfg.label_list = ['0', '180'] + cfg.cls_batch_num = 30 + cfg.cls_thresh = 0.9 + + cfg.use_zero_copy_run = False + cfg.use_pdserving = False + + return cfg diff --git a/deploy/hubserving/ocr_det/config.json b/deploy/hubserving/ocr_det/config.json index c8ef055e05470b8011db7a59782d2edc8c123782..6080d1c53ad065f47739a77e95bd15f792c63da0 100644 --- a/deploy/hubserving/ocr_det/config.json +++ b/deploy/hubserving/ocr_det/config.json @@ -9,7 +9,7 @@ } } }, - "port": 8866, + "port": 8865, "use_multiprocess": false, "workers": 2 } diff --git a/deploy/hubserving/ocr_det/module.py b/deploy/hubserving/ocr_det/module.py index be74306dacf4a3648e3227f11227e6399e6ed2eb..5f7bd6c473884b80fbe4d24088808444633f100d 100644 --- a/deploy/hubserving/ocr_det/module.py +++ b/deploy/hubserving/ocr_det/module.py @@ -3,20 +3,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse -import ast -import copy -import math import os -import time +import sys +sys.path.insert(0, ".") -from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo, runnable, serving -from PIL import Image import cv2 import numpy as np -import paddle.fluid as fluid import paddlehub as hub from tools.infer.utility import base64_to_cv2 @@ -67,9 +61,7 @@ class OCRDet(hub.Module): images.append(img) return images - def predict(self, - images=[], - paths=[]): + def predict(self, images=[], paths=[]): """ Get the text box in the predicted images. Args: @@ -87,7 +79,7 @@ class OCRDet(hub.Module): raise TypeError("The input data is inconsistent with expectations.") assert predicted_data != [], "There is not any image to be predicted. Please check the input data." - + all_results = [] for img in predicted_data: if img is None: @@ -99,11 +91,9 @@ class OCRDet(hub.Module): rec_res_final = [] for dno in range(len(dt_boxes)): - rec_res_final.append( - { - 'text_region': dt_boxes[dno].astype(np.int).tolist() - } - ) + rec_res_final.append({ + 'text_region': dt_boxes[dno].astype(np.int).tolist() + }) all_results.append(rec_res_final) return all_results @@ -116,7 +106,7 @@ class OCRDet(hub.Module): results = self.predict(images_decode, **kwargs) return results - + if __name__ == '__main__': ocr = OCRDet() image_path = [ @@ -124,4 +114,4 @@ if __name__ == '__main__': './doc/imgs/12.jpg', ] res = ocr.predict(paths=image_path) - print(res) \ No newline at end of file + print(res) diff --git a/deploy/hubserving/ocr_det/params.py b/deploy/hubserving/ocr_det/params.py index e88ab45c7bb548ef971465d4aaefb30d247ab17f..4d4a9fc27b727034d8185c82dad3e542659fd463 100644 --- a/deploy/hubserving/ocr_det/params.py +++ b/deploy/hubserving/ocr_det/params.py @@ -10,16 +10,17 @@ class Config(object): def read_params(): cfg = Config() - + #params for text detector cfg.det_algorithm = "DB" - cfg.det_model_dir = "./inference/ch_det_mv3_db/" - cfg.det_max_side_len = 960 + cfg.det_model_dir = "./inference/ch_ppocr_mobile_v1.1_det_infer/" + cfg.det_limit_side_len = 960 + cfg.det_limit_type = 'max' #DB parmas - cfg.det_db_thresh =0.3 - cfg.det_db_box_thresh =0.5 - cfg.det_db_unclip_ratio =2.0 + cfg.det_db_thresh = 0.3 + cfg.det_db_box_thresh = 0.5 + cfg.det_db_unclip_ratio = 2.0 # #EAST parmas # cfg.det_east_score_thresh = 0.8 @@ -37,5 +38,6 @@ def read_params(): # cfg.use_space_char = True cfg.use_zero_copy_run = False + cfg.use_pdserving = False return cfg diff --git a/deploy/hubserving/ocr_rec/module.py b/deploy/hubserving/ocr_rec/module.py index 846f5437fe3b0a5136bff6c902481f888558d594..41a42104a81b736eeca346737d2dcefe3d728ef8 100644 --- a/deploy/hubserving/ocr_rec/module.py +++ b/deploy/hubserving/ocr_rec/module.py @@ -3,20 +3,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse -import ast -import copy -import math import os -import time +import sys +sys.path.insert(0, ".") -from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo, runnable, serving -from PIL import Image import cv2 -import numpy as np -import paddle.fluid as fluid import paddlehub as hub from tools.infer.utility import base64_to_cv2 @@ -67,9 +60,7 @@ class OCRRec(hub.Module): images.append(img) return images - def predict(self, - images=[], - paths=[]): + def predict(self, images=[], paths=[]): """ Get the text box in the predicted images. Args: @@ -87,31 +78,28 @@ class OCRRec(hub.Module): raise TypeError("The input data is inconsistent with expectations.") assert predicted_data != [], "There is not any image to be predicted. Please check the input data." - + img_list = [] for img in predicted_data: if img is None: continue img_list.append(img) - + rec_res_final = [] try: rec_res, predict_time = self.text_recognizer(img_list) for dno in range(len(rec_res)): text, score = rec_res[dno] - rec_res_final.append( - { - 'text': text, - 'confidence': float(score), - } - ) + rec_res_final.append({ + 'text': text, + 'confidence': float(score), + }) except Exception as e: print(e) return [[]] return [rec_res_final] - @serving def serving_method(self, images, **kwargs): """ @@ -121,7 +109,7 @@ class OCRRec(hub.Module): results = self.predict(images_decode, **kwargs) return results - + if __name__ == '__main__': ocr = OCRRec() image_path = [ @@ -130,4 +118,4 @@ if __name__ == '__main__': './doc/imgs_words/ch/word_3.jpg', ] res = ocr.predict(paths=image_path) - print(res) \ No newline at end of file + print(res) diff --git a/deploy/hubserving/ocr_rec/params.py b/deploy/hubserving/ocr_rec/params.py index 59772e2163d1d5f8279dee85432b5bf93502914e..6f428ecb2686afa5ff66b84d963d1c2175b9cee2 100644 --- a/deploy/hubserving/ocr_rec/params.py +++ b/deploy/hubserving/ocr_rec/params.py @@ -10,25 +10,10 @@ class Config(object): def read_params(): cfg = Config() - - # #params for text detector - # cfg.det_algorithm = "DB" - # cfg.det_model_dir = "./inference/ch_det_mv3_db/" - # cfg.det_max_side_len = 960 - - # #DB parmas - # cfg.det_db_thresh =0.3 - # cfg.det_db_box_thresh =0.5 - # cfg.det_db_unclip_ratio =2.0 - - # #EAST parmas - # cfg.det_east_score_thresh = 0.8 - # cfg.det_east_cover_thresh = 0.1 - # cfg.det_east_nms_thresh = 0.2 #params for text recognizer cfg.rec_algorithm = "CRNN" - cfg.rec_model_dir = "./inference/ch_rec_mv3_crnn/" + cfg.rec_model_dir = "./inference/ch_ppocr_mobile_v1.1_rec_infer/" cfg.rec_image_shape = "3, 32, 320" cfg.rec_char_type = 'ch' @@ -39,5 +24,6 @@ def read_params(): cfg.use_space_char = True cfg.use_zero_copy_run = False + cfg.use_pdserving = False return cfg diff --git a/deploy/hubserving/ocr_system/module.py b/deploy/hubserving/ocr_system/module.py index cb526e1185d8eb623af84ba2451ff38523bbf642..7f3617330bc87ac77a150af9e40dc125d3bfd7de 100644 --- a/deploy/hubserving/ocr_system/module.py +++ b/deploy/hubserving/ocr_system/module.py @@ -3,20 +3,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse -import ast -import copy -import math import os +import sys +sys.path.insert(0, ".") + import time -from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo, runnable, serving -from PIL import Image import cv2 import numpy as np -import paddle.fluid as fluid import paddlehub as hub from tools.infer.utility import base64_to_cv2 @@ -52,7 +48,7 @@ class OCRSystem(hub.Module): ) cfg.ir_optim = True cfg.enable_mkldnn = enable_mkldnn - + self.text_sys = TextSystem(cfg) def read_images(self, paths=[]): @@ -67,9 +63,7 @@ class OCRSystem(hub.Module): images.append(img) return images - def predict(self, - images=[], - paths=[]): + def predict(self, images=[], paths=[]): """ Get the chinese texts in the predicted images. Args: @@ -104,13 +98,11 @@ class OCRSystem(hub.Module): for dno in range(dt_num): text, score = rec_res[dno] - rec_res_final.append( - { - 'text': text, - 'confidence': float(score), - 'text_region': dt_boxes[dno].astype(np.int).tolist() - } - ) + rec_res_final.append({ + 'text': text, + 'confidence': float(score), + 'text_region': dt_boxes[dno].astype(np.int).tolist() + }) all_results.append(rec_res_final) return all_results @@ -123,7 +115,7 @@ class OCRSystem(hub.Module): results = self.predict(images_decode, **kwargs) return results - + if __name__ == '__main__': ocr = OCRSystem() image_path = [ @@ -131,4 +123,4 @@ if __name__ == '__main__': './doc/imgs/12.jpg', ] res = ocr.predict(paths=image_path) - print(res) \ No newline at end of file + print(res) diff --git a/deploy/hubserving/ocr_system/params.py b/deploy/hubserving/ocr_system/params.py index 0ff56d37d50b30b09bb13b529a48a260dfe8f84a..1f6a07bcc0167e90564edab9c4719b9192233b4c 100644 --- a/deploy/hubserving/ocr_system/params.py +++ b/deploy/hubserving/ocr_system/params.py @@ -10,16 +10,17 @@ class Config(object): def read_params(): cfg = Config() - + #params for text detector cfg.det_algorithm = "DB" - cfg.det_model_dir = "./inference/ch_det_mv3_db/" - cfg.det_max_side_len = 960 + cfg.det_model_dir = "./inference/ch_ppocr_mobile_v1.1_det_infer/" + cfg.det_limit_side_len = 960 + cfg.det_limit_type = 'max' #DB parmas - cfg.det_db_thresh =0.3 - cfg.det_db_box_thresh =0.5 - cfg.det_db_unclip_ratio =2.0 + cfg.det_db_thresh = 0.3 + cfg.det_db_box_thresh = 0.5 + cfg.det_db_unclip_ratio = 2.0 #EAST parmas cfg.det_east_score_thresh = 0.8 @@ -28,7 +29,7 @@ def read_params(): #params for text recognizer cfg.rec_algorithm = "CRNN" - cfg.rec_model_dir = "./inference/ch_rec_mv3_crnn/" + cfg.rec_model_dir = "./inference/ch_ppocr_mobile_v1.1_rec_infer/" cfg.rec_image_shape = "3, 32, 320" cfg.rec_char_type = 'ch' @@ -38,6 +39,15 @@ def read_params(): cfg.rec_char_dict_path = "./ppocr/utils/ppocr_keys_v1.txt" cfg.use_space_char = True + #params for text classifier + cfg.use_angle_cls = True + cfg.cls_model_dir = "./inference/ch_ppocr_mobile_v1.1_cls_infer/" + cfg.cls_image_shape = "3, 48, 192" + cfg.label_list = ['0', '180'] + cfg.cls_batch_num = 30 + cfg.cls_thresh = 0.9 + cfg.use_zero_copy_run = False + cfg.use_pdserving = False return cfg diff --git a/deploy/hubserving/readme.md b/deploy/hubserving/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..f64bd372569f12ea52214e3e89927df0c859a17f --- /dev/null +++ b/deploy/hubserving/readme.md @@ -0,0 +1,193 @@ +[English](readme_en.md) | 简体中文 + +PaddleOCR提供2种服务部署方式: +- 基于PaddleHub Serving的部署:代码路径为"`./deploy/hubserving`",按照本教程使用; +- 基于PaddleServing的部署:代码路径为"`./deploy/pdserving`",使用方法参考[文档](../../deploy/pdserving/readme.md)。 + +# 基于PaddleHub Serving的服务部署 + +hubserving服务部署目录下包括检测、识别、2阶段串联三种服务包,请根据需求选择相应的服务包进行安装和启动。目录结构如下: +``` +deploy/hubserving/ + └─ ocr_cls 分类模块服务包 + └─ ocr_det 检测模块服务包 + └─ ocr_rec 识别模块服务包 + └─ ocr_system 检测+识别串联服务包 +``` + +每个服务包下包含3个文件。以2阶段串联服务包为例,目录如下: +``` +deploy/hubserving/ocr_system/ + └─ __init__.py 空文件,必选 + └─ config.json 配置文件,可选,使用配置启动服务时作为参数传入 + └─ module.py 主模块,必选,包含服务的完整逻辑 + └─ params.py 参数文件,必选,包含模型路径、前后处理参数等参数 +``` + +## 快速启动服务 +以下步骤以检测+识别2阶段串联服务为例,如果只需要检测服务或识别服务,替换相应文件路径即可。 +### 1. 准备环境 +```shell +# 安装paddlehub +pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple +``` + +### 2. 下载推理模型 +安装服务模块前,需要准备推理模型并放到正确路径。默认使用的是v1.1版的超轻量模型,默认模型路径为: +``` +检测模型:./inference/ch_ppocr_mobile_v1.1_det_infer/ +识别模型:./inference/ch_ppocr_mobile_v1.1_rec_infer/ +方向分类器:./inference/ch_ppocr_mobile_v1.1_cls_infer/ +``` + +**模型路径可在`params.py`中查看和修改。** 更多模型可以从PaddleOCR提供的[模型库](../../doc/doc_ch/models_list.md)下载,也可以替换成自己训练转换好的模型。 + +### 3. 安装服务模块 +PaddleOCR提供3种服务模块,根据需要安装所需模块。 + +* 在Linux环境下,安装示例如下: +```shell +# 安装检测服务模块: +hub install deploy/hubserving/ocr_det/ + +# 或,安装分类服务模块: +hub install deploy/hubserving/ocr_cls/ + +# 或,安装识别服务模块: +hub install deploy/hubserving/ocr_rec/ + +# 或,安装检测+识别串联服务模块: +hub install deploy/hubserving/ocr_system/ +``` + +* 在Windows环境下(文件夹的分隔符为`\`),安装示例如下: +```shell +# 安装检测服务模块: +hub install deploy\hubserving\ocr_det\ + +# 或,安装分类服务模块: +hub install deploy\hubserving\ocr_cls\ + +# 或,安装识别服务模块: +hub install deploy\hubserving\ocr_rec\ + +# 或,安装检测+识别串联服务模块: +hub install deploy\hubserving\ocr_system\ +``` + +### 4. 启动服务 +#### 方式1. 命令行命令启动(仅支持CPU) +**启动命令:** +```shell +$ hub serving start --modules [Module1==Version1, Module2==Version2, ...] \ + --port XXXX \ + --use_multiprocess \ + --workers \ +``` + +**参数:** + +|参数|用途| +|-|-| +|--modules/-m|PaddleHub Serving预安装模型,以多个Module==Version键值对的形式列出
*`当不指定Version时,默认选择最新版本`*| +|--port/-p|服务端口,默认为8866| +|--use_multiprocess|是否启用并发方式,默认为单进程方式,推荐多核CPU机器使用此方式
*`Windows操作系统只支持单进程方式`*| +|--workers|在并发方式下指定的并发任务数,默认为`2*cpu_count-1`,其中`cpu_count`为CPU核数| + +如启动串联服务: ```hub serving start -m ocr_system``` + +这样就完成了一个服务化API的部署,使用默认端口号8866。 + +#### 方式2. 配置文件启动(支持CPU、GPU) +**启动命令:** +```hub serving start -c config.json``` + +其中,`config.json`格式如下: +```python +{ + "modules_info": { + "ocr_system": { + "init_args": { + "version": "1.0.0", + "use_gpu": true + }, + "predict_args": { + } + } + }, + "port": 8868, + "use_multiprocess": false, + "workers": 2 +} +``` + +- `init_args`中的可配参数与`module.py`中的`_initialize`函数接口一致。其中,**当`use_gpu`为`true`时,表示使用GPU启动服务**。 +- `predict_args`中的可配参数与`module.py`中的`predict`函数接口一致。 + +**注意:** +- 使用配置文件启动服务时,其他参数会被忽略。 +- 如果使用GPU预测(即,`use_gpu`置为`true`),则需要在启动服务之前,设置CUDA_VISIBLE_DEVICES环境变量,如:```export CUDA_VISIBLE_DEVICES=0```,否则不用设置。 +- **`use_gpu`不可与`use_multiprocess`同时为`true`**。 + +如,使用GPU 3号卡启动串联服务: +```shell +export CUDA_VISIBLE_DEVICES=3 +hub serving start -c deploy/hubserving/ocr_system/config.json +``` + +## 发送预测请求 +配置好服务端,可使用以下命令发送预测请求,获取预测结果: + +```python tools/test_hubserving.py server_url image_path``` + +需要给脚本传递2个参数: +- **server_url**:服务地址,格式为 +`http://[ip_address]:[port]/predict/[module_name]` +例如,如果使用配置文件启动分类,检测、识别,检测+分类+识别3阶段服务,那么发送请求的url将分别是: +`http://127.0.0.1:8865/predict/ocr_det` +`http://127.0.0.1:8866/predict/ocr_cls` +`http://127.0.0.1:8867/predict/ocr_rec` +`http://127.0.0.1:8868/predict/ocr_system` +- **image_path**:测试图像路径,可以是单张图片路径,也可以是图像集合目录路径 + +访问示例: +```python tools/test_hubserving.py http://127.0.0.1:8868/predict/ocr_system ./doc/imgs/``` + +## 返回结果格式说明 +返回结果为列表(list),列表中的每一项为词典(dict),词典一共可能包含3种字段,信息如下: + +|字段名称|数据类型|意义| +|----|----|----| +|angle|str|文本角度| +|text|str|文本内容| +|confidence|float| 文本识别置信度或文本角度分类置信度| +|text_region|list|文本位置坐标| + +不同模块返回的字段不同,如,文本识别服务模块返回结果不含`text_region`字段,具体信息如下: + +| 字段名/模块名 | ocr_det | ocr_cls | ocr_rec | ocr_system | +| ---- | ---- | ---- | ---- | ---- | +|angle| | ✔ | | ✔ | +|text| | |✔|✔| +|confidence| |✔ |✔|✔| +|text_region| ✔| | |✔ | + +**说明:** 如果需要增加、删除、修改返回字段,可在相应模块的`module.py`文件中进行修改,完整流程参考下一节自定义修改服务模块。 + +## 自定义修改服务模块 +如果需要修改服务逻辑,你一般需要操作以下步骤(以修改`ocr_system`为例): + +- 1、 停止服务 +```hub serving stop --port/-p XXXX``` + +- 2、 到相应的`module.py`和`params.py`等文件中根据实际需求修改代码。 +例如,如果需要替换部署服务所用模型,则需要到`params.py`中修改模型路径参数`det_model_dir`和`rec_model_dir`,如果需要关闭文本方向分类器,则将参数`use_angle_cls`置为`False`,当然,同时可能还需要修改其他相关参数,请根据实际情况修改调试。 **强烈建议修改后先直接运行`module.py`调试,能正确运行预测后再启动服务测试。** + +- 3、 卸载旧服务包 +```hub uninstall ocr_system``` + +- 4、 安装修改后的新服务包 +```hub install deploy/hubserving/ocr_system/``` + +- 5、重新启动服务 +```hub serving start -m ocr_system``` diff --git a/deploy/hubserving/readme_en.md b/deploy/hubserving/readme_en.md new file mode 100644 index 0000000000000000000000000000000000000000..c6cf53413bc3eac45f933fead66356d1491cc60c --- /dev/null +++ b/deploy/hubserving/readme_en.md @@ -0,0 +1,204 @@ +English | [简体中文](readme.md) + +PaddleOCR provides 2 service deployment methods: +- Based on **PaddleHub Serving**: Code path is "`./deploy/hubserving`". Please follow this tutorial. +- Based on **PaddleServing**: Code path is "`./deploy/pdserving`". Please refer to the [tutorial](../../deploy/pdserving/readme.md) for usage. + +# Service deployment based on PaddleHub Serving + +The hubserving service deployment directory includes three service packages: detection, recognition, and two-stage series connection. Please select the corresponding service package to install and start service according to your needs. The directory is as follows: +``` +deploy/hubserving/ + └─ ocr_det detection module service package + └─ ocr_cls angle class module service package + └─ ocr_rec recognition module service package + └─ ocr_system two-stage series connection service package +``` + +Each service pack contains 3 files. Take the 2-stage series connection service package as an example, the directory is as follows: +``` +deploy/hubserving/ocr_system/ + └─ __init__.py Empty file, required + └─ config.json Configuration file, optional, passed in as a parameter when using configuration to start the service + └─ module.py Main module file, required, contains the complete logic of the service + └─ params.py Parameter file, required, including parameters such as model path, pre- and post-processing parameters +``` + +## Quick start service +The following steps take the 2-stage series service as an example. If only the detection service or recognition service is needed, replace the corresponding file path. + +### 1. Prepare the environment +```shell +# Install paddlehub +pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple +``` + +### 2. Download inference model +Before installing the service module, you need to prepare the inference model and put it in the correct path. By default, the ultra lightweight model of v1.1 is used, and the default model path is: +``` +detection model: ./inference/ch_ppocr_mobile_v1.1_det_infer/ +recognition model: ./inference/ch_ppocr_mobile_v1.1_rec_infer/ +text direction classifier: ./inference/ch_ppocr_mobile_v1.1_cls_infer/ +``` + +**The model path can be found and modified in `params.py`.** More models provided by PaddleOCR can be obtained from the [model library](../../doc/doc_en/models_list_en.md). You can also use models trained by yourself. + +### 3. Install Service Module +PaddleOCR provides 3 kinds of service modules, install the required modules according to your needs. + +* On Linux platform, the examples are as follows. +```shell +# Install the detection service module: +hub install deploy/hubserving/ocr_det/ + +# Or, install the angle class service module: +hub install deploy/hubserving/ocr_cls/ + +# Or, install the recognition service module: +hub install deploy/hubserving/ocr_rec/ + +# Or, install the 2-stage series service module: +hub install deploy/hubserving/ocr_system/ +``` + +* On Windows platform, the examples are as follows. +```shell +# Install the detection service module: +hub install deploy\hubserving\ocr_det\ + +# Or, install the angle class service module: +hub install deploy\hubserving\ocr_cls\ + +# Or, install the recognition service module: +hub install deploy\hubserving\ocr_rec\ + +# Or, install the 2-stage series service module: +hub install deploy\hubserving\ocr_system\ +``` + +### 4. Start service +#### Way 1. Start with command line parameters (CPU only) + +**start command:** +```shell +$ hub serving start --modules [Module1==Version1, Module2==Version2, ...] \ + --port XXXX \ + --use_multiprocess \ + --workers \ +``` +**parameters:** + +|parameters|usage| +|-|-| +|--modules/-m|PaddleHub Serving pre-installed model, listed in the form of multiple Module==Version key-value pairs
*`When Version is not specified, the latest version is selected by default`*| +|--port/-p|Service port, default is 8866| +|--use_multiprocess|Enable concurrent mode, the default is single-process mode, this mode is recommended for multi-core CPU machines
*`Windows operating system only supports single-process mode`*| +|--workers|The number of concurrent tasks specified in concurrent mode, the default is `2*cpu_count-1`, where `cpu_count` is the number of CPU cores| + +For example, start the 2-stage series service: +```shell +hub serving start -m ocr_system +``` + +This completes the deployment of a service API, using the default port number 8866. + +#### Way 2. Start with configuration file(CPU、GPU) +**start command:** +```shell +hub serving start --config/-c config.json +``` +Wherein, the format of `config.json` is as follows: +```python +{ + "modules_info": { + "ocr_system": { + "init_args": { + "version": "1.0.0", + "use_gpu": true + }, + "predict_args": { + } + } + }, + "port": 8868, + "use_multiprocess": false, + "workers": 2 +} +``` +- The configurable parameters in `init_args` are consistent with the `_initialize` function interface in `module.py`. Among them, **when `use_gpu` is `true`, it means that the GPU is used to start the service**. +- The configurable parameters in `predict_args` are consistent with the `predict` function interface in `module.py`. + +**Note:** +- When using the configuration file to start the service, other parameters will be ignored. +- If you use GPU prediction (that is, `use_gpu` is set to `true`), you need to set the environment variable CUDA_VISIBLE_DEVICES before starting the service, such as: ```export CUDA_VISIBLE_DEVICES=0```, otherwise you do not need to set it. +- **`use_gpu` and `use_multiprocess` cannot be `true` at the same time.** + +For example, use GPU card No. 3 to start the 2-stage series service: +```shell +export CUDA_VISIBLE_DEVICES=3 +hub serving start -c deploy/hubserving/ocr_system/config.json +``` + +## Send prediction requests +After the service starts, you can use the following command to send a prediction request to obtain the prediction result: +```shell +python tools/test_hubserving.py server_url image_path +``` + +Two parameters need to be passed to the script: +- **server_url**:service address,format of which is +`http://[ip_address]:[port]/predict/[module_name]` +For example, if the detection, recognition and 2-stage serial services are started with provided configuration files, the respective `server_url` would be: +`http://127.0.0.1:8865/predict/ocr_det` +`http://127.0.0.1:8866/predict/ocr_cls` +`http://127.0.0.1:8867/predict/ocr_rec` +`http://127.0.0.1:8868/predict/ocr_system` +- **image_path**:Test image path, can be a single image path or an image directory path + +**Eg.** +```shell +python tools/test_hubserving.py http://127.0.0.1:8868/predict/ocr_system ./doc/imgs/ +``` + +## Returned result format +The returned result is a list. Each item in the list is a dict. The dict may contain three fields. The information is as follows: + +|field name|data type|description| +|----|----|----| +|angle|str|angle| +|text|str|text content| +|confidence|float|text recognition confidence| +|text_region|list|text location coordinates| + +The fields returned by different modules are different. For example, the results returned by the text recognition service module do not contain `text_region`. The details are as follows: + +| field name/module name | ocr_det | ocr_cls | ocr_rec | ocr_system | +| ---- | ---- | ---- | ---- | ---- | +|angle| | ✔ | | ✔ | +|text| | |✔|✔| +|confidence| |✔ |✔|✔| +|text_region| ✔| | |✔ | + +**Note:** If you need to add, delete or modify the returned fields, you can modify the file `module.py` of the corresponding module. For the complete process, refer to the user-defined modification service module in the next section. + +## User defined service module modification +If you need to modify the service logic, the following steps are generally required (take the modification of `ocr_system` for example): + +- 1. Stop service +```shell +hub serving stop --port/-p XXXX +``` +- 2. Modify the code in the corresponding files, like `module.py` and `params.py`, according to the actual needs. +For example, if you need to replace the model used by the deployed service, you need to modify model path parameters `det_model_dir` and `rec_model_dir` in `params.py`. If you want to turn off the text direction classifier, set the parameter `use_angle_cls` to `False`. Of course, other related parameters may need to be modified at the same time. Please modify and debug according to the actual situation. It is suggested to run `module.py` directly for debugging after modification before starting the service test. +- 3. Uninstall old service module +```shell +hub uninstall ocr_system +``` +- 4. Install modified service module +```shell +hub install deploy/hubserving/ocr_system/ +``` +- 5. Restart service +```shell +hub serving start -m ocr_system +``` diff --git a/tools/infer/utility.py b/tools/infer/utility.py index ee1f954dcc4b6518cfe454a86650b397b9db449e..fabc33dc67265cd304294bab15ecd6d242a3add6 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -230,10 +230,10 @@ def draw_ocr_box_txt(image, box[2][1], box[3][0], box[3][1] ], outline=color) - box_height = math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][ - 1]) ** 2) - box_width = math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][ - 1]) ** 2) + box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][ + 1])**2) + box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][ + 1])**2) if box_height > 2 * box_width: font_size = max(int(box_width * 0.9), 10) font = ImageFont.truetype(font_path, font_size, encoding="utf-8") @@ -260,7 +260,6 @@ def str_count(s): Count the number of Chinese characters, a single English character and a single number equal to half the length of Chinese characters. - args: s(string): the input of string return(int): @@ -295,7 +294,6 @@ def text_visual(texts, img_w(int): the width of blank img font_path: the path of font which is used to draw text return(array): - """ if scores is not None: assert len(texts) == len(