diff --git a/paddle/fluid/operators/strided_slice_op.h b/paddle/fluid/operators/strided_slice_op.h index e61fb6d21ad470b2ba335e87c6065df20582ea24..3c5fb869f68f1ade7672f502fbf619f58e887174 100644 --- a/paddle/fluid/operators/strided_slice_op.h +++ b/paddle/fluid/operators/strided_slice_op.h @@ -141,8 +141,14 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends, strides[axis_index] = -strides[axis_index]; if (starts[axis_index] > ends[axis_index]) { // swap the reverse - starts[axis_index] = starts[axis_index] + 1; - ends[axis_index] = ends[axis_index] + 1; + auto end_dim = dims[axis_index] - 1 < starts[axis_index] + ? dims[axis_index] - 1 + : starts[axis_index]; + auto offset = (end_dim - ends[axis_index]) % strides[axis_index]; + offset = offset == 0 ? strides[axis_index] : offset; + + starts[axis_index] = starts[axis_index] + offset; + ends[axis_index] = ends[axis_index] + offset; } std::swap(starts[axis_index], ends[axis_index]); } else { diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py index 67d3778bcc3876a25a42d8a512242f8221814a4a..7b4a35a6a7898d9c2da04e97e4dd29cf5e7608ad 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py @@ -176,5 +176,32 @@ class TestSetValueWithLayerAndSave(unittest.TestCase): output_spec=None) +class TestPaddleStridedSlice(unittest.TestCase): + def test_compare_paddle_strided_slice_with_numpy(self): + paddle.disable_static() + array = np.arange(5) + pt = paddle.to_tensor(array) + + s1 = 3 + e1 = 1 + stride1 = -2 + sl = paddle.strided_slice( + pt, axes=[0, ], starts=[s1, ], ends=[e1, ], strides=[stride1, ]) + + self.assertTrue(array[s1:e1:stride1], sl) + + array = np.arange(6 * 6).reshape((6, 6)) + pt = paddle.to_tensor(array) + s2 = [8, -1] + e2 = [1, -5] + stride2 = [-2, -3] + sl = paddle.strided_slice( + pt, axes=[0, 1], starts=s2, ends=e2, strides=stride2) + + self.assertTrue( + np.array_equal(sl.numpy(), array[s2[0]:e2[0]:stride2[0], s2[1]:e2[ + 1]:stride2[1]])) + + if __name__ == '__main__': unittest.main()