diff --git a/tools/infer/utility.py b/tools/infer/utility.py index af50c5e6a8cb39faf416dbe7adb516c4db05aef5..d7e058c2b7c0eaf6bd40dd197a3cb1417bc7bb7d 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -271,8 +271,13 @@ def create_predictor(args, mode, logger): min_input_shape = {"x": [1, 3, 10, 10]} max_input_shape = {"x": [1, 3, 512, 512]} opt_input_shape = {"x": [1, 3, 256, 256]} - config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, - opt_input_shape) + if mode == "rec": + if args.rec_algorithm == "CRNN": + config.set_trt_dynamic_shape_info( + min_input_shape, max_input_shape, opt_input_shape) + else: + config.set_trt_dynamic_shape_info( + min_input_shape, max_input_shape, opt_input_shape) else: config.disable_gpu()