From bd8f998b69a3474502df28ceecd31603bd2a4f5f Mon Sep 17 00:00:00 2001 From: Charles-hit <56987902+Charles-hit@users.noreply.github.com> Date: Fri, 9 Sep 2022 18:50:52 +0800 Subject: [PATCH] Fix split bug in static mode (#45906) * fix split bug in static mode * modify code style * modify code style * add unit test for split --- paddle/phi/infermeta/unary.cc | 22 +++++++++++++------ .../fluid/tests/unittests/test_split_op.py | 15 +++++++++++++ 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 006f77132f..39db2579ec 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3205,10 +3205,14 @@ void SplitInferMeta(const MetaTensor& x, // fill out dims with -1 if ((sections.FromTensor() && !config.is_runtime) || axis_value == -1 || (axis_value >= 0 && x.dims().at(axis_value) <= 0)) { - std::vector out_dims( - sections_data.size(), - phi::make_ddim(std::vector(x.dims().size(), -1))); - + std::vector out_dims; + if ((sections.FromTensor() && !config.is_runtime) || axis_value == -1) { + out_dims = std::vector( + sections_data.size(), + phi::make_ddim(std::vector(x.dims().size(), -1))); + } else { + out_dims = std::vector(sections_data.size(), x.dims()); + } for (size_t i = 0; i < sections_data.size(); ++i) { if (axis_value != 0) { // Only pass LoD when not spliting along the first dim. @@ -3293,9 +3297,13 @@ void SplitWithNumInferMeta(const MetaTensor& x, int axis_value = GetSplitAxisValue(x, axis, config); // fill out dims with -1 if (axis_value == -1 || (axis_value >= 0 && x.dims().at(axis_value) <= 0)) { - std::vector out_dims( - num, phi::make_ddim(std::vector(x.dims().size(), -1))); - + std::vector out_dims; + if (axis_value == -1) { + out_dims = std::vector( + num, phi::make_ddim(std::vector(x.dims().size(), -1))); + } else { + out_dims = std::vector(num, x.dims()); + } for (int i = 0; i < num; ++i) { if (axis_value != 0) { // Only pass LoD when not spliting along the first dim. diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py index c31169feed..37ea0d429c 100644 --- a/python/paddle/fluid/tests/unittests/test_split_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_op.py @@ -441,6 +441,21 @@ class API_TestSplit5(unittest.TestCase): np.testing.assert_allclose(ex_out, re, rtol=1e-05) +class API_TestSplit6(unittest.TestCase): + + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data = fluid.layers.data('data', shape=[-1, 10], dtype='float64') + x0, x1 = paddle.split(data, num_or_sections=[1, 1], axis=0) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input1 = np.random.random([2, 10]).astype('float64') + r0, r1 = exe.run(feed={"data": input1}, fetch_list=[x0, x1]) + ex_x0, ex_x1 = np.split(input1, (1, ), axis=0) + np.testing.assert_allclose(ex_x0, r0, rtol=1e-05) + np.testing.assert_allclose(ex_x1, r1, rtol=1e-05) + + class API_TestDygraphFluidSplit(unittest.TestCase): def test_out1(self): -- GitLab