diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 5e320a027022f5d51d05645b32e8e35531486624..4c8d9d50965c069d4306eab29d71136eeb744639 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1203,8 +1203,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, auto x_var_name = desc.Input("X")[0]; auto* x_var_desc = block->FindVar(x_var_name); const auto x_shape = x_var_desc->GetShape(); - if (x_shape.size() == 1) { - VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt."; + if (!with_dynamic_shape && x_shape.size() == 1) { + VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt " + "with static shape."; return false; } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py index 5153476ae19f13d3b5edbb45d0df4f29ad07d0ba..10109cdc73a2b8b29455153a52c1de89766c4e41 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py @@ -172,29 +172,30 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest): for i in range(len(program_config.ops)) ] + def generate_trt_nodes_num(attrs, dynamic_shape): + if not dynamic_shape and self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0: + return 0, 3 + return 1, 2 + # for static_shape clear_dynamic_shape() self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), (1, 2), 1e-5 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (1, 2), 1e-5 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), (1, 2), 1e-5 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (1, 2), 1e-5 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 def add_skip_trt_case(self): - def teller1(program_config, predictor_config): - if self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0: - return True - return False - - self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT, - "Trt does not support 1-dimensional input.") - ver = paddle_infer.get_trt_compile_version() if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000: