diff --git a/doc/doc_ch/whl.md b/doc/doc_ch/whl.md index 2e93c487c2f2071c7c89c753cf86eef61ce20805..c341b49a7b12aa10f0f3187bc861306fcae05c29 100644 --- a/doc/doc_ch/whl.md +++ b/doc/doc_ch/whl.md @@ -59,7 +59,7 @@ im_show.save('result.jpg') from paddleocr import PaddleOCR, draw_ocr ocr = PaddleOCR() # need to run only once to download and load model into memory img_path = 'PaddleOCR/doc/imgs/11.jpg' -result = ocr.ocr(img_path) +result = ocr.ocr(img_path,cls=False) for line in result: print(line) diff --git a/doc/doc_en/whl_en.md b/doc/doc_en/whl_en.md index 69abf085556f466853798077bb116b3986582bcc..eeaf1347dc77a24f158ba8ba2c6f013b1fd89b81 100644 --- a/doc/doc_en/whl_en.md +++ b/doc/doc_en/whl_en.md @@ -59,7 +59,7 @@ Visualization of results from paddleocr import PaddleOCR,draw_ocr ocr = PaddleOCR(lang='en') # need to run only once to download and load model into memory img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg' -result = ocr.ocr(img_path) +result = ocr.ocr(img_path, cls=False) for line in result: print(line) diff --git a/paddleocr.py b/paddleocr.py index c5da7248d2cc7d778758a87309cfeaedcbd8ceb5..1e4d94ff4e72da951e1ffb92edb50715482581ae 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -30,7 +30,7 @@ from ppocr.utils.logging import get_logger logger = get_logger() from ppocr.utils.utility import check_and_read_gif, get_image_file_list -from tools.infer.utility import draw_ocr +from tools.infer.utility import draw_ocr, init_args, str2bool __all__ = ['PaddleOCR'] @@ -167,106 +167,24 @@ def maybe_download(model_storage_directory, url): os.remove(tmp_path) -def parse_args(mMain=True, add_help=True): +def parse_args(mMain=True): import argparse - - def str2bool(v): - return v.lower() in ("true", "t", "1") - + parser = init_args() + parser.add_help = mMain + parser.add_argument("--lang", type=str, default='ch') + parser.add_argument("--det", type=str2bool, default=True) + parser.add_argument("--rec", type=str2bool, default=True) + + for action in parser._actions: + if action.dest == 'rec_char_dict_path': + action.default = None if mMain: - parser = argparse.ArgumentParser(add_help=add_help) - # params for prediction engine - parser.add_argument("--use_gpu", type=str2bool, default=True) - parser.add_argument("--ir_optim", type=str2bool, default=True) - parser.add_argument("--use_tensorrt", type=str2bool, default=False) - parser.add_argument("--gpu_mem", type=int, default=8000) - - # params for text detector - parser.add_argument("--image_dir", type=str) - parser.add_argument("--det_algorithm", type=str, default='DB') - parser.add_argument("--det_model_dir", type=str, default=None) - parser.add_argument("--det_limit_side_len", type=float, default=960) - parser.add_argument("--det_limit_type", type=str, default='max') - - # DB parmas - parser.add_argument("--det_db_thresh", type=float, default=0.3) - parser.add_argument("--det_db_box_thresh", type=float, default=0.5) - parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) - parser.add_argument("--use_dilation", type=bool, default=False) - parser.add_argument("--det_db_score_mode", type=str, default="fast") - - # EAST parmas - parser.add_argument("--det_east_score_thresh", type=float, default=0.8) - parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) - parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) - - # params for text recognizer - parser.add_argument("--rec_algorithm", type=str, default='CRNN') - parser.add_argument("--rec_model_dir", type=str, default=None) - parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") - parser.add_argument("--rec_char_type", type=str, default='ch') - parser.add_argument("--rec_batch_num", type=int, default=6) - parser.add_argument("--max_text_length", type=int, default=25) - parser.add_argument("--rec_char_dict_path", type=str, default=None) - parser.add_argument("--use_space_char", type=bool, default=True) - parser.add_argument("--drop_score", type=float, default=0.5) - - # params for text classifier - parser.add_argument("--cls_model_dir", type=str, default=None) - parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") - parser.add_argument("--label_list", type=list, default=['0', '180']) - parser.add_argument("--cls_batch_num", type=int, default=6) - parser.add_argument("--cls_thresh", type=float, default=0.9) - - parser.add_argument("--enable_mkldnn", type=bool, default=False) - parser.add_argument("--use_zero_copy_run", type=bool, default=False) - parser.add_argument("--use_pdserving", type=str2bool, default=False) - - parser.add_argument("--lang", type=str, default='ch') - parser.add_argument("--det", type=str2bool, default=True) - parser.add_argument("--rec", type=str2bool, default=True) - parser.add_argument("--use_angle_cls", type=str2bool, default=False) return parser.parse_args() else: - return argparse.Namespace( - use_gpu=True, - ir_optim=True, - use_tensorrt=False, - gpu_mem=8000, - image_dir='', - det_algorithm='DB', - det_model_dir=None, - det_limit_side_len=960, - det_limit_type='max', - det_db_thresh=0.3, - det_db_box_thresh=0.5, - det_db_unclip_ratio=1.6, - use_dilation=False, - det_db_score_mode="fast", - det_east_score_thresh=0.8, - det_east_cover_thresh=0.1, - det_east_nms_thresh=0.2, - rec_algorithm='CRNN', - rec_model_dir=None, - rec_image_shape="3, 32, 320", - rec_char_type='ch', - rec_batch_num=6, - max_text_length=25, - rec_char_dict_path=None, - use_space_char=True, - drop_score=0.5, - cls_model_dir=None, - cls_image_shape="3, 48, 192", - label_list=['0', '180'], - cls_batch_num=6, - cls_thresh=0.9, - enable_mkldnn=False, - use_zero_copy_run=False, - use_pdserving=False, - lang='ch', - det=True, - rec=True, - use_angle_cls=False) + inference_args_dict = {} + for action in parser._actions: + inference_args_dict[action.dest] = action.default + return argparse.Namespace(**inference_args_dict) class PaddleOCR(predict_system.TextSystem): @@ -276,7 +194,7 @@ class PaddleOCR(predict_system.TextSystem): args: **kwargs: other params show in paddleocr --help """ - postprocess_params = parse_args(mMain=False, add_help=False) + postprocess_params = parse_args(mMain=False) postprocess_params.__dict__.update(**kwargs) self.use_angle_cls = postprocess_params.use_angle_cls lang = postprocess_params.lang @@ -346,7 +264,7 @@ class PaddleOCR(predict_system.TextSystem): # init det_model and rec_model super().__init__(postprocess_params) - def ocr(self, img, det=True, rec=True, cls=False): + def ocr(self, img, det=True, rec=True, cls=True): """ ocr with paddleocr args: @@ -358,9 +276,7 @@ class PaddleOCR(predict_system.TextSystem): if isinstance(img, list) and det == True: logger.error('When input a list of images, det must be false') exit(0) - if cls == False: - self.use_angle_cls = False - elif cls == True and self.use_angle_cls == False: + if cls == True and self.use_angle_cls == False: logger.warning( 'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process' ) @@ -382,7 +298,7 @@ class PaddleOCR(predict_system.TextSystem): if isinstance(img, np.ndarray) and len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if det and rec: - dt_boxes, rec_res = self.__call__(img) + dt_boxes, rec_res = self.__call__(img, cls) return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] elif det and not rec: dt_boxes, elapse = self.text_detector(img) @@ -392,7 +308,7 @@ class PaddleOCR(predict_system.TextSystem): else: if not isinstance(img, list): img = [img] - if self.use_angle_cls: + if self.use_angle_cls and cls: img, cls_res, elapse = self.text_classifier(img) if not rec: return cls_res diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index ba81aff0a940fbee234e59e98f73c62fc7f69f09..78f5a4729918e33e174705e1b9b0f6e4c27699c6 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -85,7 +85,7 @@ class TextSystem(object): cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno]) logger.info(bno, rec_res[bno]) - def __call__(self, img): + def __call__(self, img, cls=True): ori_im = img.copy() dt_boxes, elapse = self.text_detector(img) logger.info("dt_boxes num : {}, elapse : {}".format( @@ -100,7 +100,7 @@ class TextSystem(object): tmp_box = copy.deepcopy(dt_boxes[bno]) img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) - if self.use_angle_cls: + if self.use_angle_cls and cls: img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list) logger.info("cls num : {}, elapse : {}".format( diff --git a/tools/infer/utility.py b/tools/infer/utility.py index ff4a0276e2dc690faaa5c0f22a4c88a9f31a75f9..3f0ff2ff64bff2c2e70be37a95b5449deaa90046 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -23,13 +23,15 @@ import math from paddle import inference import time from ppocr.utils.logging import get_logger + logger = get_logger() -def parse_args(): - def str2bool(v): - return v.lower() in ("true", "t", "1") +def str2bool(v): + return v.lower() in ("true", "t", "1") + +def init_args(): parser = argparse.ArgumentParser() # params for prediction engine parser.add_argument("--use_gpu", type=str2bool, default=True) @@ -108,6 +110,11 @@ def parse_args(): parser.add_argument("--total_process_num", type=int, default=1) parser.add_argument("--process_id", type=int, default=0) + return parser + + +def parse_args(): + parser = init_args() return parser.parse_args() @@ -141,7 +148,7 @@ def create_predictor(args, mode, logger): config.enable_tensorrt_engine( precision_mode=inference.PrecisionType.Float32, max_batch_size=args.max_batch_size, - min_subgraph_size=3) # skip the minmum trt subgraph + min_subgraph_size=3) # skip the minmum trt subgraph if mode == "det" and "mobile" in model_file_path: min_input_shape = { "x": [1, 3, 50, 50],