未验证 提交 c86765ed 编写于 作者: J JingZhuangzhuang 提交者: GitHub

fix prelu trt convert (#39389)

上级 72ad280b
...@@ -1203,8 +1203,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -1203,8 +1203,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
auto x_var_name = desc.Input("X")[0]; auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name); auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape(); const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) { if (!with_dynamic_shape && x_shape.size() == 1) {
VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt."; VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt "
"with static shape.";
return false; return false;
} }
......
...@@ -172,29 +172,30 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest): ...@@ -172,29 +172,30 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
for i in range(len(program_config.ops)) for i in range(len(program_config.ops))
] ]
def generate_trt_nodes_num(attrs, dynamic_shape):
if not dynamic_shape and self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0:
return 0, 3
return 1, 2
# for static_shape # for static_shape
clear_dynamic_shape() clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32 self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (1, 2), 1e-5 yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 2), 1e-5 yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
# for dynamic_shape # for dynamic_shape
generate_dynamic_shape(attrs) generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32 self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (1, 2), 1e-5 yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 2), 1e-5 yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
def add_skip_trt_case(self): def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0:
return True
return False
self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT,
"Trt does not support 1-dimensional input.")
ver = paddle_infer.get_trt_compile_version() ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000: if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册