Unverified commit 0f1be6e0, authored by wanghuancoder, committed by GitHub

[Eager] first run accumulation node (#43134)

* first run accumulation node
Parent ceb20406
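This commit switches the ready queue used by the eager-mode backward pass from std::queue<GradNodeBase*> to std::deque<GradNodeBase*>, and pushes GradNodeAccumulation nodes to the front of the deque so that gradient accumulation runs before other pending grad nodes. Below is a minimal standalone sketch of that scheduling idea; the two node types are simplified stand-ins for Paddle's real classes, not the actual implementation.

// Sketch: a deque-based ready queue where accumulation nodes run first.
// GradNodeBase / GradNodeAccumulation here are hypothetical stubs.
#include <deque>
#include <iostream>
#include <memory>
#include <vector>

struct GradNodeBase {
  virtual ~GradNodeBase() = default;
  virtual const char* Name() const { return "GradNode"; }
};
struct GradNodeAccumulation : GradNodeBase {
  const char* Name() const override { return "GradNodeAccumulation"; }
};

int main() {
  std::vector<std::unique_ptr<GradNodeBase>> nodes;
  nodes.emplace_back(new GradNodeBase());
  nodes.emplace_back(new GradNodeAccumulation());
  nodes.emplace_back(new GradNodeBase());

  std::deque<GradNodeBase*> queue;
  for (auto& n : nodes) {
    // Same test the patch uses: dynamic_cast identifies accumulation nodes.
    if (dynamic_cast<GradNodeAccumulation*>(n.get())) {
      queue.push_front(n.get());  // accumulation nodes jump to the front
    } else {
      queue.push_back(n.get());   // everything else keeps FIFO order
    }
  }
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
    queue.pop_front();                  // deque replaces std::queue::pop()
    std::cout << node->Name() << "\n";  // prints GradNodeAccumulation first
  }
}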
@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/eager/backward.h"
-#include <queue>
+#include <deque>
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
@@ -23,6 +23,7 @@
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "glog/logging.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h"
@@ -100,19 +101,19 @@ class GeneralGrad {
// make sure the path from root to target_node is ok
std::unordered_set<GradNodeBase*> startup_ops;
VLOG(6) << "Running in UpdateGraphInfo";
-std::queue<GradNodeBase*> queue;
+std::deque<GradNodeBase*> queue;
for (auto& target_nodes_inputmeta_pair :
input_target_nodes_inputmeta_map_) {
-queue.emplace(target_nodes_inputmeta_pair.first);
+queue.push_back(target_nodes_inputmeta_pair.first);
}
while (!queue.empty()) {
auto* target_node = queue.front();
-queue.pop();
+queue.pop_front();
if (!(depending_nodes_)[target_node].empty()) {
auto precedding_nodes = (depending_nodes_)[target_node];
for (auto pre_nodes : precedding_nodes) {
-queue.emplace(pre_nodes);
+queue.push_back(pre_nodes);
if (potential_stop_nodes_.find(pre_nodes) !=
potential_stop_nodes_.end()) {
potential_stop_nodes_.erase(pre_nodes);
@@ -144,20 +145,20 @@ class GeneralGrad {
// Get Graph Info Between input target GradNode and outputs,
// record depending_nodes_, potential_stop_nodes_, potential_startup_nodes_
-void GetGraphInfoBetweenTargets(const std::queue<GradNodeBase*>& init_queue) {
+void GetGraphInfoBetweenTargets(const std::deque<GradNodeBase*>& init_queue) {
VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
// Calculate in_degree for each node
std::unordered_map<GradNodeBase*, int> node_in_degree_map;
// Copy nodes
-std::queue<GradNodeBase*> queue = init_queue;
+std::deque<GradNodeBase*> queue = init_queue;
std::unordered_set<GradNodeBase*> visited;
// Visit each node exactly once in any order
while (!queue.empty()) {
GradNodeBase* node = queue.front();
-queue.pop();
+queue.pop_front();
if (visited.count(node)) {
continue;
@@ -198,7 +199,7 @@ class GeneralGrad {
// Record depending relationship
(depending_nodes_)[next_node].emplace(node);
-queue.push(next_node);
+queue.push_back(next_node);
}
}
}
@@ -207,10 +208,10 @@ class GeneralGrad {
UpdateGraphInfo();
}
-void ModifyReadyQueue(std::queue<GradNodeBase*>* queue) {
-std::queue<GradNodeBase*> tmp_queue;
+void ModifyReadyQueue(std::deque<GradNodeBase*>* queue) {
+std::deque<GradNodeBase*> tmp_queue;
for (auto nodes : potential_startup_nodes_) {
-tmp_queue.emplace(nodes);
+tmp_queue.push_back(nodes);
}
tmp_queue.swap(*queue);
}
@@ -297,7 +298,7 @@ class GeneralGrad {
void PreparedForGeneralGrad(
const std::vector<paddle::experimental::Tensor>& inputs,
const std::vector<paddle::experimental::Tensor>& no_grad_vars,
-std::queue<GradNodeBase*>* queue,
+std::deque<GradNodeBase*>* queue,
const std::unordered_map<GradNodeBase*,
std::unique_ptr<GradTensorHolder>>&
node_input_buffers_dict) {
@@ -366,14 +367,14 @@ class GeneralGrad {
}
void ReconstructBackwardGraph(
const std::queue<GradNodeBase*>& orig_init_queue) {
std::queue<GradNodeBase*> queue = orig_init_queue;
const std::deque<GradNodeBase*>& orig_init_queue) {
std::deque<GradNodeBase*> queue = orig_init_queue;
std::unordered_set<GradNodeBase*> visited;
// BFS and recursively copy the grad nodes
while (!queue.empty()) {
GradNodeBase* orig_node = queue.front();
-queue.pop();
+queue.pop_front();
if (visited.count(orig_node)) {
continue;
}
@@ -417,7 +418,7 @@ class GeneralGrad {
copied_edge.SetGradNode(copied_next_node);
// Update BFS queue
-queue.push(orig_next_node.get());
+queue.push_back(orig_next_node.get());
}
}
}
@@ -449,20 +450,20 @@
};
std::unordered_map<GradNodeBase*, int> getInDegreeMap(
-const std::queue<GradNodeBase*>& init_queue) {
+const std::deque<GradNodeBase*>& init_queue) {
// Calculate in_degree for each node
// We can completely remove this pass, if in_degree were set during forward
// pass
std::unordered_map<GradNodeBase*, int> node_in_degree_map;
// Copy nodes
-std::queue<GradNodeBase*> queue = init_queue;
+std::deque<GradNodeBase*> queue = init_queue;
std::unordered_set<GradNodeBase*> visited;
// Visit each node exactly once in any order
while (!queue.empty()) {
GradNodeBase* node = queue.front();
-queue.pop();
+queue.pop_front();
if (visited.count(node)) {
continue;
@@ -490,7 +491,7 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap(
if (!node_in_degree_map.count(next_node))
node_in_degree_map[next_node] = 0;
node_in_degree_map[next_node]++;
-queue.push(next_node);
+queue.push_back(next_node);
}
}
}
@@ -548,8 +549,8 @@ std::vector<paddle::experimental::Tensor> RunBackward(
/* --- Initialization --- */
// 1. Init queue with starting nodes
// 2. Prepare initial input buffers
-std::queue<GradNodeBase*> queue;
-std::queue<GradNodeBase*> orig_queue;
+std::deque<GradNodeBase*> queue;
+std::deque<GradNodeBase*> orig_queue;
std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
node_input_buffers_dict;
for (size_t i = 0; i < tensors.size(); i++) {
@@ -582,7 +583,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
GradNodeBase* grad_node = shared_grad_node.get();
if (is_general_grad) {
// Save orig grad node
-orig_queue.push(grad_node);
+orig_queue.push_back(grad_node);
// Replace grad_node with copied grad_node
grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);
@@ -625,7 +626,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
}
// Prepare queue, potential startup_nodes
-queue.push(grad_node);
+queue.push_back(grad_node);
}
if (is_general_grad) {
@@ -663,10 +664,10 @@ std::vector<paddle::experimental::Tensor> RunBackward(
paddle::platform::TracerEventType::Operator, 1);
if (queue.size() > 1 && node_in_degree_map[node] != 0) {
-queue.pop();
+queue.pop_front();
continue;
}
-queue.pop();
+queue.pop_front();
// Run node: This is where Hook happens
auto node_input_buffer_iter = node_input_buffers_dict.find(node);
@@ -798,11 +799,19 @@ std::vector<paddle::experimental::Tensor> RunBackward(
bool is_potential_stop_node =
GeneralGrad::Instance().GetPotentialStopNodes()->count(next_node);
if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) {
-queue.emplace(std::move(next_node));
+if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
+  queue.push_front(std::move(next_node));
+} else {
+  queue.push_back(std::move(next_node));
+}
}
} else {
if (node_in_degree_map[next_node] == 0) {
-queue.emplace(std::move(next_node));
+if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
+  queue.push_front(std::move(next_node));
+} else {
+  queue.push_back(std::move(next_node));
+}
}
}
}
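Apart from the push_front special case above, the patch is a mechanical API translation: std::queue's push/emplace/pop become std::deque's push_back/pop_front, while front() is unchanged. std::deque fits here presumably because it offers the same FIFO operations as std::queue plus push_front, which std::queue lacks. A minimal sketch of the in-degree BFS pass under the new API follows, mirroring getInDegreeMap; the Node type with a plain next_nodes list is a hypothetical stand-in for Paddle's GradNodeBase edge machinery.

// Sketch: BFS in-degree computation over a deque-based ready queue.
#include <deque>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node*> next_nodes;  // hypothetical stand-in for grad-node edges
};

std::unordered_map<Node*, int> GetInDegreeMap(
    const std::deque<Node*>& init_queue) {
  std::unordered_map<Node*, int> node_in_degree_map;
  std::deque<Node*> queue = init_queue;  // copy the starting nodes
  std::unordered_set<Node*> visited;
  // Visit each node exactly once, in any order.
  while (!queue.empty()) {
    Node* node = queue.front();
    queue.pop_front();  // was queue.pop() with std::queue
    if (visited.count(node)) continue;
    visited.insert(node);
    for (Node* next_node : node->next_nodes) {
      if (!node_in_degree_map.count(next_node))
        node_in_degree_map[next_node] = 0;
      node_in_degree_map[next_node]++;
      queue.push_back(next_node);  // was queue.push(next_node)
    }
  }
  return node_in_degree_map;
}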