diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 8cea16f770631ee309ba073c4b9d9f989f4c110a..e08f1769bef48e4808aef345cadb973612983887 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3075,27 +3075,40 @@ void RepeatInterleaveInferMeta(const MetaTensor& x, MetaTensor* out) { const auto& input_dim = x.dims(); auto output_dim = phi::vectorize(input_dim); + auto n_dim = dim; - PADDLE_ENFORCE_EQ( - dim < input_dim.size() && dim >= (0 - input_dim.size()), - true, + if (n_dim < 0) n_dim += input_dim.size(); + + PADDLE_ENFORCE_LT( + dim, + input_dim.size(), phi::errors::OutOfRange( "Attr(dim) is out of range, It's expected " - "to be in range of [-%d, %d]. But received Attr(dim) = %d.", - input_dim.size(), + "to be in range of [%d, %d]. But received Attr(dim) = %d.", + -input_dim.size(), input_dim.size() - 1, dim)); - PADDLE_ENFORCE_EQ( - repeats > 0, - true, + PADDLE_ENFORCE_GE( + dim, + (0 - input_dim.size()), + phi::errors::OutOfRange( + "Attr(dim) is out of range, It's expected " + "to be in range of [%d, %d]. But received Attr(dim) = %d.", + -input_dim.size(), + input_dim.size() - 1, + dim)); + + PADDLE_ENFORCE_GT( + repeats, + 0, phi::errors::InvalidArgument("repeats should be larger than zero")); - PADDLE_ENFORCE_NE(out, - nullptr, - phi::errors::InvalidArgument( - "repeat_interleave's output tensor can't be nullptr")); + PADDLE_ENFORCE_NOT_NULL( + out, + phi::errors::InvalidArgument( + "repeat_interleave's output tensor can't be nullptr")); - output_dim[dim] = input_dim[dim] * repeats; + output_dim[n_dim] = input_dim[n_dim] * repeats; out->set_dims(phi::make_ddim(output_dim)); out->share_lod(x); out->set_dtype(x.dtype()); diff --git a/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py b/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py index 90877a3047e2c9f52ce8a58a9d93675f57156d8f..4b5272c5a4bdfdd5ba4521c8da7ad2d961a02475 100644 --- a/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py +++ b/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py @@ -188,6 +188,26 @@ class TestIndexSelectAPI(unittest.TestCase): expect_out = np.repeat(self.data_zero_dim_x, repeats) np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) + # case 4 negative axis: + with program_guard(Program(), Program()): + x = paddle.static.data(name='x', shape=[-1, 4], dtype='float32') + x.desc.set_need_check_feed(False) + index = paddle.static.data( + name='repeats_', + shape=[4], + dtype='int32', + ) + index.desc.set_need_check_feed(False) + z = paddle.repeat_interleave(x, index, axis=-1) + exe = fluid.Executor(fluid.CPUPlace()) + (res,) = exe.run( + feed={'x': self.data_x, 'repeats_': self.data_index}, + fetch_list=[z.name], + return_numpy=False, + ) + expect_out = np.repeat(self.data_x, self.data_index, axis=-1) + np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) + def test_dygraph_api(self): self.input_data() # case axis none