未验证 提交 0f1be6e0 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] first run accumulation node (#43134)

* first run accumulation node
上级 ceb20406
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/eager/backward.h" #include "paddle/fluid/eager/backward.h"
#include <queue> #include <deque>
#include "paddle/fluid/eager/autograd_meta.h" #include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h" #include "paddle/fluid/eager/grad_node_info.h"
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/event_tracing.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h" #include "paddle/phi/kernels/autotune/switch_autotune.h"
...@@ -100,19 +101,19 @@ class GeneralGrad { ...@@ -100,19 +101,19 @@ class GeneralGrad {
// make sure the path from root to target_node is ok // make sure the path from root to target_node is ok
std::unordered_set<GradNodeBase*> startup_ops; std::unordered_set<GradNodeBase*> startup_ops;
VLOG(6) << "Running in UpdateGraphInfo"; VLOG(6) << "Running in UpdateGraphInfo";
std::queue<GradNodeBase*> queue; std::deque<GradNodeBase*> queue;
for (auto& target_nodes_inputmeta_pair : for (auto& target_nodes_inputmeta_pair :
input_target_nodes_inputmeta_map_) { input_target_nodes_inputmeta_map_) {
queue.emplace(target_nodes_inputmeta_pair.first); queue.push_back(target_nodes_inputmeta_pair.first);
} }
while (!queue.empty()) { while (!queue.empty()) {
auto* target_node = queue.front(); auto* target_node = queue.front();
queue.pop(); queue.pop_front();
if (!(depending_nodes_)[target_node].empty()) { if (!(depending_nodes_)[target_node].empty()) {
auto precedding_nodes = (depending_nodes_)[target_node]; auto precedding_nodes = (depending_nodes_)[target_node];
for (auto pre_nodes : precedding_nodes) { for (auto pre_nodes : precedding_nodes) {
queue.emplace(pre_nodes); queue.push_back(pre_nodes);
if (potential_stop_nodes_.find(pre_nodes) != if (potential_stop_nodes_.find(pre_nodes) !=
potential_stop_nodes_.end()) { potential_stop_nodes_.end()) {
potential_stop_nodes_.erase(pre_nodes); potential_stop_nodes_.erase(pre_nodes);
...@@ -144,20 +145,20 @@ class GeneralGrad { ...@@ -144,20 +145,20 @@ class GeneralGrad {
// Get Graph Info Betweent input target GradNode and outputs, // Get Graph Info Betweent input target GradNode and outputs,
// record depending_nodes_、potential_stop_nodes_、potential_startup_nodes_ // record depending_nodes_、potential_stop_nodes_、potential_startup_nodes_
void GetGraphInfoBetweenTargets(const std::queue<GradNodeBase*>& init_queue) { void GetGraphInfoBetweenTargets(const std::deque<GradNodeBase*>& init_queue) {
VLOG(6) << "Runing In GetGraphInfoBetweenTargets"; VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
// Calculate in_degree for each node // Calculate in_degree for each node
std::unordered_map<GradNodeBase*, int> node_in_degree_map; std::unordered_map<GradNodeBase*, int> node_in_degree_map;
// Copy nodes // Copy nodes
std::queue<GradNodeBase*> queue = init_queue; std::deque<GradNodeBase*> queue = init_queue;
std::unordered_set<GradNodeBase*> visited; std::unordered_set<GradNodeBase*> visited;
// Visit each node exactly once in any order // Visit each node exactly once in any order
while (!queue.empty()) { while (!queue.empty()) {
GradNodeBase* node = queue.front(); GradNodeBase* node = queue.front();
queue.pop(); queue.pop_front();
if (visited.count(node)) { if (visited.count(node)) {
continue; continue;
...@@ -198,7 +199,7 @@ class GeneralGrad { ...@@ -198,7 +199,7 @@ class GeneralGrad {
// Record depending relationship // Record depending relationship
(depending_nodes_)[next_node].emplace(node); (depending_nodes_)[next_node].emplace(node);
queue.push(next_node); queue.push_back(next_node);
} }
} }
} }
...@@ -207,10 +208,10 @@ class GeneralGrad { ...@@ -207,10 +208,10 @@ class GeneralGrad {
UpdateGraphInfo(); UpdateGraphInfo();
} }
void ModifyReadyQueue(std::queue<GradNodeBase*>* queue) { void ModifyReadyQueue(std::deque<GradNodeBase*>* queue) {
std::queue<GradNodeBase*> tmp_queue; std::deque<GradNodeBase*> tmp_queue;
for (auto nodes : potential_startup_nodes_) { for (auto nodes : potential_startup_nodes_) {
tmp_queue.emplace(nodes); tmp_queue.push_back(nodes);
} }
tmp_queue.swap(*queue); tmp_queue.swap(*queue);
} }
...@@ -297,7 +298,7 @@ class GeneralGrad { ...@@ -297,7 +298,7 @@ class GeneralGrad {
void PreparedForGeneralGrad( void PreparedForGeneralGrad(
const std::vector<paddle::experimental::Tensor>& inputs, const std::vector<paddle::experimental::Tensor>& inputs,
const std::vector<paddle::experimental::Tensor>& no_grad_vars, const std::vector<paddle::experimental::Tensor>& no_grad_vars,
std::queue<GradNodeBase*>* queue, std::deque<GradNodeBase*>* queue,
const std::unordered_map<GradNodeBase*, const std::unordered_map<GradNodeBase*,
std::unique_ptr<GradTensorHolder>>& std::unique_ptr<GradTensorHolder>>&
node_input_buffers_dict) { node_input_buffers_dict) {
...@@ -366,14 +367,14 @@ class GeneralGrad { ...@@ -366,14 +367,14 @@ class GeneralGrad {
} }
void ReconstructBackwardGraph( void ReconstructBackwardGraph(
const std::queue<GradNodeBase*>& orig_init_queue) { const std::deque<GradNodeBase*>& orig_init_queue) {
std::queue<GradNodeBase*> queue = orig_init_queue; std::deque<GradNodeBase*> queue = orig_init_queue;
std::unordered_set<GradNodeBase*> visited; std::unordered_set<GradNodeBase*> visited;
// BFS and recursively copy the grad nodes // BFS and recursively copy the grad nodes
while (!queue.empty()) { while (!queue.empty()) {
GradNodeBase* orig_node = queue.front(); GradNodeBase* orig_node = queue.front();
queue.pop(); queue.pop_front();
if (visited.count(orig_node)) { if (visited.count(orig_node)) {
continue; continue;
} }
...@@ -417,7 +418,7 @@ class GeneralGrad { ...@@ -417,7 +418,7 @@ class GeneralGrad {
copied_edge.SetGradNode(copied_next_node); copied_edge.SetGradNode(copied_next_node);
// Update BFS queue // Update BFS queue
queue.push(orig_next_node.get()); queue.push_back(orig_next_node.get());
} }
} }
} }
...@@ -449,20 +450,20 @@ class GeneralGrad { ...@@ -449,20 +450,20 @@ class GeneralGrad {
}; };
std::unordered_map<GradNodeBase*, int> getInDegreeMap( std::unordered_map<GradNodeBase*, int> getInDegreeMap(
const std::queue<GradNodeBase*>& init_queue) { const std::deque<GradNodeBase*>& init_queue) {
// Calculate in_degree for each node // Calculate in_degree for each node
// We can completely remove this pass, if in_degree were set during forward // We can completely remove this pass, if in_degree were set during forward
// pass // pass
std::unordered_map<GradNodeBase*, int> node_in_degree_map; std::unordered_map<GradNodeBase*, int> node_in_degree_map;
// Copy nodes // Copy nodes
std::queue<GradNodeBase*> queue = init_queue; std::deque<GradNodeBase*> queue = init_queue;
std::unordered_set<GradNodeBase*> visited; std::unordered_set<GradNodeBase*> visited;
// Visit each node exactly once in any order // Visit each node exactly once in any order
while (!queue.empty()) { while (!queue.empty()) {
GradNodeBase* node = queue.front(); GradNodeBase* node = queue.front();
queue.pop(); queue.pop_front();
if (visited.count(node)) { if (visited.count(node)) {
continue; continue;
...@@ -490,7 +491,7 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap( ...@@ -490,7 +491,7 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap(
if (!node_in_degree_map.count(next_node)) if (!node_in_degree_map.count(next_node))
node_in_degree_map[next_node] = 0; node_in_degree_map[next_node] = 0;
node_in_degree_map[next_node]++; node_in_degree_map[next_node]++;
queue.push(next_node); queue.push_back(next_node);
} }
} }
} }
...@@ -548,8 +549,8 @@ std::vector<paddle::experimental::Tensor> RunBackward( ...@@ -548,8 +549,8 @@ std::vector<paddle::experimental::Tensor> RunBackward(
/* --- Initialization --- */ /* --- Initialization --- */
// 1. Init queue with starting nodes // 1. Init queue with starting nodes
// 2. Prepare initial input buffers // 2. Prepare initial input buffers
std::queue<GradNodeBase*> queue; std::deque<GradNodeBase*> queue;
std::queue<GradNodeBase*> orig_queue; std::deque<GradNodeBase*> orig_queue;
std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>> std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
node_input_buffers_dict; node_input_buffers_dict;
for (size_t i = 0; i < tensors.size(); i++) { for (size_t i = 0; i < tensors.size(); i++) {
...@@ -582,7 +583,7 @@ std::vector<paddle::experimental::Tensor> RunBackward( ...@@ -582,7 +583,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
GradNodeBase* grad_node = shared_grad_node.get(); GradNodeBase* grad_node = shared_grad_node.get();
if (is_general_grad) { if (is_general_grad) {
// Save orig grad node // Save orig grad node
orig_queue.push(grad_node); orig_queue.push_back(grad_node);
// Replace grad_node with copied grad_node // Replace grad_node with copied grad_node
grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node); grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);
...@@ -625,7 +626,7 @@ std::vector<paddle::experimental::Tensor> RunBackward( ...@@ -625,7 +626,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
} }
// Prepare queue, potential startup_nodes // Prepare queue, potential startup_nodes
queue.push(grad_node); queue.push_back(grad_node);
} }
if (is_general_grad) { if (is_general_grad) {
...@@ -663,10 +664,10 @@ std::vector<paddle::experimental::Tensor> RunBackward( ...@@ -663,10 +664,10 @@ std::vector<paddle::experimental::Tensor> RunBackward(
paddle::platform::TracerEventType::Operator, 1); paddle::platform::TracerEventType::Operator, 1);
if (queue.size() > 1 && node_in_degree_map[node] != 0) { if (queue.size() > 1 && node_in_degree_map[node] != 0) {
queue.pop(); queue.pop_front();
continue; continue;
} }
queue.pop(); queue.pop_front();
// Run node: This is where Hook happens // Run node: This is where Hook happens
auto node_input_buffer_iter = node_input_buffers_dict.find(node); auto node_input_buffer_iter = node_input_buffers_dict.find(node);
...@@ -798,11 +799,19 @@ std::vector<paddle::experimental::Tensor> RunBackward( ...@@ -798,11 +799,19 @@ std::vector<paddle::experimental::Tensor> RunBackward(
bool is_potential_stop_node = bool is_potential_stop_node =
GeneralGrad::Instance().GetPotentialStopNodes()->count(next_node); GeneralGrad::Instance().GetPotentialStopNodes()->count(next_node);
if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) { if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) {
queue.emplace(std::move(next_node)); if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
queue.push_front(std::move(next_node));
} else {
queue.push_back(std::move(next_node));
}
} }
} else { } else {
if (node_in_degree_map[next_node] == 0) { if (node_in_degree_map[next_node] == 0) {
queue.emplace(std::move(next_node)); if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
queue.push_front(std::move(next_node));
} else {
queue.push_back(std::move(next_node));
}
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册