diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 1c82280099f17f6d3bf848669e47439505f10576..93b54505c36e5f18913bad8b70d49b4f28334477 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -231,13 +231,13 @@ def create_predictor(args, mode, logger): max_input_shape.update(max_pact_shape) opt_input_shape.update(opt_pact_shape) elif mode == "rec": - min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]} + min_input_shape = {"x": [1, 3, 32, 10]} max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]} - opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]} + opt_input_shape = {"x": [512, 3, 32, 320]} elif mode == "cls": - min_input_shape = {"x": [args.rec_batch_num, 3, 48, 10]} + min_input_shape = {"x": [1, 3, 48, 10]} max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]} - opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]} + opt_input_shape = {"x": [512, 3, 48, 320]} else: min_input_shape = {"x": [1, 3, 10, 10]} max_input_shape = {"x": [1, 3, 1000, 1000]}