未验证 提交 d80330fe 编写于 作者: G gem5 提交者: GitHub

add reciprocal trt converter (#48230)

上级 ab6a3dad
......@@ -81,6 +81,7 @@ const std::unordered_map<std::string, nvinfer1::UnaryOperation>
{"atanh", nvinfer1::UnaryOperation::kATANH},
{"ceil", nvinfer1::UnaryOperation::kCEIL},
{"floor", nvinfer1::UnaryOperation::kFLOOR},
{"reciprocal", nvinfer1::UnaryOperation::kRECIP},
#if IS_TRT_VERSION_GE(7000)
{"erf", nvinfer1::UnaryOperation::kERF},
#endif
......@@ -152,6 +153,10 @@ class FloorOpConverter : public UnaryOpConverter {
public:
FloorOpConverter() { op_type_ = "floor"; }
};
class ReciprocalOpConverter : public UnaryOpConverter {
public:
ReciprocalOpConverter() { op_type_ = "reciprocal"; }
};
#if IS_TRT_VERSION_GE(7000)
class ErfOpConverter : public UnaryOpConverter {
public:
......@@ -179,6 +184,7 @@ REGISTER_TRT_OP_CONVERTER(asinh, AsinhOpConverter);
REGISTER_TRT_OP_CONVERTER(atanh, AtanhOpConverter);
REGISTER_TRT_OP_CONVERTER(ceil, CeilOpConverter);
REGISTER_TRT_OP_CONVERTER(floor, FloorOpConverter);
REGISTER_TRT_OP_CONVERTER(reciprocal, ReciprocalOpConverter);
#if IS_TRT_VERSION_GE(7000)
REGISTER_TRT_OP_CONVERTER(erf, ErfOpConverter);
#endif
......@@ -79,17 +79,17 @@ struct SimpleOpTypeSetTeller : public Teller {
desc.HasAttr("skip_quant"))
return false;
std::unordered_set<std::string> act_op_list = {
"relu", "relu6", "sigmoid",
"elu", "selu", "softsign",
"softplus", "stanh", "thresholded_relu",
"exp", "log", "sqrt",
"abs", "sin", "cos",
"tan", "tanh", "sinh",
"cosh", "asin", "acos",
"atan", "asinh", "atanh",
"ceil", "floor", "erf",
"silu", "celu", "tanh_shrink",
"logsigmoid"};
"relu", "relu6", "sigmoid",
"elu", "selu", "softsign",
"softplus", "stanh", "thresholded_relu",
"exp", "log", "sqrt",
"abs", "sin", "cos",
"tan", "tanh", "sinh",
"cosh", "asin", "acos",
"atan", "asinh", "atanh",
"ceil", "floor", "erf",
"reciprocal", "silu", "celu",
"tanh_shrink", "logsigmoid"};
if (act_op_list.find(op_type) != act_op_list.end()) {
auto* block = desc.Block();
if (block == nullptr) {
......@@ -2301,6 +2301,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"atanh",
"ceil",
"floor",
"reciprocal",
"erf",
"softmax",
"sigmoid",
......@@ -2428,6 +2429,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"atanh",
"ceil",
"floor",
"reciprocal",
"erf",
"softmax",
"sigmoid",
......
......@@ -59,6 +59,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
"atanh",
"ceil",
"floor",
"reciprocal",
]:
self.dims = dims
dics = [{}]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册