未验证 提交 edc9952c 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] Pylayer set grad (#45452)

* pylayer set has grad with create_graph
上级 b93b710a
...@@ -104,7 +104,10 @@ GradNodePyLayer::operator()( ...@@ -104,7 +104,10 @@ GradNodePyLayer::operator()(
PADDLE_THROW(paddle::platform::errors::InvalidArgument( PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Get backward function faild.")); "Get backward function faild."));
} }
bool need_grad_tmp = egr::Controller::Instance().HasGrad();
egr::Controller::Instance().SetHasGrad(create_graph && need_grad_tmp);
auto outputs = PyObject_CallObject(backward_fn, backward_args); auto outputs = PyObject_CallObject(backward_fn, backward_args);
egr::Controller::Instance().SetHasGrad(need_grad_tmp);
if (!outputs) { if (!outputs) {
PADDLE_THROW(paddle::platform::errors::External( PADDLE_THROW(paddle::platform::errors::External(
pybind11::detail::error_string().c_str())); pybind11::detail::error_string().c_str()));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册