From f695dc97bf5d305b3e0d7e3625a86449536b8ce6 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 9 Dec 2021 10:48:24 +0800 Subject: [PATCH] [Paddle-Inference] fix_ele_convert: IElementWiseLayer can broadcast (#37908) * fix_ele_convert: IElementWiseLayer can broadcast * fix_ele_convert --- .../tensorrt/convert/elementwise_op.cc | 2 +- .../inference/test_trt_convert_elementwise.py | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 8569dd63478..7c5af43816c 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -228,7 +228,7 @@ class ElementwiseTensorOpConverter : public OpConverter { } }; - if (CheckDims(dims_x, dims_y)) { + if (dims_x.nbDims == dims_y.nbDims) { // The two input tensor should have the same dims VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer"; nvinfer1::IElementWiseLayer* elet_layer = diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py index 992e0353837..b54b923d3b0 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py @@ -317,21 +317,28 @@ class TrtConvertElementwiseTest_two_input_with_broadcast(TrtLayerAutoScanTest): input1_shape_list = [[4, 32], [2, 4, 32], [4, 2, 4, 32]] input2_shape1_list = [[32], [4, 32], [2, 4, 32]] input2_shape2_list = [[4, 1], [2, 4, 1], [4, 2, 4, 1]] - input2_shape3_list = [[32], [2, 1, 1], [4, 2, 1, 1]] - input2_shape4_list = [[32], [4, 32], [4, 1, 1, 1]] + input2_shape3_list = [[32], [2, 1, 1], [4, 2, 1, 32]] + input2_shape4_list = [[32], [4, 32], [4, 1, 4, 32]] + input2_shape5_list = [[32], [2, 1, 32], [4, 1, 1, 32]] + input2_shape6_list = [[1, 32], [1, 32], [1, 1, 1, 32]] input2_shape_list = [ input2_shape1_list, input2_shape2_list, input2_shape3_list, - input2_shape4_list + input2_shape4_list, input2_shape5_list, input2_shape6_list ] axis1_list = [[-1], [1, -1], [1, -1]] axis2_list = [[-1], [0], [0]] axis3_list = [[-1], [0], [0]] axis4_list = [[-1], [-1], [0]] - axis_list = [axis1_list, axis2_list, axis3_list, axis4_list] + axis5_list = [[-1, 1], [-1, 0], [-1, 0]] + axis6_list = [[-1, 0], [-1, 1], [-1, 0]] + axis_list = [ + axis1_list, axis2_list, axis3_list, axis4_list, axis5_list, + axis6_list + ] for i in range(3): input1_shape = input1_shape_list[i] - for j in range(4): + for j in range(6): input2_shape = input2_shape_list[j][i] for op_type in ["elementwise_add", "elementwise_mul"]: for axis in axis_list[j][i]: -- GitLab