From 97847ae8cf200322f85c5db648cd9042171405f4 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Fri, 9 Sep 2022 09:48:42 +0800 Subject: [PATCH] modify slice op Infershape (#45855) * modify slice infershape * code style * modify slice_unittest --- paddle/phi/kernels/funcs/slice_utils.h | 4 ++++ python/paddle/fluid/tests/unittests/test_slice_op.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/funcs/slice_utils.h b/paddle/phi/kernels/funcs/slice_utils.h index bfe024f45a..ed403c75db 100644 --- a/paddle/phi/kernels/funcs/slice_utils.h +++ b/paddle/phi/kernels/funcs/slice_utils.h @@ -117,6 +117,10 @@ inline phi::DDim GetSliceDims(const phi::DDim in_dims, continue; } + if (in_dims[axis] == -1) { + continue; + } + T start = starts[i]; T end = ends[i]; T step = steps == nullptr ? 1 : (*steps)[i]; diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index d660518f04..57864c62e4 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -784,7 +784,7 @@ class TestInferShape(unittest.TestCase): self.assertEqual(x.shape, (3, -1, 5)) out0 = paddle.slice(x, axes=[1], starts=[0], ends=[3]) - self.assertEqual(out0.shape, (3, 3, 5)) + self.assertEqual(out0.shape, (3, -1, 5)) def test_axis_less_than_zero(self): # Using paddle.disable_static will make other unittests fail. -- GitLab