未验证 提交 9f588cc2 编写于 作者: W WeiXin 提交者: GitHub

support numpy.ndarray index. (#35748)

* support numpy.ndarray index.

* polish code.
上级 46ec5b3e
......@@ -553,8 +553,9 @@ def monkey_patch_varbase():
or isinstance(slice_item.step, Variable):
return True
else:
if isinstance(slice_item,
Variable) and Variable.dtype != paddle.bool:
if isinstance(
slice_item,
(Variable, np.ndarray)) and Variable.dtype != paddle.bool:
return True
return False
......
......@@ -813,6 +813,11 @@ class TestVarBase(unittest.TestCase):
[0., 0., 42., 42., 42., 0.]])
self.assertTrue(np.array_equal(res, exp))
# case3:
row = np.array([0, 1, 2])
col = np.array([2, 1, 3])
self.assertTrue(np.array_equal(array[row, col], x[row, col].numpy()))
def test_slice(self):
with fluid.dygraph.guard():
self._test_slice()
......@@ -834,6 +839,10 @@ class TestVarBase(unittest.TestCase):
with self.assertRaises(IndexError):
y = var[0 - self.shape[0] - 1]
with self.assertRaises(IndexError):
mask = np.array([1, 0, 1, 0], dtype=bool)
var[paddle.to_tensor([0, 1]), mask]
def test_var_base_to_np(self):
with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array)
......
......@@ -67,6 +67,7 @@ class SliceInfo:
def __init__(self):
self.pre_shape = None
self.indexes = []
self.dtype = None
def update(self, index):
if is_list_tuple(index, int) or isinstance(index, (
......@@ -75,6 +76,14 @@ class SliceInfo:
if not isinstance(index, paddle.fluid.Variable):
index = paddle.assign(index)
if self.dtype is None:
self.dtype = index.dtype
else:
if index.dtype != self.dtype:
raise IndexError(
"Data type of Tensor/List index should be same. The current data type is {}, but the previous data type is {}.".
format(index.dtype, self.dtype))
self.indexes.append(index)
if self.pre_shape is None:
......@@ -214,6 +223,16 @@ def replace_ellipsis(var, item):
return item
def replace_ndarray(item):
new_item = []
for slice_item in item:
if isinstance(slice_item, np.ndarray):
new_item.append(paddle.assign(slice_item))
else:
new_item.append(slice_item)
return new_item
def replace_none(item):
new_item = []
none_axes = []
......@@ -278,6 +297,7 @@ def _getitem_impl_(var, item):
reverse_axes = []
use_strided_slice = False
item = replace_ndarray(item)
item, none_axes = replace_none(item)
item = replace_ellipsis(var, item)
slice_info = SliceInfo()
......@@ -361,9 +381,6 @@ def _getitem_impl_(var, item):
idx = assign(np.array(slice_item).astype("int32"))
return index_select(var, index=idx, axis=0)
elif isinstance(slice_item, np.ndarray):
slice_info.update(slice_item)
continue
elif isinstance(slice_item, (Variable)):
if len(item) == 1:
......@@ -499,6 +516,7 @@ def _setitem_impl_(var, item, value):
ends = []
steps = []
item = replace_ndarray(item)
item, none_axes = replace_none(item)
item = replace_ellipsis(var, item)
slice_info = SliceInfo()
......@@ -556,10 +574,6 @@ def _setitem_impl_(var, item, value):
idx_tensor = assign(slice_item)
return set_value_for_bool_tensor(var, idx_tensor, value)
elif isinstance(slice_item, np.ndarray):
slice_info.update(slice_item)
continue
elif isinstance(slice_item, Variable):
if slice_item.dtype == core.VarDesc.VarType.BOOL:
if len(item) != 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册