Unverified · Commit 03a2f187, authored by qiuwenbo, committed by GitHub

Fix the memory error caused by the grad_fn / next_functions API (#55627)

* [Experiment] Add a property to Tensor whose value is the constant 1

* Expose GradNode and add new GradNode construction methods (for testing) so the Python side can access them

* Implement the grad_fn and next_functions APIs, expose them to the Python side, and normalize the code

* Add a unit test

* Clean up code style

* Move the unit-test file to the correct location

* Clean up code style

* Remove unused comments

* Fix the "__main__ has no attribute" error

* Update the unit-test file

* Update the unit-test script (temp)

* Fix the memory error caused by the grad_fn / next_functions API

* Update the unit-test contents

* Fix code-style issues
Parent 05a40691
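For context, a minimal sketch of the Python-side API this commit exposes. The names grad_fn, next_functions, and name() come from the diff below; the tensor values and printed node names are illustrative, assuming an eager-mode Paddle build that includes these bindings:

import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
y = (x * x).sum()

node = y.grad_fn                 # GradNodeBase wrapper exposed by this commit
print(node.name())               # op-specific node name, varies by op/version
for nxt in node.next_functions:  # upstream nodes in the backward graph
    print(nxt.name() if nxt is not None else None)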
@@ -317,17 +317,16 @@ PyObject* tensor_properties_get_grad_fn(TensorObject* self, void* closure) {
   if (meta) {
     // Get the GradNode from meta
-    auto grad_node = meta->GradNode();  // Convert GradNode to a Python object
-    // The conversion will depend on the structure of GradNode.
-    if (!grad_node) {
+    auto grad_node_ptr = meta->GetMutableGradNode();
+    if (!grad_node_ptr) {
       Py_INCREF(Py_None);
       return Py_None;
     }
-    PyObject* py_grad_node = ToPyObject(grad_node);
+    PyObject* py_grad_node = ToPyObject(grad_node_ptr);
     return py_grad_node;
   } else {
     // If meta does not exist, return an appropriate Python object (e.g., None
     // or a special value).
...
@@ -1006,10 +1006,9 @@ paddle::optional<paddle::Tensor> GetOptionalTensorFromArgs(
   }
 }

-PyObject* ToPyObject(egr::GradNodeBase* grad_node) {
+PyObject* ToPyObject(std::shared_ptr<egr::GradNodeBase> grad_node) {
   py::object py_obj = py::cast(grad_node, py::return_value_policy::reference);
-  py::handle py_handle = py::handle(py_obj);
-  PyObject* py_grad_node = py_handle.ptr();
+  PyObject* py_grad_node = py_obj.release().ptr();
   Py_INCREF(py_grad_node);
   return py_grad_node;
 }
...
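Replacing the raw egr::GradNodeBase* (cast with py::return_value_policy::reference) with a std::shared_ptr is the heart of the memory fix: the returned Python object now shares ownership of the node instead of aliasing memory the C++ side may free. A hedged sketch of the lifetime behavior this is intended to guarantee, illustrative rather than taken from the PR's tests:

import paddle

x = paddle.to_tensor([1.0], stop_gradient=False)
node = (x * 2.0).sum().grad_fn  # the producing tensor is a temporary
# The temporary result above is already gone, but the shared_ptr holder
# should keep the underlying C++ GradNode alive, so this access stays valid.
print(node.name())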
@@ -126,7 +126,7 @@ PyObject* ToPyObject(
     const std::unordered_map<std::string, std::vector<std::string>>& value);
 PyObject* ToPyObject(const paddle::framework::Vocab& value);

-PyObject* ToPyObject(egr::GradNodeBase* grad_node);
+PyObject* ToPyObject(std::shared_ptr<egr::GradNodeBase> grad_node);

 class PyTensorHook : public egr::TensorHook {
  public:
...
@@ -778,12 +778,24 @@ PYBIND11_MODULE(libpaddle, m) {
       }
     });

-  py::class_<egr::GradNodeBase>(m, "GradNodeBase")
-      .def("name", &egr::GradNodeBase::name)
-      .def_property_readonly("next_functions",
-                             &egr::GradNodeBase::NextFunctions)
-      .def("input_meta", &egr::GradNodeBase::InputMeta)
-      .def("output_meta", &egr::GradNodeBase::OutputMeta);
+  py::class_<egr::GradNodeBase, std::shared_ptr<egr::GradNodeBase>>(
+      m, "GradNodeBase")
+      .def("name",
+           [](const std::shared_ptr<egr::GradNodeBase> &self) {
+             return self->name();
+           })
+      .def_property_readonly(
+          "next_functions",
+          [](const std::shared_ptr<egr::GradNodeBase> &self) {
+            return self->NextFunctions();
+          })
+      .def("input_meta",
+           [](const std::shared_ptr<egr::GradNodeBase> &self) {
+             return self->InputMeta();
+           })
+      .def("output_meta", [](const std::shared_ptr<egr::GradNodeBase> &self) {
+        return self->OutputMeta();
+      });

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   m.def("cudnn_version", &platform::DnnVersion);
...
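With GradNodeBase now registered under a std::shared_ptr holder, the backward graph can be walked safely from Python. A small illustrative traversal; the None guard mirrors the unit test below, which checks next_functions[0] before recursing:

def walk(node, depth=0):
    """Print the backward graph reachable from `node`, depth first."""
    if node is None:
        return
    print("  " * depth + node.name())
    for nxt in node.next_functions:
        walk(nxt, depth + 1)

walk(y.grad_fn)  # `y` as in the earlier sketch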
@@ -83,6 +83,11 @@ class TestAnonmousSurvey(unittest.TestCase):
             grad_fn_json (dict): grad_node_json of node
         """
         self.assertEqual(grad_fn.name(), grad_fn_json["func_name"])
+        # Recursively test other nodes
+        if hasattr(grad_fn, 'next_functions') and grad_fn.next_functions[0]:
+            next_funcs_json = grad_fn_json["next_funcs"]
+            for u in grad_fn.next_functions:
+                self.check_func(u, next_funcs_json[u.name()])


 if __name__ == "__main__":
...
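The recursive check added above reads "func_name" and "next_funcs" from its fixture and keys children by u.name(). A hypothetical fixture showing the nesting the test expects; the node names here are illustrative, not taken from the test data:

grad_fn_json = {
    "func_name": "GradNodeMatmul",
    "next_funcs": {
        "GradNodeAccumulation": {
            "func_name": "GradNodeAccumulation",
            "next_funcs": {},
        },
    },
}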