未验证 提交 14dd6388 编写于 作者: L Leo Chen 提交者: GitHub

fix bug of varbase.__getitem__, test=develop (#24642)

* fix bug of varbase.__getitem__, test=develop

* fix bug of float and other type, test=develop
上级 6ca44cba
......@@ -222,6 +222,71 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
return result;
}
static bool PyCheckInteger(PyObject *obj) {
#if PY_VERSION_HEX < 0x03000000
return (PyLong_Check(obj) || PyInt_Check(obj)) && !PyBool_Check(obj);
#else
return PyLong_Check(obj) && !PyBool_Check(obj);
#endif
}
// NOTE(zhiqiu): Revised version of PySlice_GetIndices. From:
// https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Objects/sliceobject.c#L103
// Original PySlice_GetIndices return wrong result when
// slice_item contains long int, such as arr[:180L].
// NOT sure why this happens !!!
// Besides, PySlice_GetIndices cannot raise error when float in slice item.
// So, I make a revised version of PySlice_GetIndices, named to
// _PySlice_GetIndices. Try to use _PySlice_Unpack which is more robust than
// PySlice_GetIndices in the future.
static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
Py_ssize_t *start, Py_ssize_t *stop,
Py_ssize_t *step) {
/* XXX support long ints */
if (r->step == Py_None) {
*step = 1;
} else {
if (PyCheckInteger(r->step)) {
*step = PyLong_AsLong(r->step);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Currently, VarBase.__getitem__() only allows None or integers in "
"slice item, but received %s.",
std::string(Py_TYPE(r->step)->tp_name)));
}
}
if (r->start == Py_None) {
*start = *step < 0 ? length - 1 : 0;
} else {
if (PyCheckInteger(r->start)) {
*start = PyLong_AsLong(r->start);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Currently, VarBase.__getitem__() only allows None or integers in "
"slice item, but received %s.",
std::string(Py_TYPE(r->start)->tp_name)));
}
if (*start < 0) *start += length;
}
if (r->stop == Py_None) {
*stop = *step < 0 ? -1 : length;
} else {
if (PyCheckInteger(r->stop)) {
*stop = PyLong_AsLong(r->stop);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Currently, VarBase.__getitem__() only allows None or integers in "
"slice item, but received %s.",
std::string(Py_TYPE(r->stop)->tp_name)));
}
if (*stop < 0) *stop += length;
}
if (*stop > length) return -1;
if (*start >= length) return -1;
if (*step == 0) return -1;
return 0;
}
static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
std::vector<int> *slice_axes,
std::vector<int> *slice_starts,
......@@ -246,16 +311,17 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
"too many indices (%d) for tensor of dimension %d", size, rank));
for (int dim = 0; dim < size; ++dim) {
PyObject *slice_item = PyTuple_GetItem(index, dim);
PADDLE_ENFORCE_EQ(
PyNumber_Check(slice_item) || PySlice_Check(slice_item), true,
PADDLE_ENFORCE_EQ(PyCheckInteger(slice_item) || PySlice_Check(slice_item),
true,
platform::errors::InvalidArgument(
"We allow indexing by Integers, Slices, and tuples of "
"Currently, VarBase.__getitem__() only allows "
"indexing by Integers, Slices, and tuples of "
"these types, but received %s in %dth slice item",
std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
infer_flags->push_back(1);
int dim_len = shape[dim];
if (PyNumber_Check(slice_item)) {
// integer
if (PyCheckInteger(slice_item)) {
// integer, PyLong_AsLong supports both int and long
int start = static_cast<int>(PyLong_AsLong(slice_item));
auto s_t = start;
start = start < 0 ? start + dim_len : start;
......@@ -275,17 +341,15 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
slice_strides->push_back(1);
decrease_axis->push_back(dim);
} else {
// slice
// slice item
Py_ssize_t start, end, step;
// The parameter type for the slice parameter was PySliceObject* before 3.2
#if PY_VERSION_HEX >= 0x03020000
PySlice_GetIndices(slice_item, dim_len, &start, &end, &step);
#else
PySlice_GetIndices(reinterpret_cast<PySliceObject *>(slice_item), dim_len,
&start, &end, &step);
#endif
PySliceObject *p = reinterpret_cast<PySliceObject *>(slice_item);
_PySlice_GetIndices(p, dim_len, &start, &end, &step);
// :: or : or 0:dim_len:1
if (start == 0 && end == dim_len && step == 1) continue;
if (start == 0 && end == dim_len && step == 1) {
continue;
}
slice_axes->push_back(dim);
slice_starts->push_back(start);
slice_ends->push_back(end);
......@@ -493,7 +557,6 @@ void BindImperative(py::module *m_ptr) {
ParseIndexingSlice(tensor, _index.ptr(), &slice_axes,
&slice_starts, &slice_ends, &slice_strides,
&decrease_axis, &infer_flags);
// release gil and do tracing
py::gil_scoped_release release;
const auto &tracer = imperative::GetCurrentTracer();
......@@ -633,8 +696,8 @@ void BindImperative(py::module *m_ptr) {
[](imperative::VarBase &self,
const imperative::detail::BackwardStrategy &bckst,
const imperative::Tracer &tracer) {
// TODO(jiabin): when we impl more backward execution we can select
// them
// TODO(jiabin): when we impl more backward execution we can
// select them
auto *engine = tracer.GetEngine();
engine->Init(&self, bckst);
VLOG(3) << "Start backward";
......
......@@ -634,5 +634,34 @@ class TestSliceApiWithLoDTensorArray(unittest.TestCase):
self.assertTrue(np.array_equal(self.g_x2, np.ones_like(self.data)))
class TestImperativeVarBaseGetItem(unittest.TestCase):
def test_getitem_with_long(self):
with fluid.dygraph.guard():
data = np.random.random((2, 80, 16128)).astype('float32')
var = fluid.dygraph.to_variable(data)
sliced = var[:, 10:, :var.shape[1]] # var.shape[1] is 80L here
self.assertEqual(sliced.shape, [2, 70, 80])
sliced = var[:, var.shape[0]:, var.shape[0]:var.shape[1]]
self.assertEqual(sliced.shape, [2, 78, 78])
def test_getitem_with_float(self):
def test_float_in_slice_item():
with fluid.dygraph.guard():
data = np.random.random((2, 80, 16128)).astype('float32')
var = fluid.dygraph.to_variable(data)
sliced = var[:, 1.1:, :var.shape[1]]
self.assertRaises(Exception, test_float_in_slice_item)
def test_float_in_index():
with fluid.dygraph.guard():
data = np.random.random((2, 80, 16128)).astype('float32')
var = fluid.dygraph.to_variable(data)
sliced = var[1.1]
self.assertRaises(Exception, test_float_in_index)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册