Unverified commit facda828, authored by Weilong Wu, committed by GitHub

[Eager grad] Refactor partial grad logic (#40693)

* Refactor partial_grad/backward logic

* Add DuplicateCheck and polish code

* Refactor partial_grad/backward more clearly

* Refactor GeneralGrad by SingleInstance
Parent commit: cc853e95
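A minimal Python sketch of the behavior this partial-grad path provides (assuming the eager-mode paddle.grad API with inputs, allow_unused and no_grad_vars; illustrative only, not code from this patch):

# Minimal sketch of the partial-grad behavior exercised by this commit;
# assumes the eager-mode paddle.grad API, not code taken from the patch.
import numpy as np
import paddle

paddle.set_device('cpu')
x = paddle.to_tensor(np.random.random((3, 3)), dtype="float64",
                     stop_gradient=False)
y = paddle.to_tensor(np.random.random((3, 1)), dtype="float64",
                     stop_gradient=False)
out = paddle.matmul(x, y)

# Gradient of `out` w.r.t. `x` only; `y` is excluded via no_grad_vars.
(dx,) = paddle.grad(outputs=[out], inputs=[x], no_grad_vars=[y],
                    allow_unused=False, create_graph=False)
print(dx.shape)  # [3, 3]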
@@ -29,90 +29,83 @@
namespace egr {

/*
 * GeneralGrad is a helper class to implement custom grad operation between
 * outputs and inputs.
 **/
class GeneralGrad {
 public:
  static GeneralGrad& Instance() { return *general_grad_; }

  // Get inputs' / no_grad_vars' GradNodes and InputMeta info
  void GetTargetNodesInfo(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool is_no_grad_vars) {
    std::string msg = is_no_grad_vars ? "no_grad_vars" : "inputs";
    VLOG(6) << "Running in GetTargetNodesInfo.";
    if (!inputs.empty()) {
      VLOG(6) << msg << " are not empty.";
      size_t num_inputs = inputs.size();
      for (size_t i = 0; i < num_inputs; i++) {
        AutogradMeta* auto_grad_meta =
            EagerUtils::unsafe_autograd_meta(inputs[i]);
        auto target_node = auto_grad_meta->GetMutableGradNode().get();
        PADDLE_ENFORCE_NOT_NULL(target_node,
                                paddle::platform::errors::Fatal(
                                    "There is no grad op for %s:[%d] or it's"
                                    "stop_gradient=True.",
                                    msg, i));
        if (is_no_grad_vars) {
          (no_grad_var_nodes_inputmeta_map)[target_node] = auto_grad_meta;
        } else {  // normal input
          (input_target_nodes_inputmeta_map)[target_node] = auto_grad_meta;
        }
      }
    }
  }

  // Purify potential_startup_nodes: remove nodes that are the same as
  // input_target_nodes
  void PurifyPotentialStartUpNodes() {
    VLOG(6) << "Running in PurifyPotentialStartUpNodes";
    if (input_target_nodes_inputmeta_map.empty()) return;
    std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
    for (auto startup_op : potential_startup_nodes) {
      auto iter = input_target_nodes_inputmeta_map.find(startup_op);
      if (iter != input_target_nodes_inputmeta_map.end()) {
        potential_startup_nodes_to_be_erased.emplace(iter->first);
      }
    }
    if (!potential_startup_nodes_to_be_erased.empty()) {
      for (auto nodes : potential_startup_nodes_to_be_erased) {
        potential_startup_nodes.erase(nodes);
      }
    }
  }

  // Remove nodes that don't need to be stored in
  // potential_stop_nodes / potential_startup_nodes
  void UpdateGraphInfo() {
    // Update potential_stop_nodes by depending_nodes,
    // make sure the path from root to target_node is ok
    std::unordered_set<GradNodeBase*> _startup_ops;
    VLOG(6) << "Running in UpdateGraphInfo";
    std::queue<GradNodeBase*> queue;
    for (auto& target_nodes_inputmeta_pair : input_target_nodes_inputmeta_map) {
      queue.emplace(target_nodes_inputmeta_pair.first);
    }
    while (!queue.empty()) {
      auto* target_node = queue.front();
      queue.pop();
      if (!(depending_nodes)[target_node].empty()) {
        auto precedding_nodes = (depending_nodes)[target_node];
        for (auto pre_nodes : precedding_nodes) {
          queue.emplace(pre_nodes);
          if (potential_stop_nodes.find(pre_nodes) !=
              potential_stop_nodes.end()) {
            potential_stop_nodes.erase(pre_nodes);
          }
        }
      } else {  // startup_ops have no preceding nodes
@@ -124,7 +117,7 @@ void UpdateGraphInfo(
    // potential startup_nodes that cannot reach the input target nodes
    if (!_startup_ops.empty()) {
      std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
      for (auto node : potential_startup_nodes) {
        if (_startup_ops.count(node) == 0) {
          VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
          potential_startup_nodes_to_be_erased.emplace(node);
@@ -133,25 +126,15 @@ void UpdateGraphInfo(
      if (!potential_startup_nodes_to_be_erased.empty()) {
        for (auto node : potential_startup_nodes_to_be_erased) {
          VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
          potential_startup_nodes.erase(node);
        }
      }
    }
  }
  // Get graph info between the input target GradNodes and the outputs;
  // record depending_nodes, potential_stop_nodes, potential_startup_nodes
  void GetGraphInfoBetweenTargets(const std::queue<GradNodeBase*>& init_queue) {
    VLOG(6) << "Running in GetGraphInfoBetweenTargets";
    // Calculate in_degree for each node
@@ -174,7 +157,7 @@ void GetGraphInfoBetweenTargets(
      // Check whether node is a target_node; if node is not a target_node,
      // all its next_nodes will be marked in potential_stop_nodes
      bool is_potential_stop_nodes =
          input_target_nodes_inputmeta_map.count(node);

      // Find and append next nodes
      const std::vector<std::vector<Edge>>& edges = node->GetEdges();
@@ -191,7 +174,7 @@ void GetGraphInfoBetweenTargets(
          // all the next_nodes of the current node will be inserted into
          // potential_stop_nodes
          if (is_potential_stop_nodes) {
            potential_stop_nodes.emplace(next_node);
          }

          // Update in_degree
@@ -200,42 +183,62 @@ void GetGraphInfoBetweenTargets(
          node_in_degree_map[next_node]++;

          // Record depending relationship
          (depending_nodes)[next_node].emplace(node);
          queue.push(next_node);
        }
      }
    }
    // Update graph info: remove some nodes from
    // potential_stop_nodes / potential_startup_nodes
    UpdateGraphInfo();
  }
  void ModifyReadyQueue(std::queue<GradNodeBase*>* queue) {
    std::queue<GradNodeBase*> tmp_queue;
    for (auto nodes : potential_startup_nodes) {
      tmp_queue.emplace(nodes);
    }
    tmp_queue.swap(*queue);
  }

  // Set result for input target grad_var when potential_startup_nodes is empty
  void SetResultForInputTargetVar(
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    if (potential_startup_nodes.size() == 0) {
      for (auto input_target_node : *GetInPutTargetNodesInputMetaMap()) {
        // out rank_info of forward op
        auto rank_info = input_target_node.second->OutRankInfo();
        auto iter = node_input_buffers_dict.find(input_target_node.first);
        if (iter != node_input_buffers_dict.end()) {
          auto& target_result =
              (iter->second)->Buffers()[rank_info.first][rank_info.second];
          // save the target result
          results_map[input_target_node.first] = target_result;
        }
      }
    }
  }
  // Set input target grad_var from node_input_buffer by inputmeta
  void SetResultForInputTargetVar(GradTensorHolder input_buffers,
                                  GradNodeBase* node) {
    auto iter = GetInPutTargetNodesInputMetaMap()->find(node);
    if (iter != GetInPutTargetNodesInputMetaMap()->end()) {
      VLOG(6) << "Get target result by inputmeta";
      // out rank_info of forward op
      auto rank_info = (iter->second)->OutRankInfo();
      // rank_info is a pair: first means slot_id, second means rank.
      auto& target_result =
          input_buffers.Buffers()[rank_info.first][rank_info.second];
      // save the target result
      results_map[node] = target_result;
    }
  }

  std::vector<paddle::experimental::Tensor> GetResults(
      const std::vector<paddle::experimental::Tensor>& inputs,
      bool allow_unused, bool create_graph) {
    VLOG(6) << "Running in GetResults";
    if (inputs.empty()) return {};
@@ -248,8 +251,8 @@ std::vector<paddle::experimental::Tensor> GetResults(
      AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
      auto target_node = auto_grad_meta->GetMutableGradNode().get();

      auto iter = results_map.find(target_node);
      if (iter != results_map.end()) {
        // set StopGradient = !create_graph
        AutogradMeta* tensor_auto_grad_meta =
            EagerUtils::autograd_meta(&(iter->second));
@@ -259,13 +262,137 @@ std::vector<paddle::experimental::Tensor> GetResults(
        PADDLE_ENFORCE_EQ(allow_unused, true,
                          paddle::platform::errors::InvalidArgument(
                              "The %d-th input does not appear in the backward "
                              "graph. Please check the input tensor or set "
                              "allow_unused=True to get None result.",
                              i));
        results.emplace_back();
      }
    }
    Clear();
    return results;
  }
  void PreparedForGeneralGrad(
      const std::vector<paddle::experimental::Tensor>& inputs,
      const std::vector<paddle::experimental::Tensor>& no_grad_vars,
      std::queue<GradNodeBase*>* queue,
      const std::unordered_map<GradNodeBase*,
                               std::unique_ptr<GradTensorHolder>>&
          node_input_buffers_dict) {
    // Get no_grad_vars' GradNodes and InputMeta info
    GetTargetNodesInfo(no_grad_vars, true /* is_no_grad_vars */);
    // Get inputs' GradNodes and InputMeta info
    GetTargetNodesInfo(inputs, false /* is_no_grad_vars */);
    // Purify potential_startup_ops: remove those nodes that are the same as
    // input_target_nodes
    PurifyPotentialStartUpNodes();
    // Get graph info between the input target GradNodes and the outputs;
    // record the depending_nodes and
    // potential_stop_nodes / potential_startup_nodes
    GetGraphInfoBetweenTargets(*queue);
    // Reset queue. The queue is empty only when
    // 1. input equals output, or 2. input cannot reach output.
    ModifyReadyQueue(queue);
    // Set result for input target grad_var when queue is empty
    if (queue->empty()) SetResultForInputTargetVar(node_input_buffers_dict);
  }

  bool IsPotentialStopNodes(GradNodeBase* node) {
    return potential_stop_nodes.count(node);
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetNoGradVarNodesInputMetaMap() {
    return &no_grad_var_nodes_inputmeta_map;
  }

  std::unordered_map<GradNodeBase*, AutogradMeta*>*
  GetInPutTargetNodesInputMetaMap() {
    return &input_target_nodes_inputmeta_map;
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStopNodes() {
    return &potential_stop_nodes;
  }

  std::unordered_set<GradNodeBase*>* GetPotentialStartupNodes() {
    return &potential_startup_nodes;
  }

  void Clear() {
    no_grad_var_nodes_inputmeta_map.clear();
    input_target_nodes_inputmeta_map.clear();
    potential_startup_nodes.clear();
    potential_stop_nodes.clear();
    depending_nodes.clear();
    results_map.clear();
  }

 private:
  GeneralGrad() = default;
  static GeneralGrad* general_grad_;
  // no_grad_vars' GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
      no_grad_var_nodes_inputmeta_map;
  // inputs' GradNode and GradNode's InputMeta.
  std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
      input_target_nodes_inputmeta_map;
  // Record all the potential startup_nodes, will be changed.
  std::unordered_set<GradNodeBase*> potential_startup_nodes;
  // Record all the potential stop nodes, will be changed.
  std::unordered_set<GradNodeBase*> potential_stop_nodes;
  std::unordered_map<GradNodeBase* /* next node */,
                     std::unordered_set<GradNodeBase*> /* pre nodes */>
      depending_nodes;
  std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
  DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};
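The pruning GeneralGrad performs between the output-side startup nodes and the input-side target nodes can be summarised with a small, self-contained sketch; the dict-based toy graph and node names below are illustrative stand-ins for GradNodeBase/Edge, not Paddle APIs:

# Toy sketch of the GeneralGrad pruning idea: walk forward from the startup
# nodes recording who depends on whom, then walk backward from the input
# target nodes; anything not on a path between the two sets is a stop node.
from collections import deque

def prune_between(edges, startup_nodes, target_nodes):
    """Return the nodes lying on some path from startup_nodes to target_nodes."""
    # Forward pass: record depending relationships (child -> set of parents),
    # roughly what GetGraphInfoBetweenTargets does while counting in-degrees.
    depending = {}
    queue = deque(startup_nodes)
    visited = set()
    while queue:
        node = queue.popleft()
        if node in visited:
            continue
        visited.add(node)
        for nxt in edges.get(node, []):
            depending.setdefault(nxt, set()).add(node)
            queue.append(nxt)

    # Backward pass from the target (input) nodes, roughly UpdateGraphInfo:
    # everything reachable upward stays; the rest would be stop nodes.
    keep = set()
    queue = deque(t for t in target_nodes if t in visited)
    while queue:
        node = queue.popleft()
        if node in keep:
            continue
        keep.add(node)
        for parent in depending.get(node, ()):
            queue.append(parent)
    return keep

# Example: out -> a -> b -> x_leaf, plus an unrelated branch out -> c.
edges = {"out": ["a", "c"], "a": ["b"], "b": ["x_leaf"], "c": []}
print(prune_between(edges, {"out"}, {"b"}))  # {'out', 'a', 'b'}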
std::unordered_map<GradNodeBase*, int> getInDegreeMap(
    const std::queue<GradNodeBase*>& init_queue) {
  // Calculate in_degree for each node
  // We can completely remove this pass, if in_degree were set during forward
  // pass
  std::unordered_map<GradNodeBase*, int> node_in_degree_map;

  // Copy nodes
  std::queue<GradNodeBase*> queue = init_queue;
  std::unordered_set<GradNodeBase*> visited;

  // Visit each node exactly once in any order
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
    queue.pop();

    if (visited.count(node)) {
      continue;
    }
    visited.insert(node);

    PADDLE_ENFORCE_NOT_NULL(
        node,
        paddle::platform::errors::Fatal(
            "We got null node when we traverse the backward graph, and this "
            "should not happened please check your code and contact us."));

    // Find and append next nodes
    const std::vector<std::vector<Edge>>& edges = node->GetEdges();
    for (const auto& edge_list : edges) {
      for (const Edge& edge : edge_list) {
        GradNodeBase* next_node = edge.GetMutableGradNode().get();
        // Next node could be nullptr if it is leaf tensor with no
        // AccumulationNode attached
        // Or it could also originated from dispensable inputs
        if (!next_node) continue;

        // Update in_degree
        if (!node_in_degree_map.count(next_node))
          node_in_degree_map[next_node] = 0;
        node_in_degree_map[next_node]++;
        queue.push(next_node);
      }
    }
  }
  return node_in_degree_map;
}
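getInDegreeMap counts, for every GradNode reachable from the starting nodes, how many predecessors still have to send a gradient into it; the backward loop later only re-enqueues a node once that count drops to zero. A toy version of the same counting, over a plain adjacency dict rather than real GradNode edges:

# Toy in-degree computation mirroring getInDegreeMap; the dict-based graph
# and node names are made up for illustration.
from collections import deque

def in_degree_map(edges, init_nodes):
    in_degree = {}
    queue = deque(init_nodes)
    visited = set()
    while queue:
        node = queue.popleft()
        if node in visited:
            continue
        visited.add(node)
        for nxt in edges.get(node, []):
            in_degree[nxt] = in_degree.get(nxt, 0) + 1
            queue.append(nxt)
    return in_degree

edges = {"out": ["a", "b"], "a": ["leaf"], "b": ["leaf"]}
print(in_degree_map(edges, ["out"]))  # {'a': 1, 'b': 1, 'leaf': 2}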
// Enforce GradNode has TensorWrappers as Input
@@ -281,28 +408,23 @@ void EnforceGradNodeHasInput(GradNodeBase* node) {
          node->name()));
}

void DuplicateCheck(const std::vector<paddle::experimental::Tensor>& inputs,
                    bool is_input) {
  std::unordered_set<AutogradMeta*> visisted_ins;
  std::string msg = is_input ? "inputs" : "outputs";
  for (auto in : inputs) {
    AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(in);
    PADDLE_ENFORCE_EQ(
        visisted_ins.count(auto_grad_meta), 0,
        paddle::platform::errors::AlreadyExists(
            "%s contain duplicate tensor %s, please check %s carefully.", msg,
            in.name(), msg));
    visisted_ins.insert(auto_grad_meta);
  }
}

GeneralGrad* GeneralGrad::general_grad_ = new GeneralGrad();
std::vector<paddle::experimental::Tensor> RunBackward(
    const std::vector<paddle::experimental::Tensor>& tensors,  // output
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
@@ -315,10 +437,8 @@ std::vector<paddle::experimental::Tensor> RunBackward(
  // *Inplace version check should perform at node-level
  // *Cross-batch accumulation happens at forward pass

  // GeneralGrad
  bool is_general_grad = !inputs.empty();

  /* --- Initialization --- */
  // 1. Init queue with starting nodes
@@ -326,7 +446,6 @@ std::vector<paddle::experimental::Tensor> RunBackward(
  std::queue<GradNodeBase*> queue;
  std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
      node_input_buffers_dict;

  for (size_t i = 0; i < tensors.size(); i++) {
    const paddle::experimental::Tensor& tensor = tensors[i];
@@ -363,7 +482,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
          paddle::platform::errors::Fatal(
              "Detected size mismatch between tensors and grad_tensors"
              "grad_tensors should either have "
              "size = 0 or same size as tensors."));
      // Feed given tensor if it's provided
      VLOG(6) << "Fill grad input tensor " << i << "with give grad tensor";
@@ -391,7 +510,9 @@ std::vector<paddle::experimental::Tensor> RunBackward(
    // Prepare queue, potential startup_nodes
    queue.push(grad_node);
    if (is_general_grad) {
      GeneralGrad::Instance().GetPotentialStartupNodes()->emplace(grad_node);
    }
  }

  VLOG(6) << "Update In degree Map for backward";
@@ -399,56 +520,13 @@ std::vector<paddle::experimental::Tensor> RunBackward(
  std::unordered_map<GradNodeBase*, int> node_in_degree_map =
      getInDegreeMap(queue);
  if (is_general_grad) {
    // Prepare several vital preprocess for GeneralGrad
    GeneralGrad::Instance().PreparedForGeneralGrad(inputs, no_grad_vars, &queue,
                                                   node_input_buffers_dict);
  }

  VLOG(6) << " startup_ops' size is :" << queue.size();
  /* --- Topological Visit --- */
  // 1. Pop queue
@@ -458,53 +536,55 @@ std::vector<paddle::experimental::Tensor> RunBackward(
  // |- Prepare for next node
  // 3. Update queue
  VLOG(6) << "Run Backward";
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
    VLOG(6) << "Running GradNode:" << node->name();

    paddle::platform::RecordEvent node_record_event(
        std::string(typeid(*node).name()) + " grad_node",
        paddle::platform::TracerEventType::Operator, 1);

    if (queue.size() > 1 && node_in_degree_map[node] != 0) {
      queue.pop();
      continue;
    }
    queue.pop();

    // Run node: This is where Hook happens
    PADDLE_ENFORCE(
        node_input_buffers_dict.count(node),
        paddle::platform::errors::Fatal(
            "Unable to find next node in the GradTensorHolder \n"
            "Trying to run Node without configuring its GradTensorHolder."));

    std::unique_ptr<GradTensorHolder> node_input_buffer =
        std::move(node_input_buffers_dict[node]);

    // Set input target grad_var from node_input_buffer by inputmeta
    if (!inputs.empty() && is_general_grad) {
      GeneralGrad::Instance().SetResultForInputTargetVar(*node_input_buffer,
                                                         node);
    }

    // no_grad_vars
    if (!no_grad_vars.empty() && is_general_grad) {
      auto iter =
          GeneralGrad::Instance().GetNoGradVarNodesInputMetaMap()->find(node);
      if (iter !=
          GeneralGrad::Instance().GetNoGradVarNodesInputMetaMap()->end()) {
        VLOG(6) << "Change the input buffer[slot][rank] by Zeros";
        auto rank_info = (iter->second)->OutRankInfo();
        node_input_buffer->SetBufferSlotRankZeros(rank_info.first,
                                                  rank_info.second);
      }
    }

    VLOG(6) << "Running GradNode:" << node->name();

    // Check input
    EnforceGradNodeHasInput(node);

    VLOG(6) << "Run Backward Kernel with GradTensorHolder.";
    // Run Pre Backward Node and get outputs
    std::vector<std::vector<paddle::experimental::Tensor>> grad_output_tensors =
        (*node)(node_input_buffer->Buffers(), create_graph);
@@ -587,23 +667,29 @@ std::vector<paddle::experimental::Tensor> RunBackward(
            node_in_degree_map[next_node] >= 0,
            paddle::platform::errors::Fatal(
                "Detected in-degree value smaller than zero. For Node: %s"
                "Node's in-degree cannot be negative.",
                next_node->name()));

        if (is_general_grad) {
          bool is_potential_stop_node =
              GeneralGrad::Instance().GetPotentialStopNodes()->count(next_node);
          if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) {
            queue.emplace(std::move(next_node));
          }
        } else {
          if (node_in_degree_map[next_node] == 0) {
            queue.emplace(std::move(next_node));
          }
        }
      }
    }
  }

  if (!is_general_grad) return {};
  return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
}
void Backward(
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph) {
  VLOG(6) << "Run in Backward";
@@ -613,12 +699,16 @@ void Backward(
}

std::vector<paddle::experimental::Tensor> Grad(
    const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
    const std::vector<paddle::experimental::Tensor>& inputs,
    const std::vector<paddle::experimental::Tensor>& grad_tensors,
    bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused,
    const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
  VLOG(6) << "Run in Grad";

  DuplicateCheck(inputs, true /* is_input */);
  DuplicateCheck(tensors, false /* is_input */);

  return RunBackward(tensors, grad_tensors, retain_graph, create_graph, inputs,
                     allow_unused, no_grad_vars);
}
......
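GetResults above only tolerates a requested input that never shows up in the backward graph when allow_unused is true; otherwise it raises. A hedged Python-level illustration of that contract (assuming paddle.grad; the new unit tests below drive the same code path through fluid.dygraph.grad):

# Illustration of the allow_unused contract enforced in GetResults; assumes
# the eager-mode paddle.grad API and is not part of the patch.
import numpy as np
import paddle

x = paddle.to_tensor(np.ones((2, 2)), dtype="float64", stop_gradient=False)
z = paddle.to_tensor(np.ones((2, 2)), dtype="float64", stop_gradient=False)
out = (x * x).sum()  # z never participates in the forward computation

dx, dz = paddle.grad([out], [x, z], allow_unused=True)
print(dx.numpy())  # d(out)/dx = 2 * x
print(dz)          # None: z does not appear in the backward graph

try:
    paddle.grad([out], [x, z], allow_unused=False)
except Exception as e:  # concrete error type depends on how the enforce maps to Python
    print("unused input rejected:", type(e).__name__)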
@@ -116,6 +116,54 @@ class TestEagerGrad(TestCase):
            self.func_simple_example_eager_grad_not_allow_unused()
        self.func_simple_example_eager_grad_not_allow_unused()

    def func_simple_example_eager_grad_duplicate_input(self):
        np.random.seed(2021)
        paddle.set_device('cpu')
        np_x = np.random.random((3, 3))
        np_y = np.random.random((3, 1))
        np_z = np.random.random((3, 1))
        x = paddle.to_tensor(np_x, dtype="float64", stop_gradient=False)
        y = paddle.to_tensor(np_y, dtype="float64", stop_gradient=False)
        z = paddle.to_tensor(np_z, dtype="float64", stop_gradient=False)
        out_z = paddle.nn.functional.sigmoid(z)
        out = paddle.matmul(x, y)

        try:
            # duplicate input will arise RuntimeError errors
            dx = fluid.dygraph.grad(out, [x, x])
        except RuntimeError as e:
            error_msg = cpt.get_exception_message(e)
            assert error_msg.find("duplicate") > 0

    def test_simple_example_eager_grad_duplicate_input(self):
        with _test_eager_guard():
            self.func_simple_example_eager_grad_duplicate_input()
        self.func_simple_example_eager_grad_duplicate_input()

    def func_simple_example_eager_grad_duplicate_output(self):
        np.random.seed(2021)
        paddle.set_device('cpu')
        np_x = np.random.random((3, 3))
        np_y = np.random.random((3, 1))
        np_z = np.random.random((3, 1))
        x = paddle.to_tensor(np_x, dtype="float64", stop_gradient=False)
        y = paddle.to_tensor(np_y, dtype="float64", stop_gradient=False)
        z = paddle.to_tensor(np_z, dtype="float64", stop_gradient=False)
        out_z = paddle.nn.functional.sigmoid(z)
        out = paddle.matmul(x, y)

        try:
            # duplicate output will arise RuntimeError errors
            dx = fluid.dygraph.grad([out, out], [x])
        except RuntimeError as e:
            error_msg = cpt.get_exception_message(e)
            assert error_msg.find("duplicate") > 0

    def test_simple_example_eager_grad_duplicate_output(self):
        with _test_eager_guard():
            self.func_simple_example_eager_grad_duplicate_output()
        self.func_simple_example_eager_grad_duplicate_output()


class TestDygraphDoubleGrad(TestCase):
    def setUp(self):
......