未验证 提交 fc882c7b 编写于 作者: G gem5 提交者: GitHub

Support rsqrt op. (#48223)

上级 c0d31dac
...@@ -2246,6 +2246,7 @@ USE_TRT_CONVERTER(flatten_contiguous_range); ...@@ -2246,6 +2246,7 @@ USE_TRT_CONVERTER(flatten_contiguous_range);
USE_TRT_CONVERTER(matmul); USE_TRT_CONVERTER(matmul);
USE_TRT_CONVERTER(matmul_v2); USE_TRT_CONVERTER(matmul_v2);
USE_TRT_CONVERTER(bmm); USE_TRT_CONVERTER(bmm);
USE_TRT_CONVERTER(rsqrt);
USE_TRT_CONVERTER(conv2d); USE_TRT_CONVERTER(conv2d);
USE_TRT_CONVERTER(relu); USE_TRT_CONVERTER(relu);
USE_TRT_CONVERTER(exp); USE_TRT_CONVERTER(exp);
......
文件模式从 100644 更改为 100755
...@@ -52,38 +52,47 @@ class UnaryOpConverter : public OpConverter { ...@@ -52,38 +52,47 @@ class UnaryOpConverter : public OpConverter {
nvinfer1::ITensor* input_tensor = nvinfer1::ITensor* input_tensor =
engine_->GetITensor(op_desc.Input("X")[0]); engine_->GetITensor(op_desc.Input("X")[0]);
auto op_pair = ops.find(op_type_); auto op_pair = ops.find(op_type_);
nvinfer1::IUnaryLayer* layer =
TRT_ENGINE_ADD_LAYER(engine_, Unary, *input_tensor, op_pair->second); nvinfer1::IUnaryLayer* layer = nullptr;
for (auto trt_op : op_pair->second) {
layer = TRT_ENGINE_ADD_LAYER(engine_, Unary, *input_tensor, trt_op);
input_tensor = layer->getOutput(0);
}
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, op_type_, {output_name}, test_mode); RreplenishLayerAndOutput(layer, op_type_, {output_name}, test_mode);
} }
protected: protected:
std::string op_type_; std::string op_type_;
static const std::unordered_map<std::string, nvinfer1::UnaryOperation> ops; static const std::unordered_map<std::string,
std::vector<nvinfer1::UnaryOperation>>
ops;
}; };
const std::unordered_map<std::string, nvinfer1::UnaryOperation> const std::unordered_map<std::string, std::vector<nvinfer1::UnaryOperation>>
UnaryOpConverter::ops = { UnaryOpConverter::ops = {
{"exp", nvinfer1::UnaryOperation::kEXP}, {"exp", {nvinfer1::UnaryOperation::kEXP}},
{"log", nvinfer1::UnaryOperation::kLOG}, {"log", {nvinfer1::UnaryOperation::kLOG}},
{"sqrt", nvinfer1::UnaryOperation::kSQRT}, {"sqrt", {nvinfer1::UnaryOperation::kSQRT}},
{"abs", nvinfer1::UnaryOperation::kABS}, {"abs", {nvinfer1::UnaryOperation::kABS}},
{"sin", nvinfer1::UnaryOperation::kSIN}, {"sin", {nvinfer1::UnaryOperation::kSIN}},
{"cos", nvinfer1::UnaryOperation::kCOS}, {"cos", {nvinfer1::UnaryOperation::kCOS}},
{"tan", nvinfer1::UnaryOperation::kTAN}, {"tan", {nvinfer1::UnaryOperation::kTAN}},
{"sinh", nvinfer1::UnaryOperation::kSINH}, {"sinh", {nvinfer1::UnaryOperation::kSINH}},
{"cosh", nvinfer1::UnaryOperation::kCOSH}, {"cosh", {nvinfer1::UnaryOperation::kCOSH}},
{"asin", nvinfer1::UnaryOperation::kASIN}, {"asin", {nvinfer1::UnaryOperation::kASIN}},
{"acos", nvinfer1::UnaryOperation::kACOS}, {"acos", {nvinfer1::UnaryOperation::kACOS}},
{"atan", nvinfer1::UnaryOperation::kATAN}, {"atan", {nvinfer1::UnaryOperation::kATAN}},
{"asinh", nvinfer1::UnaryOperation::kASINH}, {"asinh", {nvinfer1::UnaryOperation::kASINH}},
{"atanh", nvinfer1::UnaryOperation::kATANH}, {"atanh", {nvinfer1::UnaryOperation::kATANH}},
{"ceil", nvinfer1::UnaryOperation::kCEIL}, {"ceil", {nvinfer1::UnaryOperation::kCEIL}},
{"floor", nvinfer1::UnaryOperation::kFLOOR}, {"floor", {nvinfer1::UnaryOperation::kFLOOR}},
{"reciprocal", nvinfer1::UnaryOperation::kRECIP}, {"rsqrt",
{nvinfer1::UnaryOperation::kSQRT, nvinfer1::UnaryOperation::kRECIP}},
{"reciprocal", {nvinfer1::UnaryOperation::kRECIP}},
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
{"erf", nvinfer1::UnaryOperation::kERF}, {"erf", {nvinfer1::UnaryOperation::kERF}},
#endif #endif
}; };
...@@ -153,6 +162,11 @@ class FloorOpConverter : public UnaryOpConverter { ...@@ -153,6 +162,11 @@ class FloorOpConverter : public UnaryOpConverter {
public: public:
FloorOpConverter() { op_type_ = "floor"; } FloorOpConverter() { op_type_ = "floor"; }
}; };
class RsqrtOpConverter : public UnaryOpConverter {
public:
RsqrtOpConverter() { op_type_ = "rsqrt"; }
};
class ReciprocalOpConverter : public UnaryOpConverter { class ReciprocalOpConverter : public UnaryOpConverter {
public: public:
ReciprocalOpConverter() { op_type_ = "reciprocal"; } ReciprocalOpConverter() { op_type_ = "reciprocal"; }
...@@ -184,6 +198,7 @@ REGISTER_TRT_OP_CONVERTER(asinh, AsinhOpConverter); ...@@ -184,6 +198,7 @@ REGISTER_TRT_OP_CONVERTER(asinh, AsinhOpConverter);
REGISTER_TRT_OP_CONVERTER(atanh, AtanhOpConverter); REGISTER_TRT_OP_CONVERTER(atanh, AtanhOpConverter);
REGISTER_TRT_OP_CONVERTER(ceil, CeilOpConverter); REGISTER_TRT_OP_CONVERTER(ceil, CeilOpConverter);
REGISTER_TRT_OP_CONVERTER(floor, FloorOpConverter); REGISTER_TRT_OP_CONVERTER(floor, FloorOpConverter);
REGISTER_TRT_OP_CONVERTER(rsqrt, RsqrtOpConverter);
REGISTER_TRT_OP_CONVERTER(reciprocal, ReciprocalOpConverter); REGISTER_TRT_OP_CONVERTER(reciprocal, ReciprocalOpConverter);
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
REGISTER_TRT_OP_CONVERTER(erf, ErfOpConverter); REGISTER_TRT_OP_CONVERTER(erf, ErfOpConverter);
......
...@@ -2310,6 +2310,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2310,6 +2310,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"atanh", "atanh",
"ceil", "ceil",
"floor", "floor",
"rsqrt",
"reciprocal", "reciprocal",
"erf", "erf",
"softmax", "softmax",
...@@ -2438,6 +2439,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2438,6 +2439,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"atanh", "atanh",
"ceil", "ceil",
"floor", "floor",
"rsqrt",
"reciprocal", "reciprocal",
"erf", "erf",
"softmax", "softmax",
......
...@@ -31,16 +31,14 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -31,16 +31,14 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
self.trt_param.workspace_size = 1073741824 self.trt_param.workspace_size = 1073741824
def generate_input1(dims, batch, attrs: List[Dict[str, Any]]): def generate_input1(dims, batch, attrs: List[Dict[str, Any]]):
if dims == 1: if dims == 2:
return np.random.random([32]).astype(np.float32)
elif dims == 2:
return np.random.random([3, 32]).astype(np.float32) return np.random.random([3, 32]).astype(np.float32)
elif dims == 3: elif dims == 3:
return np.random.random([3, 32, 32]).astype(np.float32) return np.random.random([3, 32, 32]).astype(np.float32)
else: else:
return np.random.random([batch, 3, 32, 32]).astype(np.float32) return np.random.random([batch, 3, 32, 32]).astype(np.float32)
for dims in [1, 2, 3, 4]: for dims in [2, 3, 4]:
for batch in [1, 4]: for batch in [1, 4]:
for op_type in [ for op_type in [
"exp", "exp",
...@@ -59,6 +57,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -59,6 +57,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
"atanh", "atanh",
"ceil", "ceil",
"floor", "floor",
"rsqrt",
"reciprocal", "reciprocal",
]: ]:
self.dims = dims self.dims = dims
...@@ -135,7 +134,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -135,7 +134,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
self.trt_param.precision = paddle_infer.PrecisionType.Float32 self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num( yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False attrs, False
), 1e-5 ), 1e-4
self.trt_param.precision = paddle_infer.PrecisionType.Half self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num( yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False attrs, False
...@@ -146,7 +145,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -146,7 +145,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
self.trt_param.precision = paddle_infer.PrecisionType.Float32 self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num( yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True attrs, True
), 1e-5 ), 1e-4
self.trt_param.precision = paddle_infer.PrecisionType.Half self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num( yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True attrs, True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册