未验证 提交 890c7315 编写于 作者: 提交者: GitHub

[inference] Add TensorRT converters for unary ops (#43509)

* add unary
上级 1ec626b1
......@@ -66,6 +66,23 @@ const std::unordered_map<std::string, nvinfer1::UnaryOperation>
// Lookup table mapping a paddle activation op name to the matching
// TensorRT unary-layer operation enum.  The op_type_ strings set by the
// converter subclasses below are the keys of this map, so the base
// UnaryOpConverter can presumably resolve the TRT op from op_type_
// (base class not visible here -- confirm against its Convert impl).
UnaryOpConverter::ops = {
{"exp", nvinfer1::UnaryOperation::kEXP},
{"log", nvinfer1::UnaryOperation::kLOG},
{"sqrt", nvinfer1::UnaryOperation::kSQRT},
{"abs", nvinfer1::UnaryOperation::kABS},
{"sin", nvinfer1::UnaryOperation::kSIN},
{"cos", nvinfer1::UnaryOperation::kCOS},
{"tan", nvinfer1::UnaryOperation::kTAN},
{"sinh", nvinfer1::UnaryOperation::kSINH},
{"cosh", nvinfer1::UnaryOperation::kCOSH},
{"asin", nvinfer1::UnaryOperation::kASIN},
{"acos", nvinfer1::UnaryOperation::kACOS},
{"atan", nvinfer1::UnaryOperation::kATAN},
{"asinh", nvinfer1::UnaryOperation::kASINH},
{"atanh", nvinfer1::UnaryOperation::kATANH},
{"ceil", nvinfer1::UnaryOperation::kCEIL},
{"floor", nvinfer1::UnaryOperation::kFLOOR},
// kERF only exists in the TensorRT API from version 7.0 onward.
#if IS_TRT_VERSION_GE(7000)
{"erf", nvinfer1::UnaryOperation::kERF},
#endif
};
class ExpOpConverter : public UnaryOpConverter {
......@@ -78,9 +95,89 @@ class LogOpConverter : public UnaryOpConverter {
LogOpConverter() { op_type_ = "log"; }
};
// One thin converter subclass per unary op.  Each constructor only sets
// op_type_ to the paddle op name; that name is a key in
// UnaryOpConverter::ops, which selects the TensorRT UnaryOperation.
// No other behavior is overridden.
class SqrtOpConverter : public UnaryOpConverter {
public:
SqrtOpConverter() { op_type_ = "sqrt"; }
};
class AbsOpConverter : public UnaryOpConverter {
public:
AbsOpConverter() { op_type_ = "abs"; }
};
class SinOpConverter : public UnaryOpConverter {
public:
SinOpConverter() { op_type_ = "sin"; }
};
class CosOpConverter : public UnaryOpConverter {
public:
CosOpConverter() { op_type_ = "cos"; }
};
class TanOpConverter : public UnaryOpConverter {
public:
TanOpConverter() { op_type_ = "tan"; }
};
class SinhOpConverter : public UnaryOpConverter {
public:
SinhOpConverter() { op_type_ = "sinh"; }
};
class CoshOpConverter : public UnaryOpConverter {
public:
CoshOpConverter() { op_type_ = "cosh"; }
};
class AsinOpConverter : public UnaryOpConverter {
public:
AsinOpConverter() { op_type_ = "asin"; }
};
class AcosOpConverter : public UnaryOpConverter {
public:
AcosOpConverter() { op_type_ = "acos"; }
};
class AtanOpConverter : public UnaryOpConverter {
public:
AtanOpConverter() { op_type_ = "atan"; }
};
class AsinhOpConverter : public UnaryOpConverter {
public:
AsinhOpConverter() { op_type_ = "asinh"; }
};
class AtanhOpConverter : public UnaryOpConverter {
public:
AtanhOpConverter() { op_type_ = "atanh"; }
};
class CeilOpConverter : public UnaryOpConverter {
public:
CeilOpConverter() { op_type_ = "ceil"; }
};
class FloorOpConverter : public UnaryOpConverter {
public:
FloorOpConverter() { op_type_ = "floor"; }
};
// erf is guarded because kERF requires TensorRT >= 7.0 (see the ops map).
#if IS_TRT_VERSION_GE(7000)
class ErfOpConverter : public UnaryOpConverter {
public:
ErfOpConverter() { op_type_ = "erf"; }
};
#endif
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Register each unary converter under its paddle op name so the TRT
// converter framework can dispatch that op to the matching class.
// The set of registrations mirrors the UnaryOpConverter::ops table.
REGISTER_TRT_OP_CONVERTER(exp, ExpOpConverter);
REGISTER_TRT_OP_CONVERTER(log, LogOpConverter);
REGISTER_TRT_OP_CONVERTER(sqrt, SqrtOpConverter);
REGISTER_TRT_OP_CONVERTER(abs, AbsOpConverter);
REGISTER_TRT_OP_CONVERTER(sin, SinOpConverter);
REGISTER_TRT_OP_CONVERTER(cos, CosOpConverter);
REGISTER_TRT_OP_CONVERTER(tan, TanOpConverter);
REGISTER_TRT_OP_CONVERTER(sinh, SinhOpConverter);
REGISTER_TRT_OP_CONVERTER(cosh, CoshOpConverter);
REGISTER_TRT_OP_CONVERTER(asin, AsinOpConverter);
REGISTER_TRT_OP_CONVERTER(acos, AcosOpConverter);
REGISTER_TRT_OP_CONVERTER(atan, AtanOpConverter);
REGISTER_TRT_OP_CONVERTER(asinh, AsinhOpConverter);
REGISTER_TRT_OP_CONVERTER(atanh, AtanhOpConverter);
REGISTER_TRT_OP_CONVERTER(ceil, CeilOpConverter);
REGISTER_TRT_OP_CONVERTER(floor, FloorOpConverter);
// ErfOpConverter is only compiled (and thus only registered) on TRT >= 7.0.
#if IS_TRT_VERSION_GE(7000)
REGISTER_TRT_OP_CONVERTER(erf, ErfOpConverter);
#endif
......@@ -75,6 +75,21 @@ struct SimpleOpTypeSetTeller : public Teller {
"relu",
"exp",
"log",
"sqrt",
"abs",
"sin",
"cos",
"tan",
"sinh",
"cosh",
"asin",
"acos",
"atan",
"asinh",
"atanh",
"ceil",
"floor",
"erf",
"softmax",
"sigmoid",
"hard_swish",
......@@ -148,6 +163,21 @@ struct SimpleOpTypeSetTeller : public Teller {
"relu",
"exp",
"log",
"sqrt",
"abs",
"sin",
"cos",
"tan",
"sinh",
"cosh",
"asin",
"acos",
"atan",
"asinh",
"atanh",
"ceil",
"floor",
"erf",
"softmax",
"sigmoid",
"hard_swish",
......@@ -227,8 +257,31 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
for (auto& teller : tellers_) {
if (op_type == "relu" || op_type == "relu6" || op_type == "tanh" ||
op_type == "sigmoid" || op_type == "exp" || op_type == "log") {
std::unordered_set<std::string> act_op_list = {"relu",
"elu",
"selu",
"softsign",
"softplus",
"stanh",
"thresholded_relu",
"exp",
"log",
"sqrt",
"abs",
"sin",
"cos",
"tan",
"sinh",
"cosh",
"asin",
"acos",
"atan",
"asinh",
"atanh",
"ceil",
"floor",
"erf"};
if (act_op_list.find(op_type) != act_op_list.end()) {
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
......@@ -244,6 +297,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
<< " op does not support input's dim is 1 in tensorrt.";
return false;
}
#if !IS_TRT_VERSION_GE(7000)
if (op_type == "erf") {
VLOG(3) << op_type << " op does not support tensorrt.";
return false;
}
#endif
}
if (op_type == "pool2d") {
......
......@@ -27,20 +27,25 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
return True
def sample_program_configs(self):
self.trt_param.workspace_size = 1073741824
def generate_input1(dims, batch, attrs: List[Dict[str, Any]]):
    """Build a random float32 input tensor whose rank matches ``dims``.

    The diff capture left both the old ``np.ones`` lines and their
    ``np.random.random`` replacements in each branch, producing dead
    duplicate ``return`` statements; only the random versions (the
    post-diff code) are kept here.

    Args:
        dims: requested tensor rank; 1, 2 or 3 pick fixed shapes,
            anything else falls through to the 4-D case.
        batch: leading batch size, used only in the 4-D case.
        attrs: op attribute dicts; unused, kept for the auto-scan
            test's generator signature.

    Returns:
        np.ndarray of float32 values drawn uniformly from [0, 1).
    """
    if dims == 1:
        return np.random.random([32]).astype(np.float32)
    elif dims == 2:
        return np.random.random([3, 32]).astype(np.float32)
    elif dims == 3:
        return np.random.random([3, 32, 32]).astype(np.float32)
    else:
        return np.random.random([batch, 3, 32, 32]).astype(np.float32)
for dims in [1, 2, 3, 4]:
for batch in [1, 4]:
for op_type in ["exp", "log"]:
for op_type in [
"exp", "log", "sqrt", "abs", "sin", "cos", "tan",
"sinh", "cosh", "asin", "acos", "atan", "asinh",
"atanh", "ceil", "floor"
]:
self.dims = dims
dics = [{}]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册