未验证 提交 763b6d91 编写于 作者: L Leo Chen 提交者: GitHub

fix potential tensor leak in tensor.__setitem__ (#35013)

* fix index tensor leak in __setitem__

* fix another usage of PyTuple_Pack

* refine code

* refine code

* handle None index

* add Py_DecRef

* revert ut

* refine code

* merge develop

* use RAII

* follow comments
上级 4bfd0445
......@@ -29,6 +29,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/imperative/all_reduce.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/basic_engine.h"
......@@ -424,7 +425,15 @@ static void ParseIndexingSlice(
// We allow indexing by Integers, Slices, Ellipsis, None, tuples of those
// types, and list of Bool and Integers.
// wrap to tuple
// NOTE(zhiqiu): PyTuple_Pack increases refcount.
PyObject *index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
DEFINE_PADDLE_SCOPE_GUARD([index, _index]() {
if (!PyTuple_Check(_index)) {
Py_DECREF(index);
VLOG(4) << "Call Py_DECREF";
}
});
PADDLE_ENFORCE_EQ(
tensor->IsInitialized(), true,
platform::errors::InvalidArgument("tensor has not been initialized"));
......@@ -550,8 +559,6 @@ static void ParseIndexingSlice(
platform::errors::InvalidArgument(
"Too many indices (%d) for tensor of dimension %d.",
valid_indexs, rank));
if (!PyTuple_Check(_index)) Py_DecRef(index);
}
template <typename P>
......@@ -811,11 +818,21 @@ void BindImperative(py::module *m_ptr) {
.def("__setitem__",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
py::object &value_obj) {
VLOG(4) << "Call __setitem__";
auto self_tensor =
self->MutableVar()->GetMutable<framework::LoDTensor>();
// NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New
// https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
PyObject *index_ptr = !PyTuple_Check(_index.ptr())
? PyTuple_Pack(1, _index.ptr())
: _index.ptr();
DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
if (!PyTuple_Check(_index.ptr())) {
Py_DECREF(index_ptr);
VLOG(4) << "Call Py_DECREF";
}
});
// 1. Check argumnets
// 1.1 Check whether value obj is a tensor.
bool value_is_tensor = true;
......@@ -826,6 +843,18 @@ void BindImperative(py::module *m_ptr) {
value_is_tensor = false;
}
auto is_tensor = [](py::handle var) {
if (!var.ptr() || var.ptr() == Py_None) {
return false;
}
try {
py::cast<std::shared_ptr<imperative::VarBase>>(var);
return true;
} catch (py::cast_error &) {
return false;
}
};
// 1.2 Check whether _index can be parsed.
const int size = PyTuple_GET_SIZE(index_ptr);
for (int dim = 0; dim < size; ++dim) {
......@@ -842,6 +871,7 @@ void BindImperative(py::module *m_ptr) {
// TODO(liym27): Try not to call TensorToPyArray because it always
// copys data to cpu place, which reduces performance.
if (parse_index && value_is_tensor) {
VLOG(4) << "index is integer/slice/ellipsis and value is tensor";
std::vector<int> axes, starts, ends, steps, decrease_axes,
none_axes, infer_flags, list_select_idxs;
// if index is a list, list_select_flag will be true
......@@ -850,7 +880,6 @@ void BindImperative(py::module *m_ptr) {
&steps, &decrease_axes, &none_axes,
&infer_flags, &list_select_idxs,
&list_select_flag);
framework::AttributeMap attrs = {
{"axes", axes},
{"starts", starts},
......@@ -882,20 +911,43 @@ void BindImperative(py::module *m_ptr) {
}
} else {
auto self_numpy = TensorToPyArray(*self_tensor);
VLOG(4) << "parse_index is false";
if (value_is_tensor) {
VLOG(4) << "value is tensor";
auto value =
value_obj.cast<std::shared_ptr<imperative::VarBase>>();
auto value_tensor =
value->MutableVar()->GetMutable<framework::LoDTensor>();
auto value_numpy = TensorToPyArray(*value_tensor);
if (is_tensor(_index)) {
VLOG(4) << "index is tensor";
auto index_var =
py::cast<std::shared_ptr<imperative::VarBase>>(_index);
auto index_tensor = index_var->MutableVar()
->GetMutable<framework::LoDTensor>();
auto index_numpy = TensorToPyArray(*index_tensor);
self_numpy[index_numpy] = value_numpy;
} else {
VLOG(4) << "index is not tensor";
self_numpy[_index] = value_numpy;
}
SetTensorFromPyArray(self_tensor, self_numpy,
self_tensor->place(), true);
} else {
auto value_numpy = value_obj;
self_numpy[_index] = value_numpy;
VLOG(4) << "value is not tensor";
if (is_tensor(_index)) {
VLOG(4) << "index is tensor";
auto index_var =
py::cast<std::shared_ptr<imperative::VarBase>>(_index);
auto index_tensor = index_var->MutableVar()
->GetMutable<framework::LoDTensor>();
auto index_numpy = TensorToPyArray(*index_tensor);
self_numpy[index_numpy] = value_obj;
} else {
VLOG(4) << "index is not tensor";
self_numpy[_index] = value_obj;
}
SetTensorFromPyArray(self_tensor, self_numpy,
self_tensor->place(), true);
}
......@@ -907,6 +959,7 @@ void BindImperative(py::module *m_ptr) {
})
.def("_getitem_index_not_tensor",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
VLOG(4) << "Call _getitem_index_not_tensor";
std::vector<int> slice_axes, slice_starts, slice_ends,
slice_strides, decrease_axis, none_axes, infer_flags,
list_select_idxs;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册