diff --git a/paddle/fluid/eager/pylayer/py_layer_node.cc b/paddle/fluid/eager/pylayer/py_layer_node.cc index 11e9d93da478f62f699b194e020880d284317ba8..6fb78d20e8a8bb2934b2f6bd45d6d1491f2b7b62 100644 --- a/paddle/fluid/eager/pylayer/py_layer_node.cc +++ b/paddle/fluid/eager/pylayer/py_layer_node.cc @@ -104,7 +104,10 @@ GradNodePyLayer::operator()( PADDLE_THROW(paddle::platform::errors::InvalidArgument( "Get backward function faild.")); } + bool need_grad_tmp = egr::Controller::Instance().HasGrad(); + egr::Controller::Instance().SetHasGrad(create_graph && need_grad_tmp); auto outputs = PyObject_CallObject(backward_fn, backward_args); + egr::Controller::Instance().SetHasGrad(need_grad_tmp); if (!outputs) { PADDLE_THROW(paddle::platform::errors::External( pybind11::detail::error_string().c_str()));