Unverified commit b150b168, authored by xiaoxiaohehe001 and committed by GitHub

[Inference Zero-Dim] Support trt 0dim of gelu, hard_swish, hard_sigmoid and leaky_relu (#53714)

* support_act
* delete_silu
Parent dc003fa3
@@ -105,15 +105,15 @@ struct SimpleOpTypeSetTeller : public Teller {
       "erf", "floor", "round",
       "sign", "silu", "logical_not",
       "reciprocal", "tanh_shrink", "logsigmoid",
-      "rsqrt", "swish"};
+      "rsqrt", "swish", "hard_sigmoid",
+      "hard_swish", "leaky_relu"};
   std::unordered_set<std::string> unary_list = {
-      "exp", "log", "sqrt", "abs", "sin",
-      "cos", "tan", "tanh", "sinh", "cosh",
-      "asin", "acos", "atan", "asinh", "acosh",
-      "atanh", "ceil", "celu", "floor", "round",
-      "sign", "silu", "logical_not", "reciprocal", "tanh_shrink",
-      "logsigmoid", "erf", "bitwise_not", "equal", "not_equal",
-      "rsqrt"};
+      "exp", "log", "sqrt", "abs", "sin",
+      "cos", "tan", "tanh", "sinh", "cosh",
+      "asin", "acos", "atan", "asinh", "acosh",
+      "atanh", "ceil", "celu", "floor", "round",
+      "sign", "logical_not", "reciprocal", "tanh_shrink", "logsigmoid",
+      "erf", "bitwise_not", "equal", "not_equal", "rsqrt"};
   // Static shape does not support 0 or 1 dim's input.
   if (!with_dynamic_shape) {
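For reference, the three activations newly added to `act_op_list` above are simple element-wise ops. A minimal NumPy sketch of their definitions (the `hard_swish` attribute values are the ones the updated test below passes explicitly; `hard_sigmoid` and `leaky_relu` parameters stay explicit arguments, since their op defaults are not shown in this diff):

```python
import numpy as np

def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0):
    # x * relu6(x + offset) / scale, with the attrs used by the test below
    return x * np.clip(x + offset, 0.0, threshold) / scale

def hard_sigmoid(x, slope, offset):
    # piecewise-linear approximation of sigmoid, clipped to [0, 1]
    return np.clip(slope * x + offset, 0.0, 1.0)

def leaky_relu(x, alpha):
    # identity for x >= 0, slope alpha for x < 0
    return np.where(x >= 0, x, alpha * x)
```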
@@ -962,20 +962,6 @@ struct SimpleOpTypeSetTeller : public Teller {
       }
     }
-    if (op_type == "hard_swish") {
-      if (desc.Input("X").size() != 1) {
-        VLOG(3) << "HardSwish op has only 1 input, but got "
-                << desc.Input("X").size();
-        return false;
-      }
-      if (desc.Output("Out").size() != 1) {
-        VLOG(3) << "HardSwish op has only 1 output, but got "
-                << desc.Output("Out").size();
-        return false;
-      }
-    }
     if (op_type == "squeeze2") {
       // If Attribute is Variable(s), HasAttr() will return False
       if (!desc.HasAttr("axes", /*with_attr_var=*/false)) {
@@ -1642,8 +1628,10 @@ struct SimpleOpTypeSetTeller : public Teller {
       auto x_var_name = desc.Input("X")[0];
       auto* x_var_desc = block->FindVar(x_var_name);
       const auto x_shape = x_var_desc->GetShape();
-      if (x_shape.size() == 1) {
-        VLOG(3) << "gelu op does not support input's dim is 1 in tensorrt.";
+      if (!with_dynamic_shape && (x_shape.size() == 1 || x_shape.size() == 0)) {
+        VLOG(3) << op_type
+                << "gelu op does not support input's dim is 1 or 0 in tensorrt "
+                   "static shape mode.";
         return false;
       }
     }
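The tightened check above rejects 0-dim and 1-dim inputs only when running without dynamic shape; with dynamic shape enabled, both now pass the teller. For intuition, a 0-dim tensor is a true scalar of shape `[]`, and gelu maps it to another scalar. A quick eager-mode sketch, assuming a Paddle build with 0-dim support (where `paddle.to_tensor` on a Python float yields a 0-dim tensor):

```python
import paddle
import paddle.nn.functional as F

x0 = paddle.to_tensor(0.5, dtype="float32")  # 0-dim scalar, shape []
x1 = paddle.ones([32], dtype="float32")      # 1-dim, shape [32]

print(x0.shape, F.gelu(x0).shape)  # [] -> []
print(x1.shape, F.gelu(x1).shape)  # [32] -> [32]
```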
@@ -1733,20 +1721,6 @@ struct SimpleOpTypeSetTeller : public Teller {
       }
     }
-    if (op_type == "leaky_relu") {
-      if (desc.Input("X").size() != 1) {
-        VLOG(3) << "Invalid number of TRT leaky_relu op converter "
-                   "inputs. Expected 1, but received "
-                << desc.Input("X").size();
-        return false;
-      }
-      if (desc.Output("Out").size() != 1) {
-        VLOG(3) << "output of leaky_relu op converter should be 1, got "
-                << desc.Output("Out").size();
-        return false;
-      }
-    }
     if (op_type == "pad") {
       if (!desc.HasAttr("pad_value") || !desc.HasAttr("paddings")) return false;
       const float pad_value =
@@ -2388,26 +2362,6 @@ struct SimpleOpTypeSetTeller : public Teller {
       }
     }
-    if (op_type == "hard_sigmoid") {
-      if (!with_dynamic_shape) {
-        auto* block = desc.Block();
-        if (block == nullptr) {
-          VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
-                     "Developers need to check whether block_desc is passed in "
-                     "the pass.";
-          return false;
-        }
-        auto x_var_name = desc.Input("X")[0];
-        auto* x_var_desc = block->FindVar(x_var_name);
-        const auto x_shape = x_var_desc->GetShape();
-        if (x_shape.size() == 1) {
-          VLOG(3) << "Hard sigmoid does not support 1-dimensional input in "
-                     "tensorrt";
-          return false;
-        }
-      }
-    }
     if (op_type == "cast") {
       // trt 6015 result in Windows ppyolo_mbv3 TRT fp32 diff
 #if !IS_TRT_VERSION_GE(7000)
......
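These `SimpleOpTypeSetTeller` checks run when a model is loaded with the TensorRT subgraph engine enabled. A minimal setup sketch; the model paths are placeholders, and the arguments follow the documented `paddle.inference` API:

```python
from paddle.inference import Config, PrecisionType, create_predictor

# hypothetical model files
config = Config("model.pdmodel", "model.pdiparams")
config.enable_use_gpu(256, 0)  # 256 MB initial GPU memory pool on device 0
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=3,
    precision_mode=PrecisionType.Float32,
    use_static=False,
    use_calib_mode=False,
)
# The teller decides, op by op, which parts of the graph go to TensorRT.
predictor = create_predictor(config)
```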
@@ -60,6 +60,9 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
             "logsigmoid",
             "tanh_shrink",
             "softplus",
+            "hard_swish",
+            "hard_sigmoid",
+            "leaky_relu",
         ]:
             # few samples to reduce time
             # for beta in [-0.2, 0.5, 0.67, 3]:
@@ -80,6 +83,18 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
                     dics = [{"threshold": alpha}]
                 if op_type == "softplus":
                     dics = [{"beta": beta}]
+                if op_type == "hard_swish":
+                    dics = [
+                        {
+                            "threshold": 6.0,
+                            "scale": 6.0,
+                            "offset": 3.0,
+                        }
+                    ]
+                if op_type == "hard_sigmoid":
+                    dics = [{"slope": beta, "offset": alpha}]
+                if op_type == "leaky_relu":
+                    dics = [{"alpha": alpha}]
                 ops_config = [
                     {
......
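The attribute values above line up with Paddle's eager activations, so a quick cross-check sketch (assuming `paddle.nn.functional.hardswish` and `hardsigmoid` keep their documented signatures):

```python
import numpy as np
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor(np.linspace(-5, 5, 11).astype("float32"))
xn = x.numpy()

# hard_swish with the test's attrs (threshold=6, scale=6, offset=3)
np.testing.assert_allclose(
    F.hardswish(x).numpy(), xn * np.clip(xn + 3.0, 0.0, 6.0) / 6.0, rtol=1e-6
)

# hard_sigmoid with explicit slope/offset, as the test sweeps them
np.testing.assert_allclose(
    F.hardsigmoid(x, slope=0.2, offset=0.5).numpy(),
    np.clip(0.2 * xn + 0.5, 0.0, 1.0),
    rtol=1e-6,
)
```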
@@ -29,7 +29,9 @@ class TrtConvertGeluTest(TrtLayerAutoScanTest):
     def sample_program_configs(self):
         def generate_input1(dims, attrs: List[Dict[str, Any]]):
-            if dims == 1:
+            if dims == 0:
+                return np.ones([]).astype(np.float32)
+            elif dims == 1:
                 return np.ones([32]).astype(np.float32)
             elif dims == 2:
                 return np.ones([3, 32]).astype(np.float32)
@@ -38,7 +40,7 @@ class TrtConvertGeluTest(TrtLayerAutoScanTest):
             else:
                 return np.ones([1, 3, 32, 32]).astype(np.float32)
-        for dims in [1, 2, 3, 4]:
+        for dims in [0, 1, 2, 3, 4]:
             for approximate in [True, False]:
                 self.dims = dims
                 dics = [{"approximate": approximate}]
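The `approximate` attribute swept above selects between the exact erf-based GELU and its tanh approximation. For reference, the two definitions in NumPy:

```python
import numpy as np
from math import erf, pi, sqrt

def gelu_exact(x):
    # 0.5 * x * (1 + erf(x / sqrt(2)))
    return np.array(
        [0.5 * v * (1.0 + erf(v / sqrt(2.0))) for v in np.atleast_1d(x)],
        dtype=np.float32,
    )

def gelu_tanh(x):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + np.tanh(sqrt(2.0 / pi) * (x + 0.044715 * x**3)))
```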
@@ -70,7 +72,11 @@ class TrtConvertGeluTest(TrtLayerAutoScanTest):
         self, program_config
     ) -> (paddle_infer.Config, List[int], float):
         def generate_dynamic_shape(attrs):
-            if self.dims == 1:
+            if self.dims == 0:
+                self.dynamic_shape.min_input_shape = {"input_data": []}
+                self.dynamic_shape.max_input_shape = {"input_data": []}
+                self.dynamic_shape.opt_input_shape = {"input_data": []}
+            elif self.dims == 1:
                 self.dynamic_shape.min_input_shape = {"input_data": [1]}
                 self.dynamic_shape.max_input_shape = {"input_data": [64]}
                 self.dynamic_shape.opt_input_shape = {"input_data": [32]}
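For the 0-dim case, every profile shape is the empty list, since a scalar has no axes for TensorRT's optimization profile to range over. Continuing the hypothetical predictor config sketched earlier, the production-side equivalent would presumably be:

```python
# min/max/opt profiles for a 0-dim "input_data": all empty shapes
config.set_trt_dynamic_shape_info(
    {"input_data": []},  # min_input_shape
    {"input_data": []},  # max_input_shape
    {"input_data": []},  # optim_input_shape
)
```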
@@ -104,7 +110,7 @@ class TrtConvertGeluTest(TrtLayerAutoScanTest):
         runtime_version = paddle_infer.get_trt_runtime_version()
         self.assertTrue(compile_version == runtime_version)
         # Dimension one only runs on Paddle OP
-        if self.dims == 1:
+        if not dynamic_shape and (self.dims == 1 or self.dims == 0):
             return 0, 3
         if compile_version >= valid_version:
             return 1, 2
......
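A note on the expected counts above: the pair these tests return is interpreted by the auto-scan harness as (number of TRT engines, number of ops left on Paddle). So `0, 3` appears to mean full fallback for 0- or 1-dim inputs in static-shape mode (feed, gelu, and fetch all stay on Paddle), while `1, 2` means gelu was fused into a single TRT engine between feed and fetch.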