未验证 提交 6c067e09 编写于 作者: Z zhangbo9674 提交者: GitHub

support weakref for eager tensor (#41769) (#41797)

上级 76d5483a
...@@ -709,6 +709,8 @@ int TensorInit(PyObject* self, PyObject* args, PyObject* kwargs) { ...@@ -709,6 +709,8 @@ int TensorInit(PyObject* self, PyObject* args, PyObject* kwargs) {
} }
static void TensorDealloc(TensorObject* self) { static void TensorDealloc(TensorObject* self) {
if (self->weakrefs != NULL)
PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self));
self->tensor.~Tensor(); self->tensor.~Tensor();
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
} }
...@@ -739,6 +741,7 @@ void BindEager(pybind11::module* module) { ...@@ -739,6 +741,7 @@ void BindEager(pybind11::module* module) {
type->tp_getset = variable_properties; type->tp_getset = variable_properties;
type->tp_init = TensorInit; type->tp_init = TensorInit;
type->tp_new = TensorNew; type->tp_new = TensorNew;
type->tp_weaklistoffset = offsetof(TensorObject, weakrefs);
Py_INCREF(&PyBaseObject_Type); Py_INCREF(&PyBaseObject_Type);
type->tp_base = reinterpret_cast<PyTypeObject*>(&PyBaseObject_Type); type->tp_base = reinterpret_cast<PyTypeObject*>(&PyBaseObject_Type);
type->tp_flags |= type->tp_flags |=
......
...@@ -22,6 +22,8 @@ namespace pybind { ...@@ -22,6 +22,8 @@ namespace pybind {
typedef struct { typedef struct {
PyObject_HEAD paddle::experimental::Tensor tensor; PyObject_HEAD paddle::experimental::Tensor tensor;
// Weak references
PyObject* weakrefs;
} TensorObject; } TensorObject;
typedef struct { typedef struct {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册