From 10f9249b991805e571846d85b6f5218c8a381a1a Mon Sep 17 00:00:00 2001 From: Zhang Jun Date: Mon, 8 May 2023 14:49:31 +0800 Subject: [PATCH] [inference][trt]Unary operation support 0d (#53506) * fix trt Unary operation do not support 0d when TRT < 8.6 * update unary ut * add rsqrt to unary_list * move rsqrt to act_list --- paddle/fluid/inference/tensorrt/op_teller.cc | 19 +++++++++++++++++-- .../inference/test_trt_convert_activation.py | 10 +++++++++- test/ir/inference/test_trt_convert_unary.py | 16 +++++++++++++++- 3 files changed, 41 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index beb289b420f..56d7fba985c 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -104,7 +104,16 @@ struct SimpleOpTypeSetTeller : public Teller { "atanh", "ceil", "celu", "erf", "floor", "round", "sign", "silu", "logical_not", - "reciprocal", "tanh_shrink", "logsigmoid"}; + "reciprocal", "tanh_shrink", "logsigmoid", + "rsqrt"}; + std::unordered_set unary_list = { + "exp", "log", "sqrt", "abs", "sin", + "cos", "tan", "tanh", "sinh", "cosh", + "asin", "acos", "atan", "asinh", "acosh", + "atanh", "ceil", "celu", "floor", "round", + "sign", "silu", "logical_not", "reciprocal", "tanh_shrink", + "logsigmoid", "erf", "bitwise_not", "equal", "not_equal", + "rsqrt"}; if (act_op_list.find(op_type) != act_op_list.end()) { auto* block = desc.Block(); if (block == nullptr) { @@ -127,9 +136,15 @@ struct SimpleOpTypeSetTeller : public Teller { VLOG(3) << op_type << " op does not support tensorrt."; return false; } +#endif +#if !IS_TRT_VERSION_GE(8600) + if (x_shape.size() == 0 && unary_list.find(op_type) != unary_list.end()) { + VLOG(3) << op_type + << " op does not support 0 dim input when TensorRT < 8.6."; + return false; + } #endif } - // In static shape in Paddle-TRT, we can't allow that one op has a // 1D intermediate tensor as input. if (!with_dynamic_shape) { diff --git a/test/ir/inference/test_trt_convert_activation.py b/test/ir/inference/test_trt_convert_activation.py index bba6beae142..4b7052e682d 100644 --- a/test/ir/inference/test_trt_convert_activation.py +++ b/test/ir/inference/test_trt_convert_activation.py @@ -49,7 +49,6 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): for op_type in [ "relu", "sigmoid", - "tanh", "relu6", "elu", "selu", @@ -146,6 +145,15 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): def generate_trt_nodes_num(attrs, dynamic_shape): if not dynamic_shape and (self.dims == 1 or self.dims == 0): return 0, 3 + runtime_version = paddle_infer.get_trt_runtime_version() + if ( + runtime_version[0] * 1000 + + runtime_version[1] * 100 + + runtime_version[2] * 10 + < 8600 + and self.dims == 0 + ) and program_config.ops[0].type in ["celu", "logsigmoid"]: + return 0, 3 return 1, 2 attrs = [ diff --git a/test/ir/inference/test_trt_convert_unary.py b/test/ir/inference/test_trt_convert_unary.py index 97e83e79714..322130c83f5 100644 --- a/test/ir/inference/test_trt_convert_unary.py +++ b/test/ir/inference/test_trt_convert_unary.py @@ -35,6 +35,8 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): self.trt_param.workspace_size = 1073741824 def generate_input1(dims, batch, attrs: List[Dict[str, Any]]): + if dims == 0: + return np.random.random([]).astype(np.float32) if dims == 2: return np.random.random([3, 32]).astype(np.float32) elif dims == 3: @@ -43,6 +45,8 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): return np.random.random([batch, 3, 32, 32]).astype(np.float32) def generate_int_input(dims, batch, attrs: List[Dict[str, Any]]): + if dims == 0: + return np.random.random([]).astype(np.int32) if dims == 2: return np.random.random([3, 32]).astype(np.int32) elif dims == 3: @@ -50,7 +54,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): else: return np.random.random([batch, 3, 32, 32]).astype(np.int32) - for dims in [2, 3, 4]: + for dims in [0, 2, 3, 4]: for batch in [1, 4]: for op_type in [ "exp", @@ -60,6 +64,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): "sin", "cos", "tan", + "tanh", "sinh", "cosh", "asin", @@ -179,6 +184,15 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ) ): return 0, 3 + runtime_version = paddle_infer.get_trt_runtime_version() + if ( + runtime_version[0] * 1000 + + runtime_version[1] * 100 + + runtime_version[2] * 10 + < 8600 + and self.dims == 0 + ): + return 0, 3 return 1, 2 attrs = [ -- GitLab