未验证 提交 f695dc97 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle-Inference] fix_ele_convert: IElementWiseLayer can broadcast (#37908)

* fix_ele_convert: IElementWiseLayer can broadcast

* fix_ele_convert
上级 1911b6f0
...@@ -228,7 +228,7 @@ class ElementwiseTensorOpConverter : public OpConverter { ...@@ -228,7 +228,7 @@ class ElementwiseTensorOpConverter : public OpConverter {
} }
}; };
if (CheckDims(dims_x, dims_y)) { if (dims_x.nbDims == dims_y.nbDims) {
// The two input tensor should have the same dims // The two input tensor should have the same dims
VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer"; VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer";
nvinfer1::IElementWiseLayer* elet_layer = nvinfer1::IElementWiseLayer* elet_layer =
......
...@@ -317,21 +317,28 @@ class TrtConvertElementwiseTest_two_input_with_broadcast(TrtLayerAutoScanTest): ...@@ -317,21 +317,28 @@ class TrtConvertElementwiseTest_two_input_with_broadcast(TrtLayerAutoScanTest):
input1_shape_list = [[4, 32], [2, 4, 32], [4, 2, 4, 32]] input1_shape_list = [[4, 32], [2, 4, 32], [4, 2, 4, 32]]
input2_shape1_list = [[32], [4, 32], [2, 4, 32]] input2_shape1_list = [[32], [4, 32], [2, 4, 32]]
input2_shape2_list = [[4, 1], [2, 4, 1], [4, 2, 4, 1]] input2_shape2_list = [[4, 1], [2, 4, 1], [4, 2, 4, 1]]
input2_shape3_list = [[32], [2, 1, 1], [4, 2, 1, 1]] input2_shape3_list = [[32], [2, 1, 1], [4, 2, 1, 32]]
input2_shape4_list = [[32], [4, 32], [4, 1, 1, 1]] input2_shape4_list = [[32], [4, 32], [4, 1, 4, 32]]
input2_shape5_list = [[32], [2, 1, 32], [4, 1, 1, 32]]
input2_shape6_list = [[1, 32], [1, 32], [1, 1, 1, 32]]
input2_shape_list = [ input2_shape_list = [
input2_shape1_list, input2_shape2_list, input2_shape3_list, input2_shape1_list, input2_shape2_list, input2_shape3_list,
input2_shape4_list input2_shape4_list, input2_shape5_list, input2_shape6_list
] ]
axis1_list = [[-1], [1, -1], [1, -1]] axis1_list = [[-1], [1, -1], [1, -1]]
axis2_list = [[-1], [0], [0]] axis2_list = [[-1], [0], [0]]
axis3_list = [[-1], [0], [0]] axis3_list = [[-1], [0], [0]]
axis4_list = [[-1], [-1], [0]] axis4_list = [[-1], [-1], [0]]
axis_list = [axis1_list, axis2_list, axis3_list, axis4_list] axis5_list = [[-1, 1], [-1, 0], [-1, 0]]
axis6_list = [[-1, 0], [-1, 1], [-1, 0]]
axis_list = [
axis1_list, axis2_list, axis3_list, axis4_list, axis5_list,
axis6_list
]
for i in range(3): for i in range(3):
input1_shape = input1_shape_list[i] input1_shape = input1_shape_list[i]
for j in range(4): for j in range(6):
input2_shape = input2_shape_list[j][i] input2_shape = input2_shape_list[j][i]
for op_type in ["elementwise_add", "elementwise_mul"]: for op_type in ["elementwise_add", "elementwise_mul"]:
for axis in axis_list[j][i]: for axis in axis_list[j][i]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册