diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 2b35545db1cd8ee5d6f9ce30a37a41029ff113bf..f2fcb3162081fb871117c39c2ed02d2334404bcc 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3395,6 +3395,21 @@ void SliceRawInferMeta(const MetaTensor& input, } } + PADDLE_ENFORCE_EQ( + axes.size(), + starts_arr.size(), + phi::errors::InvalidArgument( + "The length of axes (%d) and length of starts (%d) should be same.", + axes.size(), + starts_arr.size())); + PADDLE_ENFORCE_EQ( + axes.size(), + ends_arr.size(), + phi::errors::InvalidArgument( + "The length of axes (%d) and length of ends (%d) should be same.", + axes.size(), + ends_arr.size())); + // 2.1 Check attrs. std::vector starts = starts_arr.GetData(); std::vector ends = ends_arr.GetData(); diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index 19aa669badf5c48bfc0720c1a053fd6f5bd50bde..157818e794301249bcd2ee5b7f7948dc39dad7c1 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -852,6 +852,27 @@ class TestInferShape(unittest.TestCase): paddle.slice(x, 0, starts, ends) +class TestSliceOpError(unittest.TestCase): + def test_dismatch_shape(self): + with fluid.dygraph.guard(): + with self.assertRaises(ValueError): + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0]), dtype='float32') + paddle.slice(x, axes=[0], starts=[], ends=[]) + + with self.assertRaises(ValueError): + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0]), dtype='float32') + paddle.slice(x, axes=[0], starts=[0], ends=[]) + + # if shape match, pass + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0]), dtype='float32') + out = paddle.slice(x, axes=[0], starts=[0], ends=[0]) + self.assertEqual(out.numel(), 0) + # self.assertEqual(out.shape) + + @unittest.skipIf( not core.is_compiled_with_cuda(), "core is not compiled with CUDA" )