From 7f2b5be3c2f6b979f1875c59f2e91ce3926b4e80 Mon Sep 17 00:00:00 2001 From: WeiXin Date: Tue, 20 Jul 2021 09:26:15 +0800 Subject: [PATCH] change strided_slice when step<0. (#34205) * change strided_slice when step<0. * add unittest for paddle.strided_slice * polish unittest --- paddle/fluid/operators/strided_slice_op.h | 10 +++++-- .../unittests/dygraph_to_static/test_slice.py | 27 +++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/strided_slice_op.h b/paddle/fluid/operators/strided_slice_op.h index e61fb6d21ad..3c5fb869f68 100644 --- a/paddle/fluid/operators/strided_slice_op.h +++ b/paddle/fluid/operators/strided_slice_op.h @@ -141,8 +141,14 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends, strides[axis_index] = -strides[axis_index]; if (starts[axis_index] > ends[axis_index]) { // swap the reverse - starts[axis_index] = starts[axis_index] + 1; - ends[axis_index] = ends[axis_index] + 1; + auto end_dim = dims[axis_index] - 1 < starts[axis_index] + ? dims[axis_index] - 1 + : starts[axis_index]; + auto offset = (end_dim - ends[axis_index]) % strides[axis_index]; + offset = offset == 0 ? strides[axis_index] : offset; + + starts[axis_index] = starts[axis_index] + offset; + ends[axis_index] = ends[axis_index] + offset; } std::swap(starts[axis_index], ends[axis_index]); } else { diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py index 67d3778bcc3..7b4a35a6a78 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py @@ -176,5 +176,32 @@ class TestSetValueWithLayerAndSave(unittest.TestCase): output_spec=None) +class TestPaddleStridedSlice(unittest.TestCase): + def test_compare_paddle_strided_slice_with_numpy(self): + paddle.disable_static() + array = np.arange(5) + pt = paddle.to_tensor(array) + + s1 = 3 + e1 = 1 + stride1 = -2 + sl = paddle.strided_slice( + pt, axes=[0, ], starts=[s1, ], ends=[e1, ], strides=[stride1, ]) + + self.assertTrue(array[s1:e1:stride1], sl) + + array = np.arange(6 * 6).reshape((6, 6)) + pt = paddle.to_tensor(array) + s2 = [8, -1] + e2 = [1, -5] + stride2 = [-2, -3] + sl = paddle.strided_slice( + pt, axes=[0, 1], starts=s2, ends=e2, strides=stride2) + + self.assertTrue( + np.array_equal(sl.numpy(), array[s2[0]:e2[0]:stride2[0], s2[1]:e2[ + 1]:stride2[1]])) + + if __name__ == '__main__': unittest.main() -- GitLab