diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 644bda34a091adbbab5e77d6ddac84c9d48c514d..3710e008ca1b99c151cd53248dcf740ff5544c82 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -143,10 +143,11 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims, "the shape of Y = [%s]. Received [%d] in X is not equal to " "[%d] in Y", x_dims, y_dims, x_dims_array[i], y_dims_array[i]); - if (x_dims_array[i] == -1 || y_dims_array[i] == -1) { - out_dims_array[i] = -1; - } else { + if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) || + (x_dims_array[i] == 1 && y_dims_array[i] == 1)) { out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]); + } else { + out_dims_array[i] = -1; } } }