diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 93a3137b9991036e161c7863e4acc4b6ccf5a711..509a727ade2c343e6d5107de0d5ebc448cdd09fa 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -221,6 +221,71 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap( return result; } +static bool PyCheckInteger(PyObject *obj) { +#if PY_VERSION_HEX < 0x03000000 + return (PyLong_Check(obj) || PyInt_Check(obj)) && !PyBool_Check(obj); +#else + return PyLong_Check(obj) && !PyBool_Check(obj); +#endif +} + +// NOTE(zhiqiu): Revised version of PySlice_GetIndices. From: +// https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Objects/sliceobject.c#L103 +// Original PySlice_GetIndices return wrong result when +// slice_item contains long int, such as arr[:180L]. +// NOT sure why this happens !!! +// Besides, PySlice_GetIndices cannot raise error when float in slice item. +// So, I make a revised version of PySlice_GetIndices, named to +// _PySlice_GetIndices. Try to use _PySlice_Unpack which is more robust than +// PySlice_GetIndices in the future. +static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length, + Py_ssize_t *start, Py_ssize_t *stop, + Py_ssize_t *step) { + /* XXX support long ints */ + if (r->step == Py_None) { + *step = 1; + } else { + if (PyCheckInteger(r->step)) { + *step = PyLong_AsLong(r->step); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Currently, VarBase.__getitem__() only allows None or integers in " + "slice item, but received %s.", + std::string(Py_TYPE(r->step)->tp_name))); + } + } + if (r->start == Py_None) { + *start = *step < 0 ? length - 1 : 0; + } else { + if (PyCheckInteger(r->start)) { + *start = PyLong_AsLong(r->start); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Currently, VarBase.__getitem__() only allows None or integers in " + "slice item, but received %s.", + std::string(Py_TYPE(r->start)->tp_name))); + } + if (*start < 0) *start += length; + } + if (r->stop == Py_None) { + *stop = *step < 0 ? -1 : length; + } else { + if (PyCheckInteger(r->stop)) { + *stop = PyLong_AsLong(r->stop); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Currently, VarBase.__getitem__() only allows None or integers in " + "slice item, but received %s.", + std::string(Py_TYPE(r->stop)->tp_name))); + } + if (*stop < 0) *stop += length; + } + if (*stop > length) return -1; + if (*start >= length) return -1; + if (*step == 0) return -1; + return 0; +} + static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, std::vector *slice_axes, std::vector *slice_starts, @@ -245,16 +310,17 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, "too many indices (%d) for tensor of dimension %d", size, rank)); for (int dim = 0; dim < size; ++dim) { PyObject *slice_item = PyTuple_GetItem(index, dim); - PADDLE_ENFORCE_EQ( - PyNumber_Check(slice_item) || PySlice_Check(slice_item), true, - platform::errors::InvalidArgument( - "We allow indexing by Integers, Slices, and tuples of " - "these types, but received %s in %dth slice item", - std::string(Py_TYPE(slice_item)->tp_name), dim + 1)); + PADDLE_ENFORCE_EQ(PyCheckInteger(slice_item) || PySlice_Check(slice_item), + true, + platform::errors::InvalidArgument( + "Currently, VarBase.__getitem__() only allows " + "indexing by Integers, Slices, and tuples of " + "these types, but received %s in %dth slice item", + std::string(Py_TYPE(slice_item)->tp_name), dim + 1)); infer_flags->push_back(1); int dim_len = shape[dim]; - if (PyNumber_Check(slice_item)) { - // integer + if (PyCheckInteger(slice_item)) { + // integer, PyLong_AsLong supports both int and long int start = static_cast(PyLong_AsLong(slice_item)); start = start < 0 ? start + dim_len : start; slice_axes->push_back(dim); @@ -263,17 +329,15 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, slice_strides->push_back(1); decrease_axis->push_back(dim); } else { - // slice + // slice item Py_ssize_t start, end, step; -// The parameter type for the slice parameter was PySliceObject* before 3.2 -#if PY_VERSION_HEX >= 0x03020000 - PySlice_GetIndices(slice_item, dim_len, &start, &end, &step); -#else - PySlice_GetIndices(reinterpret_cast(slice_item), dim_len, - &start, &end, &step); -#endif + PySliceObject *p = reinterpret_cast(slice_item); + _PySlice_GetIndices(p, dim_len, &start, &end, &step); + // :: or : or 0:dim_len:1 - if (start == 0 && end == dim_len && step == 1) continue; + if (start == 0 && end == dim_len && step == 1) { + continue; + } slice_axes->push_back(dim); slice_starts->push_back(start); slice_ends->push_back(end); @@ -481,7 +545,6 @@ void BindImperative(py::module *m_ptr) { ParseIndexingSlice(tensor, _index.ptr(), &slice_axes, &slice_starts, &slice_ends, &slice_strides, &decrease_axis, &infer_flags); - // release gil and do tracing py::gil_scoped_release release; const auto &tracer = imperative::GetCurrentTracer(); @@ -621,8 +684,8 @@ void BindImperative(py::module *m_ptr) { [](imperative::VarBase &self, const imperative::detail::BackwardStrategy &bckst, const imperative::Tracer &tracer) { - // TODO(jiabin): when we impl more backward execution we can select - // them + // TODO(jiabin): when we impl more backward execution we can + // select them auto *engine = tracer.GetEngine(); engine->Init(&self, bckst); VLOG(3) << "Start backward"; diff --git a/python/paddle/fluid/tests/unittests/test_imperative_varbase_slice.py b/python/paddle/fluid/tests/unittests/test_imperative_varbase_slice.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5f1aff1fa62e7bd94938ea3a081788a436e86f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_varbase_slice.py @@ -0,0 +1,50 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle.fluid as fluid + + +class TestImperativeVarBaseGetItem(unittest.TestCase): + def test_getitem_with_long(self): + with fluid.dygraph.guard(): + data = np.random.random((2, 80, 16128)).astype('float32') + var = fluid.dygraph.to_variable(data) + sliced = var[:, 10:, :var.shape[1]] # var.shape[1] is 80L here + self.assertEqual(sliced.shape, [2, 70, 80]) + + sliced = var[:, var.shape[0]:, var.shape[0]:var.shape[1]] + self.assertEqual(sliced.shape, [2, 78, 78]) + + def test_getitem_with_float(self): + def test_float_in_slice_item(): + with fluid.dygraph.guard(): + data = np.random.random((2, 80, 16128)).astype('float32') + var = fluid.dygraph.to_variable(data) + sliced = var[:, 1.1:, :var.shape[1]] + + self.assertRaises(Exception, test_float_in_slice_item) + + def test_float_in_index(): + with fluid.dygraph.guard(): + data = np.random.random((2, 80, 16128)).astype('float32') + var = fluid.dygraph.to_variable(data) + sliced = var[1.1] + + self.assertRaises(Exception, test_float_in_index) + + +if __name__ == '__main__': + unittest.main()