From 9f588cc264802b71864642e1682880007d9bef0c Mon Sep 17 00:00:00 2001 From: WeiXin Date: Wed, 15 Sep 2021 15:39:17 +0800 Subject: [PATCH] support numpy.ndarray index. (#35748) * support numpy.ndarray index. * polish code. --- .../fluid/dygraph/varbase_patch_methods.py | 5 ++-- .../fluid/tests/unittests/test_var_base.py | 9 ++++++ python/paddle/fluid/variable_index.py | 28 ++++++++++++++----- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index e39a86e961d..9d8b1500d5b 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -553,8 +553,9 @@ def monkey_patch_varbase(): or isinstance(slice_item.step, Variable): return True else: - if isinstance(slice_item, - Variable) and Variable.dtype != paddle.bool: + if isinstance( + slice_item, + (Variable, np.ndarray)) and Variable.dtype != paddle.bool: return True return False diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index addda7fb541..cbfb9860fa6 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -813,6 +813,11 @@ class TestVarBase(unittest.TestCase): [0., 0., 42., 42., 42., 0.]]) self.assertTrue(np.array_equal(res, exp)) + # case3: + row = np.array([0, 1, 2]) + col = np.array([2, 1, 3]) + self.assertTrue(np.array_equal(array[row, col], x[row, col].numpy())) + def test_slice(self): with fluid.dygraph.guard(): self._test_slice() @@ -834,6 +839,10 @@ class TestVarBase(unittest.TestCase): with self.assertRaises(IndexError): y = var[0 - self.shape[0] - 1] + with self.assertRaises(IndexError): + mask = np.array([1, 0, 1, 0], dtype=bool) + var[paddle.to_tensor([0, 1]), mask] + def test_var_base_to_np(self): with fluid.dygraph.guard(): var = fluid.dygraph.to_variable(self.array) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 1b9a82ba85f..66be8f80594 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -67,6 +67,7 @@ class SliceInfo: def __init__(self): self.pre_shape = None self.indexes = [] + self.dtype = None def update(self, index): if is_list_tuple(index, int) or isinstance(index, ( @@ -75,6 +76,14 @@ class SliceInfo: if not isinstance(index, paddle.fluid.Variable): index = paddle.assign(index) + if self.dtype is None: + self.dtype = index.dtype + else: + if index.dtype != self.dtype: + raise IndexError( + "Data type of Tensor/List index should be same. The current data type is {}, but the previous data type is {}.". + format(index.dtype, self.dtype)) + self.indexes.append(index) if self.pre_shape is None: @@ -214,6 +223,16 @@ def replace_ellipsis(var, item): return item +def replace_ndarray(item): + new_item = [] + for slice_item in item: + if isinstance(slice_item, np.ndarray): + new_item.append(paddle.assign(slice_item)) + else: + new_item.append(slice_item) + return new_item + + def replace_none(item): new_item = [] none_axes = [] @@ -278,6 +297,7 @@ def _getitem_impl_(var, item): reverse_axes = [] use_strided_slice = False + item = replace_ndarray(item) item, none_axes = replace_none(item) item = replace_ellipsis(var, item) slice_info = SliceInfo() @@ -361,9 +381,6 @@ def _getitem_impl_(var, item): idx = assign(np.array(slice_item).astype("int32")) return index_select(var, index=idx, axis=0) - elif isinstance(slice_item, np.ndarray): - slice_info.update(slice_item) - continue elif isinstance(slice_item, (Variable)): if len(item) == 1: @@ -499,6 +516,7 @@ def _setitem_impl_(var, item, value): ends = [] steps = [] + item = replace_ndarray(item) item, none_axes = replace_none(item) item = replace_ellipsis(var, item) slice_info = SliceInfo() @@ -556,10 +574,6 @@ def _setitem_impl_(var, item, value): idx_tensor = assign(slice_item) return set_value_for_bool_tensor(var, idx_tensor, value) - elif isinstance(slice_item, np.ndarray): - slice_info.update(slice_item) - continue - elif isinstance(slice_item, Variable): if slice_item.dtype == core.VarDesc.VarType.BOOL: if len(item) != 1: -- GitLab