From b47fb7648c84721808fe7452be96d7a92b98c648 Mon Sep 17 00:00:00 2001 From: TeslaZhao Date: Fri, 21 Jan 2022 15:37:27 +0800 Subject: [PATCH] Keep strided_slice op behavior consistent with slice op when starts input is less than -rank (#39066) --- paddle/fluid/operators/strided_slice_op.h | 6 +----- .../fluid/tests/unittests/test_strided_slice_op.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/strided_slice_op.h b/paddle/fluid/operators/strided_slice_op.h index eaef9496a92..47714ebb806 100644 --- a/paddle/fluid/operators/strided_slice_op.h +++ b/paddle/fluid/operators/strided_slice_op.h @@ -121,6 +121,7 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends, // stride must not be zero if (starts[axis_index] < 0) { starts[axis_index] = starts[axis_index] + axis_size; + starts[axis_index] = std::max(starts[axis_index], 0); } if (ends[axis_index] < 0) { if (!(ends[axis_index] == -1 && @@ -139,11 +140,6 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends, } } - if ((starts[axis_index] < 0) && (axis_size > 0)) { - starts[axis_index] += axis_size; - starts[axis_index] = std::max(starts[axis_index], 0); - } - if (strides[axis_index] < 0) { reverse_axis[axis_index] = 1; strides[axis_index] = -strides[axis_index]; diff --git a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py index 9d89c7cbe13..e9be6b338fb 100644 --- a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py @@ -216,6 +216,16 @@ class TestStrideSliceOp13(TestStrideSliceOp): self.infer_flags = [1, 1, 1, 1, 1] +class TestStrideSliceOp14(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(4, 4, 4, 4) + self.axes = [1, 2, 3] + self.starts = [-5, 0, -7] + self.ends = [-1, 2, 4] + self.strides = [1, 1, 1] + self.infer_flags = [1, 1, 1] + + class TestStrideSliceOpBool(TestStrideSliceOp): def test_check_grad(self): pass -- GitLab