未验证 提交 bf80664c 编写于 作者: I Infinity_lee 提交者: GitHub

fix stackoverflow case13 gather (#50243)

上级 fb228c4a
......@@ -1321,6 +1321,27 @@ void GatherInferMeta(const MetaTensor& x,
auto input_dim = x.dims();
auto axis_v = axis.to<int>();
if (axis_v < 0) axis_v += input_dim.size();
PADDLE_ENFORCE_GE(
axis_v,
(0 - input_dim.size()),
phi::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [%d, %d]. But received Attr(axis) = %d.",
-input_dim.size(),
input_dim.size() - 1,
axis_v));
PADDLE_ENFORCE_LT(
axis_v,
input_dim.size(),
phi::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [%d, %d]. But received Attr(axis) = %d.",
-input_dim.size(),
input_dim.size() - 1,
axis_v));
if (index_dims.size() == 0) {
// 0D index will decrease the dimension
if (input_dim.size() == 1) {
......
......@@ -385,6 +385,29 @@ class TestGathertError(unittest.TestCase):
self.assertRaises(TypeError, test_index_type)
def test_error3(self):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
shape = [8, 9, 6]
x = paddle.fluid.data(shape=shape, dtype='int32', name='x')
axis = paddle.fluid.data(shape=[1], dtype='int32', name='axis')
index = paddle.fluid.data(shape=shape, dtype='int32', name='index')
index_float = paddle.fluid.data(
shape=shape, dtype='float32', name='index_float'
)
def test_axis_minsize():
paddle.gather(x, index, axis=-1)
self.assertRaises(ValueError, test_axis_minsize)
def test_axis_maxsize():
paddle.gather(x, index, axis=512)
self.assertRaises(ValueError, test_axis_maxsize)
class TestCheckOutType(unittest.TestCase):
def test_out_type(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册