未验证 提交 66682be0 编写于 作者: R RedContritio 提交者: GitHub

Fix 堆栈溢出 (stack overflow) of case9: paddle.repeat_interleave (#49982)

* support negative index in repeat_interleave

* add unittest
上级 baf96a12
...@@ -3075,27 +3075,40 @@ void RepeatInterleaveInferMeta(const MetaTensor& x, ...@@ -3075,27 +3075,40 @@ void RepeatInterleaveInferMeta(const MetaTensor& x,
MetaTensor* out) { MetaTensor* out) {
const auto& input_dim = x.dims(); const auto& input_dim = x.dims();
auto output_dim = phi::vectorize(input_dim); auto output_dim = phi::vectorize(input_dim);
auto n_dim = dim;
PADDLE_ENFORCE_EQ( if (n_dim < 0) n_dim += input_dim.size();
dim < input_dim.size() && dim >= (0 - input_dim.size()),
true, PADDLE_ENFORCE_LT(
dim,
input_dim.size(),
phi::errors::OutOfRange( phi::errors::OutOfRange(
"Attr(dim) is out of range, It's expected " "Attr(dim) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.", "to be in range of [%d, %d]. But received Attr(dim) = %d.",
input_dim.size(), -input_dim.size(),
input_dim.size() - 1, input_dim.size() - 1,
dim)); dim));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_GE(
repeats > 0, dim,
true, (0 - input_dim.size()),
phi::errors::OutOfRange(
"Attr(dim) is out of range, It's expected "
"to be in range of [%d, %d]. But received Attr(dim) = %d.",
-input_dim.size(),
input_dim.size() - 1,
dim));
PADDLE_ENFORCE_GT(
repeats,
0,
phi::errors::InvalidArgument("repeats should be larger than zero")); phi::errors::InvalidArgument("repeats should be larger than zero"));
PADDLE_ENFORCE_NE(out, PADDLE_ENFORCE_NOT_NULL(
nullptr, out,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"repeat_interleave's output tensor can't be nullptr")); "repeat_interleave's output tensor can't be nullptr"));
output_dim[dim] = input_dim[dim] * repeats; output_dim[n_dim] = input_dim[n_dim] * repeats;
out->set_dims(phi::make_ddim(output_dim)); out->set_dims(phi::make_ddim(output_dim));
out->share_lod(x); out->share_lod(x);
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
......
...@@ -188,6 +188,26 @@ class TestIndexSelectAPI(unittest.TestCase): ...@@ -188,6 +188,26 @@ class TestIndexSelectAPI(unittest.TestCase):
expect_out = np.repeat(self.data_zero_dim_x, repeats) expect_out = np.repeat(self.data_zero_dim_x, repeats)
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
# case 4 negative axis:
with program_guard(Program(), Program()):
x = paddle.static.data(name='x', shape=[-1, 4], dtype='float32')
x.desc.set_need_check_feed(False)
index = paddle.static.data(
name='repeats_',
shape=[4],
dtype='int32',
)
index.desc.set_need_check_feed(False)
z = paddle.repeat_interleave(x, index, axis=-1)
exe = fluid.Executor(fluid.CPUPlace())
(res,) = exe.run(
feed={'x': self.data_x, 'repeats_': self.data_index},
fetch_list=[z.name],
return_numpy=False,
)
expect_out = np.repeat(self.data_x, self.data_index, axis=-1)
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
def test_dygraph_api(self): def test_dygraph_api(self):
self.input_data() self.input_data()
# case axis none # case axis none
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册