diff --git a/paddle/fluid/operators/strided_slice_op.h b/paddle/fluid/operators/strided_slice_op.h index eaef9496a92dcf5ad0886d7ed06971579dc2270c..47714ebb806e9b0ac11e918351b0737a050c7b12 100644 --- a/paddle/fluid/operators/strided_slice_op.h +++ b/paddle/fluid/operators/strided_slice_op.h @@ -121,6 +121,7 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends, // stride must not be zero if (starts[axis_index] < 0) { starts[axis_index] = starts[axis_index] + axis_size; + starts[axis_index] = std::max(starts[axis_index], 0); } if (ends[axis_index] < 0) { if (!(ends[axis_index] == -1 && @@ -139,11 +140,6 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends, } } - if ((starts[axis_index] < 0) && (axis_size > 0)) { - starts[axis_index] += axis_size; - starts[axis_index] = std::max(starts[axis_index], 0); - } - if (strides[axis_index] < 0) { reverse_axis[axis_index] = 1; strides[axis_index] = -strides[axis_index]; diff --git a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py index 9d89c7cbe13971f401134fdb0562844577422504..e9be6b338fb86390bfa006d7fb4b6d5f34894d4d 100644 --- a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py @@ -216,6 +216,16 @@ class TestStrideSliceOp13(TestStrideSliceOp): self.infer_flags = [1, 1, 1, 1, 1] +class TestStrideSliceOp14(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(4, 4, 4, 4) + self.axes = [1, 2, 3] + self.starts = [-5, 0, -7] + self.ends = [-1, 2, 4] + self.strides = [1, 1, 1] + self.infer_flags = [1, 1, 1] + + class TestStrideSliceOpBool(TestStrideSliceOp): def test_check_grad(self): pass