diff --git a/paddle/phi/kernels/funcs/slice_utils.h b/paddle/phi/kernels/funcs/slice_utils.h index bfe024f45a098b39f94e9514a6b717c2c8d45a55..ed403c75dbdc8a08d6f1f7a7aeb0875dfbd9aed2 100644 --- a/paddle/phi/kernels/funcs/slice_utils.h +++ b/paddle/phi/kernels/funcs/slice_utils.h @@ -117,6 +117,10 @@ inline phi::DDim GetSliceDims(const phi::DDim in_dims, continue; } + if (in_dims[axis] == -1) { + continue; + } + T start = starts[i]; T end = ends[i]; T step = steps == nullptr ? 1 : (*steps)[i]; diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index d660518f04e4c81b3497b7fb68d9fc567c14cb55..57864c62e4c78893b89980ea57bb1dd6a5dd0192 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -784,7 +784,7 @@ class TestInferShape(unittest.TestCase): self.assertEqual(x.shape, (3, -1, 5)) out0 = paddle.slice(x, axes=[1], starts=[0], ends=[3]) - self.assertEqual(out0.shape, (3, 3, 5)) + self.assertEqual(out0.shape, (3, -1, 5)) def test_axis_less_than_zero(self): # Using paddle.disable_static will make other unittests fail.