diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 7551c51d6476f08b5f14b1f96c3dc1c702cf9fa7..03b20c535af978f66fda28ef62ab0b11c950c84c 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1321,6 +1321,27 @@ void GatherInferMeta(const MetaTensor& x, auto input_dim = x.dims(); auto axis_v = axis.to(); + if (axis_v < 0) axis_v += input_dim.size(); + + PADDLE_ENFORCE_GE( + axis_v, + (0 - input_dim.size()), + phi::errors::OutOfRange( + "Attr(axis) is out of range, It's expected " + "to be in range of [%d, %d]. But received Attr(axis) = %d.", + -input_dim.size(), + input_dim.size() - 1, + axis_v)); + PADDLE_ENFORCE_LT( + axis_v, + input_dim.size(), + phi::errors::OutOfRange( + "Attr(axis) is out of range, It's expected " + "to be in range of [%d, %d]. But received Attr(axis) = %d.", + -input_dim.size(), + input_dim.size() - 1, + axis_v)); + if (index_dims.size() == 0) { // 0D index will decrease the dimension if (input_dim.size() == 1) { diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 2f2538769a3b38e5a750d0dabfa035c13fa89a8a..5844d7b51da143b840486b974d8dc1bac843a5ba 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -385,6 +385,29 @@ class TestGathertError(unittest.TestCase): self.assertRaises(TypeError, test_index_type) + def test_error3(self): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + + shape = [8, 9, 6] + x = paddle.fluid.data(shape=shape, dtype='int32', name='x') + axis = paddle.fluid.data(shape=[1], dtype='int32', name='axis') + index = paddle.fluid.data(shape=shape, dtype='int32', name='index') + index_float = paddle.fluid.data( + shape=shape, dtype='float32', name='index_float' + ) + + def test_axis_minsize(): + paddle.gather(x, index, axis=-1) + + self.assertRaises(ValueError, test_axis_minsize) + + def test_axis_maxsize(): + paddle.gather(x, index, axis=512) + + self.assertRaises(ValueError, test_axis_maxsize) + class TestCheckOutType(unittest.TestCase): def test_out_type(self):