From 60c5adaa6804360113c8e94cc15ee7884d0319f1 Mon Sep 17 00:00:00 2001 From: WeiXin Date: Mon, 6 Sep 2021 10:50:31 +0800 Subject: [PATCH] support numpy dtype and polish code of list index. (#35404) * support numpy dtype and polish code of list index. * polish code. --- paddle/fluid/pybind/imperative.cc | 17 ++++++--- .../fluid/dygraph/varbase_patch_methods.py | 34 ++++++++++++++---- .../fluid/tests/unittests/test_var_base.py | 36 +++++++++++++++++++ 3 files changed, 76 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 7f44afabf25..67f9f8b203d 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -331,7 +331,14 @@ GetVarBaseListFromPyHandle(const py::handle &handle) { return result; } - +static bool IsNumpyType(PyObject *obj) { + // It is not a good way to judge the type of obj by its type'name. Maybe using + // `PyArray_IsScalar` will be better. However, this interface cannot be used + // by including pybind11, and it needs to compile with numpy. + auto type_name = std::string(Py_TYPE(obj)->tp_name); + return type_name == "numpy.int64" || type_name == "numpy.longlong" || + type_name == "numpy.int32" || type_name == "numpy.int16"; +} static imperative::NameVarBaseMap ConvertToNameVarBaseMap( const PyNameVarBaseMap &map) { imperative::NameVarBaseMap result; @@ -372,7 +379,7 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length, if (r->step == Py_None) { *step = 1; } else { - if (PyCheckInteger(r->step)) { + if (PyCheckInteger(r->step) || IsNumpyType(r->step)) { *step = PyLong_AsLong(r->step); } else { PADDLE_THROW(platform::errors::InvalidArgument( @@ -384,7 +391,7 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length, if (r->start == Py_None) { *start = *step < 0 ? length - 1 : 0; } else { - if (PyCheckInteger(r->start)) { + if (PyCheckInteger(r->start) || IsNumpyType(r->start)) { *start = PyLong_AsLong(r->start); } else { PADDLE_THROW(platform::errors::InvalidArgument( @@ -398,7 +405,7 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length, if (r->stop == Py_None) { *stop = *step < 0 ? -1 : length; } else { - if (PyCheckInteger(r->stop)) { + if (PyCheckInteger(r->stop) || IsNumpyType(r->stop)) { *stop = PyLong_AsLong(r->stop); } else { PADDLE_THROW(platform::errors::InvalidArgument( @@ -456,7 +463,7 @@ static void ParseIndexingSlice( infer_flags->push_back(1); int dim_len = shape[dim]; - if (PyCheckInteger(slice_item)) { + if (PyCheckInteger(slice_item) || IsNumpyType(slice_item)) { // integer, PyLong_AsLong supports both int and long int start = static_cast(PyLong_AsLong(slice_item)); auto s_t = start; diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index c42a2a5943d..102dcd43622 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -544,7 +544,7 @@ def monkey_patch_varbase(): return array def contain_tensor(item): - if not isinstance(item, tuple): + if not isinstance(item, (tuple, list)): item = [item] for slice_item in item: @@ -554,20 +554,21 @@ def monkey_patch_varbase(): or isinstance(slice_item.step, Variable): return True else: - if isinstance(slice_item, Variable): + if isinstance(slice_item, + Variable) and Variable.dtype != paddle.bool: return True return False def __getitem__(self, item): def is_list_tuple(index, contain_type): def _is_list_tuple(item): - if not (isinstance(item, (list, tuple)) or - type(item) == contain_type): - return False if isinstance(item, (tuple, list)): for s in item: if not _is_list_tuple(s): return False + else: + if type(item) != contain_type: + return False return True if not isinstance(index, (tuple, list)): @@ -599,7 +600,28 @@ def monkey_patch_varbase(): return False - if contain_tensor_or_list(item): + def is_combine_index(item): + var_type = None + item_type = None + if isinstance(item, (tuple, list)): + for slice_item in item: + if item_type is None: + item_type = type(slice_item) + else: + if type(slice_item) != item_type: + return True + + if isinstance(slice_item, Variable): + if var_type is None: + var_type = slice_item.dtype + else: + if var_type != slice_item.dtype: + return True + return False + + return False + + if contain_tensor_or_list(item) and not is_combine_index(item): # To reuse code with static graph, # Call _setitem_impl_ when item contains tensor or list. return _setitem_impl_(self, item, value) diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index c94316c7482..addda7fb541 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -779,6 +779,40 @@ class TestVarBase(unittest.TestCase): for i, e in enumerate(w): self.assertTrue(np.array_equal(e.numpy(), np_value[i])) + def _test_numpy_index(self): + array = np.arange(120).reshape([4, 5, 6]) + t = paddle.to_tensor(array) + self.assertTrue(np.array_equal(t[np.longlong(0)].numpy(), array[0])) + self.assertTrue( + np.array_equal(t[np.longlong(0):np.longlong(4):np.longlong(2)] + .numpy(), array[0:4:2])) + self.assertTrue(np.array_equal(t[np.int64(0)].numpy(), array[0])) + self.assertTrue( + np.array_equal(t[np.int32(1):np.int32(4):np.int32(2)].numpy(), + array[1:4:2])) + self.assertTrue( + np.array_equal(t[np.int16(0):np.int16(4):np.int16(2)].numpy(), + array[0:4:2])) + + def _test_list_index(self): + # case1: + array = np.arange(120).reshape([6, 5, 4]) + x = paddle.to_tensor(array) + py_idx = [[0, 2, 0, 1, 3], [0, 0, 1, 2, 0]] + idx = [paddle.to_tensor(py_idx[0]), paddle.to_tensor(py_idx[1])] + self.assertTrue(np.array_equal(x[idx].numpy(), array[py_idx])) + self.assertTrue(np.array_equal(x[py_idx].numpy(), array[py_idx])) + # case2: + tensor_x = paddle.to_tensor( + np.zeros(12).reshape(2, 6).astype(np.float32)) + tensor_y1 = paddle.zeros([1]) + 2 + tensor_y2 = paddle.zeros([1]) + 5 + tensor_x[:, tensor_y1:tensor_y2] = 42 + res = tensor_x.numpy() + exp = np.array([[0., 0., 42., 42., 42., 0.], + [0., 0., 42., 42., 42., 0.]]) + self.assertTrue(np.array_equal(res, exp)) + def test_slice(self): with fluid.dygraph.guard(): self._test_slice() @@ -787,6 +821,8 @@ class TestVarBase(unittest.TestCase): self._test_for_getitem_ellipsis_index() self._test_none_index() self._test_bool_index() + self._test_numpy_index() + self._test_list_index() var = fluid.dygraph.to_variable(self.array) self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :])) -- GitLab