未验证 提交 fc882c7b 编写于 作者: G gem5 提交者: GitHub

Support rsqrt op. (#48223)

上级 c0d31dac
......@@ -2246,6 +2246,7 @@ USE_TRT_CONVERTER(flatten_contiguous_range);
USE_TRT_CONVERTER(matmul);
USE_TRT_CONVERTER(matmul_v2);
USE_TRT_CONVERTER(bmm);
USE_TRT_CONVERTER(rsqrt);
USE_TRT_CONVERTER(conv2d);
USE_TRT_CONVERTER(relu);
USE_TRT_CONVERTER(exp);
......
文件模式从 100644 更改为 100755
......@@ -52,38 +52,47 @@ class UnaryOpConverter : public OpConverter {
nvinfer1::ITensor* input_tensor =
engine_->GetITensor(op_desc.Input("X")[0]);
auto op_pair = ops.find(op_type_);
nvinfer1::IUnaryLayer* layer =
TRT_ENGINE_ADD_LAYER(engine_, Unary, *input_tensor, op_pair->second);
nvinfer1::IUnaryLayer* layer = nullptr;
for (auto trt_op : op_pair->second) {
layer = TRT_ENGINE_ADD_LAYER(engine_, Unary, *input_tensor, trt_op);
input_tensor = layer->getOutput(0);
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, op_type_, {output_name}, test_mode);
}
protected:
std::string op_type_;
static const std::unordered_map<std::string, nvinfer1::UnaryOperation> ops;
static const std::unordered_map<std::string,
std::vector<nvinfer1::UnaryOperation>>
ops;
};
// Table keyed by op type name. UnaryOpConverter looks up its op_type_ here to
// decide which TensorRT unary operation(s) to add for that op.
const std::unordered_map<std::string, nvinfer1::UnaryOperation>
const std::unordered_map<std::string, std::vector<nvinfer1::UnaryOperation>>
UnaryOpConverter::ops = {
{"exp", nvinfer1::UnaryOperation::kEXP},
{"log", nvinfer1::UnaryOperation::kLOG},
{"sqrt", nvinfer1::UnaryOperation::kSQRT},
{"abs", nvinfer1::UnaryOperation::kABS},
{"sin", nvinfer1::UnaryOperation::kSIN},
{"cos", nvinfer1::UnaryOperation::kCOS},
{"tan", nvinfer1::UnaryOperation::kTAN},
{"sinh", nvinfer1::UnaryOperation::kSINH},
{"cosh", nvinfer1::UnaryOperation::kCOSH},
{"asin", nvinfer1::UnaryOperation::kASIN},
{"acos", nvinfer1::UnaryOperation::kACOS},
{"atan", nvinfer1::UnaryOperation::kATAN},
{"asinh", nvinfer1::UnaryOperation::kASINH},
{"atanh", nvinfer1::UnaryOperation::kATANH},
{"ceil", nvinfer1::UnaryOperation::kCEIL},
{"floor", nvinfer1::UnaryOperation::kFLOOR},
{"reciprocal", nvinfer1::UnaryOperation::kRECIP},
{"exp", {nvinfer1::UnaryOperation::kEXP}},
{"log", {nvinfer1::UnaryOperation::kLOG}},
{"sqrt", {nvinfer1::UnaryOperation::kSQRT}},
{"abs", {nvinfer1::UnaryOperation::kABS}},
{"sin", {nvinfer1::UnaryOperation::kSIN}},
{"cos", {nvinfer1::UnaryOperation::kCOS}},
{"tan", {nvinfer1::UnaryOperation::kTAN}},
{"sinh", {nvinfer1::UnaryOperation::kSINH}},
{"cosh", {nvinfer1::UnaryOperation::kCOSH}},
{"asin", {nvinfer1::UnaryOperation::kASIN}},
{"acos", {nvinfer1::UnaryOperation::kACOS}},
{"atan", {nvinfer1::UnaryOperation::kATAN}},
{"asinh", {nvinfer1::UnaryOperation::kASINH}},
{"atanh", {nvinfer1::UnaryOperation::kATANH}},
{"ceil", {nvinfer1::UnaryOperation::kCEIL}},
{"floor", {nvinfer1::UnaryOperation::kFLOOR}},
// rsqrt is expressed as two chained unary layers: kSQRT first, then kRECIP
// applied to the square-root layer's output (the converter feeds each
// layer's output into the next one in list order).
{"rsqrt",
{nvinfer1::UnaryOperation::kSQRT, nvinfer1::UnaryOperation::kRECIP}},
{"reciprocal", {nvinfer1::UnaryOperation::kRECIP}},
// The erf entry is only present when IS_TRT_VERSION_GE(7000) holds.
#if IS_TRT_VERSION_GE(7000)
{"erf", nvinfer1::UnaryOperation::kERF},
{"erf", {nvinfer1::UnaryOperation::kERF}},
#endif
};
......@@ -153,6 +162,11 @@ class FloorOpConverter : public UnaryOpConverter {
public:
FloorOpConverter() { op_type_ = "floor"; }
};
// Converter for the "rsqrt" op. It only sets op_type_ to "rsqrt"; the base
// UnaryOpConverter uses that key to look up the operation list in
// UnaryOpConverter::ops (kSQRT followed by kRECIP for this op).
class RsqrtOpConverter : public UnaryOpConverter {
public:
RsqrtOpConverter() { op_type_ = "rsqrt"; }
};
class ReciprocalOpConverter : public UnaryOpConverter {
public:
ReciprocalOpConverter() { op_type_ = "reciprocal"; }
......@@ -184,6 +198,7 @@ REGISTER_TRT_OP_CONVERTER(asinh, AsinhOpConverter);
REGISTER_TRT_OP_CONVERTER(atanh, AtanhOpConverter);
REGISTER_TRT_OP_CONVERTER(ceil, CeilOpConverter);
REGISTER_TRT_OP_CONVERTER(floor, FloorOpConverter);
REGISTER_TRT_OP_CONVERTER(rsqrt, RsqrtOpConverter);
REGISTER_TRT_OP_CONVERTER(reciprocal, ReciprocalOpConverter);
#if IS_TRT_VERSION_GE(7000)
REGISTER_TRT_OP_CONVERTER(erf, ErfOpConverter);
......
......@@ -2310,6 +2310,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"atanh",
"ceil",
"floor",
"rsqrt",
"reciprocal",
"erf",
"softmax",
......@@ -2438,6 +2439,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"atanh",
"ceil",
"floor",
"rsqrt",
"reciprocal",
"erf",
"softmax",
......
......@@ -31,16 +31,14 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
self.trt_param.workspace_size = 1073741824
def generate_input1(dims, batch, attrs: List[Dict[str, Any]]):
if dims == 1:
return np.random.random([32]).astype(np.float32)
elif dims == 2:
if dims == 2:
return np.random.random([3, 32]).astype(np.float32)
elif dims == 3:
return np.random.random([3, 32, 32]).astype(np.float32)
else:
return np.random.random([batch, 3, 32, 32]).astype(np.float32)
for dims in [1, 2, 3, 4]:
for dims in [2, 3, 4]:
for batch in [1, 4]:
for op_type in [
"exp",
......@@ -59,6 +57,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
"atanh",
"ceil",
"floor",
"rsqrt",
"reciprocal",
]:
self.dims = dims
......@@ -135,7 +134,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), 1e-5
), 1e-4
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
......@@ -146,7 +145,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-5
), 1e-4
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册