From 0111a3549f434d23f68ec8d26066361ddf0ca525 Mon Sep 17 00:00:00 2001 From: Ryan <44900829+DrRyanHuang@users.noreply.github.com> Date: Mon, 30 Jan 2023 16:24:50 +0800 Subject: [PATCH] [Fix IndexError] add unstack axis check (#49943) * add unstack axis check * IndexErr -> ValueError * add static select --- .../fluid/tests/unittests/test_unstack_op.py | 24 +++++++++++++++++++ python/paddle/tensor/manipulation.py | 4 ++++ 2 files changed, 28 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py index 1dda05fb0a6..745e14983a5 100755 --- a/python/paddle/fluid/tests/unittests/test_unstack_op.py +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -84,5 +84,29 @@ class TestStackOp6(TestUnStackOpBase): self.axis = 2 +class TestUnstackZeroInputOp(unittest.TestCase): + def unstack_zero_input_static(self): + + paddle.enable_static() + + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0]), dtype='float32') + paddle.unstack(x, axis=1) + + def unstack_zero_input_dynamic(self): + + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0]), dtype='float32') + paddle.unstack(x, axis=1) + + def test_type_error(self): + paddle.disable_static() + + self.assertRaises(ValueError, self.unstack_zero_input_dynamic) + self.assertRaises(ValueError, self.unstack_zero_input_static) + + paddle.disable_static() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index bdd903ee8f1..923e6923d6d 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -543,6 +543,10 @@ def unstack(x, axis=0, num=None): y = paddle.unstack(x, axis=1) # unstack with second axis, which results 3 tensors with shape=[2, 5] """ + if not (-x.ndim <= axis < x.ndim): + raise ValueError( + '`axis` must be in the range [-{0}, {0})'.format(x.ndim) + ) if in_dygraph_mode(): if num is None: num = x.shape[axis] -- GitLab