From f3022dfa4aedbeccb30c05c8b23b22461c9d7b73 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Tue, 29 Mar 2022 14:57:54 +0800 Subject: [PATCH] add elementwise sub and elementwise div in tensorrt op teller (#40806) * add elementwise sub and elementwise div in tensorrt op teller * add unittest of elementwise mul, sub and div --- paddle/fluid/inference/tensorrt/op_teller.cc | 7 ++++++- .../inference/test_trt_convert_elementwise.py | 10 ++++++++-- .../ir/inference/test_trt_elementwise_op.py | 18 ++++++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 7ddd4b55822..4a632bef774 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -74,7 +74,9 @@ struct SimpleOpTypeSetTeller : public Teller { "tanh", "pad", "elementwise_add", + "elementwise_sub", "elementwise_mul", + "elementwise_div", "dropout", "prelu", "conv2d_transpose", @@ -133,7 +135,9 @@ struct SimpleOpTypeSetTeller : public Teller { "tanh", "pad", "elementwise_add", + "elementwise_sub", "elementwise_mul", + "elementwise_div", "dropout", "prelu", "conv2d_transpose", @@ -958,7 +962,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } } - if (op_type == "elementwise_add" || op_type == "elementwise_mul") { + if (op_type == "elementwise_add" || op_type == "elementwise_mul" || + op_type == "elementwise_sub" || op_type == "elementwise_div") { if (desc.Input("X").size() != 1) { VLOG(3) << "The input op's Input(\"X\").size() " "should equal to 1, but received Input(\"X\").size() = " diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py index 505060e31a0..047a6094ec1 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py @@ -150,7 +150,10 @@ class TrtConvertElementwiseTest_two_input_without_broadcast( return np.random.random(shape).astype(np.float32) for shape in [[4], [4, 32], [2, 64, 32], [1, 8, 16, 32]]: - for op_type in ["elementwise_add", "elementwise_mul"]: + for op_type in [ + "elementwise_add", "elementwise_mul", "elementwise_sub", + "elementwise_div" + ]: for axis in [0, -1]: self.dims = len(shape) dics = [{"axis": axis}] @@ -306,7 +309,10 @@ class TrtConvertElementwiseTest_two_input_with_broadcast(TrtLayerAutoScanTest): input1_shape = input1_shape_list[i] for j in range(6): input2_shape = input2_shape_list[j][i] - for op_type in ["elementwise_add", "elementwise_mul"]: + for op_type in [ + "elementwise_add", "elementwise_mul", "elementwise_sub", + "elementwise_div" + ]: for axis in axis_list[j][i]: self.shape1 = input1_shape self.shape2 = input2_shape diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py index f84202df5fb..b40daba4868 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py @@ -56,5 +56,23 @@ class TensorRTSubgraphPassElementwiseBroadcastTest(InferencePassTest): PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) +class TensorRTSubgraphPassElementwiseBroadcastTest1( + TensorRTSubgraphPassElementwiseBroadcastTest): + def append_eltwise(self, data1, data2): + return fluid.layers.elementwise_sub(x=data1, y=data2, axis=0) + + +class TensorRTSubgraphPassElementwiseBroadcastTest2( + TensorRTSubgraphPassElementwiseBroadcastTest): + def append_eltwise(self, data1, data2): + return fluid.layers.elementwise_mul(x=data1, y=data2, axis=0) + + +class TensorRTSubgraphPassElementwiseBroadcastTest3( + TensorRTSubgraphPassElementwiseBroadcastTest): + def append_eltwise(self, data1, data2): + return fluid.layers.elementwise_div(x=data1, y=data2, axis=0) + + if __name__ == "__main__": unittest.main() -- GitLab