From 03a2f1878cc37efabe55e3dbdf9c08f80019c0e1 Mon Sep 17 00:00:00 2001 From: qiuwenbo Date: Tue, 25 Jul 2023 10:34:28 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A7=A3=E5=86=B3=20grad=5Ffn=20next=5Ffunctio?= =?UTF-8?q?ns=20api=20=E6=8E=A5=E5=8F=A3=E5=AF=BC=E8=87=B4=E5=86=85?= =?UTF-8?q?=E5=AD=98=E5=BC=82=E5=B8=B8=E7=9A=84=E9=97=AE=E9=A2=98=20-=20?= =?UTF-8?q?=20(#55627)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [尝试] 给tensor增加一个属性, 这个属性是一个定值 1 * 暴露gradnode 并构建gradnode新的方法(用来测试)进行暴露给python python端可以访问 * 开发grad_fn、next_functions两个API 并暴露到python端- 做一些规范化处理 * 增加一个单元测试 * 优化 code-style * 将单侧文件迁到正确的位置 * 优化 code-style * 删除无用注释 * 解决 __main__ has no attribute * 修改单侧文件 * 修改单侧脚本-temp * 解决 grad_fn next_functions api 接口导致内存异常的问题 * 修改单测内容 * 解决 code-style 问题 --- paddle/fluid/pybind/eager_properties.cc | 9 ++++--- paddle/fluid/pybind/eager_utils.cc | 5 ++-- paddle/fluid/pybind/eager_utils.h | 2 +- paddle/fluid/pybind/pybind.cc | 24 ++++++++++++++----- .../test_grad_fn_and_next_functions.py | 5 ++++ 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/pybind/eager_properties.cc b/paddle/fluid/pybind/eager_properties.cc index 42d53ad7bee..2a7692ee99b 100644 --- a/paddle/fluid/pybind/eager_properties.cc +++ b/paddle/fluid/pybind/eager_properties.cc @@ -317,17 +317,16 @@ PyObject* tensor_properties_get_grad_fn(TensorObject* self, void* closure) { if (meta) { // Get the GradNode from meta - auto grad_node = meta->GradNode(); // Convert GradNode to a Python object - // The conversion will depend on the structure of GradNode. - - if (!grad_node) { + auto grad_node_ptr = meta->GetMutableGradNode(); + if (!grad_node_ptr) { Py_INCREF(Py_None); return Py_None; } - PyObject* py_grad_node = ToPyObject(grad_node); + PyObject* py_grad_node = ToPyObject(grad_node_ptr); return py_grad_node; + } else { // If meta does not exist, return an appropriate Python object (e.g., None // or a special value). diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 8dfc7cfc8e4..ee270042f41 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -1006,10 +1006,9 @@ paddle::optional GetOptionalTensorFromArgs( } } -PyObject* ToPyObject(egr::GradNodeBase* grad_node) { +PyObject* ToPyObject(std::shared_ptr grad_node) { py::object py_obj = py::cast(grad_node, py::return_value_policy::reference); - py::handle py_handle = py::handle(py_obj); - PyObject* py_grad_node = py_handle.ptr(); + PyObject* py_grad_node = py_obj.release().ptr(); Py_INCREF(py_grad_node); return py_grad_node; } diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index 1fb53a3b9f7..f50ec9395b2 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -126,7 +126,7 @@ PyObject* ToPyObject( const std::unordered_map>& value); PyObject* ToPyObject(const paddle::framework::Vocab& value); -PyObject* ToPyObject(egr::GradNodeBase* grad_node); +PyObject* ToPyObject(std::shared_ptr grad_node); class PyTensorHook : public egr::TensorHook { public: diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index d55cab98b1e..504e1adf225 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -778,12 +778,24 @@ PYBIND11_MODULE(libpaddle, m) { } }); - py::class_(m, "GradNodeBase") - .def("name", &egr::GradNodeBase::name) - .def_property_readonly("next_functions", - &egr::GradNodeBase::NextFunctions) - .def("input_meta", &egr::GradNodeBase::InputMeta) - .def("output_meta", &egr::GradNodeBase::OutputMeta); + py::class_>( + m, "GradNodeBase") + .def("name", + [](const std::shared_ptr &self) { + return self->name(); + }) + .def_property_readonly( + "next_functions", + [](const std::shared_ptr &self) { + return self->NextFunctions(); + }) + .def("input_meta", + [](const std::shared_ptr &self) { + return self->InputMeta(); + }) + .def("output_meta", [](const std::shared_ptr &self) { + return self->OutputMeta(); + }); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) m.def("cudnn_version", &platform::DnnVersion); diff --git a/test/legacy_test/test_grad_fn_and_next_functions.py b/test/legacy_test/test_grad_fn_and_next_functions.py index 54647750012..531cdfa98a0 100644 --- a/test/legacy_test/test_grad_fn_and_next_functions.py +++ b/test/legacy_test/test_grad_fn_and_next_functions.py @@ -83,6 +83,11 @@ class TestAnonmousSurvey(unittest.TestCase): grad_fn_json (dict): grad_node_json of node """ self.assertEqual(grad_fn.name(), grad_fn_json["func_name"]) + # Recursively test other nodes + if hasattr(grad_fn, 'next_functions') and grad_fn.next_functions[0]: + next_funcs_json = grad_fn_json["next_funcs"] + for u in grad_fn.next_functions: + self.check_func(u, next_funcs_json[u.name()]) if __name__ == "__main__": -- GitLab