diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 3b3202c2917251da1f368f8d8b60da78b9e83b6c..eb05437ada8a5e00cb684c8581b418721c1f460a 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -4253,7 +4253,20 @@ void UnbindInferMeta(const MetaTensor& x, std::vector outs) { auto in_dims = x.dims(); std::vector out_dim; + + PADDLE_ENFORCE_GE( + axis, + -in_dims.size(), + phi::errors::InvalidArgument( + "axis must be in range(%d, %d).", -in_dims.size(), in_dims.size())); + PADDLE_ENFORCE_LT( + axis, + in_dims.size(), + phi::errors::InvalidArgument( + "axis must be in range(%d, %d).", -in_dims.size(), in_dims.size())); + axis = axis < 0 ? in_dims.size() + axis : axis; + for (int i = 0; i < in_dims.size(); ++i) { if (i != axis) out_dim.push_back(in_dims[i]); } diff --git a/python/paddle/fluid/tests/unittests/test_unbind_op.py b/python/paddle/fluid/tests/unittests/test_unbind_op.py index 6ec82a96bc16534572d59eb81c690839681be557..8cafc1b5a8e1b7d65d3e7c57f2964de7958dbf59 100644 --- a/python/paddle/fluid/tests/unittests/test_unbind_op.py +++ b/python/paddle/fluid/tests/unittests/test_unbind_op.py @@ -25,6 +25,7 @@ from paddle.fluid import Program, program_guard class TestUnbind(unittest.TestCase): def test_unbind(self): + paddle.enable_static() x_1 = fluid.data(shape=[2, 3], dtype='float32', name='x_1') [out_0, out_1] = tensor.unbind(input=x_1, axis=0) @@ -59,6 +60,7 @@ class TestUnbind(unittest.TestCase): class TestLayersUnbind(unittest.TestCase): def test_layers_unbind(self): + paddle.enable_static() x_1 = fluid.data(shape=[2, 3], dtype='float32', name='x_1') [out_0, out_1] = paddle.unbind(input=x_1, axis=0) @@ -214,6 +216,11 @@ class TestUnbindAxisError(unittest.TestCase): self.assertRaises(TypeError, test_table_Variable) + def test_invalid_axis(): + tensor.unbind(input=x, axis=2) + + self.assertRaises(ValueError, test_invalid_axis) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 923e6923d6d63c66af55cb7f360411ce8bf88684..b5308e6cee63d15dbfda7508cd016cb768a0232c 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2755,14 +2755,19 @@ def unbind(input, axis=0): # x2.shape [3, 5] # x3.shape [3, 5] """ + if not isinstance(axis, (int)): + raise TypeError( + "The type of 'axis' must be int, but received %s." % (type(axis)) + ) + + if axis not in range(-input.ndim, input.ndim): + raise ValueError( + f'The axis must in range({-input.ndim}, {input.ndim}).' + ) + if in_dygraph_mode(): return _C_ops.unbind(input, axis) else: - if not isinstance(axis, (int)): - raise TypeError( - "The type of 'axis' must be int, but received %s." - % (type(axis)) - ) if isinstance(axis, np.generic): axis = np.asscalar(axis) input_shape = input.shape