未验证 提交 10f9249b 编写于 作者: Z Zhang Jun 提交者: GitHub

[inference][trt] Unary operations support 0-D input (#53506)

* fix TRT unary operations not supporting 0-D input when TRT < 8.6

* update unary unit tests

* add rsqrt to unary_list

* move rsqrt to act_list
上级 fe919400
...@@ -104,7 +104,16 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -104,7 +104,16 @@ struct SimpleOpTypeSetTeller : public Teller {
"atanh", "ceil", "celu", "atanh", "ceil", "celu",
"erf", "floor", "round", "erf", "floor", "round",
"sign", "silu", "logical_not", "sign", "silu", "logical_not",
"reciprocal", "tanh_shrink", "logsigmoid"}; "reciprocal", "tanh_shrink", "logsigmoid",
"rsqrt"};
std::unordered_set<std::string> unary_list = {
"exp", "log", "sqrt", "abs", "sin",
"cos", "tan", "tanh", "sinh", "cosh",
"asin", "acos", "atan", "asinh", "acosh",
"atanh", "ceil", "celu", "floor", "round",
"sign", "silu", "logical_not", "reciprocal", "tanh_shrink",
"logsigmoid", "erf", "bitwise_not", "equal", "not_equal",
"rsqrt"};
if (act_op_list.find(op_type) != act_op_list.end()) { if (act_op_list.find(op_type) != act_op_list.end()) {
auto* block = desc.Block(); auto* block = desc.Block();
if (block == nullptr) { if (block == nullptr) {
...@@ -127,9 +136,15 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -127,9 +136,15 @@ struct SimpleOpTypeSetTeller : public Teller {
VLOG(3) << op_type << " op does not support tensorrt."; VLOG(3) << op_type << " op does not support tensorrt.";
return false; return false;
} }
#endif
#if !IS_TRT_VERSION_GE(8600)
if (x_shape.size() == 0 && unary_list.find(op_type) != unary_list.end()) {
VLOG(3) << op_type
<< " op does not support 0 dim input when TensorRT < 8.6.";
return false;
}
#endif #endif
} }
// In static shape in Paddle-TRT, we can't allow that one op has a // In static shape in Paddle-TRT, we can't allow that one op has a
// 1D intermediate tensor as input. // 1D intermediate tensor as input.
if (!with_dynamic_shape) { if (!with_dynamic_shape) {
......
...@@ -49,7 +49,6 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -49,7 +49,6 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
for op_type in [ for op_type in [
"relu", "relu",
"sigmoid", "sigmoid",
"tanh",
"relu6", "relu6",
"elu", "elu",
"selu", "selu",
...@@ -146,6 +145,15 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -146,6 +145,15 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
def generate_trt_nodes_num(attrs, dynamic_shape): def generate_trt_nodes_num(attrs, dynamic_shape):
if not dynamic_shape and (self.dims == 1 or self.dims == 0): if not dynamic_shape and (self.dims == 1 or self.dims == 0):
return 0, 3 return 0, 3
runtime_version = paddle_infer.get_trt_runtime_version()
if (
runtime_version[0] * 1000
+ runtime_version[1] * 100
+ runtime_version[2] * 10
< 8600
and self.dims == 0
) and program_config.ops[0].type in ["celu", "logsigmoid"]:
return 0, 3
return 1, 2 return 1, 2
attrs = [ attrs = [
......
...@@ -35,6 +35,8 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -35,6 +35,8 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
self.trt_param.workspace_size = 1073741824 self.trt_param.workspace_size = 1073741824
def generate_input1(dims, batch, attrs: List[Dict[str, Any]]): def generate_input1(dims, batch, attrs: List[Dict[str, Any]]):
if dims == 0:
return np.random.random([]).astype(np.float32)
if dims == 2: if dims == 2:
return np.random.random([3, 32]).astype(np.float32) return np.random.random([3, 32]).astype(np.float32)
elif dims == 3: elif dims == 3:
...@@ -43,6 +45,8 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -43,6 +45,8 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
return np.random.random([batch, 3, 32, 32]).astype(np.float32) return np.random.random([batch, 3, 32, 32]).astype(np.float32)
def generate_int_input(dims, batch, attrs: List[Dict[str, Any]]): def generate_int_input(dims, batch, attrs: List[Dict[str, Any]]):
if dims == 0:
return np.random.random([]).astype(np.int32)
if dims == 2: if dims == 2:
return np.random.random([3, 32]).astype(np.int32) return np.random.random([3, 32]).astype(np.int32)
elif dims == 3: elif dims == 3:
...@@ -50,7 +54,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -50,7 +54,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
else: else:
return np.random.random([batch, 3, 32, 32]).astype(np.int32) return np.random.random([batch, 3, 32, 32]).astype(np.int32)
for dims in [2, 3, 4]: for dims in [0, 2, 3, 4]:
for batch in [1, 4]: for batch in [1, 4]:
for op_type in [ for op_type in [
"exp", "exp",
...@@ -60,6 +64,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -60,6 +64,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
"sin", "sin",
"cos", "cos",
"tan", "tan",
"tanh",
"sinh", "sinh",
"cosh", "cosh",
"asin", "asin",
...@@ -179,6 +184,15 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -179,6 +184,15 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
) )
): ):
return 0, 3 return 0, 3
runtime_version = paddle_infer.get_trt_runtime_version()
if (
runtime_version[0] * 1000
+ runtime_version[1] * 100
+ runtime_version[2] * 10
< 8600
and self.dims == 0
):
return 0, 3
return 1, 2 return 1, 2
attrs = [ attrs = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册