未验证 提交 325e5712 编写于 作者: C chentianyu03 提交者: GitHub

[Phi]fix split error when sections has 0 size and add test case (#41708)

* fix split error when sections has 0 size and add test case

* fix test case
上级 8cbf79a3
...@@ -134,7 +134,7 @@ inline void StridedMemcpyWithAxis0( ...@@ -134,7 +134,7 @@ inline void StridedMemcpyWithAxis0(
for (size_t i = 0; i < outputs->size(); ++i) { for (size_t i = 0; i < outputs->size(); ++i) {
auto out_stride = stride_numel(shape_refer[i]->dims()); auto out_stride = stride_numel(shape_refer[i]->dims());
auto out = outputs->at(i); auto out = outputs->at(i);
if (out != nullptr) { if (out != nullptr && out->initialized()) {
StridedNumelCopyWithAxis<T>(dev_ctx, axis, out->data<T>(), out_stride, StridedNumelCopyWithAxis<T>(dev_ctx, axis, out->data<T>(), out_stride,
input.data<T>() + input_offset, in_stride, input.data<T>() + input_offset, in_stride,
out_stride[axis]); out_stride[axis]);
......
...@@ -459,5 +459,24 @@ class API_TestDygraphSplit(unittest.TestCase): ...@@ -459,5 +459,24 @@ class API_TestDygraphSplit(unittest.TestCase):
self.assertTrue(np.allclose(ex_x2, x2_out)) self.assertTrue(np.allclose(ex_x2, x2_out))
class API_TestEmptySplit(unittest.TestCase):
def test_axis_input_empty_section(self):
with fluid.dygraph.guard():
input_1 = np.random.random([8, 6, 6]).astype("float32")
# input is a variable which shape is [8, 6, 6]
input = paddle.to_tensor(input_1)
x0, x1, x2 = paddle.split(input, num_or_sections=[5, 0, 3])
x0_out = x0.numpy()
x1_out = x1.numpy()
x2_out = x2.numpy()
ex_x0, ex_x1, ex_x2 = np.split(input_1, [
5,
5,
])
self.assertTrue(np.allclose(ex_x0, x0_out))
self.assertTrue(np.allclose(ex_x1, x1_out))
self.assertTrue(np.allclose(ex_x2, x2_out))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册