From e12c9221d6bdd57f57022b1c14e9b3bc33ca46aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Mon, 6 Feb 2023 14:51:04 +0800 Subject: [PATCH] fix div 0 error of split (#49958) --- paddle/phi/infermeta/unary.cc | 4 ++++ python/paddle/fluid/tests/unittests/test_split_op.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 6b5cadc9d2..39ea06c89e 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3654,6 +3654,10 @@ void SplitWithNumInferMeta(const MetaTensor& x, auto input_axis_dim = x.dims().at(axis_value); // step1: get formated sections std::vector sections_vec; + PADDLE_ENFORCE_NE( + num, + 0, + phi::errors::InvalidArgument("Attr(num_or_sections) should not be 0.")); PADDLE_ENFORCE_EQ(input_axis_dim % num, 0, phi::errors::InvalidArgument( diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py index d250302165..4153e3e655 100644 --- a/python/paddle/fluid/tests/unittests/test_split_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_op.py @@ -358,6 +358,14 @@ class TestSplitOpError(unittest.TestCase): self.assertRaises(TypeError, test_axis_type_tensor) + with paddle.fluid.dygraph.guard(): + + def test_0_num_tensor(): + x = paddle.uniform([1, 1, 1], dtype='float32') + paddle.split(x, num_or_sections=0) + + self.assertRaises(ValueError, test_0_num_tensor) + class API_TestSplit(unittest.TestCase): def test_out(self): -- GitLab