未验证 提交 7a78a571 编写于 作者: W wanghuancoder 提交者: GitHub

fix force sync bug in paddle.grad (#52779)

上级 2a24a6bb
...@@ -113,7 +113,6 @@ std::vector<paddle::Tensor> RunBackward( ...@@ -113,7 +113,6 @@ std::vector<paddle::Tensor> RunBackward(
std::queue<GradNodeBase*> force_sequential_nodes_forward_queue = std::queue<GradNodeBase*> force_sequential_nodes_forward_queue =
egr::Controller::Instance().GetForceSequentialNodes(); egr::Controller::Instance().GetForceSequentialNodes();
egr::Controller::Instance().ClearForceSequentialNodes();
std::deque<GradNodeBase*> force_sequential_nodes_queue; std::deque<GradNodeBase*> force_sequential_nodes_queue;
std::set<GradNodeBase*> force_sequential_nodes_set; std::set<GradNodeBase*> force_sequential_nodes_set;
std::set<GradNodeBase*> ready_force_sequential_nodes; std::set<GradNodeBase*> ready_force_sequential_nodes;
...@@ -421,6 +420,7 @@ void Backward(const std::vector<paddle::Tensor>& tensors, // outputs ...@@ -421,6 +420,7 @@ void Backward(const std::vector<paddle::Tensor>& tensors, // outputs
VLOG(3) << "Run in Backward"; VLOG(3) << "Run in Backward";
paddle::platform::RecordEvent backward_record_event( paddle::platform::RecordEvent backward_record_event(
"backward", paddle::platform::TracerEventType::UserDefined, 1); "backward", paddle::platform::TracerEventType::UserDefined, 1);
egr::Controller::Instance().ClearForceSequentialNodes();
RunBackward(tensors, grad_tensors, retain_graph); RunBackward(tensors, grad_tensors, retain_graph);
phi::autotune::AutoTuneStatus::Instance().Update(); phi::autotune::AutoTuneStatus::Instance().Update();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册