提交 b16c3d42 编写于 作者: A andyjpaddle

update v3 rec name

上级 3e1db518
...@@ -48,7 +48,7 @@ __all__ = [ ...@@ -48,7 +48,7 @@ __all__ = [
SUPPORT_DET_MODEL = ['DB'] SUPPORT_DET_MODEL = ['DB']
VERSION = '2.5.0.1' VERSION = '2.5.0.1'
SUPPORT_REC_MODEL = ['CRNN'] SUPPORT_REC_MODEL = ['CRNN', 'SVTR_LCNet']
BASE_DIR = os.path.expanduser("~/.paddleocr/") BASE_DIR = os.path.expanduser("~/.paddleocr/")
DEFAULT_OCR_MODEL_VERSION = 'PP-OCRv3' DEFAULT_OCR_MODEL_VERSION = 'PP-OCRv3'
......
...@@ -79,7 +79,7 @@ def init_args(): ...@@ -79,7 +79,7 @@ def init_args():
parser.add_argument("--det_fce_box_type", type=str, default='poly') parser.add_argument("--det_fce_box_type", type=str, default='poly')
# params for text recognizer # params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_batch_num", type=int, default=6) parser.add_argument("--rec_batch_num", type=int, default=6)
...@@ -269,11 +269,11 @@ def create_predictor(args, mode, logger): ...@@ -269,11 +269,11 @@ def create_predictor(args, mode, logger):
max_input_shape.update(max_pact_shape) max_input_shape.update(max_pact_shape)
opt_input_shape.update(opt_pact_shape) opt_input_shape.update(opt_pact_shape)
elif mode == "rec": elif mode == "rec":
if args.rec_algorithm != "CRNN": if args.rec_algorithm not in ["CRNN", "SVTR_LCNet"]:
use_dynamic_shape = False use_dynamic_shape = False
imgH = int(args.rec_image_shape.split(',')[-2]) imgH = int(args.rec_image_shape.split(',')[-2])
min_input_shape = {"x": [1, 3, imgH, 10]} min_input_shape = {"x": [1, 3, imgH, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 1536]} max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 2304]}
opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]} opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]}
elif mode == "cls": elif mode == "cls":
min_input_shape = {"x": [1, 3, 48, 10]} min_input_shape = {"x": [1, 3, 48, 10]}
...@@ -320,7 +320,7 @@ def create_predictor(args, mode, logger): ...@@ -320,7 +320,7 @@ def create_predictor(args, mode, logger):
def get_output_tensors(args, mode, predictor): def get_output_tensors(args, mode, predictor):
output_names = predictor.get_output_names() output_names = predictor.get_output_names()
output_tensors = [] output_tensors = []
if mode == "rec" and args.rec_algorithm == "CRNN": if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]:
output_name = 'softmax_0.tmp_0' output_name = 'softmax_0.tmp_0'
if output_name in output_names: if output_name in output_names:
return [predictor.get_output_handle(output_name)] return [predictor.get_output_handle(output_name)]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册