diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 88fefb8eac99da13d921811f02430c8e2a78290d..056b9d79c84e2db5b0815639e4c2495b0a8867cd 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -714,7 +714,11 @@ std::vector split_impl(const Tensor& x, // Calculate the number of out tensors size_t out_number; if (num_or_sections.size() == 1) { - out_number = num_or_sections.GetData()[0]; + if (num_or_sections.GetData()[0] < 0) { + out_number = 1; + } else { + out_number = num_or_sections.GetData()[0]; + } } else { out_number = num_or_sections.size(); } diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index bdcef965be2585adc523835b9b18fdc645af3178..14f0951c3d26ea76f2b25fc13e09d3f7cc3dee63 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2947,7 +2947,7 @@ void SplitInferMeta(const MetaTensor& x, // step1: get formated sections std::vector sections; // num_or_sections is a number - if (num_or_sections_data.size() == 1) { + if (num_or_sections_data.size() == 1 && num_or_sections_data[0] > 0) { int num = num_or_sections_data.at(0); PADDLE_ENFORCE_EQ(input_axis_dim % num, diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py index e3f72d7b41ca26bdc46ab49b256bf96a954b7dc4..4f438e26a7bb6fa1a0f87e944fe3338db5b89c7e 100644 --- a/python/paddle/fluid/tests/unittests/test_split_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_op.py @@ -503,6 +503,21 @@ class API_TestDygraphSplit(unittest.TestCase): self.assertTrue(np.allclose(ex_x1, x1_out)) self.assertTrue(np.allclose(ex_x2, x2_out)) + def func_negative_one_section(self): + with fluid.dygraph.guard(): + input_1 = np.random.random([4, 6, 6]).astype("int32") + # input is a variable which shape is [4, 6, 6] + input = paddle.to_tensor(input_1) + num1 = paddle.full(shape=[1], fill_value=1, dtype='int32') + x0 = paddle.split(input, num_or_sections=[-1], axis=num1) + x0_out = x0[0].numpy() + self.assertTrue(np.array_equal(x0_out, input.numpy())) + + def test_negative_one_section(self): + with _test_eager_guard(): + self.func_negative_one_section() + self.func_negative_one_section() + class API_TestEmptySplit(unittest.TestCase):