未验证 提交 1b1d6d3f 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle Inference] Add sign and not trt converter (#48557)

上级 529e74e4
...@@ -90,7 +90,11 @@ const std::unordered_map<std::string, std::vector<nvinfer1::UnaryOperation>> ...@@ -90,7 +90,11 @@ const std::unordered_map<std::string, std::vector<nvinfer1::UnaryOperation>>
{"floor", {nvinfer1::UnaryOperation::kFLOOR}}, {"floor", {nvinfer1::UnaryOperation::kFLOOR}},
{"rsqrt", {"rsqrt",
{nvinfer1::UnaryOperation::kSQRT, nvinfer1::UnaryOperation::kRECIP}}, {nvinfer1::UnaryOperation::kSQRT, nvinfer1::UnaryOperation::kRECIP}},
{"logical_not", {nvinfer1::UnaryOperation::kNOT}},
{"reciprocal", {nvinfer1::UnaryOperation::kRECIP}}, {"reciprocal", {nvinfer1::UnaryOperation::kRECIP}},
#if IS_TRT_VERSION_GE(8200)
{"sign", {nvinfer1::UnaryOperation::kSIGN}},
#endif
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
{"erf", {nvinfer1::UnaryOperation::kERF}}, {"erf", {nvinfer1::UnaryOperation::kERF}},
#endif #endif
...@@ -167,10 +171,24 @@ class RsqrtOpConverter : public UnaryOpConverter { ...@@ -167,10 +171,24 @@ class RsqrtOpConverter : public UnaryOpConverter {
public: public:
RsqrtOpConverter() { op_type_ = "rsqrt"; } RsqrtOpConverter() { op_type_ = "rsqrt"; }
}; };
class LogicalNotOpConverter : public UnaryOpConverter {
public:
LogicalNotOpConverter() { op_type_ = "logical_not"; }
};
class ReciprocalOpConverter : public UnaryOpConverter { class ReciprocalOpConverter : public UnaryOpConverter {
public: public:
ReciprocalOpConverter() { op_type_ = "reciprocal"; } ReciprocalOpConverter() { op_type_ = "reciprocal"; }
}; };
#if IS_TRT_VERSION_GE(8200)
class SignOpConverter : public UnaryOpConverter {
public:
SignOpConverter() { op_type_ = "sign"; }
};
#endif
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
class ErfOpConverter : public UnaryOpConverter { class ErfOpConverter : public UnaryOpConverter {
public: public:
...@@ -199,7 +217,11 @@ REGISTER_TRT_OP_CONVERTER(atanh, AtanhOpConverter); ...@@ -199,7 +217,11 @@ REGISTER_TRT_OP_CONVERTER(atanh, AtanhOpConverter);
REGISTER_TRT_OP_CONVERTER(ceil, CeilOpConverter); REGISTER_TRT_OP_CONVERTER(ceil, CeilOpConverter);
REGISTER_TRT_OP_CONVERTER(floor, FloorOpConverter); REGISTER_TRT_OP_CONVERTER(floor, FloorOpConverter);
REGISTER_TRT_OP_CONVERTER(rsqrt, RsqrtOpConverter); REGISTER_TRT_OP_CONVERTER(rsqrt, RsqrtOpConverter);
REGISTER_TRT_OP_CONVERTER(logical_not, LogicalNotOpConverter);
REGISTER_TRT_OP_CONVERTER(reciprocal, ReciprocalOpConverter); REGISTER_TRT_OP_CONVERTER(reciprocal, ReciprocalOpConverter);
#if IS_TRT_VERSION_GE(8200)
REGISTER_TRT_OP_CONVERTER(sign, SignOpConverter);
#endif
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
REGISTER_TRT_OP_CONVERTER(erf, ErfOpConverter); REGISTER_TRT_OP_CONVERTER(erf, ErfOpConverter);
#endif #endif
...@@ -89,7 +89,8 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -89,7 +89,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"atan", "asinh", "atanh", "atan", "asinh", "atanh",
"ceil", "floor", "erf", "ceil", "floor", "erf",
"reciprocal", "silu", "celu", "reciprocal", "silu", "celu",
"tanh_shrink", "logsigmoid"}; "tanh_shrink", "logsigmoid", "sign",
"logical_not"};
if (act_op_list.find(op_type) != act_op_list.end()) { if (act_op_list.find(op_type) != act_op_list.end()) {
auto* block = desc.Block(); auto* block = desc.Block();
if (block == nullptr) { if (block == nullptr) {
...@@ -336,6 +337,29 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -336,6 +337,29 @@ struct SimpleOpTypeSetTeller : public Teller {
} }
} }
if (op_type == "sign") {
#if IS_TRT_VERSION_GE(8200)
if (!with_dynamic_shape) {
return false;
}
#else
VLOG(3) << "sign op is only supported by trt8.2 above ";
return false;
#endif
}
if (op_type == "logical_not") {
#if IS_TRT_VERSION_GE(8400)
if (!with_dynamic_shape) {
return false;
}
#else
VLOG(3) << "logical_not op is only supported by trt8.4 above because of "
"cast op";
return false;
#endif
}
if (op_type == "matmul_v2") { if (op_type == "matmul_v2") {
if (!with_dynamic_shape) { if (!with_dynamic_shape) {
return false; return false;
...@@ -2341,7 +2365,9 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2341,7 +2365,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"ceil", "ceil",
"floor", "floor",
"rsqrt", "rsqrt",
"sign",
"reciprocal", "reciprocal",
"logical_not",
"erf", "erf",
"softmax", "softmax",
"sigmoid", "sigmoid",
...@@ -2471,7 +2497,9 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2471,7 +2497,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"ceil", "ceil",
"floor", "floor",
"rsqrt", "rsqrt",
"sign",
"reciprocal", "reciprocal",
"logical_not",
"erf", "erf",
"softmax", "softmax",
"sigmoid", "sigmoid",
......
...@@ -59,8 +59,10 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -59,8 +59,10 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
"floor", "floor",
"rsqrt", "rsqrt",
"reciprocal", "reciprocal",
"sign",
]: ]:
self.dims = dims self.dims = dims
self.op_type = op_type
dics = [{}] dics = [{}]
ops_config = [ ops_config = [
...@@ -121,7 +123,14 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -121,7 +123,14 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
self.dynamic_shape.opt_input_shape = {} self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape): def generate_trt_nodes_num(attrs, dynamic_shape):
if self.dims == 1: ver = paddle_infer.get_trt_compile_version()
if self.dims == 1 or (
self.op_type == "sign"
and (
not dynamic_shape
or ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8200
)
):
return 0, 3 return 0, 3
return 1, 2 return 1, 2
...@@ -155,5 +164,143 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -155,5 +164,143 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
self.run_test() self.run_test()
class TrtConvertLogicalNotTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input(shape):
return np.random.random(shape).astype(np.float32)
for shape in [[2, 16], [2, 16, 32], [1, 32, 16, 32]]:
for op_type in ["logical_not"]:
for axis in [-1]:
self.dims = len(shape)
dics = [
{"axis": axis},
{"in_dtype": 5, "out_dtype": 0},
{"in_dtype": 0, "out_dtype": 5},
]
ops_config = [
{
"op_type": "cast",
"op_inputs": {"X": ["input_data"]},
"op_outputs": {"Out": ["cast_output_data1"]},
"op_attrs": dics[1],
"outputs_dtype": {"cast_output_data1": np.bool},
},
{
"op_type": op_type,
"op_inputs": {
"X": ["cast_output_data1"],
},
"op_outputs": {"Out": ["cast_output_data0"]},
"op_attrs": dics[0],
"outputs_dtype": {"cast_output_data0": np.bool},
},
{
"op_type": "cast",
"op_inputs": {"X": ["cast_output_data0"]},
"op_outputs": {"Out": ["output_data"]},
"op_attrs": dics[2],
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input, shape)
),
},
outputs=["output_data"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
if self.dims == 2:
self.dynamic_shape.min_input_shape = {
"input_data": [2, 16],
}
self.dynamic_shape.max_input_shape = {
"input_data": [2, 16],
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 16],
}
if self.dims == 3:
self.dynamic_shape.min_input_shape = {
"input_data": [2, 16, 32],
}
self.dynamic_shape.max_input_shape = {
"input_data": [2, 16, 32],
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 16, 32],
}
if self.dims == 4:
self.dynamic_shape.min_input_shape = {
"input_data": [1, 32, 16, 32],
}
self.dynamic_shape.max_input_shape = {
"input_data": [1, 32, 16, 32],
}
self.dynamic_shape.opt_input_shape = {
"input_data": [1, 32, 16, 32],
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
if dynamic_shape:
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8400:
return 0, 5
return 1, 2
return 0, 5
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), (1e-3, 1e-3)
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), (1e-3, 1e-3)
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册