未验证 提交 ea1c9b89 编写于 作者: W wanghuancoder 提交者: GitHub

refine force syncbn (#52860)

上级 bd06be00
......@@ -372,32 +372,31 @@ std::vector<paddle::Tensor> RunBackward(
auto add_next_node_func = [&node_in_degree_map,
&queue](GradNodeBase* next_node) {
if (node_in_degree_map[next_node] == 0) {
if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
queue.push_front(std::move(next_node));
} else {
queue.push_back(std::move(next_node));
}
if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
queue.push_front(std::move(next_node));
} else {
queue.push_back(std::move(next_node));
}
};
if (force_sequential_nodes_set.count(next_node)) {
if (force_sequential_nodes_queue.front() == next_node) {
force_sequential_nodes_queue.pop_front();
add_next_node_func(next_node);
while (ready_force_sequential_nodes.count(
force_sequential_nodes_queue.front())) {
ready_force_sequential_nodes.erase(
force_sequential_nodes_queue.front());
add_next_node_func(force_sequential_nodes_queue.front());
if (node_in_degree_map[next_node] == 0) {
if (force_sequential_nodes_set.count(next_node)) {
if (force_sequential_nodes_queue.front() == next_node) {
force_sequential_nodes_queue.pop_front();
add_next_node_func(next_node);
while (ready_force_sequential_nodes.count(
force_sequential_nodes_queue.front())) {
ready_force_sequential_nodes.erase(
force_sequential_nodes_queue.front());
add_next_node_func(force_sequential_nodes_queue.front());
force_sequential_nodes_queue.pop_front();
}
} else {
ready_force_sequential_nodes.insert(next_node);
continue;
}
} else {
ready_force_sequential_nodes.insert(next_node);
continue;
add_next_node_func(next_node);
}
} else {
add_next_node_func(next_node);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册