diff --git a/paddleocr.py b/paddleocr.py index 7c45c4381fdc4e0eddf46969692dfe8fb729d593..1e4d94ff4e72da951e1ffb92edb50715482581ae 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -30,7 +30,7 @@ from ppocr.utils.logging import get_logger logger = get_logger() from ppocr.utils.utility import check_and_read_gif, get_image_file_list -from tools.infer.utility import draw_ocr, inference_args_list, str2bool, parse_args +from tools.infer.utility import draw_ocr, init_args, str2bool __all__ = ['PaddleOCR'] @@ -167,23 +167,23 @@ def maybe_download(model_storage_directory, url): os.remove(tmp_path) -def parse_args_whl(mMain=True): +def parse_args(mMain=True): import argparse - extend_args_list = [ - ['lang', str, 'ch'], - ['det', str2bool, True], - ['rec', str2bool, True], - ] - for item in inference_args_list: - if item[0] == 'rec_char_dict_path': - item[2] = None - inference_args_list.extend(extend_args_list) + parser = init_args() + parser.add_help = mMain + parser.add_argument("--lang", type=str, default='ch') + parser.add_argument("--det", type=str2bool, default=True) + parser.add_argument("--rec", type=str2bool, default=True) + + for action in parser._actions: + if action.dest == 'rec_char_dict_path': + action.default = None if mMain: - return parse_args() + return parser.parse_args() else: inference_args_dict = {} - for item in inference_args_list: - inference_args_dict[item[0]] = item[2] + for action in parser._actions: + inference_args_dict[action.dest] = action.default return argparse.Namespace(**inference_args_dict) @@ -194,7 +194,7 @@ class PaddleOCR(predict_system.TextSystem): args: **kwargs: other params show in paddleocr --help """ - postprocess_params = parse_args_whl(mMain=False) + postprocess_params = parse_args(mMain=False) postprocess_params.__dict__.update(**kwargs) self.use_angle_cls = postprocess_params.use_angle_cls lang = postprocess_params.lang @@ -318,7 +318,7 @@ class PaddleOCR(predict_system.TextSystem): def main(): # for cmd - args = 
parse_args_whl(mMain=True) + args = parse_args(mMain=True) image_dir = args.image_dir if image_dir.startswith('http'): download_with_progressbar(image_dir, 'tmp.jpg') diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 7b543353d2f9accb94b0c855a1e4d199dec8a6aa..3f0ff2ff64bff2c2e70be37a95b5449deaa90046 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -23,6 +23,7 @@ import math from paddle import inference import time from ppocr.utils.logging import get_logger + logger = get_logger() @@ -30,77 +31,90 @@ def str2bool(v): return v.lower() in ("true", "t", "1") -inference_args_list = [ - # name type defalue +def init_args(): + parser = argparse.ArgumentParser() # params for prediction engine - ['use_gpu', str2bool, True], - ['use_tensorrt', str2bool, False], - ['use_fp16', str2bool, False], - ['use_pdserving', str2bool, False], - ['use_mp', str2bool, False], - ['enable_mkldnn', str2bool, False], - ['ir_optim', str2bool, True], - ['total_process_num', int, 1], - ['process_id', int, 0], - ['gpu_mem', int, 500], - ['cpu_threads', int, 10], + parser.add_argument("--use_gpu", type=str2bool, default=True) + parser.add_argument("--ir_optim", type=str2bool, default=True) + parser.add_argument("--use_tensorrt", type=str2bool, default=False) + parser.add_argument("--use_fp16", type=str2bool, default=False) + parser.add_argument("--gpu_mem", type=int, default=500) + # params for text detector - ['image_dir', str, None], - ['det_algorithm', str, 'DB'], - ['det_model_dir', str, None], - ['det_limit_side_len', float, 960], - ['det_limit_type', str, 'max'], + parser.add_argument("--image_dir", type=str) + parser.add_argument("--det_algorithm", type=str, default='DB') + parser.add_argument("--det_model_dir", type=str) + parser.add_argument("--det_limit_side_len", type=float, default=960) + parser.add_argument("--det_limit_type", type=str, default='max') + # DB parmas - ['det_db_thresh', float, 0.3], - ['det_db_box_thresh', float, 0.5], - 
['det_db_unclip_ratio', float, 1.6], - ['max_batch_size', int, 10], - ['use_dilation', str2bool, False], - ['det_db_score_mode', str, 'fast'], + parser.add_argument("--det_db_thresh", type=float, default=0.3) + parser.add_argument("--det_db_box_thresh", type=float, default=0.5) + parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) + parser.add_argument("--max_batch_size", type=int, default=10) + parser.add_argument("--use_dilation", type=str2bool, default=False) + parser.add_argument("--det_db_score_mode", type=str, default="fast") # EAST parmas - ['det_east_score_thresh', float, 0.8], - ['det_east_cover_thresh', float, 0.1], - ['det_east_nms_thresh', float, 0.2], + parser.add_argument("--det_east_score_thresh", type=float, default=0.8) + parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) + parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) + # SAST parmas - ['det_sast_score_thresh', float, 0.5], - ['det_sast_nms_thresh', float, 0.2], - ['det_sast_polygon', str2bool, False], + parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) + parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) + parser.add_argument("--det_sast_polygon", type=str2bool, default=False) + # params for text recognizer - ['rec_algorithm', str, 'CRNN'], - ['rec_model_dir', str, None], - ['rec_image_shape', str, '3, 32, 320'], - ['rec_char_type', str, "ch"], - ['rec_batch_num', int, 6], - ['max_text_length', int, 25], - ['rec_char_dict_path', str, './ppocr/utils/ppocr_keys_v1.txt'], - ['use_space_char', str2bool, True], - ['vis_font_path', str, './doc/fonts/simfang.ttf'], - ['drop_score', float, 0.5], + parser.add_argument("--rec_algorithm", type=str, default='CRNN') + parser.add_argument("--rec_model_dir", type=str) + parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") + parser.add_argument("--rec_char_type", type=str, default='ch') + parser.add_argument("--rec_batch_num", type=int, default=6) 
+ parser.add_argument("--max_text_length", type=int, default=25) + parser.add_argument( + "--rec_char_dict_path", + type=str, + default="./ppocr/utils/ppocr_keys_v1.txt") + parser.add_argument("--use_space_char", type=str2bool, default=True) + parser.add_argument( + "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf") + parser.add_argument("--drop_score", type=float, default=0.5) + # params for e2e - ['e2e_algorithm', str, 'PGNet'], - ['e2e_model_dir', str, None], - ['e2e_limit_side_len', float, 768], - ['e2e_limit_type', str, 'max'], + parser.add_argument("--e2e_algorithm", type=str, default='PGNet') + parser.add_argument("--e2e_model_dir", type=str) + parser.add_argument("--e2e_limit_side_len", type=float, default=768) + parser.add_argument("--e2e_limit_type", type=str, default='max') + # PGNet parmas - ['e2e_pgnet_score_thresh', float, 0.5], - ['e2e_char_dict_path', str, './ppocr/utils/ic15_dict.txt'], - ['e2e_pgnet_valid_set', str, 'totaltext'], - ['e2e_pgnet_polygon', str2bool, True], - ['e2e_pgnet_mode', str, 'fast'], + parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5) + parser.add_argument( + "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt") + parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext') + parser.add_argument("--e2e_pgnet_polygon", type=str2bool, default=True) + parser.add_argument("--e2e_pgnet_mode", type=str, default='fast') + # params for text classifier - ['use_angle_cls', str2bool, False], - ['cls_model_dir', str, None], - ['cls_image_shape', str, '3, 48, 192'], - ['label_list', list, ['0', '180']], - ['cls_batch_num', int, 6], - ['cls_thresh', float, 0.9], -] + parser.add_argument("--use_angle_cls", type=str2bool, default=False) + parser.add_argument("--cls_model_dir", type=str) + parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") + parser.add_argument("--label_list", type=list, default=['0', '180']) + parser.add_argument("--cls_batch_num", 
type=int, default=6) + parser.add_argument("--cls_thresh", type=float, default=0.9) + + parser.add_argument("--enable_mkldnn", type=str2bool, default=False) + parser.add_argument("--cpu_threads", type=int, default=10) + parser.add_argument("--use_pdserving", type=str2bool, default=False) + + parser.add_argument("--use_mp", type=str2bool, default=False) + parser.add_argument("--total_process_num", type=int, default=1) + parser.add_argument("--process_id", type=int, default=0) + + return parser def parse_args(): - parser = argparse.ArgumentParser() - for item in inference_args_list: - parser.add_argument('--' + item[0], type=item[1], default=item[2]) + parser = init_args() return parser.parse_args() @@ -217,8 +231,8 @@ def create_predictor(args, mode, logger): if hasattr(args, "cpu_threads"): config.set_cpu_math_library_num_threads(args.cpu_threads) else: - config.set_cpu_math_library_num_threads( - 10) # default cpu threads as 10 + # default cpu threads as 10 + config.set_cpu_math_library_num_threads(10) if args.enable_mkldnn: # cache 10 different shapes for mkldnn to avoid memory leak config.set_mkldnn_cache_capacity(10)