From edc9952ccbbfb6ed5b747589ae65f570fa218692 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Mon, 29 Aug 2022 14:19:23 +0800 Subject: [PATCH] [Eager] Pylayer set grad (#45452) * pylayer set has grad with create_graph --- paddle/fluid/eager/pylayer/py_layer_node.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/fluid/eager/pylayer/py_layer_node.cc b/paddle/fluid/eager/pylayer/py_layer_node.cc index 11e9d93da47..6fb78d20e8a 100644 --- a/paddle/fluid/eager/pylayer/py_layer_node.cc +++ b/paddle/fluid/eager/pylayer/py_layer_node.cc @@ -104,7 +104,10 @@ GradNodePyLayer::operator()( PADDLE_THROW(paddle::platform::errors::InvalidArgument( "Get backward function faild.")); } + bool need_grad_tmp = egr::Controller::Instance().HasGrad(); + egr::Controller::Instance().SetHasGrad(create_graph && need_grad_tmp); auto outputs = PyObject_CallObject(backward_fn, backward_args); + egr::Controller::Instance().SetHasGrad(need_grad_tmp); if (!outputs) { PADDLE_THROW(paddle::platform::errors::External( pybind11::detail::error_string().c_str())); -- GitLab