未验证 提交 d80330fe 编写于 作者: G gem5 提交者: GitHub

add reciprocal trt converter (#48230)

上级 ab6a3dad
...@@ -81,6 +81,7 @@ const std::unordered_map<std::string, nvinfer1::UnaryOperation> ...@@ -81,6 +81,7 @@ const std::unordered_map<std::string, nvinfer1::UnaryOperation>
{"atanh", nvinfer1::UnaryOperation::kATANH}, {"atanh", nvinfer1::UnaryOperation::kATANH},
{"ceil", nvinfer1::UnaryOperation::kCEIL}, {"ceil", nvinfer1::UnaryOperation::kCEIL},
{"floor", nvinfer1::UnaryOperation::kFLOOR}, {"floor", nvinfer1::UnaryOperation::kFLOOR},
{"reciprocal", nvinfer1::UnaryOperation::kRECIP},
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
{"erf", nvinfer1::UnaryOperation::kERF}, {"erf", nvinfer1::UnaryOperation::kERF},
#endif #endif
...@@ -152,6 +153,10 @@ class FloorOpConverter : public UnaryOpConverter { ...@@ -152,6 +153,10 @@ class FloorOpConverter : public UnaryOpConverter {
public: public:
FloorOpConverter() { op_type_ = "floor"; } FloorOpConverter() { op_type_ = "floor"; }
}; };
class ReciprocalOpConverter : public UnaryOpConverter {
public:
ReciprocalOpConverter() { op_type_ = "reciprocal"; }
};
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
class ErfOpConverter : public UnaryOpConverter { class ErfOpConverter : public UnaryOpConverter {
public: public:
...@@ -179,6 +184,7 @@ REGISTER_TRT_OP_CONVERTER(asinh, AsinhOpConverter); ...@@ -179,6 +184,7 @@ REGISTER_TRT_OP_CONVERTER(asinh, AsinhOpConverter);
REGISTER_TRT_OP_CONVERTER(atanh, AtanhOpConverter); REGISTER_TRT_OP_CONVERTER(atanh, AtanhOpConverter);
REGISTER_TRT_OP_CONVERTER(ceil, CeilOpConverter); REGISTER_TRT_OP_CONVERTER(ceil, CeilOpConverter);
REGISTER_TRT_OP_CONVERTER(floor, FloorOpConverter); REGISTER_TRT_OP_CONVERTER(floor, FloorOpConverter);
REGISTER_TRT_OP_CONVERTER(reciprocal, ReciprocalOpConverter);
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
REGISTER_TRT_OP_CONVERTER(erf, ErfOpConverter); REGISTER_TRT_OP_CONVERTER(erf, ErfOpConverter);
#endif #endif
...@@ -79,17 +79,17 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -79,17 +79,17 @@ struct SimpleOpTypeSetTeller : public Teller {
desc.HasAttr("skip_quant")) desc.HasAttr("skip_quant"))
return false; return false;
std::unordered_set<std::string> act_op_list = { std::unordered_set<std::string> act_op_list = {
"relu", "relu6", "sigmoid", "relu", "relu6", "sigmoid",
"elu", "selu", "softsign", "elu", "selu", "softsign",
"softplus", "stanh", "thresholded_relu", "softplus", "stanh", "thresholded_relu",
"exp", "log", "sqrt", "exp", "log", "sqrt",
"abs", "sin", "cos", "abs", "sin", "cos",
"tan", "tanh", "sinh", "tan", "tanh", "sinh",
"cosh", "asin", "acos", "cosh", "asin", "acos",
"atan", "asinh", "atanh", "atan", "asinh", "atanh",
"ceil", "floor", "erf", "ceil", "floor", "erf",
"silu", "celu", "tanh_shrink", "reciprocal", "silu", "celu",
"logsigmoid"}; "tanh_shrink", "logsigmoid"};
if (act_op_list.find(op_type) != act_op_list.end()) { if (act_op_list.find(op_type) != act_op_list.end()) {
auto* block = desc.Block(); auto* block = desc.Block();
if (block == nullptr) { if (block == nullptr) {
...@@ -2301,6 +2301,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2301,6 +2301,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"atanh", "atanh",
"ceil", "ceil",
"floor", "floor",
"reciprocal",
"erf", "erf",
"softmax", "softmax",
"sigmoid", "sigmoid",
...@@ -2428,6 +2429,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2428,6 +2429,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"atanh", "atanh",
"ceil", "ceil",
"floor", "floor",
"reciprocal",
"erf", "erf",
"softmax", "softmax",
"sigmoid", "sigmoid",
......
...@@ -59,6 +59,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -59,6 +59,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
"atanh", "atanh",
"ceil", "ceil",
"floor", "floor",
"reciprocal",
]: ]:
self.dims = dims self.dims = dims
dics = [{}] dics = [{}]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册