From b16c3d42345b604d3dab4bea12e9ea7af55e0134 Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Tue, 10 May 2022 02:41:54 +0000 Subject: [PATCH] update v3 rec name --- paddleocr.py | 2 +- tools/infer/utility.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddleocr.py b/paddleocr.py index f7871db6..5d362557 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -48,7 +48,7 @@ __all__ = [ SUPPORT_DET_MODEL = ['DB'] VERSION = '2.5.0.1' -SUPPORT_REC_MODEL = ['CRNN'] +SUPPORT_REC_MODEL = ['CRNN', 'SVTR_LCNet'] BASE_DIR = os.path.expanduser("~/.paddleocr/") DEFAULT_OCR_MODEL_VERSION = 'PP-OCRv3' diff --git a/tools/infer/utility.py b/tools/infer/utility.py index ce4e2d92..81bee85c 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -79,7 +79,7 @@ def init_args(): parser.add_argument("--det_fce_box_type", type=str, default='poly') # params for text recognizer - parser.add_argument("--rec_algorithm", type=str, default='CRNN') + parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet') parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") parser.add_argument("--rec_batch_num", type=int, default=6) @@ -269,11 +269,11 @@ def create_predictor(args, mode, logger): max_input_shape.update(max_pact_shape) opt_input_shape.update(opt_pact_shape) elif mode == "rec": - if args.rec_algorithm != "CRNN": + if args.rec_algorithm not in ["CRNN", "SVTR_LCNet"]: use_dynamic_shape = False imgH = int(args.rec_image_shape.split(',')[-2]) min_input_shape = {"x": [1, 3, imgH, 10]} - max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 1536]} + max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 2304]} opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]} elif mode == "cls": min_input_shape = {"x": [1, 3, 48, 10]} @@ -320,7 +320,7 @@ def create_predictor(args, mode, logger): def get_output_tensors(args, mode, predictor): output_names = predictor.get_output_names() output_tensors = [] - if mode == "rec" and args.rec_algorithm == "CRNN": + if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]: output_name = 'softmax_0.tmp_0' if output_name in output_names: return [predictor.get_output_handle(output_name)] -- GitLab