未验证 提交 9f588cc2 编写于 作者: W WeiXin 提交者: GitHub

support numpy.ndarray index. (#35748)

* support numpy.ndarray index.

* polish code.
上级 46ec5b3e
...@@ -553,8 +553,9 @@ def monkey_patch_varbase(): ...@@ -553,8 +553,9 @@ def monkey_patch_varbase():
or isinstance(slice_item.step, Variable): or isinstance(slice_item.step, Variable):
return True return True
else: else:
if isinstance(slice_item, if isinstance(
Variable) and Variable.dtype != paddle.bool: slice_item,
(Variable, np.ndarray)) and Variable.dtype != paddle.bool:
return True return True
return False return False
......
...@@ -813,6 +813,11 @@ class TestVarBase(unittest.TestCase): ...@@ -813,6 +813,11 @@ class TestVarBase(unittest.TestCase):
[0., 0., 42., 42., 42., 0.]]) [0., 0., 42., 42., 42., 0.]])
self.assertTrue(np.array_equal(res, exp)) self.assertTrue(np.array_equal(res, exp))
# case3:
row = np.array([0, 1, 2])
col = np.array([2, 1, 3])
self.assertTrue(np.array_equal(array[row, col], x[row, col].numpy()))
def test_slice(self): def test_slice(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
self._test_slice() self._test_slice()
...@@ -834,6 +839,10 @@ class TestVarBase(unittest.TestCase): ...@@ -834,6 +839,10 @@ class TestVarBase(unittest.TestCase):
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
y = var[0 - self.shape[0] - 1] y = var[0 - self.shape[0] - 1]
with self.assertRaises(IndexError):
mask = np.array([1, 0, 1, 0], dtype=bool)
var[paddle.to_tensor([0, 1]), mask]
def test_var_base_to_np(self): def test_var_base_to_np(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
......
...@@ -67,6 +67,7 @@ class SliceInfo: ...@@ -67,6 +67,7 @@ class SliceInfo:
def __init__(self): def __init__(self):
self.pre_shape = None self.pre_shape = None
self.indexes = [] self.indexes = []
self.dtype = None
def update(self, index): def update(self, index):
if is_list_tuple(index, int) or isinstance(index, ( if is_list_tuple(index, int) or isinstance(index, (
...@@ -75,6 +76,14 @@ class SliceInfo: ...@@ -75,6 +76,14 @@ class SliceInfo:
if not isinstance(index, paddle.fluid.Variable): if not isinstance(index, paddle.fluid.Variable):
index = paddle.assign(index) index = paddle.assign(index)
if self.dtype is None:
self.dtype = index.dtype
else:
if index.dtype != self.dtype:
raise IndexError(
"Data type of Tensor/List index should be same. The current data type is {}, but the previous data type is {}.".
format(index.dtype, self.dtype))
self.indexes.append(index) self.indexes.append(index)
if self.pre_shape is None: if self.pre_shape is None:
...@@ -214,6 +223,16 @@ def replace_ellipsis(var, item): ...@@ -214,6 +223,16 @@ def replace_ellipsis(var, item):
return item return item
def replace_ndarray(item):
new_item = []
for slice_item in item:
if isinstance(slice_item, np.ndarray):
new_item.append(paddle.assign(slice_item))
else:
new_item.append(slice_item)
return new_item
def replace_none(item): def replace_none(item):
new_item = [] new_item = []
none_axes = [] none_axes = []
...@@ -278,6 +297,7 @@ def _getitem_impl_(var, item): ...@@ -278,6 +297,7 @@ def _getitem_impl_(var, item):
reverse_axes = [] reverse_axes = []
use_strided_slice = False use_strided_slice = False
item = replace_ndarray(item)
item, none_axes = replace_none(item) item, none_axes = replace_none(item)
item = replace_ellipsis(var, item) item = replace_ellipsis(var, item)
slice_info = SliceInfo() slice_info = SliceInfo()
...@@ -361,9 +381,6 @@ def _getitem_impl_(var, item): ...@@ -361,9 +381,6 @@ def _getitem_impl_(var, item):
idx = assign(np.array(slice_item).astype("int32")) idx = assign(np.array(slice_item).astype("int32"))
return index_select(var, index=idx, axis=0) return index_select(var, index=idx, axis=0)
elif isinstance(slice_item, np.ndarray):
slice_info.update(slice_item)
continue
elif isinstance(slice_item, (Variable)): elif isinstance(slice_item, (Variable)):
if len(item) == 1: if len(item) == 1:
...@@ -499,6 +516,7 @@ def _setitem_impl_(var, item, value): ...@@ -499,6 +516,7 @@ def _setitem_impl_(var, item, value):
ends = [] ends = []
steps = [] steps = []
item = replace_ndarray(item)
item, none_axes = replace_none(item) item, none_axes = replace_none(item)
item = replace_ellipsis(var, item) item = replace_ellipsis(var, item)
slice_info = SliceInfo() slice_info = SliceInfo()
...@@ -556,10 +574,6 @@ def _setitem_impl_(var, item, value): ...@@ -556,10 +574,6 @@ def _setitem_impl_(var, item, value):
idx_tensor = assign(slice_item) idx_tensor = assign(slice_item)
return set_value_for_bool_tensor(var, idx_tensor, value) return set_value_for_bool_tensor(var, idx_tensor, value)
elif isinstance(slice_item, np.ndarray):
slice_info.update(slice_item)
continue
elif isinstance(slice_item, Variable): elif isinstance(slice_item, Variable):
if slice_item.dtype == core.VarDesc.VarType.BOOL: if slice_item.dtype == core.VarDesc.VarType.BOOL:
if len(item) != 1: if len(item) != 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册