diff --git a/tools/infer/utility.py b/tools/infer/utility.py index b16aecd496ec291fcbe9c66dccf3ec04bb662034..22ffe4d6c223ca64cd7a5d0a6bcad4cd307e89b1 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -271,9 +271,10 @@ def create_predictor(args, mode, logger): elif mode == "rec": if args.rec_algorithm != "CRNN": use_dynamic_shape = False - min_input_shape = {"x": [1, 3, 32, 10]} - max_input_shape = {"x": [args.rec_batch_num, 3, 32, 1536]} - opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]} + imgH = int(args.rec_image_shape.split(',')[1]) + min_input_shape = {"x": [1, 3, imgH, 10]} + max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 1536]} + opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]} elif mode == "cls": min_input_shape = {"x": [1, 3, 48, 10]} max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}