From 3f234db017162c632758459faed3cdf54571db37 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Mon, 8 Aug 2022 10:41:17 +0800 Subject: [PATCH] [Eager] fix split op in final state (#44952) --- paddle/phi/api/lib/api_custom_impl.cc | 6 +++++- paddle/phi/infermeta/unary.cc | 2 +- .../paddle/fluid/tests/unittests/test_split_op.py | 15 +++++++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 88fefb8eac..056b9d79c8 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -714,7 +714,11 @@ std::vector split_impl(const Tensor& x, // Calculate the number of out tensors size_t out_number; if (num_or_sections.size() == 1) { - out_number = num_or_sections.GetData()[0]; + if (num_or_sections.GetData()[0] < 0) { + out_number = 1; + } else { + out_number = num_or_sections.GetData()[0]; + } } else { out_number = num_or_sections.size(); } diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index bdcef965be..14f0951c3d 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2947,7 +2947,7 @@ void SplitInferMeta(const MetaTensor& x, // step1: get formated sections std::vector sections; // num_or_sections is a number - if (num_or_sections_data.size() == 1) { + if (num_or_sections_data.size() == 1 && num_or_sections_data[0] > 0) { int num = num_or_sections_data.at(0); PADDLE_ENFORCE_EQ(input_axis_dim % num, diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py index e3f72d7b41..4f438e26a7 100644 --- a/python/paddle/fluid/tests/unittests/test_split_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_op.py @@ -503,6 +503,21 @@ class API_TestDygraphSplit(unittest.TestCase): self.assertTrue(np.allclose(ex_x1, x1_out)) self.assertTrue(np.allclose(ex_x2, x2_out)) + def func_negative_one_section(self): + with fluid.dygraph.guard(): + input_1 = np.random.random([4, 6, 6]).astype("int32") + # input is a variable which shape is [4, 6, 6] + input = paddle.to_tensor(input_1) + num1 = paddle.full(shape=[1], fill_value=1, dtype='int32') + x0 = paddle.split(input, num_or_sections=[-1], axis=num1) + x0_out = x0[0].numpy() + self.assertTrue(np.array_equal(x0_out, input.numpy())) + + def test_negative_one_section(self): + with _test_eager_guard(): + self.func_negative_one_section() + self.func_negative_one_section() + class API_TestEmptySplit(unittest.TestCase): -- GitLab