diff --git a/tools/infer/utility.py b/tools/infer/utility.py index b16aecd496ec291fcbe9c66dccf3ec04bb662034..762db868f58d11aaa626a2e55591d47bfa9536a9 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -271,9 +271,10 @@ def create_predictor(args, mode, logger): elif mode == "rec": if args.rec_algorithm != "CRNN": use_dynamic_shape = False - min_input_shape = {"x": [1, 3, 32, 10]} - max_input_shape = {"x": [args.rec_batch_num, 3, 32, 1536]} - opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]} + imgH = int(args.rec_image_shape.split(',')[-2]) + min_input_shape = {"x": [1, 3, imgH, 10]} + max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 1536]} + opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]} elif mode == "cls": min_input_shape = {"x": [1, 3, 48, 10]} max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}