From 9ce8cfcf04fd53f1aa57d8e08d82b39eed3aaf3f Mon Sep 17 00:00:00 2001 From: RedContritio Date: Wed, 1 Feb 2023 10:44:04 +0800 Subject: [PATCH] =?UTF-8?q?Fix=20UFA=E9=9D=9E=E6=B3=95=E5=9C=B0=E5=9D=80?= =?UTF-8?q?=E8=AE=BF=E9=97=AE(UFA=20illegal=20address=20access)=20of=20cas?= =?UTF-8?q?e4:=20paddle.unbind=20(#49995)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add axis check for unbind * add axis range check for unbind * update unittest and axis validation for unbind * add unittest invalid axis for unbind * restore axis extract for unbind --- paddle/phi/infermeta/unary.cc | 13 +++++++++++++ .../fluid/tests/unittests/test_unbind_op.py | 7 +++++++ python/paddle/tensor/manipulation.py | 15 ++++++++++----- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 3b3202c291..eb05437ada 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -4253,7 +4253,20 @@ void UnbindInferMeta(const MetaTensor& x, std::vector outs) { auto in_dims = x.dims(); std::vector out_dim; + + PADDLE_ENFORCE_GE( + axis, + -in_dims.size(), + phi::errors::InvalidArgument( + "axis must be in range(%d, %d).", -in_dims.size(), in_dims.size())); + PADDLE_ENFORCE_LT( + axis, + in_dims.size(), + phi::errors::InvalidArgument( + "axis must be in range(%d, %d).", -in_dims.size(), in_dims.size())); + axis = axis < 0 ? in_dims.size() + axis : axis; + for (int i = 0; i < in_dims.size(); ++i) { if (i != axis) out_dim.push_back(in_dims[i]); } diff --git a/python/paddle/fluid/tests/unittests/test_unbind_op.py b/python/paddle/fluid/tests/unittests/test_unbind_op.py index 6ec82a96bc..8cafc1b5a8 100644 --- a/python/paddle/fluid/tests/unittests/test_unbind_op.py +++ b/python/paddle/fluid/tests/unittests/test_unbind_op.py @@ -25,6 +25,7 @@ from paddle.fluid import Program, program_guard class TestUnbind(unittest.TestCase): def test_unbind(self): + paddle.enable_static() x_1 = fluid.data(shape=[2, 3], dtype='float32', name='x_1') [out_0, out_1] = tensor.unbind(input=x_1, axis=0) @@ -59,6 +60,7 @@ class TestUnbind(unittest.TestCase): class TestLayersUnbind(unittest.TestCase): def test_layers_unbind(self): + paddle.enable_static() x_1 = fluid.data(shape=[2, 3], dtype='float32', name='x_1') [out_0, out_1] = paddle.unbind(input=x_1, axis=0) @@ -214,6 +216,11 @@ class TestUnbindAxisError(unittest.TestCase): self.assertRaises(TypeError, test_table_Variable) + def test_invalid_axis(): + tensor.unbind(input=x, axis=2) + + self.assertRaises(ValueError, test_invalid_axis) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 923e6923d6..b5308e6cee 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2755,14 +2755,19 @@ def unbind(input, axis=0): # x2.shape [3, 5] # x3.shape [3, 5] """ + if not isinstance(axis, (int)): + raise TypeError( + "The type of 'axis' must be int, but received %s." % (type(axis)) + ) + + if axis not in range(-input.ndim, input.ndim): + raise ValueError( + f'The axis must in range({-input.ndim}, {input.ndim}).' + ) + if in_dygraph_mode(): return _C_ops.unbind(input, axis) else: - if not isinstance(axis, (int)): - raise TypeError( - "The type of 'axis' must be int, but received %s." - % (type(axis)) - ) if isinstance(axis, np.generic): axis = np.asscalar(axis) input_shape = input.shape -- GitLab