diff --git a/deploy/hubserving/readme.md b/deploy/hubserving/readme.md index 9135615520ddfcff30af74c07348d35a4d0de8d7..f8f7cf03cfe42fa57ce013194309a9029e6b259c 100755 --- a/deploy/hubserving/readme.md +++ b/deploy/hubserving/readme.md @@ -27,6 +27,7 @@ deploy/hubserving/ └─ ocr_rec 识别模块服务包 └─ ocr_system 检测+识别串联服务包 └─ structure_table 表格识别服务包 + └─ structure_system PP-Structure服务包 ``` 每个服务包下包含3个文件。以2阶段串联服务包为例,目录如下: @@ -77,6 +78,9 @@ hub install deploy/hubserving/ocr_system/ # 或,安装表格识别服务模块: hub install deploy/hubserving/structure_table/ + +# 或,安装PP-Structure服务模块: +hub install deploy/hubserving/structure_system/ ``` * 在Windows环境下(文件夹的分隔符为`\`),安装示例如下: @@ -95,6 +99,9 @@ hub install deploy\hubserving\ocr_system\ # 或,安装表格识别服务模块: hub install deploy\hubserving\structure_table\ + +# 或,安装PP-Structure服务模块: +hub install deploy\hubserving\structure_system\ ``` ### 4. 启动服务 @@ -165,14 +172,16 @@ hub serving start -c deploy/hubserving/ocr_system/config.json 需要给脚本传递2个参数: - **server_url**:服务地址,格式为 `http://[ip_address]:[port]/predict/[module_name]` -例如,如果使用配置文件启动分类,检测、识别,检测+分类+识别3阶段,表格识别服务,那么发送请求的url将分别是: +例如,如果使用配置文件启动分类,检测、识别,检测+分类+识别3阶段,表格识别和PP-Structure服务,那么发送请求的url将分别是: `http://127.0.0.1:8865/predict/ocr_det` `http://127.0.0.1:8866/predict/ocr_cls` `http://127.0.0.1:8867/predict/ocr_rec` `http://127.0.0.1:8868/predict/ocr_system` `http://127.0.0.1:8869/predict/structure_table` +`http://127.0.0.1:8870/predict/structure_system` - **image_dir**:测试图像路径,可以是单张图片路径,也可以是图像集合目录路径 - **visualize**:是否可视化结果,默认为False +- **output**:可视化结果保存路径,默认为`./hubserving_result` 访问示例: ```python tools/test_hubserving.py --server_url=http://127.0.0.1:8868/predict/ocr_system --image_dir./doc/imgs/ --visualize=false``` @@ -187,16 +196,18 @@ hub serving start -c deploy/hubserving/ocr_system/config.json |confidence|float| 文本识别置信度或文本角度分类置信度| |text_region|list|文本位置坐标| |html|str|表格的html字符串| +|regions|list|版面分析+表格识别+OCR的结果,每一项为一个list,包含表示区域坐标的`bbox`,区域类型的`type`和区域结果的`res`三个字段| 不同模块返回的字段不同,如,文本识别服务模块返回结果不含`text_region`字段,具体信息如下: -| 字段名/模块名 | ocr_det | ocr_cls | ocr_rec | ocr_system | structure_table | -| --- | --- | --- | --- | --- | --- | -|angle| | ✔ | | ✔ | | -|text| | |✔|✔| | -|confidence| |✔ |✔| | | -|text_region| ✔| | |✔ | | -|html| | | | |✔ | +| 字段名/模块名 | ocr_det | ocr_cls | ocr_rec | ocr_system | structure_table | structure_system | +| --- | --- | --- | --- | --- | --- |--- | +|angle| | ✔ | | ✔ | || +|text| | |✔|✔| | ✔ | +|confidence| |✔ |✔| | | ✔| +|text_region| ✔| | |✔ | | ✔| +|html| | | | |✔ |✔| +|regions| | | | |✔ |✔ | **说明:** 如果需要增加、删除、修改返回字段,可在相应模块的`module.py`文件中进行修改,完整流程参考下一节自定义修改服务模块。 diff --git a/deploy/hubserving/readme_en.md b/deploy/hubserving/readme_en.md index 03c1630de6f51aeab1926d432a56b50140114083..79a0768752a3be1cbea7172f22fd2c0e298d33b3 100755 --- a/deploy/hubserving/readme_en.md +++ b/deploy/hubserving/readme_en.md @@ -27,6 +27,7 @@ deploy/hubserving/ └─ ocr_rec text recognition module service package └─ ocr_system two-stage series connection service package └─ structure_table table recognition service package + └─ structure_system PP-Structure service package ``` Each service pack contains 3 files. Take the 2-stage series connection service package as an example, the directory is as follows: @@ -78,6 +79,9 @@ hub install deploy/hubserving/ocr_system/ # Or install table recognition service module hub install deploy/hubserving/structure_table/ + +# Or install PP-Structure service module +hub install deploy/hubserving/structure_system/ ``` * On Windows platform, the examples are as follows. @@ -96,6 +100,9 @@ hub install deploy\hubserving\ocr_system\ # Or install table recognition service module hub install deploy/hubserving/structure_table/ + +# Or install PP-Structure service module +hub install deploy\hubserving\structure_system\ ``` ### 4. Start service @@ -170,15 +177,17 @@ python tools/test_hubserving.py server_url image_path Two parameters need to be passed to the script: - **server_url**:service address,format of which is `http://[ip_address]:[port]/predict/[module_name]` -For example, if using the configuration file to start the text angle classification, text detection, text recognition, detection+classification+recognition 3 stages, table recognition service, then the `server_url` to send the request will be: +For example, if using the configuration file to start the text angle classification, text detection, text recognition, detection+classification+recognition 3 stages, table recognition and PP-Structure service, then the `server_url` to send the request will be: `http://127.0.0.1:8865/predict/ocr_det` `http://127.0.0.1:8866/predict/ocr_cls` `http://127.0.0.1:8867/predict/ocr_rec` `http://127.0.0.1:8868/predict/ocr_system` `http://127.0.0.1:8869/predict/structure_table` +`http://127.0.0.1:8870/predict/structure_system` - **image_dir**:Test image path, can be a single image path or an image directory path - **visualize**:Whether to visualize the results, the default value is False +- **output**:The floder to save Visualization result, default value is `./hubserving_result` **Eg.** ```shell @@ -195,16 +204,18 @@ The returned result is a list. Each item in the list is a dict. The dict may con |confidence|float|text recognition confidence| |text_region|list|text location coordinates| |html|str|table html str| +|regions|list|The result of layout analysis + table recognition + OCR, each item is a list, including `bbox` indicating area coordinates, `type` of area type and `res` of area results| The fields returned by different modules are different. For example, the results returned by the text recognition service module do not contain `text_region`. The details are as follows: -| field name/module name| ocr_det | ocr_cls | ocr_rec | ocr_system | structure_table | -| ---- | ---- | ---- | ---- | ---- | ---- | -|angle| | ✔ | | ✔ | | -|text| | |✔|✔| | -|confidence| |✔ |✔| | | -|text_region| ✔| | |✔ | | -|html| | | | |✔ | +| field name/module name | ocr_det | ocr_cls | ocr_rec | ocr_system | structure_table | structure_system | +| --- | --- | --- | --- | --- | --- |--- | +|angle| | ✔ | | ✔ | || +|text| | |✔|✔| | ✔ | +|confidence| |✔ |✔| | | ✔| +|text_region| ✔| | |✔ | | ✔| +|html| | | | |✔ |✔| +|regions| | | | |✔ |✔ | **Note:** If you need to add, delete or modify the returned fields, you can modify the file `module.py` of the corresponding module. For the complete process, refer to the user-defined modification service module in the next section. diff --git a/deploy/hubserving/structure_system/__init__.py b/deploy/hubserving/structure_system/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deploy/hubserving/structure_system/config.json b/deploy/hubserving/structure_system/config.json new file mode 100644 index 0000000000000000000000000000000000000000..642aa94a2a25759469f74280f6aab9a2495f493f --- /dev/null +++ b/deploy/hubserving/structure_system/config.json @@ -0,0 +1,16 @@ +{ + "modules_info": { + "structure_system": { + "init_args": { + "version": "1.0.0", + "use_gpu": true + }, + "predict_args": { + } + } + }, + "port": 8870, + "use_multiprocess": false, + "workers": 2 +} + diff --git a/deploy/hubserving/structure_system/module.py b/deploy/hubserving/structure_system/module.py new file mode 100644 index 0000000000000000000000000000000000000000..cb2b422d5fa3ae77d7158ef1fb1833ccc9a8bc4c --- /dev/null +++ b/deploy/hubserving/structure_system/module.py @@ -0,0 +1,136 @@ +# -*- coding:utf-8 -*- +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +sys.path.insert(0, ".") +import copy + +import time +import paddlehub +from paddlehub.common.logger import logger +from paddlehub.module.module import moduleinfo, runnable, serving +import cv2 +import numpy as np +import paddlehub as hub + +from tools.infer.utility import base64_to_cv2 +from ppstructure.predict_system import StructureSystem as PPStructureSystem +from ppstructure.predict_system import save_structure_res +from ppstructure.utility import parse_args +from deploy.hubserving.structure_system.params import read_params + + +@moduleinfo( + name="structure_system", + version="1.0.0", + summary="PP-Structure system service", + author="paddle-dev", + author_email="paddle-dev@baidu.com", + type="cv/structure_system") +class StructureSystem(hub.Module): + def _initialize(self, use_gpu=False, enable_mkldnn=False): + """ + initialize with the necessary elements + """ + cfg = self.merge_configs() + + cfg.use_gpu = use_gpu + if use_gpu: + try: + _places = os.environ["CUDA_VISIBLE_DEVICES"] + int(_places[0]) + print("use gpu: ", use_gpu) + print("CUDA_VISIBLE_DEVICES: ", _places) + cfg.gpu_mem = 8000 + except: + raise RuntimeError( + "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id." + ) + cfg.ir_optim = True + cfg.enable_mkldnn = enable_mkldnn + + self.table_sys = PPStructureSystem(cfg) + + def merge_configs(self): + # deafult cfg + backup_argv = copy.deepcopy(sys.argv) + sys.argv = sys.argv[:1] + cfg = parse_args() + + update_cfg_map = vars(read_params()) + + for key in update_cfg_map: + cfg.__setattr__(key, update_cfg_map[key]) + + sys.argv = copy.deepcopy(backup_argv) + return cfg + + def read_images(self, paths=[]): + images = [] + for img_path in paths: + assert os.path.isfile( + img_path), "The {} isn't a valid file.".format(img_path) + img = cv2.imread(img_path) + if img is None: + logger.info("error in loading image:{}".format(img_path)) + continue + images.append(img) + return images + + def predict(self, images=[], paths=[]): + """ + Get the chinese texts in the predicted images. + Args: + images (list(numpy.ndarray)): images data, shape of each is [H, W, C]. If images not paths + paths (list[str]): The paths of images. If paths not images + Returns: + res (list): The result of chinese texts and save path of images. + """ + + if images != [] and isinstance(images, list) and paths == []: + predicted_data = images + elif images == [] and isinstance(paths, list) and paths != []: + predicted_data = self.read_images(paths) + else: + raise TypeError("The input data is inconsistent with expectations.") + + assert predicted_data != [], "There is not any image to be predicted. Please check the input data." + + all_results = [] + for img in predicted_data: + if img is None: + logger.info("error in loading image") + all_results.append([]) + continue + starttime = time.time() + res = self.table_sys(img) + elapse = time.time() - starttime + logger.info("Predict time: {}".format(elapse)) + + # parse result + res_final = [] + for region in res: + region.pop('img') + res_final.append(region) + all_results.append({'regions': res_final}) + return all_results + + @serving + def serving_method(self, images, **kwargs): + """ + Run as a service. + """ + images_decode = [base64_to_cv2(image) for image in images] + results = self.predict(images_decode, **kwargs) + return results + + +if __name__ == '__main__': + structure_system = StructureSystem() + structure_system._initialize() + image_path = ['./doc/table/1.png'] + res = structure_system.predict(paths=image_path) + print(res) diff --git a/deploy/hubserving/structure_system/params.py b/deploy/hubserving/structure_system/params.py new file mode 100755 index 0000000000000000000000000000000000000000..608e016d728c0edc8c44e96d2cc6347660ad591c --- /dev/null +++ b/deploy/hubserving/structure_system/params.py @@ -0,0 +1,18 @@ +# -*- coding:utf-8 -*- +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from deploy.hubserving.structure_table.params import read_params as table_read_params + + +def read_params(): + cfg = table_read_params() + + # params for layout parser model + cfg.layout_path_model = 'lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config' + cfg.layout_label_map = None + + cfg.mode = 'structure' + cfg.output = './output' + return cfg diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py index f6703e1a419a38b51fb2cb034c6e2b1aa6d8aa90..96227aabbbf38904417f3e3a6fd6c49031c4bc58 100644 --- a/ppstructure/predict_system.py +++ b/ppstructure/predict_system.py @@ -22,6 +22,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' import cv2 +import json import numpy as np import time import logging @@ -66,8 +67,7 @@ class StructureSystem(object): self.use_angle_cls = args.use_angle_cls self.drop_score = args.drop_score elif self.mode == 'vqa': - from ppstructure.vqa.infer_ser_e2e import SerPredictor, draw_ser_results - self.vqa_engine = SerPredictor(args) + raise NotImplementedError def __call__(self, img): if self.mode == 'structure': @@ -82,24 +82,24 @@ class StructureSystem(object): res = self.table_system(roi_img) else: filter_boxes, filter_rec_res = self.text_system(roi_img) - filter_boxes = [x + [x1, y1] for x in filter_boxes] - filter_boxes = [ - x.reshape(-1).tolist() for x in filter_boxes - ] # remove style char style_token = [ '', '', '', '', '', '', '', '', '', '', '', '', '', '' ] - filter_rec_res_tmp = [] - for rec_res in filter_rec_res: + res = [] + for box, rec_res in zip(filter_boxes, filter_rec_res): rec_str, rec_conf = rec_res for token in style_token: if token in rec_str: rec_str = rec_str.replace(token, '') - filter_rec_res_tmp.append((rec_str, rec_conf)) - res = (filter_boxes, filter_rec_res_tmp) + box += [x1, y1] + res.append({ + 'text': rec_str, + 'confidence': float(rec_conf), + 'text_region': box.tolist() + }) res_list.append({ 'type': region.type, 'bbox': [x1, y1, x2, y2], @@ -107,7 +107,7 @@ class StructureSystem(object): 'res': res }) elif self.mode == 'vqa': - res_list, _ = self.vqa_engine(img) + raise NotImplementedError return res_list @@ -123,15 +123,14 @@ def save_structure_res(res, save_folder, img_name): excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox'])) to_excel(region['res'], excel_path) - if region['type'] == 'Figure': + elif region['type'] == 'Figure': roi_img = region['img'] img_path = os.path.join(excel_save_folder, '{}.jpg'.format(region['bbox'])) cv2.imwrite(img_path, roi_img) else: - for box, rec_res in zip(region['res'][0], region['res'][1]): - f.write('{}\t{}\n'.format( - np.array(box).reshape(-1).tolist(), rec_res)) + for text_result in region['res']: + f.write('{}\n'.format(json.dumps(text_result))) def main(args): @@ -162,8 +161,9 @@ def main(args): draw_img = draw_structure_result(img, res, args.vis_font_path) img_save_path = os.path.join(save_folder, img_name, 'show.jpg') elif structure_sys.mode == 'vqa': - draw_img = draw_ser_results(img, res, args.vis_font_path) - img_save_path = os.path.join(save_folder, img_name + '.jpg') + raise NotImplementedError + # draw_img = draw_ser_results(img, res, args.vis_font_path) + # img_save_path = os.path.join(save_folder, img_name + '.jpg') cv2.imwrite(img_save_path, draw_img) logger.info('result save to {}'.format(img_save_path)) elapse = time.time() - starttime diff --git a/ppstructure/utility.py b/ppstructure/utility.py index 43cb0b0873812baf3ce2dc689fb62f1d0ca2c551..10d9f71a7cdfed00b555c46689b2dd3c5aad807c 100644 --- a/ppstructure/utility.py +++ b/ppstructure/utility.py @@ -40,12 +40,6 @@ def init_args(): type=ast.literal_eval, default=None, help='label map according to ppstructure/layout/README_ch.md') - # params for ser - parser.add_argument("--model_name_or_path", type=str) - parser.add_argument("--max_seq_length", type=int, default=512) - parser.add_argument( - "--label_map_path", type=str, default='./vqa/labels/labels_ser.txt') - parser.add_argument( "--mode", type=str, @@ -67,10 +61,10 @@ def draw_structure_result(image, result, font_path): if region['type'] == 'Table': pass else: - for box, rec_res in zip(region['res'][0], region['res'][1]): - boxes.append(np.array(box).reshape(-1, 2)) - txts.append(rec_res[0]) - scores.append(rec_res[1]) + for text_result in region['res']: + boxes.append(np.array(text_result['text_region'])) + txts.append(text_result['text']) + scores.append(text_result['confidence']) im_show = draw_ocr_box_txt( image, boxes, txts, scores, font_path=font_path, drop_score=0) return im_show diff --git a/tools/test_hubserving.py b/tools/test_hubserving.py index 7051e98c40b362d92af6f04add6d4c5b4dfb12a2..5d75aacb4ebb90b369172b8eddd83a0e4d6a33cb 100755 --- a/tools/test_hubserving.py +++ b/tools/test_hubserving.py @@ -27,7 +27,7 @@ from PIL import Image from ppocr.utils.utility import get_image_file_list from tools.infer.utility import draw_ocr, draw_boxes, str2bool from ppstructure.utility import draw_structure_result -from ppstructure.predict_system import save_structure_res, to_excel +from ppstructure.predict_system import to_excel import requests import json @@ -71,6 +71,31 @@ def draw_server_result(image_file, res): return draw_img +def save_structure_res(res, save_folder, image_file): + img = cv2.imread(image_file) + excel_save_folder = os.path.join(save_folder, os.path.basename(image_file)) + os.makedirs(excel_save_folder, exist_ok=True) + # save res + with open( + os.path.join(excel_save_folder, 'res.txt'), 'w', + encoding='utf8') as f: + for region in res: + if region['type'] == 'Table': + excel_path = os.path.join(excel_save_folder, + '{}.xlsx'.format(region['bbox'])) + to_excel(region['res'], excel_path) + elif region['type'] == 'Figure': + x1, y1, x2, y2 = region['bbox'] + print(region['bbox']) + roi_img = img[y1:y2, x1:x2, :] + img_path = os.path.join(excel_save_folder, + '{}.jpg'.format(region['bbox'])) + cv2.imwrite(img_path, roi_img) + else: + for text_result in region['res']: + f.write('{}\n'.format(json.dumps(text_result))) + + def main(args): image_file_list = get_image_file_list(args.image_dir) is_visualize = False @@ -97,20 +122,19 @@ def main(args): if args.visualize: draw_img = None if 'structure_table' in args.server_url: - to_excel(res, './{}.xlsx'.format(img_name)) + to_excel(res['html'], './{}.xlsx'.format(img_name)) elif 'structure_system' in args.server_url: - pass + save_structure_res(res['regions'], args.output, image_file) else: draw_img = draw_server_result(image_file, res) if draw_img is not None: - draw_img_save = "./server_results/" - if not os.path.exists(draw_img_save): - os.makedirs(draw_img_save) + if not os.path.exists(args.output): + os.makedirs(args.output) cv2.imwrite( - os.path.join(draw_img_save, os.path.basename(image_file)), + os.path.join(args.output, os.path.basename(image_file)), draw_img[:, :, ::-1]) logger.info("The visualized image saved in {}".format( - os.path.join(draw_img_save, os.path.basename(image_file)))) + os.path.join(args.output, os.path.basename(image_file)))) cnt += 1 if cnt % 100 == 0: logger.info("{} processed".format(cnt)) @@ -123,6 +147,7 @@ def parse_args(): parser.add_argument("--server_url", type=str, required=True) parser.add_argument("--image_dir", type=str, required=True) parser.add_argument("--visualize", type=str2bool, default=False) + parser.add_argument("--output", type=str, default='./hubserving_result') args = parser.parse_args() return args