未验证 提交 78ec942b 编写于 作者: R RedContritio 提交者: GitHub

Fix 空指针 (Null pointer) of case15: paddle.broadcast_tensors (#49980)

* fix incorrect output shape of broadcast

* add unittest
上级 1048b166
...@@ -791,7 +791,7 @@ void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x, ...@@ -791,7 +791,7 @@ void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
// We performed bcast semantics check at python level // We performed bcast semantics check at python level
// So input tensors should all have legal shape // So input tensors should all have legal shape
target_dim_size = std::max(target_dim_size, dim_size); target_dim_size = dim_size == 1 ? target_dim_size : dim_size;
} }
target_dims[target_rank - index - 1] = target_dim_size; target_dims[target_rank - index - 1] = target_dim_size;
} }
......
...@@ -33,14 +33,12 @@ def find_output_shape(input_list): ...@@ -33,14 +33,12 @@ def find_output_shape(input_list):
rank = len(x.shape) rank = len(x.shape)
output_rank = max(output_rank, rank) output_rank = max(output_rank, rank)
output_shape = [0 for i in range(output_rank)] output_shape = [1 for i in range(output_rank)]
for i in range(output_rank): for i in range(output_rank):
for x in input_list: for x in input_list:
shape = list(reversed(x.shape)) shape = list(reversed(x.shape))
size = 1 if i < len(shape) and shape[i] != 1:
if i < len(shape): output_shape[i] = shape[i]
size = shape[i]
output_shape[i] = max(output_shape[i], size)
return list(reversed(output_shape)) return list(reversed(output_shape))
...@@ -80,6 +78,11 @@ def gen_mixed_tensors_test(dtype): ...@@ -80,6 +78,11 @@ def gen_mixed_tensors_test(dtype):
return make_inputs_outputs(input_shapes, dtype) return make_inputs_outputs(input_shapes, dtype)
def gen_empty_tensors_test(dtype):
input_shapes = [(0), (0), (0)]
return make_inputs_outputs(input_shapes, dtype)
class TestCPUBroadcastTensorsOp(OpTest): class TestCPUBroadcastTensorsOp(OpTest):
def set_place(self): def set_place(self):
self.place = core.CPUPlace() self.place = core.CPUPlace()
...@@ -95,6 +98,7 @@ class TestCPUBroadcastTensorsOp(OpTest): ...@@ -95,6 +98,7 @@ class TestCPUBroadcastTensorsOp(OpTest):
gen_rank_diff_test, gen_rank_diff_test,
gen_no_broadcast_test, gen_no_broadcast_test,
gen_mixed_tensors_test, gen_mixed_tensors_test,
gen_empty_tensors_test,
] ]
self.set_place() self.set_place()
self.set_dtypes() self.set_dtypes()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册