未验证 提交 78ec942b 编写于 作者: R RedContritio 提交者: GitHub

Fix 空指针 (Null pointer) of case15: paddle.broadcast_tensors (#49980)

* fix incorrect output shape of broadcast

* add unittest
上级 1048b166
......@@ -791,7 +791,7 @@ void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
// We performed bcast semantics check at python level
// So input tensors should all have legal shape
target_dim_size = std::max(target_dim_size, dim_size);
target_dim_size = dim_size == 1 ? target_dim_size : dim_size;
}
target_dims[target_rank - index - 1] = target_dim_size;
}
......
......@@ -33,14 +33,12 @@ def find_output_shape(input_list):
rank = len(x.shape)
output_rank = max(output_rank, rank)
output_shape = [0 for i in range(output_rank)]
output_shape = [1 for i in range(output_rank)]
for i in range(output_rank):
for x in input_list:
shape = list(reversed(x.shape))
size = 1
if i < len(shape):
size = shape[i]
output_shape[i] = max(output_shape[i], size)
if i < len(shape) and shape[i] != 1:
output_shape[i] = shape[i]
return list(reversed(output_shape))
......@@ -80,6 +78,11 @@ def gen_mixed_tensors_test(dtype):
return make_inputs_outputs(input_shapes, dtype)
def gen_empty_tensors_test(dtype):
input_shapes = [(0), (0), (0)]
return make_inputs_outputs(input_shapes, dtype)
class TestCPUBroadcastTensorsOp(OpTest):
def set_place(self):
self.place = core.CPUPlace()
......@@ -95,6 +98,7 @@ class TestCPUBroadcastTensorsOp(OpTest):
gen_rank_diff_test,
gen_no_broadcast_test,
gen_mixed_tensors_test,
gen_empty_tensors_test,
]
self.set_place()
self.set_dtypes()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册