From c86765ed476ef7f51fa80aa7c1538c41761672f4 Mon Sep 17 00:00:00 2001 From: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com> Date: Fri, 11 Feb 2022 15:06:56 +0800 Subject: [PATCH] fix prelu trt convert (#39389) --- paddle/fluid/inference/tensorrt/op_teller.cc | 5 ++-- .../ir/inference/test_trt_convert_prelu.py | 25 ++++++++++--------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 5e320a0270..4c8d9d5096 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1203,8 +1203,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, auto x_var_name = desc.Input("X")[0]; auto* x_var_desc = block->FindVar(x_var_name); const auto x_shape = x_var_desc->GetShape(); - if (x_shape.size() == 1) { - VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt."; + if (!with_dynamic_shape && x_shape.size() == 1) { + VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt " + "with static shape."; return false; } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py index 5153476ae1..10109cdc73 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py @@ -172,29 +172,30 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest): for i in range(len(program_config.ops)) ] + def generate_trt_nodes_num(attrs, dynamic_shape): + if not dynamic_shape and self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0: + return 0, 3 + return 1, 2 + # for static_shape clear_dynamic_shape() self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), (1, 2), 1e-5 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (1, 2), 1e-5 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), (1, 2), 1e-5 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (1, 2), 1e-5 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 def add_skip_trt_case(self): - def teller1(program_config, predictor_config): - if self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0: - return True - return False - - self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT, - "Trt does not support 1-dimensional input.") - ver = paddle_infer.get_trt_compile_version() if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000: -- GitLab