From 669853f5c47162a2d7909d776a1171a58f9f0158 Mon Sep 17 00:00:00 2001 From: WeiXin Date: Fri, 27 Aug 2021 11:16:25 +0800 Subject: [PATCH] Polish the error message of paddle.slice. (#35179) * polish the error message of paddle.slice. * polish code. --- paddle/fluid/operators/slice_op.cc | 2 +- python/paddle/fluid/layers/nn.py | 16 ++++++++ .../fluid/tests/unittests/test_slice_op.py | 39 +++++++++++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index 01daba7c07..a55959385f 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -113,7 +113,7 @@ class SliceOp : public framework::OperatorWithKernel { } ctx->SetOutputDim("Out", out_dims); - if (axes[0] != 0) { + if (axes.size() > 0 && axes[0] != 0) { ctx->ShareLoD("Input", /*->*/ "Out"); } } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index bd7ecfeee6..59dfec005d 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10975,6 +10975,22 @@ def slice(input, axes, starts, ends): attrs = () starts_tensor = None ends_tensor = None + + if isinstance(axes, (list, tuple)): + if len(axes) == 0: + raise ValueError( + "Input axes should not be an empty list/tuple.") + for i in range(len(axes)): + if axes[i] < 0: + axes[i] = max(0, axes[i] + len(input.shape)) + else: + axes[i] = min(len(input.shape) - 1, axes[i]) + + else: + raise ValueError( + "Input axes must be a python list or tuple, but reveived {}". + format(type(axes))) + infer_flags = list(1 for i in range(len(axes))) if isinstance(starts, (list, tuple)): diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index b83478a5b8..f69993c52a 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -694,6 +694,45 @@ class TestInferShape(unittest.TestCase): out0 = paddle.slice(x, axes=[1], starts=[0], ends=[3]) self.assertEqual(out0.shape, (3, 3, 5)) + def test_axis_less_than_zero(self): + + # Using paddle.disable_static will make other unittests fail. + with fluid.dygraph.guard(): + x_arr = np.arange(0, 24, dtype=np.float32).reshape([2, 3, 4]) + x = paddle.to_tensor(x_arr) + + pp_slice = paddle.slice(x, [100, ], [0], [1]) + np_slice = x_arr[:, :, 0:1] + self.assertTrue(np.array_equal(pp_slice, np_slice)) + + pp_slice = paddle.slice(x, [-100, ], [0], [1]) + np_slice = x_arr[0:1] + self.assertTrue(np.array_equal(pp_slice, np_slice)) + + x_arr = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(x_arr, (0, 0, 0))) + + starts = paddle.to_tensor( + np.reshape( + np.array( + [], dtype=np.int32), (0, ))) + ends = paddle.to_tensor( + np.reshape( + np.array( + [], dtype=np.int32), (0, ))) + + with self.assertRaises(ValueError): + paddle.slice(x, [-1000000], starts, ends) + + with self.assertRaises(ValueError): + paddle.slice(x, [1000000], starts, ends) + + with self.assertRaises(ValueError): + paddle.slice(x, [], starts, ends) + + with self.assertRaises(ValueError): + paddle.slice(x, 0, starts, ends) + @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -- GitLab