未验证 提交 e12c9221 编写于 作者: 张春乔 提交者: GitHub

fix div 0 error of split (#49958)

上级 6908550e
......@@ -3654,6 +3654,10 @@ void SplitWithNumInferMeta(const MetaTensor& x,
auto input_axis_dim = x.dims().at(axis_value);
// step1: get formated sections
std::vector<int64_t> sections_vec;
PADDLE_ENFORCE_NE(
num,
0,
phi::errors::InvalidArgument("Attr(num_or_sections) should not be 0."));
PADDLE_ENFORCE_EQ(input_axis_dim % num,
0,
phi::errors::InvalidArgument(
......
......@@ -358,6 +358,14 @@ class TestSplitOpError(unittest.TestCase):
self.assertRaises(TypeError, test_axis_type_tensor)
with paddle.fluid.dygraph.guard():
def test_0_num_tensor():
x = paddle.uniform([1, 1, 1], dtype='float32')
paddle.split(x, num_or_sections=0)
self.assertRaises(ValueError, test_0_num_tensor)
class API_TestSplit(unittest.TestCase):
def test_out(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册