未验证 提交 379222c3 编写于 作者: P Pei Yang 提交者: GitHub

add output scale and trt op teller support for hard_swish and hard_sigmoid (#26499)

上级 74836ec7
......@@ -24,6 +24,8 @@ struct SimpleOpTypeSetTeller : public Teller {
#if IS_TRT_VERSION_GE(5130)
teller_set.insert("relu6");
teller_set.insert("hard_sigmoid");
int8_teller_set.insert("relu6");
int8_teller_set.insert("hard_sigmoid");
#endif
#if IS_TRT_VERSION_GE(6000)
teller_set.insert("fused_embedding_eltwise_layernorm");
......@@ -53,11 +55,11 @@ struct SimpleOpTypeSetTeller : public Teller {
"elementwise_add",
"leaky_relu",
"fc",
"relu6",
"concat",
"scale",
"elementwise_mul",
"conv2d_transpose"};
"conv2d_transpose",
"hard_swish"};
std::unordered_set<std::string> teller_set{
"mul",
"conv2d",
......
......@@ -66,6 +66,8 @@ _out_scale_op_list = [
"concat",
"elementwise_mul",
"scale",
"hard_swish",
"hard_sigmoid",
]
# list op real input and output names, to avoid processing input such as AxisTensor.
......@@ -109,6 +111,8 @@ _op_real_in_out_name = {
"sigmoid": [["X"], ["Out"]],
"elementwise_mul": [["X", "Y"], ["Out"]],
"scale": [["X"], ["Out"]],
"hard_swish": [["X"], ["Out"]],
"hard_sigmoid": [["X"], ["Out"]],
}
_conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册