diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 006f77132f0119fd38772f5b9efa1ff2a0cd0f52..39db2579ecb2323acc47aad72c56cdd4d3fe203b 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3205,10 +3205,14 @@ void SplitInferMeta(const MetaTensor& x, // fill out dims with -1 if ((sections.FromTensor() && !config.is_runtime) || axis_value == -1 || (axis_value >= 0 && x.dims().at(axis_value) <= 0)) { - std::vector out_dims( - sections_data.size(), - phi::make_ddim(std::vector(x.dims().size(), -1))); - + std::vector out_dims; + if ((sections.FromTensor() && !config.is_runtime) || axis_value == -1) { + out_dims = std::vector( + sections_data.size(), + phi::make_ddim(std::vector(x.dims().size(), -1))); + } else { + out_dims = std::vector(sections_data.size(), x.dims()); + } for (size_t i = 0; i < sections_data.size(); ++i) { if (axis_value != 0) { // Only pass LoD when not spliting along the first dim. @@ -3293,9 +3297,13 @@ void SplitWithNumInferMeta(const MetaTensor& x, int axis_value = GetSplitAxisValue(x, axis, config); // fill out dims with -1 if (axis_value == -1 || (axis_value >= 0 && x.dims().at(axis_value) <= 0)) { - std::vector out_dims( - num, phi::make_ddim(std::vector(x.dims().size(), -1))); - + std::vector out_dims; + if (axis_value == -1) { + out_dims = std::vector( + num, phi::make_ddim(std::vector(x.dims().size(), -1))); + } else { + out_dims = std::vector(num, x.dims()); + } for (int i = 0; i < num; ++i) { if (axis_value != 0) { // Only pass LoD when not spliting along the first dim. diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py index c31169feedbdd53583abd305872dbf11064e23fd..37ea0d429ca6726a259ff1084fc9cd422aacd3bb 100644 --- a/python/paddle/fluid/tests/unittests/test_split_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_op.py @@ -441,6 +441,21 @@ class API_TestSplit5(unittest.TestCase): np.testing.assert_allclose(ex_out, re, rtol=1e-05) +class API_TestSplit6(unittest.TestCase): + + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data = fluid.layers.data('data', shape=[-1, 10], dtype='float64') + x0, x1 = paddle.split(data, num_or_sections=[1, 1], axis=0) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input1 = np.random.random([2, 10]).astype('float64') + r0, r1 = exe.run(feed={"data": input1}, fetch_list=[x0, x1]) + ex_x0, ex_x1 = np.split(input1, (1, ), axis=0) + np.testing.assert_allclose(ex_x0, r0, rtol=1e-05) + np.testing.assert_allclose(ex_x1, r1, rtol=1e-05) + + class API_TestDygraphFluidSplit(unittest.TestCase): def test_out1(self):