From c2329c8850bb38397d8a1b1373bbe28ef9edde58 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Thu, 21 May 2020 15:37:30 +0800 Subject: [PATCH] fix bug of varbase.__getitem__, test=develop (#24642) (#24646) * fix bug of varbase.__getitem__, test=develop * fix bug of float and other type, test=develop --- paddle/fluid/pybind/imperative.cc | 103 ++++++++++++++---- .../fluid/tests/unittests/test_slice_op.py | 29 +++++ 2 files changed, 112 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index e06eafa0cb..defb1edad9 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -222,6 +222,71 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap( return result; } +static bool PyCheckInteger(PyObject *obj) { +#if PY_VERSION_HEX < 0x03000000 + return (PyLong_Check(obj) || PyInt_Check(obj)) && !PyBool_Check(obj); +#else + return PyLong_Check(obj) && !PyBool_Check(obj); +#endif +} + +// NOTE(zhiqiu): Revised version of PySlice_GetIndices. From: +// https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Objects/sliceobject.c#L103 +// Original PySlice_GetIndices return wrong result when +// slice_item contains long int, such as arr[:180L]. +// NOT sure why this happens !!! +// Besides, PySlice_GetIndices cannot raise error when float in slice item. +// So, I make a revised version of PySlice_GetIndices, named to +// _PySlice_GetIndices. Try to use _PySlice_Unpack which is more robust than +// PySlice_GetIndices in the future. +static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length, + Py_ssize_t *start, Py_ssize_t *stop, + Py_ssize_t *step) { + /* XXX support long ints */ + if (r->step == Py_None) { + *step = 1; + } else { + if (PyCheckInteger(r->step)) { + *step = PyLong_AsLong(r->step); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Currently, VarBase.__getitem__() only allows None or integers in " + "slice item, but received %s.", + std::string(Py_TYPE(r->step)->tp_name))); + } + } + if (r->start == Py_None) { + *start = *step < 0 ? length - 1 : 0; + } else { + if (PyCheckInteger(r->start)) { + *start = PyLong_AsLong(r->start); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Currently, VarBase.__getitem__() only allows None or integers in " + "slice item, but received %s.", + std::string(Py_TYPE(r->start)->tp_name))); + } + if (*start < 0) *start += length; + } + if (r->stop == Py_None) { + *stop = *step < 0 ? -1 : length; + } else { + if (PyCheckInteger(r->stop)) { + *stop = PyLong_AsLong(r->stop); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Currently, VarBase.__getitem__() only allows None or integers in " + "slice item, but received %s.", + std::string(Py_TYPE(r->stop)->tp_name))); + } + if (*stop < 0) *stop += length; + } + if (*stop > length) return -1; + if (*start >= length) return -1; + if (*step == 0) return -1; + return 0; +} + static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, std::vector *slice_axes, std::vector *slice_starts, @@ -246,16 +311,17 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, "too many indices (%d) for tensor of dimension %d", size, rank)); for (int dim = 0; dim < size; ++dim) { PyObject *slice_item = PyTuple_GetItem(index, dim); - PADDLE_ENFORCE_EQ( - PyNumber_Check(slice_item) || PySlice_Check(slice_item), true, - platform::errors::InvalidArgument( - "We allow indexing by Integers, Slices, and tuples of " - "these types, but received %s in %dth slice item", - std::string(Py_TYPE(slice_item)->tp_name), dim + 1)); + PADDLE_ENFORCE_EQ(PyCheckInteger(slice_item) || PySlice_Check(slice_item), + true, + platform::errors::InvalidArgument( + "Currently, VarBase.__getitem__() only allows " + "indexing by Integers, Slices, and tuples of " + "these types, but received %s in %dth slice item", + std::string(Py_TYPE(slice_item)->tp_name), dim + 1)); infer_flags->push_back(1); int dim_len = shape[dim]; - if (PyNumber_Check(slice_item)) { - // integer + if (PyCheckInteger(slice_item)) { + // integer, PyLong_AsLong supports both int and long int start = static_cast(PyLong_AsLong(slice_item)); auto s_t = start; start = start < 0 ? start + dim_len : start; @@ -275,17 +341,15 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, slice_strides->push_back(1); decrease_axis->push_back(dim); } else { - // slice + // slice item Py_ssize_t start, end, step; -// The parameter type for the slice parameter was PySliceObject* before 3.2 -#if PY_VERSION_HEX >= 0x03020000 - PySlice_GetIndices(slice_item, dim_len, &start, &end, &step); -#else - PySlice_GetIndices(reinterpret_cast(slice_item), dim_len, - &start, &end, &step); -#endif + PySliceObject *p = reinterpret_cast(slice_item); + _PySlice_GetIndices(p, dim_len, &start, &end, &step); + // :: or : or 0:dim_len:1 - if (start == 0 && end == dim_len && step == 1) continue; + if (start == 0 && end == dim_len && step == 1) { + continue; + } slice_axes->push_back(dim); slice_starts->push_back(start); slice_ends->push_back(end); @@ -493,7 +557,6 @@ void BindImperative(py::module *m_ptr) { ParseIndexingSlice(tensor, _index.ptr(), &slice_axes, &slice_starts, &slice_ends, &slice_strides, &decrease_axis, &infer_flags); - // release gil and do tracing py::gil_scoped_release release; const auto &tracer = imperative::GetCurrentTracer(); @@ -633,8 +696,8 @@ void BindImperative(py::module *m_ptr) { [](imperative::VarBase &self, const imperative::detail::BackwardStrategy &bckst, const imperative::Tracer &tracer) { - // TODO(jiabin): when we impl more backward execution we can select - // them + // TODO(jiabin): when we impl more backward execution we can + // select them auto *engine = tracer.GetEngine(); engine->Init(&self, bckst); VLOG(3) << "Start backward"; diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index 23ff09218b..4efd018723 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -609,5 +609,34 @@ class TestSliceApiWithLoDTensorArray(unittest.TestCase): self.assertTrue(np.array_equal(self.g_x2, np.zeros_like(self.data))) +class TestImperativeVarBaseGetItem(unittest.TestCase): + def test_getitem_with_long(self): + with fluid.dygraph.guard(): + data = np.random.random((2, 80, 16128)).astype('float32') + var = fluid.dygraph.to_variable(data) + sliced = var[:, 10:, :var.shape[1]] # var.shape[1] is 80L here + self.assertEqual(sliced.shape, [2, 70, 80]) + + sliced = var[:, var.shape[0]:, var.shape[0]:var.shape[1]] + self.assertEqual(sliced.shape, [2, 78, 78]) + + def test_getitem_with_float(self): + def test_float_in_slice_item(): + with fluid.dygraph.guard(): + data = np.random.random((2, 80, 16128)).astype('float32') + var = fluid.dygraph.to_variable(data) + sliced = var[:, 1.1:, :var.shape[1]] + + self.assertRaises(Exception, test_float_in_slice_item) + + def test_float_in_index(): + with fluid.dygraph.guard(): + data = np.random.random((2, 80, 16128)).astype('float32') + var = fluid.dygraph.to_variable(data) + sliced = var[1.1] + + self.assertRaises(Exception, test_float_in_index) + + if __name__ == '__main__': unittest.main() -- GitLab