From ff4bdac31b5b6b1f4ea801f157c98e63b40ec750 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 13 Aug 2021 19:14:59 +0800 Subject: [PATCH] fix a bug of slice by none index (#34877) --- paddle/fluid/pybind/imperative.cc | 23 +++++++++++++++++++ .../fluid/tests/unittests/test_var_base.py | 5 +++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index b540d459c26..0b6af3b5423 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -921,6 +921,29 @@ void BindImperative(py::module *m_ptr) { axis -= len; } + // Deal with cases that there are more than one + // prefix none index, For example: + // [None, None, :, :, None] + // the none_axes int the return of ParseIndexingSlice is: + // [0, 0, 2 ] + // according to the interface of "unsqueeze2", + // we should convert it to: + // [0, 0, 4 ] + int prefix_zero_cnt = 0; + for (const auto &axis : none_axes) { + if (axis == 0) { + prefix_zero_cnt++; + } else { + break; + } + } + if (prefix_zero_cnt > 0) { + int none_axes_num = static_cast(none_axes.size()); + for (int i = prefix_zero_cnt; i < none_axes_num; ++i) { + none_axes[i] += prefix_zero_cnt; + } + } + imperative::NameVarBaseMap ins = {{"X", {out}}}; framework::AttributeMap attrs = {{"axes", none_axes}}; auto new_out = std::shared_ptr( diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 4b52cfceabf..cdf34c27c0a 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -711,6 +711,7 @@ class TestVarBase(unittest.TestCase): var_tensor[None, 2, None, 1].numpy(), var_tensor[None].numpy(), var_tensor[0, 0, None, 0, 0, None].numpy(), + var_tensor[None, None, 0, ..., None].numpy(), var_tensor[0, 1:10:2, None, None, ...].numpy(), ] @@ -724,11 +725,13 @@ class TestVarBase(unittest.TestCase): self.assertTrue(np.array_equal(var[7], np_value[None])) self.assertTrue( np.array_equal(var[8], np_value[0, 0, None, 0, 0, None])) + self.assertTrue( + np.array_equal(var[9], np_value[None, None, 0, ..., None])) # TODO(zyfncg) there is a bug of dimensions when slice step > 1 and # indexs has int type # self.assertTrue( - # np.array_equal(var[9], np_value[0, 1:10:2, None, None, ...])) + # np.array_equal(var[10], np_value[0, 1:10:2, None, None, ...])) def _test_for_var(self): np_value = np.random.random((30, 100, 100)).astype('float32') -- GitLab