未验证 提交 3f234db0 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] fix split op in final state (#44952)

上级 f9165878
......@@ -714,7 +714,11 @@ std::vector<Tensor> split_impl(const Tensor& x,
// Calculate the number of out tensors
size_t out_number;
if (num_or_sections.size() == 1) {
if (num_or_sections.GetData()[0] < 0) {
out_number = 1;
} else {
out_number = num_or_sections.GetData()[0];
}
} else {
out_number = num_or_sections.size();
}
......
......@@ -2947,7 +2947,7 @@ void SplitInferMeta(const MetaTensor& x,
// step1: get formated sections
std::vector<int64_t> sections;
// num_or_sections is a number
if (num_or_sections_data.size() == 1) {
if (num_or_sections_data.size() == 1 && num_or_sections_data[0] > 0) {
int num = num_or_sections_data.at(0);
PADDLE_ENFORCE_EQ(input_axis_dim % num,
......
......@@ -503,6 +503,21 @@ class API_TestDygraphSplit(unittest.TestCase):
self.assertTrue(np.allclose(ex_x1, x1_out))
self.assertTrue(np.allclose(ex_x2, x2_out))
def func_negative_one_section(self):
with fluid.dygraph.guard():
input_1 = np.random.random([4, 6, 6]).astype("int32")
# input is a variable which shape is [4, 6, 6]
input = paddle.to_tensor(input_1)
num1 = paddle.full(shape=[1], fill_value=1, dtype='int32')
x0 = paddle.split(input, num_or_sections=[-1], axis=num1)
x0_out = x0[0].numpy()
self.assertTrue(np.array_equal(x0_out, input.numpy()))
def test_negative_one_section(self):
with _test_eager_guard():
self.func_negative_one_section()
self.func_negative_one_section()
class API_TestEmptySplit(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册