未验证 提交 10f9249b 编写于 作者: Z Zhang Jun 提交者: GitHub

[inference][trt]Unary operation support 0d (#53506)

* fix trt Unary operation do not support 0d when TRT < 8.6

* update unary ut

* add rsqrt to unary_list

* move rsqrt to act_list
上级 fe919400
......@@ -104,7 +104,16 @@ struct SimpleOpTypeSetTeller : public Teller {
"atanh", "ceil", "celu",
"erf", "floor", "round",
"sign", "silu", "logical_not",
"reciprocal", "tanh_shrink", "logsigmoid"};
"reciprocal", "tanh_shrink", "logsigmoid",
"rsqrt"};
std::unordered_set<std::string> unary_list = {
"exp", "log", "sqrt", "abs", "sin",
"cos", "tan", "tanh", "sinh", "cosh",
"asin", "acos", "atan", "asinh", "acosh",
"atanh", "ceil", "celu", "floor", "round",
"sign", "silu", "logical_not", "reciprocal", "tanh_shrink",
"logsigmoid", "erf", "bitwise_not", "equal", "not_equal",
"rsqrt"};
if (act_op_list.find(op_type) != act_op_list.end()) {
auto* block = desc.Block();
if (block == nullptr) {
......@@ -127,9 +136,15 @@ struct SimpleOpTypeSetTeller : public Teller {
VLOG(3) << op_type << " op does not support tensorrt.";
return false;
}
#endif
#if !IS_TRT_VERSION_GE(8600)
if (x_shape.size() == 0 && unary_list.find(op_type) != unary_list.end()) {
VLOG(3) << op_type
<< " op does not support 0 dim input when TensorRT < 8.6.";
return false;
}
#endif
}
// In static shape in Paddle-TRT, we can't allow that one op has a
// 1D intermediate tensor as input.
if (!with_dynamic_shape) {
......
......@@ -49,7 +49,6 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
for op_type in [
"relu",
"sigmoid",
"tanh",
"relu6",
"elu",
"selu",
......@@ -146,6 +145,15 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
def generate_trt_nodes_num(attrs, dynamic_shape):
if not dynamic_shape and (self.dims == 1 or self.dims == 0):
return 0, 3
runtime_version = paddle_infer.get_trt_runtime_version()
if (
runtime_version[0] * 1000
+ runtime_version[1] * 100
+ runtime_version[2] * 10
< 8600
and self.dims == 0
) and program_config.ops[0].type in ["celu", "logsigmoid"]:
return 0, 3
return 1, 2
attrs = [
......
......@@ -35,6 +35,8 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
self.trt_param.workspace_size = 1073741824
def generate_input1(dims, batch, attrs: List[Dict[str, Any]]):
if dims == 0:
return np.random.random([]).astype(np.float32)
if dims == 2:
return np.random.random([3, 32]).astype(np.float32)
elif dims == 3:
......@@ -43,6 +45,8 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
return np.random.random([batch, 3, 32, 32]).astype(np.float32)
def generate_int_input(dims, batch, attrs: List[Dict[str, Any]]):
if dims == 0:
return np.random.random([]).astype(np.int32)
if dims == 2:
return np.random.random([3, 32]).astype(np.int32)
elif dims == 3:
......@@ -50,7 +54,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
else:
return np.random.random([batch, 3, 32, 32]).astype(np.int32)
for dims in [2, 3, 4]:
for dims in [0, 2, 3, 4]:
for batch in [1, 4]:
for op_type in [
"exp",
......@@ -60,6 +64,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
"sin",
"cos",
"tan",
"tanh",
"sinh",
"cosh",
"asin",
......@@ -179,6 +184,15 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
)
):
return 0, 3
runtime_version = paddle_infer.get_trt_runtime_version()
if (
runtime_version[0] * 1000
+ runtime_version[1] * 100
+ runtime_version[2] * 10
< 8600
and self.dims == 0
):
return 0, 3
return 1, 2
attrs = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册