diff --git a/paddleocr.py b/paddleocr.py index 3958f7ad5727e52f405cd189cb5d5a1da911cbdf..8e39084b9a2d4ea3464ad1c753a1f2ba2164a305 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -117,61 +117,98 @@ def maybe_download(model_storage_directory, url): os.remove(tmp_path) -def parse_args(): +def parse_args(mMain=True): import argparse def str2bool(v): return v.lower() in ("true", "t", "1") - parser = argparse.ArgumentParser() - # params for prediction engine - parser.add_argument("--use_gpu", type=str2bool, default=True) - parser.add_argument("--ir_optim", type=str2bool, default=True) - parser.add_argument("--use_tensorrt", type=str2bool, default=False) - parser.add_argument("--gpu_mem", type=int, default=8000) - - # params for text detector - parser.add_argument("--image_dir", type=str) - parser.add_argument("--det_algorithm", type=str, default='DB') - parser.add_argument("--det_model_dir", type=str, default=None) - parser.add_argument("--det_max_side_len", type=float, default=960) - - # DB parmas - parser.add_argument("--det_db_thresh", type=float, default=0.3) - parser.add_argument("--det_db_box_thresh", type=float, default=0.5) - parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) - - # EAST parmas - parser.add_argument("--det_east_score_thresh", type=float, default=0.8) - parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) - parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) - - # params for text recognizer - parser.add_argument("--rec_algorithm", type=str, default='CRNN') - parser.add_argument("--rec_model_dir", type=str, default=None) - parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") - parser.add_argument("--rec_char_type", type=str, default='ch') - parser.add_argument("--rec_batch_num", type=int, default=30) - parser.add_argument("--max_text_length", type=int, default=25) - parser.add_argument("--rec_char_dict_path", type=str, default=None) - parser.add_argument("--use_space_char", type=bool, default=True) - - # params for text classifier - parser.add_argument("--cls_model_dir", type=str, default=None) - parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") - parser.add_argument("--label_list", type=list, default=['0', '180']) - parser.add_argument("--cls_batch_num", type=int, default=30) - parser.add_argument("--cls_thresh", type=float, default=0.9) - - parser.add_argument("--enable_mkldnn", type=bool, default=False) - parser.add_argument("--use_zero_copy_run", type=bool, default=False) - parser.add_argument("--use_pdserving", type=str2bool, default=False) - - parser.add_argument("--lang", type=str, default='ch') - parser.add_argument("--det", type=str2bool, default=True) - parser.add_argument("--rec", type=str2bool, default=True) - parser.add_argument("--use_angle_cls", type=str2bool, default=True) - return parser.parse_args() + if mMain: + parser = argparse.ArgumentParser() + # params for prediction engine + parser.add_argument("--use_gpu", type=str2bool, default=True) + parser.add_argument("--ir_optim", type=str2bool, default=True) + parser.add_argument("--use_tensorrt", type=str2bool, default=False) + parser.add_argument("--gpu_mem", type=int, default=8000) + + # params for text detector + parser.add_argument("--image_dir", type=str) + parser.add_argument("--det_algorithm", type=str, default='DB') + parser.add_argument("--det_model_dir", type=str, default=None) + parser.add_argument("--det_max_side_len", type=float, default=960) + + # DB parmas + parser.add_argument("--det_db_thresh", type=float, default=0.3) + parser.add_argument("--det_db_box_thresh", type=float, default=0.5) + parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) + + # EAST parmas + parser.add_argument("--det_east_score_thresh", type=float, default=0.8) + parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) + parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) + + # params for text recognizer + parser.add_argument("--rec_algorithm", type=str, default='CRNN') + parser.add_argument("--rec_model_dir", type=str, default=None) + parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") + parser.add_argument("--rec_char_type", type=str, default='ch') + parser.add_argument("--rec_batch_num", type=int, default=30) + parser.add_argument("--max_text_length", type=int, default=25) + parser.add_argument("--rec_char_dict_path", type=str, default=None) + parser.add_argument("--use_space_char", type=bool, default=True) + + # params for text classifier + parser.add_argument("--cls_model_dir", type=str, default=None) + parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") + parser.add_argument("--label_list", type=list, default=['0', '180']) + parser.add_argument("--cls_batch_num", type=int, default=30) + parser.add_argument("--cls_thresh", type=float, default=0.9) + + parser.add_argument("--enable_mkldnn", type=bool, default=False) + parser.add_argument("--use_zero_copy_run", type=bool, default=False) + parser.add_argument("--use_pdserving", type=str2bool, default=False) + + parser.add_argument("--lang", type=str, default='ch') + parser.add_argument("--det", type=str2bool, default=True) + parser.add_argument("--rec", type=str2bool, default=True) + parser.add_argument("--use_angle_cls", type=str2bool, default=True) + return parser.parse_args() + else: + return argparse.Namespace( use_gpu=True, + ir_optim=True, + use_tensorrt=False, + gpu_mem=8000, + image_dir='', + det_algorithm='DB', + det_model_dir=None, + det_max_side_len=960, + det_db_thresh=0.3, + det_db_box_thresh=0.5, + det_db_unclip_ratio=2.0, + det_east_score_thresh=0.8, + det_east_cover_thresh=0.1, + det_east_nms_thresh=0.2, + rec_algorithm='CRNN', + rec_model_dir=None, + rec_image_shape="3, 32, 320", + rec_char_type='ch', + rec_batch_num=30, + max_text_length=25, + rec_char_dict_path=None, + use_space_char=True, + cls_model_dir=None, + cls_image_shape="3, 48, 192", + label_list=['0', '180'], + cls_batch_num=30, + cls_thresh=0.9, + enable_mkldnn=False, + use_zero_copy_run=False, + use_pdserving=False, + lang='ch', + det=True, + rec=True, + use_angle_cls=True + ) class PaddleOCR(predict_system.TextSystem): @@ -181,7 +218,7 @@ class PaddleOCR(predict_system.TextSystem): args: **kwargs: other params show in paddleocr --help """ - postprocess_params = parse_args() + postprocess_params = parse_args(mMain=False) postprocess_params.__dict__.update(**kwargs) self.use_angle_cls = postprocess_params.use_angle_cls lang = postprocess_params.lang @@ -259,12 +296,13 @@ class PaddleOCR(predict_system.TextSystem): def main(): # for com - args = parse_args() + args = parse_args(mMain=True) image_file_list = get_image_file_list(args.image_dir) if len(image_file_list) == 0: logger.error('no images find in {}'.format(args.image_dir)) return - ocr_engine = PaddleOCR() + + ocr_engine = PaddleOCR(**(args.__dict__)) for img_path in image_file_list: print(img_path) result = ocr_engine.ocr(img_path,