diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index ef94266b4ebe19eb0e668f4eaf4a0fd2f3395624..545b3c6f52354db9bd81440392112fcad9d24655 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -791,7 +791,7 @@ void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
 
       // We performed bcast semantics check at python level
      // So input tensors should all have legal shape
-      target_dim_size = std::max(target_dim_size, dim_size);
+      target_dim_size = dim_size == 1 ? target_dim_size : dim_size;
     }
     target_dims[target_rank - index - 1] = target_dim_size;
   }
diff --git a/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py b/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py
index 6eec711c49e0abb854cd5da84216e7a966221ec6..9879aac254fb702d5c016f025f2445ee49776349 100644
--- a/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py
+++ b/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py
@@ -33,14 +33,12 @@ def find_output_shape(input_list):
         rank = len(x.shape)
         output_rank = max(output_rank, rank)
 
-    output_shape = [0 for i in range(output_rank)]
+    output_shape = [1 for i in range(output_rank)]
     for i in range(output_rank):
         for x in input_list:
             shape = list(reversed(x.shape))
-            size = 1
-            if i < len(shape):
-                size = shape[i]
-            output_shape[i] = max(output_shape[i], size)
+            if i < len(shape) and shape[i] != 1:
+                output_shape[i] = shape[i]
 
     return list(reversed(output_shape))
 
@@ -80,6 +78,11 @@ def gen_mixed_tensors_test(dtype):
     return make_inputs_outputs(input_shapes, dtype)
 
 
+def gen_empty_tensors_test(dtype):
+    input_shapes = [(0), (0), (0)]
+    return make_inputs_outputs(input_shapes, dtype)
+
+
 class TestCPUBroadcastTensorsOp(OpTest):
     def set_place(self):
         self.place = core.CPUPlace()
@@ -95,6 +98,7 @@ class TestCPUBroadcastTensorsOp(OpTest):
             gen_rank_diff_test,
             gen_no_broadcast_test,
             gen_mixed_tensors_test,
+            gen_empty_tensors_test,
         ]
         self.set_place()
         self.set_dtypes()
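Note (not part of the patch): a minimal standalone sketch of the broadcast-shape rule that the updated find_output_shape helper encodes, where a size-1 dimension defers to the other inputs and any other size, including 0 for empty tensors, becomes the output size. The example shapes below are hypothetical and chosen only to illustrate the new zero-size behaviour.

```python
import numpy as np


def find_output_shape(input_list):
    # Same logic as the updated test helper: shapes are aligned from the
    # trailing axis; size-1 dims defer, any other size (including 0) wins.
    output_rank = 0
    for x in input_list:
        output_rank = max(output_rank, len(x.shape))

    output_shape = [1 for _ in range(output_rank)]
    for i in range(output_rank):
        for x in input_list:
            shape = list(reversed(x.shape))
            if i < len(shape) and shape[i] != 1:
                output_shape[i] = shape[i]

    return list(reversed(output_shape))


# Hypothetical inputs: under the old max-based rule the middle dim would
# come out as 1; with the new rule the empty (0-sized) dim propagates.
a = np.zeros((2, 1, 4))
b = np.zeros((0, 4))
print(find_output_shape([a, b]))              # [2, 0, 4]
print(np.broadcast_shapes(a.shape, b.shape))  # (2, 0, 4) -- NumPy agrees
```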