From 2d20869c944fa116f7b8e84c30c91823a3723faf Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Fri, 17 Jan 2020 14:47:36 +0800 Subject: [PATCH] Fix infer_shape in compling for elementwise_op (#22291) --- .../fluid/operators/elementwise/elementwise_op_function.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 644bda34a09..3710e008ca1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -143,10 +143,11 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims, "the shape of Y = [%s]. Received [%d] in X is not equal to " "[%d] in Y", x_dims, y_dims, x_dims_array[i], y_dims_array[i]); - if (x_dims_array[i] == -1 || y_dims_array[i] == -1) { - out_dims_array[i] = -1; - } else { + if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) || + (x_dims_array[i] == 1 && y_dims_array[i] == 1)) { out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]); + } else { + out_dims_array[i] = -1; } } } -- GitLab