From acde295cecfbe525c84a130c9b05fd23d802c3f9 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Thu, 28 Jul 2022 10:32:30 +0800 Subject: [PATCH] [Eager] refactor general_grad and fix some bugs (#44611) * refactor general_grad and fix some bugs * add TODO: support prune logic deeper --- paddle/fluid/eager/backward.cc | 510 +------------- paddle/fluid/eager/general_grad.h | 653 ++++++++++++++++++ paddle/fluid/eager/grad_node_info.h | 13 + .../unittests/test_imperative_double_grad.py | 40 ++ 4 files changed, 726 insertions(+), 490 deletions(-) create mode 100644 paddle/fluid/eager/general_grad.h diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index 3206b9e7cfa..7c7a09db2b1 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -14,461 +14,11 @@ #include "paddle/fluid/eager/backward.h" -#include - -#include "glog/logging.h" -#include "paddle/fluid/eager/accumulation/accumulation_node.h" -#include "paddle/fluid/eager/autograd_meta.h" -#include "paddle/fluid/eager/grad_node_info.h" -#include "paddle/fluid/eager/grad_tensor_holder.h" -#include "paddle/fluid/eager/utils.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/errors.h" -#include "paddle/fluid/platform/profiler.h" -#include "paddle/fluid/platform/profiler/event_tracing.h" +#include "paddle/fluid/eager/general_grad.h" #include "paddle/phi/kernels/autotune/switch_autotune.h" namespace egr { -/* - * GeneralGrad is Helpper class to implement custom grad operation between - * outputs and inputs. - * - * **/ -class GeneralGrad { - public: - static GeneralGrad& Instance() { return *general_grad_; } - - // Get inputs's / no_grad_vars's GradNodes and InputMeta Info - void GetTargetNodesInfo( - const std::vector& inputs, - bool is_no_grad_vars) { - std::string msg = is_no_grad_vars ? 
"no_grad_vars" : "inputs"; - VLOG(6) << "Running in GetTargetNodesInfo."; - if (!inputs.empty()) { - VLOG(6) << msg << " are not empty."; - size_t num_inputs = inputs.size(); - for (size_t i = 0; i < num_inputs; i++) { - AutogradMeta* auto_grad_meta = - EagerUtils::unsafe_autograd_meta(inputs[i]); - auto* target_node = auto_grad_meta->GetMutableGradNode().get(); - VLOG(8) << "Get no grad vars' grad_node: " << target_node->name() - << ", " << target_node << " with output rank info: " - << auto_grad_meta->OutRankInfo().first << ", " - << auto_grad_meta->OutRankInfo().second; - if (is_no_grad_vars) { - (no_grad_var_nodes_inputmeta_map_)[target_node] = auto_grad_meta; - continue; - } - if (orig_to_copied_node_mapping_.count(target_node)) { - target_node = orig_to_copied_node_mapping_[target_node].get(); - } else { - VLOG(6) << "Unable to find target node in " - "orig_to_copied_node_mapping_, likely indicating an " - "unused input"; - } - - PADDLE_ENFORCE_NOT_NULL(target_node, - paddle::platform::errors::Fatal( - "There is no grad op for %s:[%d] or it's" - "stop_gradient=True.", - msg, - i)); - // normal input - (input_target_nodes_inputmeta_map_)[target_node] = auto_grad_meta; - } - } - } - - // Purify potential_startup_nodes_, remove nodes those are the same as - // input_target_nodes - void PurifyPotentialStartUpNodes() { - VLOG(6) << "Running in PurifyPotentialStartUpNodes"; - if (input_target_nodes_inputmeta_map_.empty()) return; - std::unordered_set potential_startup_nodes_to_be_erased; - for (auto startup_op : potential_startup_nodes_) { - auto iter = input_target_nodes_inputmeta_map_.find(startup_op); - if (iter != input_target_nodes_inputmeta_map_.end()) { - potential_startup_nodes_to_be_erased.emplace(iter->first); - } - } - if (!potential_startup_nodes_to_be_erased.empty()) { - for (auto nodes : potential_startup_nodes_to_be_erased) { - potential_startup_nodes_.erase(nodes); - } - } - } - - // Remove some nodes those doesn't need to be - // stored in potential_stop_nodes_、potential_startup_nodes_ - void UpdateGraphInfo() { - // Updated potential_sotp_nodes by depending_nodes_, - // make sure the path from root to target_node is ok - std::unordered_set startup_ops; - VLOG(6) << "Running in UpdateGraphInfo"; - std::deque queue; - for (auto& target_nodes_inputmeta_pair : - input_target_nodes_inputmeta_map_) { - queue.push_back(target_nodes_inputmeta_pair.first); - } - - while (!queue.empty()) { - auto* target_node = queue.front(); - queue.pop_front(); - if (!(depending_nodes_)[target_node].empty()) { - auto precedding_nodes = (depending_nodes_)[target_node]; - for (auto pre_nodes : precedding_nodes) { - queue.push_back(pre_nodes); - if (potential_stop_nodes_.find(pre_nodes) != - potential_stop_nodes_.end()) { - potential_stop_nodes_.erase(pre_nodes); - } - } - } else { // startup_ops have no precedding nodes - VLOG(6) << "Emplace startup_ops"; - startup_ops.emplace(target_node); - } - } - // Purify potential_startup_nodes_ again, remove some - // potential startup_nodes that unreach to input target nodes - if (!startup_ops.empty()) { - std::unordered_set potential_startup_nodes_to_be_erased; - for (auto node : potential_startup_nodes_) { - if (startup_ops.count(node) == 0) { - VLOG(6) << "Set up potential_startup_nodes_to_be_erased"; - potential_startup_nodes_to_be_erased.emplace(node); - } - } - if (!potential_startup_nodes_to_be_erased.empty()) { - for (auto node : potential_startup_nodes_to_be_erased) { - VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased"; - 
potential_startup_nodes_.erase(node); - } - } - } - } - - // Get Graph Info Betweent input target GradNode and outputs, - // record depending_nodes_、potential_stop_nodes_、potential_startup_nodes_ - void GetGraphInfoBetweenTargets(const std::deque& init_queue) { - VLOG(6) << "Runing In GetGraphInfoBetweenTargets"; - - // Calculate in_degree for each node - std::unordered_map node_in_degree_map; - - // Copy nodes - std::deque queue = init_queue; - std::unordered_set visited; - - // Visit each node exactly once in any order - while (!queue.empty()) { - GradNodeBase* node = queue.front(); - queue.pop_front(); - - if (visited.count(node)) { - continue; - } - visited.insert(node); - - // Check node is target_nodes or not, if node is not target_node, - // all the next_node will be marked in potential_stop_nodes_ - bool is_potential_stop_nodes = - input_target_nodes_inputmeta_map_.count(node); - - // Find and append next nodes - const paddle::small_vector, - kSlotSmallVectorSize>& metas = - node->OutputMeta(); - for (const auto& meta_list : metas) { - for (const GradSlotMeta& meta : meta_list) { - const auto& edge = meta.GetEdge(); - GradNodeBase* next_node = edge.GetMutableGradNode().get(); - - // Next node could be nullptr if it is leaf tensor with no - // AccumulationNode attached - // Or it could also originated from dispensable inputs - if (!next_node) continue; - - // if node not in input_target_nodes, - // all the next_nodes of current node will be inserted to - // potential_stop_node - if (is_potential_stop_nodes) { - potential_stop_nodes_.emplace(next_node); - } - - // Update in_degree - if (!node_in_degree_map.count(next_node)) { - node_in_degree_map[next_node] = 0; - } - node_in_degree_map[next_node]++; - - // Record depending relationship - (depending_nodes_)[next_node].emplace(node); - queue.push_back(next_node); - } - } - } - // Update Graph Info, remove some nodes in - // potential_stop_nodes_、potential_startup_nodes_、 - UpdateGraphInfo(); - } - - void ModifyReadyQueue(std::deque* queue) { - std::deque tmp_queue; - for (auto nodes : potential_startup_nodes_) { - tmp_queue.push_back(nodes); - } - tmp_queue.swap(*queue); - } - - // Set result for input target grad_var when potential_startup_nodes_ is empty - void SetResultForInputTargetVar( - const std::unordered_map>& - node_input_buffers_dict) { - if (potential_startup_nodes_.size() == 0) { - for (auto input_target_node : *GetInputTargetNodesInputMetaMap()) { - // out rank_info of forward op - auto rank_info = input_target_node.second->OutRankInfo(); - auto iter = node_input_buffers_dict.find(input_target_node.first); - if (iter != node_input_buffers_dict.end()) { - auto& target_result = - (iter->second)->Buffers()[rank_info.first][rank_info.second]; - // save the target result - results_map_[input_target_node.first] = target_result; - } - } - } - } - - // Set input target grad_var from node_input_buffer by inputmeta - void SetResultForInputTargetVar(GradTensorHolder input_buffers, - GradNodeBase* node) { - auto iter = GetInputTargetNodesInputMetaMap()->find(node); - if (iter != GetInputTargetNodesInputMetaMap()->end()) { - VLOG(6) << "Get target result by by inputmeta"; - // out rank_info of forward op - auto rank_info = (iter->second)->OutRankInfo(); - // rank_info is a pair, first means slot_id, second means rank. 
- auto& target_result = - input_buffers.Buffers()[rank_info.first][rank_info.second]; - // save the target result - results_map_[node] = target_result; - } - } - - std::vector GetResults( - const std::vector& inputs, - bool allow_unused, - bool create_graph) { - VLOG(6) << "Running in GetResults"; - if (inputs.empty()) return {}; - - std::vector results; - results.reserve(inputs.size()); - - for (size_t i = 0; i < inputs.size(); ++i) { - auto& input = inputs[i]; - AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input); - - auto* target_node = auto_grad_meta->GetMutableGradNode().get(); - if (orig_to_copied_node_mapping_.count(target_node)) { - target_node = orig_to_copied_node_mapping_[target_node].get(); - } else { - VLOG(6) << "Unable to find target node in " - "orig_to_copied_node_mapping_, likely indicating an unused " - "input"; - } - - auto iter = results_map_.find(target_node); - if (iter != results_map_.end()) { - // set StopGradient = !create_graph - AutogradMeta* tensor_auto_grad_meta = - EagerUtils::autograd_meta(&(iter->second)); - tensor_auto_grad_meta->SetStopGradient(!create_graph); - results.emplace_back(iter->second); - } else { - PADDLE_ENFORCE_EQ(allow_unused, - true, - paddle::platform::errors::InvalidArgument( - "The %d-th input does not appear in the backward " - "graph. Please check the input tensor or set " - "allow_unused=True to get None result.", - i)); - results.emplace_back(); - } - } - Clear(); - return results; - } - - void PreparedForGeneralGrad( - const std::vector& inputs, - const std::vector& no_grad_vars, - std::deque* queue, - const std::unordered_map>& - node_input_buffers_dict) { - // Get inputs's GradNodes and InputMeta Info - GetTargetNodesInfo(inputs, false /* is_no_grad_vars */); - // Purify potentialstartup_ops, remove those nodes that are the same as - // input_target_nodes - PurifyPotentialStartUpNodes(); - // Get Graph Info Betweent input target gradnode and outputs - // Record the depending_nodes_ and - // potential_stop_nodes_、potential_startup_nodes_ - GetGraphInfoBetweenTargets(*queue); - // Reset queue. Queue is empty only when - // 1.input equals to output. 2.input can not reach to output. 
- ModifyReadyQueue(queue); - // Set result for input target grad_var when queue is empty - if (queue->empty()) SetResultForInputTargetVar(node_input_buffers_dict); - } - - bool IsPotentialStopNodes(GradNodeBase* node) { - return potential_stop_nodes_.count(node); - } - - std::unordered_map* - GetNoGradVarNodesInputMetaMap() { - return &no_grad_var_nodes_inputmeta_map_; - } - - std::unordered_map* - GetInputTargetNodesInputMetaMap() { - return &input_target_nodes_inputmeta_map_; - } - - std::unordered_set* GetPotentialStopNodes() { - return &potential_stop_nodes_; - } - - std::unordered_set* GetPotentialStartupNodes() { - return &potential_startup_nodes_; - } - - void Clear() { - no_grad_var_nodes_inputmeta_map_.clear(); - input_target_nodes_inputmeta_map_.clear(); - potential_startup_nodes_.clear(); - potential_stop_nodes_.clear(); - depending_nodes_.clear(); - results_map_.clear(); - copied_grad_nodes_.clear(); - orig_to_copied_node_mapping_.clear(); - } - - GradNodeBase* CopyGradNode(const std::shared_ptr& orig_node) { - if (orig_to_copied_node_mapping_.count(orig_node.get())) { - return orig_to_copied_node_mapping_[orig_node.get()].get(); - } - std::shared_ptr copied_node = orig_node->Copy(); - - // Save node and update mapping - orig_to_copied_node_mapping_[orig_node.get()] = copied_node; - copied_grad_nodes_.push_back(copied_node); - - return copied_node.get(); - } - - void ReconstructBackwardGraph( - const std::deque& orig_init_queue) { - std::deque queue = orig_init_queue; - std::unordered_set visited; - - // BFS and recursively copy the grad nodes - while (!queue.empty()) { - GradNodeBase* orig_node = queue.front(); - queue.pop_front(); - if (visited.count(orig_node)) { - continue; - } - visited.insert(orig_node); - - PADDLE_ENFORCE( - orig_to_copied_node_mapping_.count(orig_node), - paddle::platform::errors::Fatal( - "Cannot reconstruct backward graph," - "unable to find copied target for certain grad node.")); - GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node].get(); - - const paddle::small_vector, - kSlotSmallVectorSize>& orig_meta = - orig_node->OutputMeta(); - paddle::small_vector, kSlotSmallVectorSize>& - copied_edges = copied_node->MutableOutputMeta(); - for (size_t i = 0; i < orig_meta.size(); i++) { - for (size_t j = 0; j < orig_meta[i].size(); j++) { - const Edge& orig_edge = orig_meta[i][j].GetEdge(); - Edge& copied_edge = copied_edges[i][j].GetMutableEdge(); - - std::shared_ptr orig_next_node = - orig_edge.GetMutableGradNode(); - - if (no_grad_var_nodes_inputmeta_map_.count(orig_next_node.get()) && - (no_grad_var_nodes_inputmeta_map_[orig_next_node.get()] - ->OutRankInfo() == orig_edge.GetEdgeRankInfo())) { - VLOG(3) << "Get no grad edge from grad_node: " << orig_node->name() - << " : " << orig_node << " to:" << orig_next_node->name() - << ", " << orig_next_node.get() - << " with output rank info: " - << orig_edge.GetEdgeRankInfo().first << ", " - << orig_edge.GetEdgeRankInfo().second; - // Stop no grad var's preceding node - copied_node->MutableOutputMeta()[i][j].SetStopGradient(true); - copied_edge.Clear(); - continue; - } - if (!orig_next_node) continue; - - // Copy Next Node - std::shared_ptr copied_next_node; - if (orig_to_copied_node_mapping_.count(orig_next_node.get())) { - copied_next_node = - orig_to_copied_node_mapping_[orig_next_node.get()]; - - } else { - copied_next_node = orig_next_node->Copy(); - orig_to_copied_node_mapping_[orig_next_node.get()] = - copied_next_node; - copied_grad_nodes_.push_back(copied_next_node); - } - - // Update 
Edge's Grad Node - copied_edge.SetGradNode(copied_next_node); - - // Update BFS queue - queue.push_back(orig_next_node.get()); - } - } - } - } - - private: - GeneralGrad() = default; - static GeneralGrad* general_grad_; - // no_grad_vars's GradNode and GradNode's InputMeta. - std::unordered_map - no_grad_var_nodes_inputmeta_map_; - // inputs's GradNode and GradNode's InputMeta. - std::unordered_map - input_target_nodes_inputmeta_map_; - // Record all the potential startup_nodes, will be changed. - std::unordered_set potential_startup_nodes_; - // Record all the potential stop nodes, will be changed. - std::unordered_set potential_stop_nodes_; - std::unordered_map /* pre nodes */> - depending_nodes_; - std::unordered_map results_map_; - - std::vector> copied_grad_nodes_; - std::unordered_map> - orig_to_copied_node_mapping_; - - DISABLE_COPY_AND_ASSIGN(GeneralGrad); -}; - std::unordered_map getInDegreeMap( const std::deque& init_queue) { // Calculate in_degree for each node @@ -655,25 +205,17 @@ std::vector RunBackward( } if (is_general_grad) { - // Get no_grad_vars's GradNodes and InputMeta Info - GeneralGrad::Instance().GetTargetNodesInfo(no_grad_vars, - true /* is_no_grad_vars */); - // Copy Backward Graph - GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue); + // Prepare several vital preprocess for GeneralGrad + GeneralGrad::Instance().PreparedForGeneralGrad( + inputs, no_grad_vars, orig_queue, &queue, node_input_buffers_dict); } - VLOG(3) << "Update In degree Map for backward"; + VLOG(6) << "Update In degree Map for backward"; // 3. Compute in_degree for each node std::unordered_map node_in_degree_map = getInDegreeMap(queue); - if (is_general_grad) { - // Prepare several vital preprocess for GeneralGrad - GeneralGrad::Instance().PreparedForGeneralGrad( - inputs, no_grad_vars, &queue, node_input_buffers_dict); - } - - VLOG(6) << " startup_ops' size is :" << queue.size(); + VLOG(3) << "Startup_ops's size is " << queue.size(); /* --- Topological Visit --- */ // 1. 
Pop queue @@ -685,7 +227,7 @@ std::vector RunBackward( VLOG(3) << "Run Backward"; while (!queue.empty()) { GradNodeBase* node = queue.front(); - VLOG(6) << "Running GradNode:" << node->name(); + VLOG(3) << "Running GradNode:" << node->name() << " addr:" << node; paddle::platform::RecordEvent node_record_event( std::string((*node).name()), @@ -710,12 +252,6 @@ std::vector RunBackward( std::unique_ptr node_input_buffer = std::move(node_input_buffer_iter->second); - // Set input target grad_var from node_input_buffer by inputmeta - if (!inputs.empty() && is_general_grad) { - GeneralGrad::Instance().SetResultForInputTargetVar(*node_input_buffer, - node); - } - // Check input EnforceGradNodeHasInput(node); @@ -726,6 +262,11 @@ std::vector RunBackward( grad_output_tensors = (*node)( node_input_buffer->Buffers(), create_graph, is_general_grad); + if (!inputs.empty() && is_general_grad) { + GeneralGrad::Instance().SetResultForEnddingNodes(grad_output_tensors, + node); + } + // retain_grad or not if (!retain_graph) { VLOG(6) @@ -757,8 +298,9 @@ std::vector RunBackward( // Since we make edge has as same rank as bwd outputs, we indexing them // with the same rank(i, j) auto next_node_shared = edge.GetMutableGradNode(); - VLOG(3) << "Found pending node: " << next_node_shared->name() << ": " - << next_node_shared.get(); + VLOG(3) << "Node: " << node->name() << " addr:" << node + << ", Found pending node: " << next_node_shared->name() + << " addr: " << next_node_shared.get(); // Next node could be nullptr if it is leaf tensor with no // AccumulationNode attached // Or it could also originated from dispensable inputs @@ -818,23 +360,11 @@ std::vector RunBackward( "Node's in-degree cannot be negative.", next_node->name())); - if (is_general_grad) { - bool is_potential_stop_node = - GeneralGrad::Instance().GetPotentialStopNodes()->count(next_node); - if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) { - if (dynamic_cast(next_node)) { - queue.push_front(std::move(next_node)); - } else { - queue.push_back(std::move(next_node)); - } - } - } else { - if (node_in_degree_map[next_node] == 0) { - if (dynamic_cast(next_node)) { - queue.push_front(std::move(next_node)); - } else { - queue.push_back(std::move(next_node)); - } + if (node_in_degree_map[next_node] == 0) { + if (dynamic_cast(next_node)) { + queue.push_front(std::move(next_node)); + } else { + queue.push_back(std::move(next_node)); } } } diff --git a/paddle/fluid/eager/general_grad.h b/paddle/fluid/eager/general_grad.h new file mode 100644 index 00000000000..554afcd8ccd --- /dev/null +++ b/paddle/fluid/eager/general_grad.h @@ -0,0 +1,653 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "glog/logging.h" +#include "paddle/fluid/eager/accumulation/accumulation_node.h" +#include "paddle/fluid/eager/api/utils/hook_utils.h" +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/eager/grad_node_info.h" +#include "paddle/fluid/eager/grad_tensor_holder.h" +#include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" +#include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/platform/profiler/event_tracing.h" + +namespace egr { + +/* + * GeneralGrad is Helpper class to implement custom grad operation between + * outputs and inputs. + * + * **/ +class GeneralGrad { + public: + static GeneralGrad& Instance() { return *general_grad_; } + + // Get inputs's / no_grad_vars's GradNodes and InputMeta Info + void GetTargetNodesInfo( + const std::vector& inputs, + bool is_no_grad_vars) { + std::string msg = is_no_grad_vars ? "no_grad_vars" : "inputs"; + VLOG(6) << "Running in GetTargetNodesInfo."; + if (!inputs.empty()) { + VLOG(6) << msg << " are not empty."; + size_t num_inputs = inputs.size(); + for (size_t i = 0; i < num_inputs; i++) { + AutogradMeta* auto_grad_meta = + EagerUtils::unsafe_autograd_meta(inputs[i]); + auto* target_node = auto_grad_meta->GetMutableGradNode().get(); + + if (orig_to_copied_node_map_.count(target_node)) { + target_node = orig_to_copied_node_map_[target_node].get(); + } else { + VLOG(6) << "Unable to find target node in " + "orig_to_copied_node_map_, likely indicating an " + "unused input"; + } + + PADDLE_ENFORCE_NOT_NULL(target_node, + paddle::platform::errors::Fatal( + "There is no grad op for %s:[%d] or it's" + "stop_gradient=True.", + msg, + i)); + + if (is_no_grad_vars) { + (no_grad_var_nodes_inputmeta_map_)[target_node] = auto_grad_meta; + } else { + // normal input + (input_target_nodes_inputmeta_map_)[target_node] = auto_grad_meta; + } + } + } + } + + // Purify potential_startup_nodes_, remove nodes those are the same as + // input_target_nodes + void PurifyPotentialStartUpNodes() { + VLOG(6) << "Running in PurifyPotentialStartUpNodes"; + if (input_target_nodes_inputmeta_map_.empty()) return; + std::unordered_set potential_startup_nodes_to_be_erased; + for (auto startup_op : potential_startup_nodes_) { + auto iter = input_target_nodes_inputmeta_map_.find(startup_op); + if (iter != input_target_nodes_inputmeta_map_.end()) { + potential_startup_nodes_to_be_erased.emplace(iter->first); + } + } + if (!potential_startup_nodes_to_be_erased.empty()) { + for (auto nodes : potential_startup_nodes_to_be_erased) { + potential_startup_nodes_.erase(nodes); + } + } + } + + // Update Graph Info and remove some nodes those doesn't need to be + // stored in potential_startup_nodes_ + void UpdateGraphInfo() { + std::unordered_set startup_ops; + VLOG(6) << "Running in UpdateGraphInfo"; + std::deque queue; + for (auto& target_nodes_inputmeta_pair : + input_target_nodes_inputmeta_map_) { + queue.push_back(target_nodes_inputmeta_pair.first); + needed_nodes_.emplace(target_nodes_inputmeta_pair.first); + } + std::unordered_set visited; + std::unordered_set input_target_nodes_on_path; + while (!queue.empty()) { + auto* target_node = queue.front(); + queue.pop_front(); + if (visited.count(target_node)) { + continue; + } + visited.insert(target_node); + if (!(depending_nodes_)[target_node].empty()) { + auto precedding_nodes = (depending_nodes_)[target_node]; + for (auto pre_nodes : precedding_nodes) { + queue.push_back(pre_nodes); + 
needed_nodes_.emplace(pre_nodes); + if (IsInputTargetNodes(pre_nodes)) { + input_target_nodes_on_path.emplace(pre_nodes); + } + } + } else { // startup_ops have no precedding nodes + VLOG(6) << "Emplace startup_ops"; + startup_ops.emplace(target_node); + needed_nodes_.emplace(target_node); + } + } + + for (auto& target_nodes_inputmeta_pair : + input_target_nodes_inputmeta_map_) { + if (!input_target_nodes_on_path.count( + target_nodes_inputmeta_pair.first)) { + endding_nodes_.emplace(target_nodes_inputmeta_pair.first); + } + } + + // Purify potential_startup_nodes_ again, remove some + // potential startup nodes that unreach to input target nodes + if (!startup_ops.empty()) { + std::unordered_set potential_startup_nodes_to_be_erased; + for (auto node : potential_startup_nodes_) { + if (startup_ops.count(node) == 0) { + VLOG(6) << "Set up potential_startup_nodes_to_be_erased"; + potential_startup_nodes_to_be_erased.emplace(node); + } + } + if (!potential_startup_nodes_to_be_erased.empty()) { + for (auto node : potential_startup_nodes_to_be_erased) { + VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased"; + potential_startup_nodes_.erase(node); + } + } + } + } + + // Get Graph Info Betweent input target GradNode and outputs, + // record depending_nodes_, potential_startup_nodes_ + void GetGraphInfoBetweenTargets(const std::deque& init_queue) { + VLOG(6) << "Runing In GetGraphInfoBetweenTargets"; + + // Calculate in_degree for each node + std::unordered_map node_in_degree_map; + + // Copy nodes + std::deque queue = init_queue; + std::unordered_set visited; + + // Visit each node exactly once in any order + while (!queue.empty()) { + GradNodeBase* node = queue.front(); + queue.pop_front(); + + if (visited.count(node)) { + continue; + } + visited.insert(node); + + // Find and append next nodes + const paddle::small_vector, + kSlotSmallVectorSize>& metas = + node->OutputMeta(); + for (const auto& meta_list : metas) { + for (const GradSlotMeta& meta : meta_list) { + const auto& edge = meta.GetEdge(); + GradNodeBase* next_node = edge.GetMutableGradNode().get(); + + // Next node could be nullptr if it is leaf tensor with no + // AccumulationNode attached + // Or it could also originated from dispensable inputs + if (!next_node) continue; + + // Update in_degree + if (!node_in_degree_map.count(next_node)) { + node_in_degree_map[next_node] = 0; + } + node_in_degree_map[next_node]++; + + // Record depending relationship + (depending_nodes_)[next_node].emplace(node); + queue.push_back(next_node); + } + } + } + } + + void ModifyReadyQueue(std::deque* queue) { + std::deque tmp_queue; + for (auto nodes : potential_startup_nodes_) { + tmp_queue.push_back(nodes); + } + tmp_queue.swap(*queue); + } + + // Set result for input target grad_var when potential_startup_nodes_ is empty + void SetResultForInputTargetVar( + const std::unordered_map>& + node_input_buffers_dict) { + if (potential_startup_nodes_.size() == 0) { + for (auto input_target_node : *GetInputTargetNodesInputMetaMap()) { + // out rank_info of forward op + auto rank_info = input_target_node.second->OutRankInfo(); + auto iter = node_input_buffers_dict.find(input_target_node.first); + if (iter != node_input_buffers_dict.end()) { + auto& target_result = + (iter->second)->Buffers()[rank_info.first][rank_info.second]; + // save the target result + results_map_[input_target_node.first] = + std::make_shared(target_result); + } + } + } + } + + void SetResultForEnddingNodes( + paddle::small_vector, + kSlotSmallVectorSize> grad_output, + 
GradNodeBase* node) { + if (IsEnddingNodes(node)) { + VLOG(6) << "Set result for endding_nodes_ with grad_output_tensors"; + results_map_[node] = + std::make_shared(grad_output[0][0]); + } + } + + std::shared_ptr FetchGradForTensor( + const paddle::experimental::Tensor& tensor, + egr::GradNodeBase* target_node) { + std::shared_ptr tmp{ + std::make_shared()}; + VLOG(6) + << "Running in FetchGradForTensor, prepare FetchGrad Hook for tensor: " + << tensor.name(); + auto hook = [tmp](const paddle::experimental::Tensor& t) { + auto tmp_grad = tmp.get(); + if (t.defined()) { + VLOG(6) << "Set impl for FetchGrad Hook for tensor: " << t.name(); + tmp_grad->set_impl(t.impl()); + tmp_grad->set_autograd_meta(t.mutable_autograd_meta()); + return t; + } else { + VLOG(6) << "Retain NULL paddle::experimental::Tensor in FetchGrad Hook"; + return paddle::experimental::Tensor(); + } + }; + + // Append to GradientHooks + auto rank_info = EagerUtils::unsafe_autograd_meta(tensor)->OutRankInfo(); + target_node->RegisterGradientHook( + rank_info.first, + rank_info.second, + std::move(std::make_shared(hook))); + return tmp; + } + + // Register Hook to fetch input's gradients, when input's grad node is not an + // endding node in backward graph. If input's grad node is an endding node in + // backward graph, use grad node's output as inputs' gradients and no need to + // register Hook. Please note that endding node must be GradNodeAccumulation + // after ModifyBackwardGraph function. + void RegisterFetchGradHook( + const std::vector& inputs) { + VLOG(6) << "Running in RegisterFetchGradHook."; + if (!inputs.empty()) { + size_t num_inputs = inputs.size(); + for (size_t i = 0; i < num_inputs; i++) { + AutogradMeta* auto_grad_meta = + EagerUtils::unsafe_autograd_meta(inputs[i]); + auto* target_node = auto_grad_meta->GetMutableGradNode().get(); + + if (dynamic_cast(target_node)) { + VLOG(6) + << "No need to call FetchGradForTensor for GradNodeAccumulation"; + continue; + } + + if (orig_to_copied_node_map_.count(target_node)) { + target_node = orig_to_copied_node_map_[target_node].get(); + if (copied_node_to_endding_node_map_.count(target_node)) { + VLOG(6) << "No need to call FetchGradForTensor for endding_nodes"; + continue; + } + } + + PADDLE_ENFORCE_NOT_NULL( + target_node, + paddle::platform::errors::Fatal( + "There is no grad op for inputs:[%d] or it's" + "stop_gradient=True.", + i)); + + if (!IsEnddingNodes(target_node)) { + // Fetch grad for tensor in target_node on path. 
+ auto fetched_grad = FetchGradForTensor(inputs[i], target_node); + results_map_[target_node] = fetched_grad; + } + } + } + } + + void SetNodeToAccumulationNode(GradNodeBase* node) { + if (dynamic_cast(node)) return; + if (!(depending_nodes_)[node].empty()) { + auto precedding_nodes = (depending_nodes_)[node]; + for (auto pre_nodes : precedding_nodes) { + paddle::small_vector, kSlotSmallVectorSize>& + pre_nodes_edges = pre_nodes->MutableOutputMeta(); + for (size_t i = 0; i < pre_nodes_edges.size(); i++) { + for (size_t j = 0; j < pre_nodes_edges[i].size(); j++) { + auto edge_ = pre_nodes_edges[i][j].GetEdge(); + if (edge_.GetGradNode() == node) { + auto autograd_meta = egr::AutogradMeta(edge_); + Edge& pre_node_edge = pre_nodes_edges[i][j].GetMutableEdge(); + + if (copied_node_to_endding_node_map_.count(node)) { + pre_node_edge.SetGradNode( + copied_node_to_endding_node_map_[node]); + } else { + std::shared_ptr shared_grad_node_accumulation = + std::make_shared(&autograd_meta); + pre_node_edge.SetGradNode(shared_grad_node_accumulation); + copied_node_to_endding_node_map_[node] = + shared_grad_node_accumulation; + } + + auto* grad_node = pre_node_edge.GetGradNode(); + needed_nodes_.emplace(grad_node); + endding_nodes_.emplace(grad_node); + input_target_nodes_inputmeta_map_[grad_node] = + input_target_nodes_inputmeta_map_[node]; + + VLOG(6) + << node->name() << " (addr:" << node + << ") has been transformed to GradNodeAccumulation (addr: " + << grad_node << ")"; + + // Copy Hook func + if (node->GradientHooksRegistered()) { + VLOG(6) << "Copy hook func from node: " << node->name() + << " (addr: " << node + << ") to GradNodeAccumulation (addr: " << grad_node + << ")"; + grad_node->SetGradientHookFuntions( + node->GetGradientHookFuntions()); + } + } + } + } + } + } + } + + void ModifyBackwardGraph(std::deque* queue) { + std::deque queue_ = *queue; + std::unordered_set visited; + + while (!queue_.empty()) { + GradNodeBase* node = queue_.front(); + queue_.pop_front(); + + if (visited.count(node)) { + continue; + } + visited.insert(node); + + if (IsInputTargetNodes(node)) { + if (IsEnddingNodes(node)) { + SetNodeToAccumulationNode(node); + continue; + } + } + + paddle::small_vector, kSlotSmallVectorSize>& + meta = node->MutableOutputMeta(); + for (size_t i = 0; i < meta.size(); i++) { + for (size_t j = 0; j < meta[i].size(); j++) { + Edge& edge = meta[i][j].GetMutableEdge(); + std::shared_ptr next_node = edge.GetMutableGradNode(); + + if (!next_node) continue; + + if (no_grad_var_nodes_inputmeta_map_.count(next_node.get()) && + (no_grad_var_nodes_inputmeta_map_[next_node.get()] + ->OutRankInfo() == edge.GetEdgeRankInfo())) { + VLOG(3) << "Get no grad edge from grad_node: " << node->name() + << " : " << node << " to:" << next_node->name() << ", " + << next_node.get() << " with output rank info: " + << edge.GetEdgeRankInfo().first << ", " + << edge.GetEdgeRankInfo().second; + // no_grad_var's grad no need to be computed + meta[i][j].SetStopGradient(true); + edge.Clear(); + continue; + } + + // TODO(weilong): support prune logic deeper + + // Update BFS queue + queue_.push_back(next_node.get()); + } + } + } + } + + std::vector GetResults( + const std::vector& inputs, + bool allow_unused, + bool create_graph) { + VLOG(6) << "Running in GetResults"; + if (inputs.empty()) return {}; + + std::vector results; + results.reserve(inputs.size()); + + for (size_t i = 0; i < inputs.size(); ++i) { + auto& input = inputs[i]; + AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input); + + auto* 
target_node = auto_grad_meta->GetMutableGradNode().get(); + if (orig_to_copied_node_map_.count(target_node)) { + target_node = orig_to_copied_node_map_[target_node].get(); + if (copied_node_to_endding_node_map_.count(target_node)) { + target_node = copied_node_to_endding_node_map_[target_node].get(); + } + } else { + VLOG(6) << "Unable to find target node in " + "orig_to_copied_node_map_, likely indicating an unused " + "input"; + } + auto iter = results_map_.find(target_node); + if (iter != results_map_.end()) { + // set StopGradient = !create_graph + AutogradMeta* tensor_auto_grad_meta = + EagerUtils::autograd_meta(iter->second.get()); + tensor_auto_grad_meta->SetStopGradient(!create_graph); + results.emplace_back(*(iter->second.get())); + } else { + PADDLE_ENFORCE_EQ(allow_unused, + true, + paddle::platform::errors::InvalidArgument( + "The %d-th input does not appear in the backward " + "graph. Please check the input tensor or set " + "allow_unused=True to get None result.", + i)); + results.emplace_back(); + } + } + Clear(); + return results; + } + + bool IsNeededNodes(GradNodeBase* node) { return needed_nodes_.count(node); } + + bool IsEnddingNodes(GradNodeBase* node) { return endding_nodes_.count(node); } + + bool IsInputTargetNodes(GradNodeBase* node) { + auto iter = input_target_nodes_inputmeta_map_.find(node); + if (iter != input_target_nodes_inputmeta_map_.end()) { + return true; + } + return false; + } + + std::unordered_map* + GetNoGradVarNodesInputMetaMap() { + return &no_grad_var_nodes_inputmeta_map_; + } + + std::unordered_map* + GetInputTargetNodesInputMetaMap() { + return &input_target_nodes_inputmeta_map_; + } + + std::unordered_set* GetPotentialStartupNodes() { + return &potential_startup_nodes_; + } + + GradNodeBase* CopyGradNode(const std::shared_ptr& orig_node) { + if (orig_to_copied_node_map_.count(orig_node.get())) { + return orig_to_copied_node_map_[orig_node.get()].get(); + } + std::shared_ptr copied_node = orig_node->Copy(); + + // Save node and update mapping + orig_to_copied_node_map_[orig_node.get()] = copied_node; + copied_grad_nodes_.push_back(copied_node); + + return copied_node.get(); + } + + void CopyBackwardGraph(const std::deque& orig_init_queue) { + std::deque queue = orig_init_queue; + std::unordered_set visited; + + // BFS and recursively copy the grad nodes + while (!queue.empty()) { + GradNodeBase* orig_node = queue.front(); + queue.pop_front(); + if (visited.count(orig_node)) { + continue; + } + visited.insert(orig_node); + + PADDLE_ENFORCE( + orig_to_copied_node_map_.count(orig_node), + paddle::platform::errors::Fatal( + "Cannot copy backward graph," + "unable to find copied target for certain grad node.")); + GradNodeBase* copied_node = orig_to_copied_node_map_[orig_node].get(); + + const paddle::small_vector, + kSlotSmallVectorSize>& orig_meta = + orig_node->OutputMeta(); + paddle::small_vector, kSlotSmallVectorSize>& + copied_edges = copied_node->MutableOutputMeta(); + for (size_t i = 0; i < orig_meta.size(); i++) { + for (size_t j = 0; j < orig_meta[i].size(); j++) { + const Edge& orig_edge = orig_meta[i][j].GetEdge(); + Edge& copied_edge = copied_edges[i][j].GetMutableEdge(); + + std::shared_ptr orig_next_node = + orig_edge.GetMutableGradNode(); + if (!orig_next_node) continue; + + // Copy Next Node + std::shared_ptr copied_next_node; + if (orig_to_copied_node_map_.count(orig_next_node.get())) { + copied_next_node = orig_to_copied_node_map_[orig_next_node.get()]; + + } else { + copied_next_node = orig_next_node->Copy(); + 
orig_to_copied_node_map_[orig_next_node.get()] = copied_next_node; + copied_grad_nodes_.push_back(copied_next_node); + } + + // Update Edge's Grad Node + copied_edge.SetGradNode(copied_next_node); + + // Update BFS queue + queue.push_back(orig_next_node.get()); + } + } + } + } + + void PreparedForGeneralGrad( + const std::vector& inputs, + const std::vector& no_grad_vars, + const std::deque& orig_queue, + std::deque* queue, + const std::unordered_map>& + node_input_buffers_dict) { + // Copy Backward Graph + CopyBackwardGraph(orig_queue); + // Get no_grad_vars's GradNodes and InputMeta Info + GetTargetNodesInfo(no_grad_vars, true /* is_no_grad_vars */); + // Get inputs's GradNodes and InputMeta Info + GetTargetNodesInfo(inputs, false /* is_no_grad_vars */); + // Purify potentialstartup_ops, remove those nodes that are the same as + // input_target_nodes + PurifyPotentialStartUpNodes(); + // Get Graph Info Betweent input target gradnode and outputs + // Record the depending_nodes_ and potential_startup_nodes_ + GetGraphInfoBetweenTargets(*queue); + // Update Graph Info, remove some nodes in + // potential_startup_nodes_ + UpdateGraphInfo(); + // Reset queue. Queue is empty only when + // 1.input equals to output. 2.input can not reach to output. + ModifyReadyQueue(queue); + // Set result for input target grad_var when queue is empty + if (queue->empty()) { + SetResultForInputTargetVar(node_input_buffers_dict); + } else { + // TODO(wuweilong): Find a better design here. + ModifyBackwardGraph(queue); + // Register Hook to fetch input's gradients + RegisterFetchGradHook(inputs); + } + } + + void Clear() { + no_grad_var_nodes_inputmeta_map_.clear(); + input_target_nodes_inputmeta_map_.clear(); + potential_startup_nodes_.clear(); + depending_nodes_.clear(); + results_map_.clear(); + copied_grad_nodes_.clear(); + orig_to_copied_node_map_.clear(); + copied_node_to_endding_node_map_.clear(); + needed_nodes_.clear(); + endding_nodes_.clear(); + } + + private: + GeneralGrad() = default; + static GeneralGrad* general_grad_; + // no_grad_vars's GradNode and GradNode's InputMeta. + std::unordered_map + no_grad_var_nodes_inputmeta_map_; + // inputs's GradNode and GradNode's InputMeta. + std::unordered_map + input_target_nodes_inputmeta_map_; + // Record all the potential startup_nodes, will be changed. 
+ std::unordered_set potential_startup_nodes_; + std::unordered_map /* pre nodes */> + depending_nodes_; + std::unordered_map> + results_map_; + + std::vector> copied_grad_nodes_; + std::unordered_map> + orig_to_copied_node_map_; + std::unordered_set needed_nodes_; + // Record which grad_node has been transformed to AccumulationNode + std::unordered_map> + copied_node_to_endding_node_map_; + std::unordered_set endding_nodes_; + + DISABLE_COPY_AND_ASSIGN(GeneralGrad); +}; + +} // namespace egr diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index 2f8ca2bb420..a65a044895a 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -253,6 +253,19 @@ class GradNodeBase { * **/ inline bool GradientHooksRegistered() { return !gradient_hooks_.empty(); } + std::map>> + GetGradientHookFuntions() { + VLOG(6) << "GetGradientHookFuntions "; + return gradient_hooks_; + } + + void SetGradientHookFuntions( + std::map>> + hooks) { + VLOG(6) << "SetGradientHookFuntions "; + gradient_hooks_ = hooks; + } + paddle::small_vector, kSlotSmallVectorSize> ApplyGradientHooks( diff --git a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py index 5e9374bac05..d80b708ebf2 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py @@ -166,6 +166,46 @@ class TestEagerGrad(TestCase): self.func_simple_example_eager_grad_duplicate_output() self.func_simple_example_eager_grad_duplicate_output() + def test_simple_example_eager_two_grad_output(self): + with _test_eager_guard(): + x1 = paddle.to_tensor([1.0, 2.0]) + x1.stop_gradient = False + x2 = paddle.to_tensor([1.0, 2.0]) + x2.stop_gradient = False + out1 = x1 * 2 + out2 = x2 * 2 + + dout2_record_by_hook = [] + + def record_hook(grad): + dout2_record_by_hook.append(grad) + + out2.register_hook(record_hook) + + out3 = paddle.multiply(out1, out2) + out4 = paddle.mean(out3) + egr_dout2, egr_dout3 = paddle.grad([out4], [out2, out3]) + + self.assertTrue( + np.array_equal(dout2_record_by_hook[0].numpy(), + np.array([1., 2.]))) + + x1 = paddle.to_tensor([1.0, 2.0]) + x1.stop_gradient = False + x2 = paddle.to_tensor([1.0, 2.0]) + x2.stop_gradient = False + out1 = x1 * 2 + out2 = x2 * 2 + + out3 = paddle.multiply(out1, out2) + out4 = paddle.mean(out3) + dout2, dout3 = paddle.grad([out4], [out2, out3]) + + self.assertEqual(dout2.stop_gradient, egr_dout2.stop_gradient) + self.assertEqual(dout3.stop_gradient, egr_dout3.stop_gradient) + self.assertTrue(np.array_equal(dout2.numpy(), egr_dout2.numpy())) + self.assertTrue(np.array_equal(dout3.numpy(), egr_dout3.numpy())) + class TestDygraphDoubleGrad(TestCase): -- GitLab
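
For quick local verification, the scenario covered by the new test_simple_example_eager_two_grad_output case reduces to the standalone sketch below. It is an illustration only (not part of the diff above) and assumes eager mode is active, which is the default in recent Paddle releases, so no _test_eager_guard() is needed. It shows the behavior this refactor targets: paddle.grad() taken with respect to an intermediate tensor (out2) now works together with a gradient hook registered on that tensor.

import paddle

x1 = paddle.to_tensor([1.0, 2.0], stop_gradient=False)
x2 = paddle.to_tensor([1.0, 2.0], stop_gradient=False)
out1 = x1 * 2
out2 = x2 * 2

# A user hook on the intermediate tensor should still observe its gradient
# even though out2 itself is passed in the `inputs` list of paddle.grad.
recorded = []
out2.register_hook(lambda grad: recorded.append(grad))

out3 = paddle.multiply(out1, out2)
out4 = paddle.mean(out3)

dout2, dout3 = paddle.grad([out4], [out2, out3])

# d(out4)/d(out3) = [0.5, 0.5] since out4 averages two elements;
# d(out4)/d(out2) = [0.5, 0.5] * out1 = [1., 2.]
print(dout3.numpy())        # [0.5 0.5]
print(dout2.numpy())        # [1. 2.]
print(recorded[0].numpy())  # [1. 2.]  -- the hook fired with the same gradient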