From de0cb38677ad0620f808a76da754ac044dd972a1 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 22 Nov 2021 16:57:32 +0800 Subject: [PATCH] fix bug of indexing tensor with None (#37400) --- paddle/fluid/pybind/imperative.cc | 28 ++----------------- .../tests/unittests/test_set_value_op.py | 8 ++++++ .../fluid/tests/unittests/test_var_base.py | 4 ++- .../fluid/tests/unittests/test_variable.py | 6 ++-- python/paddle/fluid/variable_index.py | 9 +++--- 5 files changed, 23 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 7f6e3644bc3..1423a89f5df 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -562,7 +562,7 @@ static void ParseIndexingSlice( PADDLE_ENFORCE_LE(ell_count, 1, platform::errors::InvalidArgument( "An index can only have a single ellipsis ('...')")); - + int none_count = 0; for (int i = 0, dim = 0; i < size; ++i) { PyObject *slice_item = PyTuple_GetItem(index, i); @@ -608,7 +608,8 @@ static void ParseIndexingSlice( } else if (slice_item == Py_Ellipsis) { dim += rank - specified_dims; } else if (slice_item == Py_None) { - none_axes->push_back(dim); + none_axes->push_back(dim + none_count); + none_count++; } else if (PyList_Check(slice_item)) { *list_select_flag = true; PADDLE_ENFORCE_EQ( @@ -1214,29 +1215,6 @@ void BindImperative(py::module *m_ptr) { axis -= len; } - // Deal with cases that there are more than one - // prefix none index, For example: - // [None, None, :, :, None] - // the none_axes int the return of ParseIndexingSlice is: - // [0, 0, 2 ] - // according to the interface of "unsqueeze2", - // we should convert it to: - // [0, 0, 4 ] - int prefix_zero_cnt = 0; - for (const auto &axis : none_axes) { - if (axis == 0) { - prefix_zero_cnt++; - } else { - break; - } - } - if (prefix_zero_cnt > 0) { - int none_axes_num = static_cast(none_axes.size()); - for (int i = prefix_zero_cnt; i < none_axes_num; ++i) { - none_axes[i] += prefix_zero_cnt; - } - } - imperative::NameVarBaseMap ins = {{"X", {out}}}; framework::AttributeMap attrs = {{"axes", none_axes}}; auto new_out = std::shared_ptr( diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py index e9809318cb3..057d1b590a0 100644 --- a/python/paddle/fluid/tests/unittests/test_set_value_op.py +++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py @@ -408,6 +408,14 @@ class TestSetValueItemNone9(TestSetValueApi): self.data[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None] +class TestSetValueItemNone10(TestSetValueApi): + def _call_setitem(self, x): + x[..., None, :, None] = np.zeros(self.shape)[..., None, :, None] + + def _get_answer(self): + self.data[..., None, :, None] = np.zeros(self.shape)[..., None, :, None] + + # 1.5 item is list or Tensor of bol class TestSetValueItemBool1(TestSetValueApi): def _call_setitem(self, x): diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 34e679e752a..3e7b14aa99a 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -723,6 +723,7 @@ class TestVarBase(unittest.TestCase): var_tensor[None].numpy(), var_tensor[0, 0, None, 0, 0, None].numpy(), var_tensor[None, None, 0, ..., None].numpy(), + var_tensor[..., None, :, None].numpy(), var_tensor[0, 1:10:2, None, None, ...].numpy(), ] @@ -738,11 +739,12 @@ class TestVarBase(unittest.TestCase): np.array_equal(var[8], np_value[0, 0, None, 0, 0, None])) self.assertTrue( np.array_equal(var[9], np_value[None, None, 0, ..., None])) + self.assertTrue(np.array_equal(var[10], np_value[..., None, :, None])) # TODO(zyfncg) there is a bug of dimensions when slice step > 1 and # indexs has int type # self.assertTrue( - # np.array_equal(var[10], np_value[0, 1:10:2, None, None, ...])) + # np.array_equal(var[11], np_value[0, 1:10:2, None, None, ...])) def _test_bool_index(self): shape = (4, 2, 5, 64) diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index 2df336fbe8e..2eb3ecf7104 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -436,13 +436,15 @@ class TestVariableSlice(unittest.TestCase): out1 = x[0:, None] out2 = x[None, 1:] out3 = x[None] + out4 = x[..., None, :, None] - outs = [out0, out1, out2, out3] + outs = [out0, out1, out2, out3, out4] exe = paddle.static.Executor(place) result = exe.run(prog, fetch_list=outs) expected = [ - data[0:, None, 1:], data[0:, None], data[None, 1:], data[None] + data[0:, None, 1:], data[0:, None], data[None, 1:], data[None], + data[..., None, :, None] ] for i in range(len(outs)): self.assertEqual(outs[i].shape, expected[i].shape) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 66be8f80594..19067b8ae12 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -204,7 +204,8 @@ def replace_ellipsis(var, item): # Remove Variable to skip bug when counting Ellipsis item_remove_var = [ - ele for ele in item if not isinstance(ele, (Variable, np.ndarray)) + ele for ele in item + if not isinstance(ele, (Variable, np.ndarray)) and ele is not None ] ell_count = item_remove_var.count(Ellipsis) if ell_count == 0: @@ -218,7 +219,7 @@ def replace_ellipsis(var, item): return item[:-1] else: item[ell_idx:ell_idx + 1] = [slice(None)] * ( - len(var.shape) - len(item) + 1) + len(var.shape) - len(item) + item.count(None) + 1) return item @@ -298,8 +299,8 @@ def _getitem_impl_(var, item): use_strided_slice = False item = replace_ndarray(item) - item, none_axes = replace_none(item) item = replace_ellipsis(var, item) + item, none_axes = replace_none(item) slice_info = SliceInfo() for dim, slice_item in enumerate(item): @@ -517,8 +518,8 @@ def _setitem_impl_(var, item, value): steps = [] item = replace_ndarray(item) - item, none_axes = replace_none(item) item = replace_ellipsis(var, item) + item, none_axes = replace_none(item) slice_info = SliceInfo() dim = 0 for _, slice_item in enumerate(item): -- GitLab