Unverified commit b150b168 authored by xiaoxiaohehe001, committed by GitHub

[Inference Zero-Dim] Support trt 0dim of gelu, hard_swish, hard_sigmoid and leaky_relu (#53714)

* support_act
* delete_silu
Parent dc003fa3
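For context, "0dim" here means a rank-0 (scalar) tensor of shape []. A minimal usage sketch of what the commit enables, assuming a Paddle build with TensorRT (the sketch is illustrative and not part of the commit):

import numpy as np
import paddle
import paddle.nn.functional as F

# A 0-D (scalar) tensor: shape [] rather than [1].
x = paddle.to_tensor(np.float32(0.5))

# With this commit, TensorRT dynamic-shape mode can take such inputs
# for these activations instead of falling back to Paddle ops.
print(F.gelu(x))         # gelu
print(F.hardswish(x))    # hard_swish
print(F.hardsigmoid(x))  # hard_sigmoid
print(F.leaky_relu(x))   # leaky_relu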
@@ -105,15 +105,15 @@ struct SimpleOpTypeSetTeller : public Teller {
         "erf", "floor", "round",
         "sign", "silu", "logical_not",
         "reciprocal", "tanh_shrink", "logsigmoid",
-        "rsqrt", "swish"};
+        "rsqrt", "swish", "hard_sigmoid",
+        "hard_swish", "leaky_relu"};
     std::unordered_set<std::string> unary_list = {
         "exp", "log", "sqrt", "abs", "sin",
         "cos", "tan", "tanh", "sinh", "cosh",
         "asin", "acos", "atan", "asinh", "acosh",
         "atanh", "ceil", "celu", "floor", "round",
-        "sign", "silu", "logical_not", "reciprocal", "tanh_shrink",
-        "logsigmoid", "erf", "bitwise_not", "equal", "not_equal",
-        "rsqrt"};
+        "sign", "logical_not", "reciprocal", "tanh_shrink", "logsigmoid",
+        "erf", "bitwise_not", "equal", "not_equal", "rsqrt"};
     // Static shape does not support 0 or 1 dim's input.
     if (!with_dynamic_shape) {
@@ -962,20 +962,6 @@ struct SimpleOpTypeSetTeller : public Teller {
       }
     }
-    if (op_type == "hard_swish") {
-      if (desc.Input("X").size() != 1) {
-        VLOG(3) << "HardSwish op has only 1 input, but got "
-                << desc.Input("X").size();
-        return false;
-      }
-      if (desc.Output("Out").size() != 1) {
-        VLOG(3) << "HardSwish op has only 1 output, but got "
-                << desc.Output("Out").size();
-        return false;
-      }
-    }
     if (op_type == "squeeze2") {
       // If Attribute is Variable(s), HasAttr() will return False
       if (!desc.HasAttr("axes", /*with_attr_var=*/false)) {
@@ -1642,8 +1628,10 @@ struct SimpleOpTypeSetTeller : public Teller {
       auto x_var_name = desc.Input("X")[0];
       auto* x_var_desc = block->FindVar(x_var_name);
       const auto x_shape = x_var_desc->GetShape();
-      if (x_shape.size() == 1) {
-        VLOG(3) << "gelu op does not support input's dim is 1 in tensorrt.";
+      if (!with_dynamic_shape && (x_shape.size() == 1 || x_shape.size() == 0)) {
+        VLOG(3) << op_type
+                << " op does not support input's dim is 1 or 0 in tensorrt "
+                   "static shape mode.";
         return false;
       }
     }
@@ -1733,20 +1721,6 @@ struct SimpleOpTypeSetTeller : public Teller {
       }
     }
-    if (op_type == "leaky_relu") {
-      if (desc.Input("X").size() != 1) {
-        VLOG(3) << "Invalid number of TRT leaky_relu op converter "
-                   "inputs. Expected 1, but received "
-                << desc.Input("X").size();
-        return false;
-      }
-      if (desc.Output("Out").size() != 1) {
-        VLOG(3) << "output of leaky_relu op converter should be 1, got "
-                << desc.Output("Out").size();
-        return false;
-      }
-    }
     if (op_type == "pad") {
       if (!desc.HasAttr("pad_value") || !desc.HasAttr("paddings")) return false;
       const float pad_value =
@@ -2388,26 +2362,6 @@ struct SimpleOpTypeSetTeller : public Teller {
       }
     }
-    if (op_type == "hard_sigmoid") {
-      if (!with_dynamic_shape) {
-        auto* block = desc.Block();
-        if (block == nullptr) {
-          VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
-                     "Developers need to check whether block_desc is passed in "
-                     "the pass.";
-          return false;
-        }
-        auto x_var_name = desc.Input("X")[0];
-        auto* x_var_desc = block->FindVar(x_var_name);
-        const auto x_shape = x_var_desc->GetShape();
-        if (x_shape.size() == 1) {
-          VLOG(3) << "Hard sigmoid does not support 1-dimensional input in "
-                     "tensorrt";
-          return false;
-        }
-      }
-    }
     if (op_type == "cast") {
 // trt 6015 result in Windows ppyolo_mbv3 TRT fp32 diff
 #if !IS_TRT_VERSION_GE(7000)
...
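Taken together, the op_teller.cc changes replace three per-op checks with one shared rule: these activations now sit in the common activation list, and 0-D/1-D inputs are rejected only when running without dynamic shape. A rough Python paraphrase of the resulting rule (names here are illustrative, not the actual C++ API):

def trt_supports_activation(op_type: str, x_rank: int, with_dynamic_shape: bool) -> bool:
    """Hypothetical paraphrase of the teller logic after this commit."""
    zero_dim_acts = {"gelu", "hard_swish", "hard_sigmoid", "leaky_relu"}
    if op_type not in zero_dim_acts:
        return True  # other ops keep their own checks
    # TensorRT static-shape mode cannot handle 0-D or 1-D inputs;
    # dynamic-shape mode now can, so reject only in static mode.
    if not with_dynamic_shape and x_rank in (0, 1):
        return False
    return True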
@@ -60,6 +60,9 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
             "logsigmoid",
             "tanh_shrink",
             "softplus",
+            "hard_swish",
+            "hard_sigmoid",
+            "leaky_relu",
         ]:
             # few samples to reduce time
             # for beta in [-0.2, 0.5, 0.67, 3]:
@@ -80,6 +83,18 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
                     dics = [{"threshold": alpha}]
                 if op_type == "softplus":
                     dics = [{"beta": beta}]
+                if op_type == "hard_swish":
+                    dics = [
+                        {
+                            "threshold": 6.0,
+                            "scale": 6.0,
+                            "offset": 3.0,
+                        }
+                    ]
+                if op_type == "hard_sigmoid":
+                    dics = [{"slope": beta, "offset": alpha}]
+                if op_type == "leaky_relu":
+                    dics = [{"alpha": alpha}]
                 ops_config = [
                     {
...
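For reference, the attributes fed to the three new ops above follow the standard Paddle definitions. A NumPy sketch of those formulas (not part of the commit; the hard_swish defaults shown are the values the test passes):

import numpy as np

def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0):
    # x * clip(x + offset, 0, threshold) / scale
    return x * np.clip(x + offset, 0.0, threshold) / scale

def hard_sigmoid(x, slope, offset):
    # clip(slope * x + offset, 0, 1)
    return np.clip(slope * x + offset, 0.0, 1.0)

def leaky_relu(x, alpha):
    # x for x >= 0, alpha * x otherwise
    return np.where(x >= 0, x, alpha * x)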
@@ -29,7 +29,9 @@ class TrtConvertGeluTest(TrtLayerAutoScanTest):
     def sample_program_configs(self):
         def generate_input1(dims, attrs: List[Dict[str, Any]]):
-            if dims == 1:
+            if dims == 0:
+                return np.ones([]).astype(np.float32)
+            elif dims == 1:
                 return np.ones([32]).astype(np.float32)
             elif dims == 2:
                 return np.ones([3, 32]).astype(np.float32)
@@ -38,7 +40,7 @@ class TrtConvertGeluTest(TrtLayerAutoScanTest):
             else:
                 return np.ones([1, 3, 32, 32]).astype(np.float32)

-        for dims in [1, 2, 3, 4]:
+        for dims in [0, 1, 2, 3, 4]:
             for approximate in [True, False]:
                 self.dims = dims
                 dics = [{"approximate": approximate}]
@@ -70,7 +72,11 @@ class TrtConvertGeluTest(TrtLayerAutoScanTest):
         self, program_config
     ) -> (paddle_infer.Config, List[int], float):
         def generate_dynamic_shape(attrs):
-            if self.dims == 1:
+            if self.dims == 0:
+                self.dynamic_shape.min_input_shape = {"input_data": []}
+                self.dynamic_shape.max_input_shape = {"input_data": []}
+                self.dynamic_shape.opt_input_shape = {"input_data": []}
+            elif self.dims == 1:
                 self.dynamic_shape.min_input_shape = {"input_data": [1]}
                 self.dynamic_shape.max_input_shape = {"input_data": [64]}
                 self.dynamic_shape.opt_input_shape = {"input_data": [32]}
@@ -104,7 +110,7 @@ class TrtConvertGeluTest(TrtLayerAutoScanTest):
             runtime_version = paddle_infer.get_trt_runtime_version()
             self.assertTrue(compile_version == runtime_version)
             # Dimension one only runs on Paddle OP
-            if self.dims == 1:
+            if not dynamic_shape and (self.dims == 1 or self.dims == 0):
                 return 0, 3
             if compile_version >= valid_version:
                 return 1, 2
...
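Note how the 0-D case declares an empty shape list ([]) for the min/max/opt profile; that is how a rank-0 input is described to TensorRT's dynamic-shape mode. A standalone sketch of the same configuration with the paddle.inference API (the model path and sizes are placeholders):

from paddle.inference import Config

config = Config("./model_dir")  # hypothetical model path
config.enable_use_gpu(256, 0)
config.enable_tensorrt_engine(workspace_size=1 << 30, max_batch_size=1)
config.set_trt_dynamic_shape_info(
    {"input_data": []},  # min_input_shape: empty list means rank 0
    {"input_data": []},  # max_input_shape
    {"input_data": []},  # opt_input_shape
)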