From 1046636b7ced00f55dd4fcaa76ffb64a2dec8756 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 30 Dec 2021 16:10:28 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90cherry-pick=E3=80=91Fix=20bug=20of=20t?= =?UTF-8?q?ensor=20slice=20(#37400,=20#38098)=20(#38593)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本PR修复了以下两个tensor切片索引bug: 1.修复Tensor索引赋值调用set_value op出现的显存泄露问题,该问题主要是由Inplace策略的使用不当导致,本PR中已完成修复。 2.修复使用多个None类型索引时结果维度异常的问题 --- paddle/fluid/pybind/imperative.cc | 37 ++++--------------- .../tests/unittests/test_set_value_op.py | 8 ++++ .../fluid/tests/unittests/test_var_base.py | 4 +- .../fluid/tests/unittests/test_variable.py | 6 ++- python/paddle/fluid/variable_index.py | 9 +++-- 5 files changed, 28 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 60b97b76491..05037d2a1b9 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -541,7 +541,7 @@ static void ParseIndexingSlice( PADDLE_ENFORCE_LE(ell_count, 1, platform::errors::InvalidArgument( "An index can only have a single ellipsis ('...')")); - + int none_count = 0; for (int i = 0, dim = 0; i < size; ++i) { PyObject *slice_item = PyTuple_GetItem(index, i); @@ -587,7 +587,8 @@ static void ParseIndexingSlice( } else if (slice_item == Py_Ellipsis) { dim += rank - specified_dims; } else if (slice_item == Py_None) { - none_axes->push_back(dim); + none_axes->push_back(dim + none_count); + none_count++; } else if (PyList_Check(slice_item)) { *list_select_flag = true; PADDLE_ENFORCE_EQ( @@ -939,6 +940,11 @@ void BindImperative(py::module *m_ptr) { } }; + // NOTE(liym27): + // Increase the version of VarBase self because __setitem__ is an + // inplace operator for the VarBase self. + self->BumpInplaceVersion(); + // 1. Check argumnets bool parse_index = true; @@ -1106,10 +1112,6 @@ void BindImperative(py::module *m_ptr) { SetTensorFromPyArray(self_tensor, self_numpy, self_tensor->place(), false); } - // NOTE(liym27): - // Increase the version of VarBase self because __setitem__ is an - // inplace operator for the VarBase self. - self->BumpInplaceVersion(); }) .def("_getitem_index_not_tensor", [](std::shared_ptr &self, py::handle _index) { @@ -1183,29 +1185,6 @@ void BindImperative(py::module *m_ptr) { axis -= len; } - // Deal with cases that there are more than one - // prefix none index, For example: - // [None, None, :, :, None] - // the none_axes int the return of ParseIndexingSlice is: - // [0, 0, 2 ] - // according to the interface of "unsqueeze2", - // we should convert it to: - // [0, 0, 4 ] - int prefix_zero_cnt = 0; - for (const auto &axis : none_axes) { - if (axis == 0) { - prefix_zero_cnt++; - } else { - break; - } - } - if (prefix_zero_cnt > 0) { - int none_axes_num = static_cast(none_axes.size()); - for (int i = prefix_zero_cnt; i < none_axes_num; ++i) { - none_axes[i] += prefix_zero_cnt; - } - } - imperative::NameVarBaseMap ins = {{"X", {out}}}; framework::AttributeMap attrs = {{"axes", none_axes}}; auto new_out = std::shared_ptr( diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py index e9809318cb3..057d1b590a0 100644 --- a/python/paddle/fluid/tests/unittests/test_set_value_op.py +++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py @@ -408,6 +408,14 @@ class TestSetValueItemNone9(TestSetValueApi): self.data[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None] +class TestSetValueItemNone10(TestSetValueApi): + def _call_setitem(self, x): + x[..., None, :, None] = np.zeros(self.shape)[..., None, :, None] + + def _get_answer(self): + self.data[..., None, :, None] = np.zeros(self.shape)[..., None, :, None] + + # 1.5 item is list or Tensor of bol class TestSetValueItemBool1(TestSetValueApi): def _call_setitem(self, x): diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index e4ef7428202..1dc48720a32 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -723,6 +723,7 @@ class TestVarBase(unittest.TestCase): var_tensor[None].numpy(), var_tensor[0, 0, None, 0, 0, None].numpy(), var_tensor[None, None, 0, ..., None].numpy(), + var_tensor[..., None, :, None].numpy(), var_tensor[0, 1:10:2, None, None, ...].numpy(), ] @@ -738,11 +739,12 @@ class TestVarBase(unittest.TestCase): np.array_equal(var[8], np_value[0, 0, None, 0, 0, None])) self.assertTrue( np.array_equal(var[9], np_value[None, None, 0, ..., None])) + self.assertTrue(np.array_equal(var[10], np_value[..., None, :, None])) # TODO(zyfncg) there is a bug of dimensions when slice step > 1 and # indexs has int type # self.assertTrue( - # np.array_equal(var[10], np_value[0, 1:10:2, None, None, ...])) + # np.array_equal(var[11], np_value[0, 1:10:2, None, None, ...])) def _test_bool_index(self): shape = (4, 2, 5, 64) diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index 2df336fbe8e..2eb3ecf7104 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -436,13 +436,15 @@ class TestVariableSlice(unittest.TestCase): out1 = x[0:, None] out2 = x[None, 1:] out3 = x[None] + out4 = x[..., None, :, None] - outs = [out0, out1, out2, out3] + outs = [out0, out1, out2, out3, out4] exe = paddle.static.Executor(place) result = exe.run(prog, fetch_list=outs) expected = [ - data[0:, None, 1:], data[0:, None], data[None, 1:], data[None] + data[0:, None, 1:], data[0:, None], data[None, 1:], data[None], + data[..., None, :, None] ] for i in range(len(outs)): self.assertEqual(outs[i].shape, expected[i].shape) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 66be8f80594..19067b8ae12 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -204,7 +204,8 @@ def replace_ellipsis(var, item): # Remove Variable to skip bug when counting Ellipsis item_remove_var = [ - ele for ele in item if not isinstance(ele, (Variable, np.ndarray)) + ele for ele in item + if not isinstance(ele, (Variable, np.ndarray)) and ele is not None ] ell_count = item_remove_var.count(Ellipsis) if ell_count == 0: @@ -218,7 +219,7 @@ def replace_ellipsis(var, item): return item[:-1] else: item[ell_idx:ell_idx + 1] = [slice(None)] * ( - len(var.shape) - len(item) + 1) + len(var.shape) - len(item) + item.count(None) + 1) return item @@ -298,8 +299,8 @@ def _getitem_impl_(var, item): use_strided_slice = False item = replace_ndarray(item) - item, none_axes = replace_none(item) item = replace_ellipsis(var, item) + item, none_axes = replace_none(item) slice_info = SliceInfo() for dim, slice_item in enumerate(item): @@ -517,8 +518,8 @@ def _setitem_impl_(var, item, value): steps = [] item = replace_ndarray(item) - item, none_axes = replace_none(item) item = replace_ellipsis(var, item) + item, none_axes = replace_none(item) slice_info = SliceInfo() dim = 0 for _, slice_item in enumerate(item): -- GitLab