Unverified commit 763b6d91, authored by Leo Chen, committed by GitHub

fix potential tensor leak in tensor.__setitem__ (#35013)

* fix index tensor leak in __setitem__

* fix another usage of PyTuple_Pack

* refine code

* refine code

* handle None index

* add Py_DecRef

* revert ut

* refine code

* merge develop

* use RAII

* follow comments
Parent 4bfd0445
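
The commit notes above summarize the fix: when `__setitem__` receives a non-tuple index, it wraps the index with `PyTuple_Pack`, which returns a new reference, so the temporary tuple must be released on every exit path; the commit does this with an RAII scope guard (`DEFINE_PADDLE_SCOPE_GUARD`). Below is a minimal sketch of the same pattern using only the CPython C API; the `ScopeExit` helper and `WrapAndParseIndex` function are illustrative stand-ins, not Paddle code.

#include <Python.h>
#include <utility>

// Minimal RAII guard: runs the stored callable when it leaves scope,
// mirroring what DEFINE_PADDLE_SCOPE_GUARD provides in the real commit.
// (Illustrative helper, not part of Paddle.)
template <typename F>
class ScopeExit {
 public:
  explicit ScopeExit(F f) : f_(std::move(f)) {}
  ~ScopeExit() { f_(); }
  ScopeExit(const ScopeExit &) = delete;
  ScopeExit &operator=(const ScopeExit &) = delete;

 private:
  F f_;
};

// Hypothetical function with the same shape as ParseIndexingSlice /
// __setitem__ in the diff below: wrap a non-tuple index into a 1-tuple.
void WrapAndParseIndex(PyObject *_index) {
  // PyTuple_Pack returns a NEW reference; an index that is already a tuple
  // is only borrowed, so only the wrapped case must be released.
  PyObject *index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
  ScopeExit guard([index, _index]() {
    if (index && !PyTuple_Check(_index)) {
      Py_DECREF(index);  // runs on every exit path, including exceptions
    }
  });

  // ... parse `index` here; early returns or thrown exceptions no longer
  // leak the temporary tuple, which was the leak this commit fixes.
}
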
@@ -29,6 +29,7 @@ limitations under the License. */
 #include <utility>
 #include <vector>
+#include "paddle/fluid/framework/scope_guard.h"
 #include "paddle/fluid/imperative/all_reduce.h"
 #include "paddle/fluid/imperative/amp_auto_cast.h"
 #include "paddle/fluid/imperative/basic_engine.h"
@@ -424,7 +425,15 @@ static void ParseIndexingSlice(
   // We allow indexing by Integers, Slices, Ellipsis, None, tuples of those
   // types, and list of Bool and Integers.
   // wrap to tuple
+  // NOTE(zhiqiu): PyTuple_Pack increases refcount.
   PyObject *index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
+  DEFINE_PADDLE_SCOPE_GUARD([index, _index]() {
+    if (!PyTuple_Check(_index)) {
+      Py_DECREF(index);
+      VLOG(4) << "Call Py_DECREF";
+    }
+  });
   PADDLE_ENFORCE_EQ(
       tensor->IsInitialized(), true,
       platform::errors::InvalidArgument("tensor has not been initialized"));
@@ -550,8 +559,6 @@ static void ParseIndexingSlice(
       platform::errors::InvalidArgument(
           "Too many indices (%d) for tensor of dimension %d.",
           valid_indexs, rank));
-  if (!PyTuple_Check(_index)) Py_DecRef(index);
 }
 
 template <typename P>
@@ -811,11 +818,21 @@ void BindImperative(py::module *m_ptr) {
      .def("__setitem__",
           [](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
              py::object &value_obj) {
+            VLOG(4) << "Call __setitem__";
             auto self_tensor =
                 self->MutableVar()->GetMutable<framework::LoDTensor>();
+            // NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New
+            // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
             PyObject *index_ptr = !PyTuple_Check(_index.ptr())
                                       ? PyTuple_Pack(1, _index.ptr())
                                       : _index.ptr();
+            DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
+              if (!PyTuple_Check(_index.ptr())) {
+                Py_DECREF(index_ptr);
+                VLOG(4) << "Call Py_DECREF";
+              }
+            });
             // 1. Check argumnets
             // 1.1 Check whether value obj is a tensor.
             bool value_is_tensor = true;
@@ -826,6 +843,18 @@ void BindImperative(py::module *m_ptr) {
               value_is_tensor = false;
             }
+            auto is_tensor = [](py::handle var) {
+              if (!var.ptr() || var.ptr() == Py_None) {
+                return false;
+              }
+              try {
+                py::cast<std::shared_ptr<imperative::VarBase>>(var);
+                return true;
+              } catch (py::cast_error &) {
+                return false;
+              }
+            };
             // 1.2 Check whether _index can be parsed.
             const int size = PyTuple_GET_SIZE(index_ptr);
             for (int dim = 0; dim < size; ++dim) {
@@ -842,6 +871,7 @@ void BindImperative(py::module *m_ptr) {
             // TODO(liym27): Try not to call TensorToPyArray because it always
             // copys data to cpu place, which reduces performance.
             if (parse_index && value_is_tensor) {
+              VLOG(4) << "index is integer/slice/ellipsis and value is tensor";
               std::vector<int> axes, starts, ends, steps, decrease_axes,
                   none_axes, infer_flags, list_select_idxs;
               // if index is a list, list_select_flag will be true
@@ -850,7 +880,6 @@ void BindImperative(py::module *m_ptr) {
                                  &steps, &decrease_axes, &none_axes,
                                  &infer_flags, &list_select_idxs,
                                  &list_select_flag);
               framework::AttributeMap attrs = {
                   {"axes", axes},
                   {"starts", starts},
@@ -882,20 +911,43 @@ void BindImperative(py::module *m_ptr) {
               }
             } else {
               auto self_numpy = TensorToPyArray(*self_tensor);
+              VLOG(4) << "parse_index is false";
               if (value_is_tensor) {
+                VLOG(4) << "value is tensor";
                 auto value =
                     value_obj.cast<std::shared_ptr<imperative::VarBase>>();
                 auto value_tensor =
                     value->MutableVar()->GetMutable<framework::LoDTensor>();
                 auto value_numpy = TensorToPyArray(*value_tensor);
+                if (is_tensor(_index)) {
+                  VLOG(4) << "index is tensor";
+                  auto index_var =
+                      py::cast<std::shared_ptr<imperative::VarBase>>(_index);
+                  auto index_tensor = index_var->MutableVar()
+                                          ->GetMutable<framework::LoDTensor>();
+                  auto index_numpy = TensorToPyArray(*index_tensor);
+                  self_numpy[index_numpy] = value_numpy;
+                } else {
+                  VLOG(4) << "index is not tensor";
                   self_numpy[_index] = value_numpy;
+                }
                 SetTensorFromPyArray(self_tensor, self_numpy,
                                      self_tensor->place(), true);
               } else {
-                auto value_numpy = value_obj;
-                self_numpy[_index] = value_numpy;
+                VLOG(4) << "value is not tensor";
+                if (is_tensor(_index)) {
+                  VLOG(4) << "index is tensor";
+                  auto index_var =
+                      py::cast<std::shared_ptr<imperative::VarBase>>(_index);
+                  auto index_tensor = index_var->MutableVar()
+                                          ->GetMutable<framework::LoDTensor>();
+                  auto index_numpy = TensorToPyArray(*index_tensor);
+                  self_numpy[index_numpy] = value_obj;
+                } else {
+                  VLOG(4) << "index is not tensor";
+                  self_numpy[_index] = value_obj;
+                }
                 SetTensorFromPyArray(self_tensor, self_numpy,
                                      self_tensor->place(), true);
               }
@@ -907,6 +959,7 @@ void BindImperative(py::module *m_ptr) {
           })
      .def("_getitem_index_not_tensor",
           [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
+            VLOG(4) << "Call _getitem_index_not_tensor";
             std::vector<int> slice_axes, slice_starts, slice_ends,
                 slice_strides, decrease_axis, none_axes, infer_flags,
                 list_select_idxs;
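
The is_tensor lambda added in the __setitem__ binding relies on a common pybind11 idiom: attempt py::cast<T>() and treat a py::cast_error as "this handle does not wrap a T". A generic version of that probe is sketched below; the holds<T> helper name is hypothetical and belongs neither to Paddle nor to pybind11.

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical probe: true if `obj` wraps a C++ object of type T that has
// been registered with pybind11, without letting the cast_error escape.
template <typename T>
bool holds(py::handle obj) {
  if (!obj || obj.is_none()) {
    return false;  // a null handle or None can never be a bound T
  }
  try {
    py::cast<T>(obj);  // throws py::cast_error when the type does not match
    return true;
  } catch (const py::cast_error &) {
    return false;
  }
}

// Usage in the spirit of the diff above (VarBase is Paddle's tensor holder):
//   if (holds<std::shared_ptr<imperative::VarBase>>(_index)) { ... }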