From 890c73158f663b327be7664ed6c4d08fb2c236a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B4=A5?= <33440396+xingjing1@users.noreply.github.com> Date: Thu, 16 Jun 2022 15:23:27 +0800 Subject: [PATCH] [inference]add unary trt convert (#43509) * add unary --- .../inference/tensorrt/convert/unary_op.cc | 97 +++++++++++++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 63 +++++++++++- .../ir/inference/test_trt_convert_unary.py | 15 ++- 3 files changed, 168 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/unary_op.cc b/paddle/fluid/inference/tensorrt/convert/unary_op.cc index 72d5cb2aeb..2d56dc1196 100644 --- a/paddle/fluid/inference/tensorrt/convert/unary_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/unary_op.cc @@ -66,6 +66,23 @@ const std::unordered_map UnaryOpConverter::ops = { {"exp", nvinfer1::UnaryOperation::kEXP}, {"log", nvinfer1::UnaryOperation::kLOG}, + {"sqrt", nvinfer1::UnaryOperation::kSQRT}, + {"abs", nvinfer1::UnaryOperation::kABS}, + {"sin", nvinfer1::UnaryOperation::kSIN}, + {"cos", nvinfer1::UnaryOperation::kCOS}, + {"tan", nvinfer1::UnaryOperation::kTAN}, + {"sinh", nvinfer1::UnaryOperation::kSINH}, + {"cosh", nvinfer1::UnaryOperation::kCOSH}, + {"asin", nvinfer1::UnaryOperation::kASIN}, + {"acos", nvinfer1::UnaryOperation::kACOS}, + {"atan", nvinfer1::UnaryOperation::kATAN}, + {"asinh", nvinfer1::UnaryOperation::kASINH}, + {"atanh", nvinfer1::UnaryOperation::kATANH}, + {"ceil", nvinfer1::UnaryOperation::kCEIL}, + {"floor", nvinfer1::UnaryOperation::kFLOOR}, +#if IS_TRT_VERSION_GE(7000) + {"erf", nvinfer1::UnaryOperation::kERF}, +#endif }; class ExpOpConverter : public UnaryOpConverter { @@ -78,9 +95,89 @@ class LogOpConverter : public UnaryOpConverter { LogOpConverter() { op_type_ = "log"; } }; +class SqrtOpConverter : public UnaryOpConverter { + public: + SqrtOpConverter() { op_type_ = "sqrt"; } +}; +class AbsOpConverter : public UnaryOpConverter { + public: + AbsOpConverter() { op_type_ = "abs"; } +}; +class SinOpConverter : public UnaryOpConverter { + public: + SinOpConverter() { op_type_ = "sin"; } +}; +class CosOpConverter : public UnaryOpConverter { + public: + CosOpConverter() { op_type_ = "cos"; } +}; +class TanOpConverter : public UnaryOpConverter { + public: + TanOpConverter() { op_type_ = "tan"; } +}; +class SinhOpConverter : public UnaryOpConverter { + public: + SinhOpConverter() { op_type_ = "sinh"; } +}; +class CoshOpConverter : public UnaryOpConverter { + public: + CoshOpConverter() { op_type_ = "cosh"; } +}; +class AsinOpConverter : public UnaryOpConverter { + public: + AsinOpConverter() { op_type_ = "asin"; } +}; +class AcosOpConverter : public UnaryOpConverter { + public: + AcosOpConverter() { op_type_ = "acos"; } +}; +class AtanOpConverter : public UnaryOpConverter { + public: + AtanOpConverter() { op_type_ = "atan"; } +}; +class AsinhOpConverter : public UnaryOpConverter { + public: + AsinhOpConverter() { op_type_ = "asinh"; } +}; +class AtanhOpConverter : public UnaryOpConverter { + public: + AtanhOpConverter() { op_type_ = "atanh"; } +}; +class CeilOpConverter : public UnaryOpConverter { + public: + CeilOpConverter() { op_type_ = "ceil"; } +}; +class FloorOpConverter : public UnaryOpConverter { + public: + FloorOpConverter() { op_type_ = "floor"; } +}; +#if IS_TRT_VERSION_GE(7000) +class ErfOpConverter : public UnaryOpConverter { + public: + ErfOpConverter() { op_type_ = "erf"; } +}; +#endif + } // namespace tensorrt } // namespace inference } // namespace paddle REGISTER_TRT_OP_CONVERTER(exp, ExpOpConverter); REGISTER_TRT_OP_CONVERTER(log, LogOpConverter); +REGISTER_TRT_OP_CONVERTER(sqrt, SqrtOpConverter); +REGISTER_TRT_OP_CONVERTER(abs, AbsOpConverter); +REGISTER_TRT_OP_CONVERTER(sin, SinOpConverter); +REGISTER_TRT_OP_CONVERTER(cos, CosOpConverter); +REGISTER_TRT_OP_CONVERTER(tan, TanOpConverter); +REGISTER_TRT_OP_CONVERTER(sinh, SinhOpConverter); +REGISTER_TRT_OP_CONVERTER(cosh, CoshOpConverter); +REGISTER_TRT_OP_CONVERTER(asin, AsinOpConverter); +REGISTER_TRT_OP_CONVERTER(acos, AcosOpConverter); +REGISTER_TRT_OP_CONVERTER(atan, AtanOpConverter); +REGISTER_TRT_OP_CONVERTER(asinh, AsinhOpConverter); +REGISTER_TRT_OP_CONVERTER(atanh, AtanhOpConverter); +REGISTER_TRT_OP_CONVERTER(ceil, CeilOpConverter); +REGISTER_TRT_OP_CONVERTER(floor, FloorOpConverter); +#if IS_TRT_VERSION_GE(7000) +REGISTER_TRT_OP_CONVERTER(erf, ErfOpConverter); +#endif diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index d9b1e9b85f..05e8d196a8 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -75,6 +75,21 @@ struct SimpleOpTypeSetTeller : public Teller { "relu", "exp", "log", + "sqrt", + "abs", + "sin", + "cos", + "tan", + "sinh", + "cosh", + "asin", + "acos", + "atan", + "asinh", + "atanh", + "ceil", + "floor", + "erf", "softmax", "sigmoid", "hard_swish", @@ -148,6 +163,21 @@ struct SimpleOpTypeSetTeller : public Teller { "relu", "exp", "log", + "sqrt", + "abs", + "sin", + "cos", + "tan", + "sinh", + "cosh", + "asin", + "acos", + "atan", + "asinh", + "atanh", + "ceil", + "floor", + "erf", "softmax", "sigmoid", "hard_swish", @@ -227,8 +257,31 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, return false; for (auto& teller : tellers_) { - if (op_type == "relu" || op_type == "relu6" || op_type == "tanh" || - op_type == "sigmoid" || op_type == "exp" || op_type == "log") { + std::unordered_set act_op_list = {"relu", + "elu", + "selu", + "softsign", + "softplus", + "stanh", + "thresholded_relu", + "exp", + "log", + "sqrt", + "abs", + "sin", + "cos", + "tan", + "sinh", + "cosh", + "asin", + "acos", + "atan", + "asinh", + "atanh", + "ceil", + "floor", + "erf"}; + if (act_op_list.find(op_type) != act_op_list.end()) { auto* block = desc.Block(); if (block == nullptr) { VLOG(3) << "The block desc is nullptr, we can't continue to analyze. " @@ -244,6 +297,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, << " op does not support input's dim is 1 in tensorrt."; return false; } +#if !IS_TRT_VERSION_GE(7000) + if (op_type == "erf") { + VLOG(3) << op_type << " op does not support tensorrt."; + return false; + } +#endif } if (op_type == "pool2d") { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unary.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unary.py index fd4753528e..ca4231a356 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unary.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unary.py @@ -27,20 +27,25 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): return True def sample_program_configs(self): + self.trt_param.workspace_size = 1073741824 def generate_input1(dims, batch, attrs: List[Dict[str, Any]]): if dims == 1: - return np.ones([32]).astype(np.float32) + return np.random.random([32]).astype(np.float32) elif dims == 2: - return np.ones([3, 32]).astype(np.float32) + return np.random.random([3, 32]).astype(np.float32) elif dims == 3: - return np.ones([3, 32, 32]).astype(np.float32) + return np.random.random([3, 32, 32]).astype(np.float32) else: - return np.ones([batch, 3, 32, 32]).astype(np.float32) + return np.random.random([batch, 3, 32, 32]).astype(np.float32) for dims in [1, 2, 3, 4]: for batch in [1, 4]: - for op_type in ["exp", "log"]: + for op_type in [ + "exp", "log", "sqrt", "abs", "sin", "cos", "tan", + "sinh", "cosh", "asin", "acos", "atan", "asinh", + "atanh", "ceil", "floor" + ]: self.dims = dims dics = [{}] -- GitLab