From 0f1be6e050e1c0c2fbc643f9458353a9991b1bc6 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Thu, 2 Jun 2022 12:00:15 +0800 Subject: [PATCH] [Eager] first run accumulation node (#43134) * first run accumulation node --- paddle/fluid/eager/backward.cc | 67 +++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index 63b899f6d6..9de647a21a 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/eager/backward.h" -#include +#include #include "paddle/fluid/eager/autograd_meta.h" #include "paddle/fluid/eager/grad_node_info.h" @@ -23,6 +23,7 @@ #include "paddle/fluid/platform/profiler/event_tracing.h" #include "glog/logging.h" +#include "paddle/fluid/eager/accumulation/accumulation_node.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" #include "paddle/phi/kernels/autotune/switch_autotune.h" @@ -100,19 +101,19 @@ class GeneralGrad { // make sure the path from root to target_node is ok std::unordered_set startup_ops; VLOG(6) << "Running in UpdateGraphInfo"; - std::queue queue; + std::deque queue; for (auto& target_nodes_inputmeta_pair : input_target_nodes_inputmeta_map_) { - queue.emplace(target_nodes_inputmeta_pair.first); + queue.push_back(target_nodes_inputmeta_pair.first); } while (!queue.empty()) { auto* target_node = queue.front(); - queue.pop(); + queue.pop_front(); if (!(depending_nodes_)[target_node].empty()) { auto precedding_nodes = (depending_nodes_)[target_node]; for (auto pre_nodes : precedding_nodes) { - queue.emplace(pre_nodes); + queue.push_back(pre_nodes); if (potential_stop_nodes_.find(pre_nodes) != potential_stop_nodes_.end()) { potential_stop_nodes_.erase(pre_nodes); @@ -144,20 +145,20 @@ class GeneralGrad { // Get Graph Info Betweent input target GradNode and outputs, // record depending_nodes_、potential_stop_nodes_、potential_startup_nodes_ - void GetGraphInfoBetweenTargets(const std::queue& init_queue) { + void GetGraphInfoBetweenTargets(const std::deque& init_queue) { VLOG(6) << "Runing In GetGraphInfoBetweenTargets"; // Calculate in_degree for each node std::unordered_map node_in_degree_map; // Copy nodes - std::queue queue = init_queue; + std::deque queue = init_queue; std::unordered_set visited; // Visit each node exactly once in any order while (!queue.empty()) { GradNodeBase* node = queue.front(); - queue.pop(); + queue.pop_front(); if (visited.count(node)) { continue; @@ -198,7 +199,7 @@ class GeneralGrad { // Record depending relationship (depending_nodes_)[next_node].emplace(node); - queue.push(next_node); + queue.push_back(next_node); } } } @@ -207,10 +208,10 @@ class GeneralGrad { UpdateGraphInfo(); } - void ModifyReadyQueue(std::queue* queue) { - std::queue tmp_queue; + void ModifyReadyQueue(std::deque* queue) { + std::deque tmp_queue; for (auto nodes : potential_startup_nodes_) { - tmp_queue.emplace(nodes); + tmp_queue.push_back(nodes); } tmp_queue.swap(*queue); } @@ -297,7 +298,7 @@ class GeneralGrad { void PreparedForGeneralGrad( const std::vector& inputs, const std::vector& no_grad_vars, - std::queue* queue, + std::deque* queue, const std::unordered_map>& node_input_buffers_dict) { @@ -366,14 +367,14 @@ class GeneralGrad { } void ReconstructBackwardGraph( - const std::queue& orig_init_queue) { - std::queue queue = orig_init_queue; + const std::deque& orig_init_queue) { + std::deque queue = orig_init_queue; std::unordered_set visited; // BFS and recursively copy the grad nodes while (!queue.empty()) { GradNodeBase* orig_node = queue.front(); - queue.pop(); + queue.pop_front(); if (visited.count(orig_node)) { continue; } @@ -417,7 +418,7 @@ class GeneralGrad { copied_edge.SetGradNode(copied_next_node); // Update BFS queue - queue.push(orig_next_node.get()); + queue.push_back(orig_next_node.get()); } } } @@ -449,20 +450,20 @@ class GeneralGrad { }; std::unordered_map getInDegreeMap( - const std::queue& init_queue) { + const std::deque& init_queue) { // Calculate in_degree for each node // We can completely remove this pass, if in_degree were set during forward // pass std::unordered_map node_in_degree_map; // Copy nodes - std::queue queue = init_queue; + std::deque queue = init_queue; std::unordered_set visited; // Visit each node exactly once in any order while (!queue.empty()) { GradNodeBase* node = queue.front(); - queue.pop(); + queue.pop_front(); if (visited.count(node)) { continue; @@ -490,7 +491,7 @@ std::unordered_map getInDegreeMap( if (!node_in_degree_map.count(next_node)) node_in_degree_map[next_node] = 0; node_in_degree_map[next_node]++; - queue.push(next_node); + queue.push_back(next_node); } } } @@ -548,8 +549,8 @@ std::vector RunBackward( /* --- Initialization --- */ // 1. Init queue with starting nodes // 2. Prepare initial input buffers - std::queue queue; - std::queue orig_queue; + std::deque queue; + std::deque orig_queue; std::unordered_map> node_input_buffers_dict; for (size_t i = 0; i < tensors.size(); i++) { @@ -582,7 +583,7 @@ std::vector RunBackward( GradNodeBase* grad_node = shared_grad_node.get(); if (is_general_grad) { // Save orig grad node - orig_queue.push(grad_node); + orig_queue.push_back(grad_node); // Replace grad_node with copied grad_node grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node); @@ -625,7 +626,7 @@ std::vector RunBackward( } // Prepare queue, potential startup_nodes - queue.push(grad_node); + queue.push_back(grad_node); } if (is_general_grad) { @@ -663,10 +664,10 @@ std::vector RunBackward( paddle::platform::TracerEventType::Operator, 1); if (queue.size() > 1 && node_in_degree_map[node] != 0) { - queue.pop(); + queue.pop_front(); continue; } - queue.pop(); + queue.pop_front(); // Run node: This is where Hook happens auto node_input_buffer_iter = node_input_buffers_dict.find(node); @@ -798,11 +799,19 @@ std::vector RunBackward( bool is_potential_stop_node = GeneralGrad::Instance().GetPotentialStopNodes()->count(next_node); if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) { - queue.emplace(std::move(next_node)); + if (dynamic_cast(next_node)) { + queue.push_front(std::move(next_node)); + } else { + queue.push_back(std::move(next_node)); + } } } else { if (node_in_degree_map[next_node] == 0) { - queue.emplace(std::move(next_node)); + if (dynamic_cast(next_node)) { + queue.push_front(std::move(next_node)); + } else { + queue.push_back(std::move(next_node)); + } } } } -- GitLab