diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 777d68ea2f9ca7807a23f985486ec558a4dd36ba..e6b8239b44208837b815755f08efb248a8a57ac7 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -29,6 +29,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/scope_guard.h" #include "paddle/fluid/imperative/all_reduce.h" #include "paddle/fluid/imperative/amp_auto_cast.h" #include "paddle/fluid/imperative/basic_engine.h" @@ -424,7 +425,15 @@ static void ParseIndexingSlice( // We allow indexing by Integers, Slices, Ellipsis, None, tuples of those // types, and list of Bool and Integers. // wrap to tuple + + // NOTE(zhiqiu): PyTuple_Pack increases refcount. PyObject *index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index; + DEFINE_PADDLE_SCOPE_GUARD([index, _index]() { + if (!PyTuple_Check(_index)) { + Py_DECREF(index); + VLOG(4) << "Call Py_DECREF"; + } + }); PADDLE_ENFORCE_EQ( tensor->IsInitialized(), true, platform::errors::InvalidArgument("tensor has not been initialized")); @@ -550,8 +559,6 @@ static void ParseIndexingSlice( platform::errors::InvalidArgument( "Too many indices (%d) for tensor of dimension %d.", valid_indexs, rank)); - - if (!PyTuple_Check(_index)) Py_DecRef(index); } template @@ -811,11 +818,21 @@ void BindImperative(py::module *m_ptr) { .def("__setitem__", [](std::shared_ptr &self, py::handle _index, py::object &value_obj) { + VLOG(4) << "Call __setitem__"; + auto self_tensor = self->MutableVar()->GetMutable(); + // NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New + // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251 PyObject *index_ptr = !PyTuple_Check(_index.ptr()) ? PyTuple_Pack(1, _index.ptr()) : _index.ptr(); + DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() { + if (!PyTuple_Check(_index.ptr())) { + Py_DECREF(index_ptr); + VLOG(4) << "Call Py_DECREF"; + } + }); // 1. Check argumnets // 1.1 Check whether value obj is a tensor. bool value_is_tensor = true; @@ -826,6 +843,18 @@ void BindImperative(py::module *m_ptr) { value_is_tensor = false; } + auto is_tensor = [](py::handle var) { + if (!var.ptr() || var.ptr() == Py_None) { + return false; + } + try { + py::cast>(var); + return true; + } catch (py::cast_error &) { + return false; + } + }; + // 1.2 Check whether _index can be parsed. const int size = PyTuple_GET_SIZE(index_ptr); for (int dim = 0; dim < size; ++dim) { @@ -842,6 +871,7 @@ void BindImperative(py::module *m_ptr) { // TODO(liym27): Try not to call TensorToPyArray because it always // copys data to cpu place, which reduces performance. if (parse_index && value_is_tensor) { + VLOG(4) << "index is integer/slice/ellipsis and value is tensor"; std::vector axes, starts, ends, steps, decrease_axes, none_axes, infer_flags, list_select_idxs; // if index is a list, list_select_flag will be true @@ -850,7 +880,6 @@ void BindImperative(py::module *m_ptr) { &steps, &decrease_axes, &none_axes, &infer_flags, &list_select_idxs, &list_select_flag); - framework::AttributeMap attrs = { {"axes", axes}, {"starts", starts}, @@ -882,20 +911,43 @@ void BindImperative(py::module *m_ptr) { } } else { auto self_numpy = TensorToPyArray(*self_tensor); + VLOG(4) << "parse_index is false"; if (value_is_tensor) { + VLOG(4) << "value is tensor"; auto value = value_obj.cast>(); auto value_tensor = value->MutableVar()->GetMutable(); auto value_numpy = TensorToPyArray(*value_tensor); - - self_numpy[_index] = value_numpy; + if (is_tensor(_index)) { + VLOG(4) << "index is tensor"; + auto index_var = + py::cast>(_index); + auto index_tensor = index_var->MutableVar() + ->GetMutable(); + auto index_numpy = TensorToPyArray(*index_tensor); + self_numpy[index_numpy] = value_numpy; + } else { + VLOG(4) << "index is not tensor"; + self_numpy[_index] = value_numpy; + } SetTensorFromPyArray(self_tensor, self_numpy, self_tensor->place(), true); } else { - auto value_numpy = value_obj; - self_numpy[_index] = value_numpy; + VLOG(4) << "value is not tensor"; + if (is_tensor(_index)) { + VLOG(4) << "index is tensor"; + auto index_var = + py::cast>(_index); + auto index_tensor = index_var->MutableVar() + ->GetMutable(); + auto index_numpy = TensorToPyArray(*index_tensor); + self_numpy[index_numpy] = value_obj; + } else { + VLOG(4) << "index is not tensor"; + self_numpy[_index] = value_obj; + } SetTensorFromPyArray(self_tensor, self_numpy, self_tensor->place(), true); } @@ -907,6 +959,7 @@ void BindImperative(py::module *m_ptr) { }) .def("_getitem_index_not_tensor", [](std::shared_ptr &self, py::handle _index) { + VLOG(4) << "Call _getitem_index_not_tensor"; std::vector slice_axes, slice_starts, slice_ends, slice_strides, decrease_axis, none_axes, infer_flags, list_select_idxs;