From eaf38b9b12cb529daa9eb920b843c7754dac38a2 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Wed, 26 May 2021 17:34:47 +0800 Subject: [PATCH] combine args in paddleocr and ppocr/infer/utility --- doc/doc_ch/whl.md | 2 +- doc/doc_en/whl_en.md | 2 +- paddleocr.py | 140 ++++---------- tools/infer/predict_system.py | 4 +- tools/infer/utility.py | 341 +++++++++++++++++++++++++++------- 5 files changed, 309 insertions(+), 180 deletions(-) diff --git a/doc/doc_ch/whl.md b/doc/doc_ch/whl.md index 2e93c487..c341b49a 100644 --- a/doc/doc_ch/whl.md +++ b/doc/doc_ch/whl.md @@ -59,7 +59,7 @@ im_show.save('result.jpg') from paddleocr import PaddleOCR, draw_ocr ocr = PaddleOCR() # need to run only once to download and load model into memory img_path = 'PaddleOCR/doc/imgs/11.jpg' -result = ocr.ocr(img_path) +result = ocr.ocr(img_path,cls=False) for line in result: print(line) diff --git a/doc/doc_en/whl_en.md b/doc/doc_en/whl_en.md index 69abf085..eeaf1347 100644 --- a/doc/doc_en/whl_en.md +++ b/doc/doc_en/whl_en.md @@ -59,7 +59,7 @@ Visualization of results from paddleocr import PaddleOCR,draw_ocr ocr = PaddleOCR(lang='en') # need to run only once to download and load model into memory img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg' -result = ocr.ocr(img_path) +result = ocr.ocr(img_path, cls=False) for line in result: print(line) diff --git a/paddleocr.py b/paddleocr.py index c5da7248..5e14deb1 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -30,7 +30,7 @@ from ppocr.utils.logging import get_logger logger = get_logger() from ppocr.utils.utility import check_and_read_gif, get_image_file_list -from tools.infer.utility import draw_ocr +from tools.infer.utility import draw_ocr, inference_args_list, str2bool, parse_args __all__ = ['PaddleOCR'] @@ -167,106 +167,36 @@ def maybe_download(model_storage_directory, url): os.remove(tmp_path) -def parse_args(mMain=True, add_help=True): +def parse_args_whl(mMain=True): import argparse - - def str2bool(v): - return v.lower() in ("true", "t", "1") - + extend_args_list = [ + { + 'name': 'lang', + 'type': str, + 'default': 'ch' + }, + { + 'name': 'det', + 'type': str2bool, + 'default': True + }, + { + 'name': 'rec', + 'type': str2bool, + 'default': True + }, + ] + for item in inference_args_list: + if item['name'] == 'rec_char_dict_path': + item['default'] = None + inference_args_list.extend(extend_args_list) if mMain: - parser = argparse.ArgumentParser(add_help=add_help) - # params for prediction engine - parser.add_argument("--use_gpu", type=str2bool, default=True) - parser.add_argument("--ir_optim", type=str2bool, default=True) - parser.add_argument("--use_tensorrt", type=str2bool, default=False) - parser.add_argument("--gpu_mem", type=int, default=8000) - - # params for text detector - parser.add_argument("--image_dir", type=str) - parser.add_argument("--det_algorithm", type=str, default='DB') - parser.add_argument("--det_model_dir", type=str, default=None) - parser.add_argument("--det_limit_side_len", type=float, default=960) - parser.add_argument("--det_limit_type", type=str, default='max') - - # DB parmas - parser.add_argument("--det_db_thresh", type=float, default=0.3) - parser.add_argument("--det_db_box_thresh", type=float, default=0.5) - parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) - parser.add_argument("--use_dilation", type=bool, default=False) - parser.add_argument("--det_db_score_mode", type=str, default="fast") - - # EAST parmas - parser.add_argument("--det_east_score_thresh", type=float, default=0.8) - parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) - parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) - - # params for text recognizer - parser.add_argument("--rec_algorithm", type=str, default='CRNN') - parser.add_argument("--rec_model_dir", type=str, default=None) - parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") - parser.add_argument("--rec_char_type", type=str, default='ch') - parser.add_argument("--rec_batch_num", type=int, default=6) - parser.add_argument("--max_text_length", type=int, default=25) - parser.add_argument("--rec_char_dict_path", type=str, default=None) - parser.add_argument("--use_space_char", type=bool, default=True) - parser.add_argument("--drop_score", type=float, default=0.5) - - # params for text classifier - parser.add_argument("--cls_model_dir", type=str, default=None) - parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") - parser.add_argument("--label_list", type=list, default=['0', '180']) - parser.add_argument("--cls_batch_num", type=int, default=6) - parser.add_argument("--cls_thresh", type=float, default=0.9) - - parser.add_argument("--enable_mkldnn", type=bool, default=False) - parser.add_argument("--use_zero_copy_run", type=bool, default=False) - parser.add_argument("--use_pdserving", type=str2bool, default=False) - - parser.add_argument("--lang", type=str, default='ch') - parser.add_argument("--det", type=str2bool, default=True) - parser.add_argument("--rec", type=str2bool, default=True) - parser.add_argument("--use_angle_cls", type=str2bool, default=False) - return parser.parse_args() + return parse_args() else: - return argparse.Namespace( - use_gpu=True, - ir_optim=True, - use_tensorrt=False, - gpu_mem=8000, - image_dir='', - det_algorithm='DB', - det_model_dir=None, - det_limit_side_len=960, - det_limit_type='max', - det_db_thresh=0.3, - det_db_box_thresh=0.5, - det_db_unclip_ratio=1.6, - use_dilation=False, - det_db_score_mode="fast", - det_east_score_thresh=0.8, - det_east_cover_thresh=0.1, - det_east_nms_thresh=0.2, - rec_algorithm='CRNN', - rec_model_dir=None, - rec_image_shape="3, 32, 320", - rec_char_type='ch', - rec_batch_num=6, - max_text_length=25, - rec_char_dict_path=None, - use_space_char=True, - drop_score=0.5, - cls_model_dir=None, - cls_image_shape="3, 48, 192", - label_list=['0', '180'], - cls_batch_num=6, - cls_thresh=0.9, - enable_mkldnn=False, - use_zero_copy_run=False, - use_pdserving=False, - lang='ch', - det=True, - rec=True, - use_angle_cls=False) + inference_args_dict = {} + for item in inference_args_list: + inference_args_dict[item['name']] = item['default'] + return argparse.Namespace(**inference_args_dict) class PaddleOCR(predict_system.TextSystem): @@ -276,7 +206,7 @@ class PaddleOCR(predict_system.TextSystem): args: **kwargs: other params show in paddleocr --help """ - postprocess_params = parse_args(mMain=False, add_help=False) + postprocess_params = parse_args_whl(mMain=False) postprocess_params.__dict__.update(**kwargs) self.use_angle_cls = postprocess_params.use_angle_cls lang = postprocess_params.lang @@ -346,7 +276,7 @@ class PaddleOCR(predict_system.TextSystem): # init det_model and rec_model super().__init__(postprocess_params) - def ocr(self, img, det=True, rec=True, cls=False): + def ocr(self, img, det=True, rec=True, cls=True): """ ocr with paddleocr args: @@ -358,9 +288,7 @@ class PaddleOCR(predict_system.TextSystem): if isinstance(img, list) and det == True: logger.error('When input a list of images, det must be false') exit(0) - if cls == False: - self.use_angle_cls = False - elif cls == True and self.use_angle_cls == False: + if cls == True and self.use_angle_cls == False: logger.warning( 'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process' ) @@ -382,7 +310,7 @@ class PaddleOCR(predict_system.TextSystem): if isinstance(img, np.ndarray) and len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if det and rec: - dt_boxes, rec_res = self.__call__(img) + dt_boxes, rec_res = self.__call__(img, cls) return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] elif det and not rec: dt_boxes, elapse = self.text_detector(img) @@ -392,7 +320,7 @@ class PaddleOCR(predict_system.TextSystem): else: if not isinstance(img, list): img = [img] - if self.use_angle_cls: + if self.use_angle_cls and cls: img, cls_res, elapse = self.text_classifier(img) if not rec: return cls_res @@ -402,7 +330,7 @@ class PaddleOCR(predict_system.TextSystem): def main(): # for cmd - args = parse_args(mMain=True) + args = parse_args_whl(mMain=True) image_dir = args.image_dir if image_dir.startswith('http'): download_with_progressbar(image_dir, 'tmp.jpg') diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index ba81aff0..78f5a472 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -85,7 +85,7 @@ class TextSystem(object): cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno]) logger.info(bno, rec_res[bno]) - def __call__(self, img): + def __call__(self, img, cls=True): ori_im = img.copy() dt_boxes, elapse = self.text_detector(img) logger.info("dt_boxes num : {}, elapse : {}".format( @@ -100,7 +100,7 @@ class TextSystem(object): tmp_box = copy.deepcopy(dt_boxes[bno]) img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) - if self.use_angle_cls: + if self.use_angle_cls and cls: img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list) logger.info("cls num : {}, elapse : {}".format( diff --git a/tools/infer/utility.py b/tools/infer/utility.py index b5fe3ba9..3aa2a01e 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -23,87 +23,288 @@ import math from paddle import inference -def parse_args(): - def str2bool(v): - return v.lower() in ("true", "t", "1") +def str2bool(v): + return v.lower() in ("true", "t", "1") - parser = argparse.ArgumentParser() - # params for prediction engine - parser.add_argument("--use_gpu", type=str2bool, default=True) - parser.add_argument("--ir_optim", type=str2bool, default=True) - parser.add_argument("--use_tensorrt", type=str2bool, default=False) - parser.add_argument("--use_fp16", type=str2bool, default=False) - parser.add_argument("--gpu_mem", type=int, default=500) +inference_args_list = [ + # params for prediction engine + { + 'name': 'use_gpu', + 'type': str2bool, + 'default': True + }, + { + 'name': 'ir_optim', + 'type': str2bool, + 'default': True + }, + { + 'name': 'use_tensorrt', + 'type': str2bool, + 'default': False + }, + { + 'name': 'use_fp16', + 'type': str2bool, + 'default': False + }, + { + 'name': 'enable_mkldnn', + 'type': str2bool, + 'default': False + }, + { + 'name': 'use_pdserving', + 'type': str2bool, + 'default': False + }, + { + 'name': 'use_mp', + 'type': str2bool, + 'default': False + }, + { + 'name': 'total_process_num', + 'type': int, + 'default': 1 + }, + { + 'name': 'process_id', + 'type': int, + 'default': 0 + }, + { + 'name': 'gpu_mem', + 'type': int, + 'default': 500 + }, # params for text detector - parser.add_argument("--image_dir", type=str) - parser.add_argument("--det_algorithm", type=str, default='DB') - parser.add_argument("--det_model_dir", type=str) - parser.add_argument("--det_limit_side_len", type=float, default=960) - parser.add_argument("--det_limit_type", type=str, default='max') - + { + 'name': 'image_dir', + 'type': str, + 'default': None + }, + { + 'name': 'det_algorithm', + 'type': str, + 'default': 'DB' + }, + { + 'name': 'det_model_dir', + 'type': str, + 'default': None + }, + { + 'name': 'det_limit_side_len', + 'type': float, + 'default': 960 + }, + { + 'name': 'det_limit_type', + 'type': str, + 'default': 'max' + }, # DB parmas - parser.add_argument("--det_db_thresh", type=float, default=0.3) - parser.add_argument("--det_db_box_thresh", type=float, default=0.5) - parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) - parser.add_argument("--max_batch_size", type=int, default=10) - parser.add_argument("--use_dilation", type=bool, default=False) - parser.add_argument("--det_db_score_mode", type=str, default="fast") + { + 'name': 'det_db_thresh', + 'type': float, + 'default': 0.3 + }, + { + 'name': 'det_db_box_thresh', + 'type': float, + 'default': 0.5 + }, + { + 'name': 'det_db_unclip_ratio', + 'type': float, + 'default': 1.6 + }, + { + 'name': 'max_batch_size', + 'type': int, + 'default': 10 + }, + { + 'name': 'use_dilation', + 'type': str2bool, + 'default': False + }, + { + 'name': 'det_db_score_mode', + 'type': str, + 'default': 'fast' + }, # EAST parmas - parser.add_argument("--det_east_score_thresh", type=float, default=0.8) - parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) - parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) - + { + 'name': 'det_east_score_thresh', + 'type': float, + 'default': 0.8 + }, + { + 'name': 'det_east_cover_thresh', + 'type': float, + 'default': 0.1 + }, + { + 'name': 'det_east_nms_thresh', + 'type': float, + 'default': 0.2 + }, # SAST parmas - parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) - parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) - parser.add_argument("--det_sast_polygon", type=bool, default=False) - + { + 'name': 'det_sast_score_thresh', + 'type': float, + 'default': 0.5 + }, + { + 'name': 'det_sast_nms_thresh', + 'type': float, + 'default': 0.2 + }, + { + 'name': 'det_sast_polygon', + 'type': str2bool, + 'default': False + }, # params for text recognizer - parser.add_argument("--rec_algorithm", type=str, default='CRNN') - parser.add_argument("--rec_model_dir", type=str) - parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") - parser.add_argument("--rec_char_type", type=str, default='ch') - parser.add_argument("--rec_batch_num", type=int, default=6) - parser.add_argument("--max_text_length", type=int, default=25) - parser.add_argument( - "--rec_char_dict_path", - type=str, - default="./ppocr/utils/ppocr_keys_v1.txt") - parser.add_argument("--use_space_char", type=str2bool, default=True) - parser.add_argument( - "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf") - parser.add_argument("--drop_score", type=float, default=0.5) - + { + 'name': 'rec_algorithm', + 'type': str, + 'default': 'CRNN' + }, + { + 'name': 'rec_model_dir', + 'type': str, + 'default': None + }, + { + 'name': 'rec_image_shape', + 'type': str, + 'default': '3, 32, 320' + }, + { + 'name': 'rec_char_type', + 'type': str, + 'default': "ch" + }, + { + 'name': 'rec_batch_num', + 'type': int, + 'default': 6 + }, + { + 'name': 'max_text_length', + 'type': int, + 'default': 25 + }, + { + 'name': 'rec_char_dict_path', + 'type': str, + 'default': './ppocr/utils/ppocr_keys_v1.txt' + }, + { + 'name': 'use_space_char', + 'type': str2bool, + 'default': True + }, + { + 'name': 'vis_font_path', + 'type': str, + 'default': './doc/fonts/simfang.ttf' + }, + { + 'name': 'drop_score', + 'type': float, + 'default': 0.5 + }, # params for e2e - parser.add_argument("--e2e_algorithm", type=str, default='PGNet') - parser.add_argument("--e2e_model_dir", type=str) - parser.add_argument("--e2e_limit_side_len", type=float, default=768) - parser.add_argument("--e2e_limit_type", type=str, default='max') - + { + 'name': 'e2e_algorithm', + 'type': str, + 'default': 'PGNet' + }, + { + 'name': 'e2e_model_dir', + 'type': str, + 'default': None + }, + { + 'name': 'e2e_limit_side_len', + 'type': float, + 'default': 768 + }, + { + 'name': 'e2e_limit_type', + 'type': str, + 'default': 'max' + }, # PGNet parmas - parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5) - parser.add_argument( - "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt") - parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext') - parser.add_argument("--e2e_pgnet_polygon", type=bool, default=True) - parser.add_argument("--e2e_pgnet_mode", type=str, default='fast') - + { + 'name': 'e2e_pgnet_score_thresh', + 'type': float, + 'default': 0.5 + }, + { + 'name': 'e2e_char_dict_path', + 'type': str, + 'default': './ppocr/utils/ic15_dict.txt' + }, + { + 'name': 'e2e_pgnet_valid_set', + 'type': str, + 'default': 'totaltext' + }, + { + 'name': 'e2e_pgnet_polygon', + 'type': str2bool, + 'default': True + }, + { + 'name': 'e2e_pgnet_mode', + 'type': str, + 'default': 'fast' + }, # params for text classifier - parser.add_argument("--use_angle_cls", type=str2bool, default=False) - parser.add_argument("--cls_model_dir", type=str) - parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") - parser.add_argument("--label_list", type=list, default=['0', '180']) - parser.add_argument("--cls_batch_num", type=int, default=6) - parser.add_argument("--cls_thresh", type=float, default=0.9) + { + 'name': 'use_angle_cls', + 'type': str2bool, + 'default': False + }, + { + 'name': 'cls_model_dir', + 'type': str, + 'default': None + }, + { + 'name': 'cls_image_shape', + 'type': str, + 'default': '3, 48, 192' + }, + { + 'name': 'label_list', + 'type': list, + 'default': ['0', '180'] + }, + { + 'name': 'cls_batch_num', + 'type': int, + 'default': 6 + }, + { + 'name': 'cls_thresh', + 'type': float, + 'default': 0.9 + }, +] - parser.add_argument("--enable_mkldnn", type=str2bool, default=False) - parser.add_argument("--use_pdserving", type=str2bool, default=False) - - parser.add_argument("--use_mp", type=str2bool, default=False) - parser.add_argument("--total_process_num", type=int, default=1) - parser.add_argument("--process_id", type=int, default=0) +def parse_args(): + parser = argparse.ArgumentParser() + for item in inference_args_list: + parser.add_argument( + '--' + item['name'], type=item['type'], default=item['default']) return parser.parse_args() @@ -146,7 +347,7 @@ def create_predictor(args, mode, logger): config.set_mkldnn_cache_capacity(10) config.enable_mkldnn() # TODO LDOUBLEV: fix mkldnn bug when bach_size > 1 - #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'}) + # config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'}) args.rec_batch_num = 1 # enable memory optim -- GitLab