From 4a675b7ad96f99cbb3c3e9ca8c613c7dd79dc7ee Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Fri, 9 Sep 2022 10:29:27 +0800 Subject: [PATCH] Enhance slice to support 0 shape Tensor (#45861) * Enhance slice to support 0 dims Tensor * Add UT --- paddle/fluid/pybind/eager_method.cc | 4 +++- .../unittests/dygraph_to_static/test_slice.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 0e8bf1d0f8..94756a41f8 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -759,8 +759,10 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self, decrease_axis, none_axes, infer_flags, list_select_idxs; // if index is a list, list_select_flag will be true bool list_select_flag = false; + // Note(0x45f): Using defined() instead of initialized() + // to support slice tensor which shape like [0, 0, 0]. PADDLE_ENFORCE_EQ( - self->tensor.initialized(), + self->tensor.defined(), true, platform::errors::InvalidArgument( "tensor %s has not been initialized, we can only slice initialized " diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py index 3a7a1dc1b0..6d8f0b1440 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py @@ -284,5 +284,23 @@ class TestPaddleStridedSlice(unittest.TestCase): np.testing.assert_array_equal(sl.numpy(), array_slice) +def slice_zero_shape_tensor(x): + y = x[1:2] + return y + + +class TestSliceZeroShapeTensor(unittest.TestCase): + + def test_slice(self): + paddle.disable_static() + x = paddle.ones([0, 0, 0, 0]) + y = slice_zero_shape_tensor(x) + np.testing.assert_equal(y.shape, [0, 0, 0, 0]) + + static_func = paddle.jit.to_static(slice_zero_shape_tensor) + y = static_func(x) + np.testing.assert_equal(y.shape, [0, 0, 0, 0]) + + if __name__ == '__main__': unittest.main() -- GitLab