Unverified commit facda828, authored by Weilong Wu, committed by GitHub

[Eager grad] Refactor partial grad logic (#40693)

* Refactor partial_grad/backward logic

* Add DuplicateCheck and polish code

* Refactor partial_grad/backward more clearly

* Refactor GeneralGrad by SingleInstance
Parent cc853e95
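
For orientation, here is a minimal usage sketch of the eager-mode partial-grad API that this refactor and the new DuplicateCheck serve. It assumes eager mode is enabled (the tests at the end of this diff do so via _test_eager_guard); tensor shapes and names are illustrative only.

import numpy as np
import paddle

paddle.set_device('cpu')
x = paddle.to_tensor(np.random.random((3, 3)), dtype="float64", stop_gradient=False)
y = paddle.to_tensor(np.random.random((3, 1)), dtype="float64", stop_gradient=False)
out = paddle.matmul(x, y)

# Partial grad: gradient of a chosen output w.r.t. a chosen input.
dx, = paddle.grad([out], [x], retain_graph=True)

# The new DuplicateCheck rejects a tensor passed twice, with an error
# message that mentions "duplicate".
try:
    paddle.grad([out], [x, x])
except RuntimeError as e:
    assert "duplicate" in str(e)

paddle.grad returns one gradient tensor per requested input; with DuplicateCheck, duplicated inputs or outputs are rejected up front.
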
......@@ -29,132 +29,330 @@
namespace egr {
std::unordered_map<GradNodeBase*, int> getInDegreeMap(
const std::queue<GradNodeBase*>& init_queue) {
// Calculate in_degree for each node
// We can completely remove this pass, if in_degree were set during forward
// pass
std::unordered_map<GradNodeBase*, int> node_in_degree_map;
/*
* GeneralGrad is a helper class that implements custom grad operations
* between outputs and inputs.
*/
class GeneralGrad {
public:
static GeneralGrad& Instance() { return *general_grad_; }
// Get the GradNodes and InputMeta info of inputs / no_grad_vars
void GetTargetNodesInfo(
const std::vector<paddle::experimental::Tensor>& inputs,
bool is_no_grad_vars) {
std::string msg = is_no_grad_vars ? "no_grad_vars" : "inputs";
VLOG(6) << "Running in GetTargetNodesInfo.";
if (!inputs.empty()) {
VLOG(6) << msg << " are not empty.";
size_t num_inputs = inputs.size();
for (size_t i = 0; i < num_inputs; i++) {
AutogradMeta* auto_grad_meta =
EagerUtils::unsafe_autograd_meta(inputs[i]);
auto target_node = auto_grad_meta->GetMutableGradNode().get();
PADDLE_ENFORCE_NOT_NULL(target_node,
paddle::platform::errors::Fatal(
"There is no grad op for %s:[%d] or it's"
"stop_gradient=True.",
msg, i));
if (is_no_grad_vars) {
(no_grad_var_nodes_inputmeta_map)[target_node] = auto_grad_meta;
} else { // normal input
(input_target_nodes_inputmeta_map)[target_node] = auto_grad_meta;
}
}
}
}
// Copy nodes
std::queue<GradNodeBase*> queue = init_queue;
std::unordered_set<GradNodeBase*> visited;
size_t potential_startup_ops_cnt = queue.size();
size_t cnt = 0;
// Purify potential_startup_nodes: remove nodes that are the same as
// input_target_nodes
void PurifyPotentialStartUpNodes() {
VLOG(6) << "Running in PurifyPotentialStartUpNodes";
if (input_target_nodes_inputmeta_map.empty()) return;
std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
for (auto startup_op : potential_startup_nodes) {
auto iter = input_target_nodes_inputmeta_map.find(startup_op);
if (iter != input_target_nodes_inputmeta_map.end()) {
potential_startup_nodes_to_be_erased.emplace(iter->first);
}
}
if (!potential_startup_nodes_to_be_erased.empty()) {
for (auto nodes : potential_startup_nodes_to_be_erased) {
potential_startup_nodes.erase(nodes);
}
}
}
// Visit each node exactly once in any order
while (!queue.empty()) {
GradNodeBase* node = queue.front();
queue.pop();
// Remove nodes that don't need to be stored in
// potential_stop_nodes / potential_startup_nodes
void UpdateGraphInfo() {
// Update potential_stop_nodes based on depending_nodes,
// making sure the path from root to target_node is valid
std::unordered_set<GradNodeBase*> _startup_ops;
VLOG(6) << "Running in UpdateGraphInfo";
std::queue<GradNodeBase*> queue;
for (auto& target_nodes_inputmeta_pair : input_target_nodes_inputmeta_map) {
queue.emplace(target_nodes_inputmeta_pair.first);
}
if (cnt < potential_startup_ops_cnt) {
if (!node_in_degree_map.count(node)) {
node_in_degree_map[node] = 0;
while (!queue.empty()) {
auto* target_node = queue.front();
queue.pop();
if (!(depending_nodes)[target_node].empty()) {
auto preceding_nodes = (depending_nodes)[target_node];
for (auto pre_nodes : preceding_nodes) {
queue.emplace(pre_nodes);
if (potential_stop_nodes.find(pre_nodes) !=
potential_stop_nodes.end()) {
potential_stop_nodes.erase(pre_nodes);
}
}
} else { // startup_ops have no preceding nodes
VLOG(6) << "Emplace _startup_ops";
_startup_ops.emplace(target_node);
}
cnt += 1;
}
if (visited.count(node)) {
continue;
// Purify potential_startup_nodes again: remove potential
// startup_nodes that cannot reach the input target nodes
if (!_startup_ops.empty()) {
std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
for (auto node : potential_startup_nodes) {
if (_startup_ops.count(node) == 0) {
VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
potential_startup_nodes_to_be_erased.emplace(node);
}
}
if (!potential_startup_nodes_to_be_erased.empty()) {
for (auto node : potential_startup_nodes_to_be_erased) {
VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
potential_startup_nodes.erase(node);
}
}
}
visited.insert(node);
}
PADDLE_ENFORCE_NOT_NULL(
node,
paddle::platform::errors::Fatal(
"We got null node when we traverse the backward graph, and this "
"should not happened please check your code and contact us."));
// Find and append next nodes
const std::vector<std::vector<Edge>>& edges = node->GetEdges();
for (const auto& edge_list : edges) {
for (const Edge& edge : edge_list) {
GradNodeBase* next_node = edge.GetMutableGradNode().get();
// Next node could be nullptr if it is a leaf tensor with no
// AccumulationNode attached
// Or it could also originate from dispensable inputs
if (!next_node) continue;
// Get graph info between input target GradNodes and outputs;
// record depending_nodes, potential_stop_nodes, potential_startup_nodes
void GetGraphInfoBetweenTargets(const std::queue<GradNodeBase*>& init_queue) {
VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
// Update in_degree
if (!node_in_degree_map.count(next_node))
node_in_degree_map[next_node] = 0;
node_in_degree_map[next_node]++;
queue.push(next_node);
// Calculate in_degree for each node
std::unordered_map<GradNodeBase*, int> node_in_degree_map;
// Copy nodes
std::queue<GradNodeBase*> queue = init_queue;
std::unordered_set<GradNodeBase*> visited;
// Visit each node exactly once in any order
while (!queue.empty()) {
GradNodeBase* node = queue.front();
queue.pop();
if (visited.count(node)) {
continue;
}
visited.insert(node);
// Check whether node is a target node; if it is, all of its
// next_nodes will be marked as potential_stop_nodes
bool is_potential_stop_nodes =
input_target_nodes_inputmeta_map.count(node);
// Find and append next nodes
const std::vector<std::vector<Edge>>& edges = node->GetEdges();
for (const auto& edge_list : edges) {
for (const Edge& edge : edge_list) {
GradNodeBase* next_node = edge.GetMutableGradNode().get();
// Next node could be nullptr if it is a leaf tensor with no
// AccumulationNode attached
// Or it could also originate from dispensable inputs
if (!next_node) continue;
// If node is in input_target_nodes,
// all the next_nodes of the current node will be inserted into
// potential_stop_nodes
if (is_potential_stop_nodes) {
potential_stop_nodes.emplace(next_node);
}
// Update in_degree
if (!node_in_degree_map.count(next_node))
node_in_degree_map[next_node] = 0;
node_in_degree_map[next_node]++;
// Record depending relationship
(depending_nodes)[next_node].emplace(node);
queue.push(next_node);
}
}
}
// Update graph info: remove some nodes from
// potential_stop_nodes / potential_startup_nodes
UpdateGraphInfo();
}
return node_in_degree_map;
}
// Remove nodes that don't need to be stored in
// potential_stop_nodes / potential_startup_nodes
void UpdateGraphInfo(
std::unordered_map<GradNodeBase*, AutogradMeta*>*
target_nodes_inputmeta_map,
std::unordered_map<GradNodeBase*, std::unordered_set<GradNodeBase*>>*
depending_nodes,
std::unordered_set<GradNodeBase*>* potential_stop_nodes,
std::unordered_set<GradNodeBase*>* potential_startup_nodes) {
// Update potential_stop_nodes based on depending_nodes,
// making sure the path from root to target_node is valid
std::unordered_set<GradNodeBase*> _startup_ops;
VLOG(6) << "Running in UpdateGraphInfo";
std::queue<GradNodeBase*> queue;
for (auto& target_nodes_inputmeta_pair : *target_nodes_inputmeta_map) {
queue.emplace(target_nodes_inputmeta_pair.first);
void ModifyReadyQueue(std::queue<GradNodeBase*>* queue) {
std::queue<GradNodeBase*> tmp_queue;
for (auto nodes : potential_startup_nodes) {
tmp_queue.emplace(nodes);
}
tmp_queue.swap(*queue);
}
while (!queue.empty()) {
auto* target_node = queue.front();
queue.pop();
if (!(*depending_nodes)[target_node].empty()) {
auto preceding_nodes = (*depending_nodes)[target_node];
for (auto pre_nodes : preceding_nodes) {
queue.emplace(pre_nodes);
if (potential_stop_nodes->find(pre_nodes) !=
potential_stop_nodes->end()) {
potential_stop_nodes->erase(pre_nodes);
// Set result for input target grad_var when potential_startup_nodes is empty
void SetResultForInputTargetVar(
const std::unordered_map<GradNodeBase*,
std::unique_ptr<GradTensorHolder>>&
node_input_buffers_dict) {
if (potential_startup_nodes.size() == 0) {
for (auto input_target_node : *GetInPutTargetNodesInputMetaMap()) {
// out rank_info of forward op
auto rank_info = input_target_node.second->OutRankInfo();
auto iter = node_input_buffers_dict.find(input_target_node.first);
if (iter != node_input_buffers_dict.end()) {
auto& target_result =
(iter->second)->Buffers()[rank_info.first][rank_info.second];
// save the target result
results_map[input_target_node.first] = target_result;
}
}
} else { // startup_ops have no preceding nodes
VLOG(6) << "Emplace _startup_ops";
_startup_ops.emplace(target_node);
}
}
// Purify potential_startup_nodes again: remove potential
// startup_nodes that cannot reach the input target nodes
if (!_startup_ops.empty()) {
std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
for (auto node : *potential_startup_nodes) {
if (_startup_ops.count(node) == 0) {
VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
potential_startup_nodes_to_be_erased.emplace(node);
}
// Set input target grad_var from node_input_buffer by inputmeta
void SetResultForInputTargetVar(GradTensorHolder input_buffers,
GradNodeBase* node) {
auto iter = GetInPutTargetNodesInputMetaMap()->find(node);
if (iter != GetInPutTargetNodesInputMetaMap()->end()) {
VLOG(6) << "Get target result by by inputmeta";
// out rank_info of forward op
auto rank_info = (iter->second)->OutRankInfo();
// rank_info is a pair, first means slot_id, second means rank.
auto& target_result =
input_buffers.Buffers()[rank_info.first][rank_info.second];
// save the target result
results_map[node] = target_result;
}
if (!potential_startup_nodes_to_be_erased.empty()) {
for (auto node : potential_startup_nodes_to_be_erased) {
VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
potential_startup_nodes->erase(node);
}
std::vector<paddle::experimental::Tensor> GetResults(
const std::vector<paddle::experimental::Tensor>& inputs,
bool allow_unused, bool create_graph) {
VLOG(6) << "Running in GetResults";
if (inputs.empty()) return {};
std::vector<paddle::experimental::Tensor> results;
results.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
auto& input = inputs[i];
AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
auto target_node = auto_grad_meta->GetMutableGradNode().get();
auto iter = results_map.find(target_node);
if (iter != results_map.end()) {
// set StopGradient = !create_graph
AutogradMeta* tensor_auto_grad_meta =
EagerUtils::autograd_meta(&(iter->second));
tensor_auto_grad_meta->SetStopGradient(!create_graph);
results.emplace_back(iter->second);
} else {
PADDLE_ENFORCE_EQ(allow_unused, true,
paddle::platform::errors::InvalidArgument(
"The %d-th input does not appear in the backward "
"graph. Please check the input tensor or set "
"allow_unused=True to get None result.",
i));
results.emplace_back();
}
}
Clear();
return results;
}
void PreparedForGeneralGrad(
const std::vector<paddle::experimental::Tensor>& inputs,
const std::vector<paddle::experimental::Tensor>& no_grad_vars,
std::queue<GradNodeBase*>* queue,
const std::unordered_map<GradNodeBase*,
std::unique_ptr<GradTensorHolder>>&
node_input_buffers_dict) {
// Get no_grad_vars's GradNodes and InputMeta Info
GetTargetNodesInfo(no_grad_vars, true /* is_no_grad_vars */);
// Get inputs's GradNodes and InputMeta Info
GetTargetNodesInfo(inputs, false /* is_no_grad_vars */);
// Purify potential_startup_ops, remove those nodes that are the same as
// input_target_nodes
PurifyPotentialStartUpNodes();
// Get graph info between input target GradNodes and outputs.
// Record the depending_nodes and
// potential_stop_nodes / potential_startup_nodes
GetGraphInfoBetweenTargets(*queue);
// Reset queue. Queue is empty only when
// 1. input equals output, or 2. input cannot reach output.
ModifyReadyQueue(queue);
// Set result for input target grad_var when queue is empty
if (queue->empty()) SetResultForInputTargetVar(node_input_buffers_dict);
}
bool IsPotentialStopNodes(GradNodeBase* node) {
return potential_stop_nodes.count(node);
}
std::unordered_map<GradNodeBase*, AutogradMeta*>*
GetNoGradVarNodesInputMetaMap() {
return &no_grad_var_nodes_inputmeta_map;
}
std::unordered_map<GradNodeBase*, AutogradMeta*>*
GetInPutTargetNodesInputMetaMap() {
return &input_target_nodes_inputmeta_map;
}
std::unordered_set<GradNodeBase*>* GetPotentialStopNodes() {
return &potential_stop_nodes;
}
std::unordered_set<GradNodeBase*>* GetPotentialStartupNodes() {
return &potential_startup_nodes;
}
void Clear() {
no_grad_var_nodes_inputmeta_map.clear();
input_target_nodes_inputmeta_map.clear();
potential_startup_nodes.clear();
potential_stop_nodes.clear();
depending_nodes.clear();
results_map.clear();
}
}
// Get graph info between input target GradNodes and outputs;
// record depending_nodes, potential_stop_nodes, potential_startup_nodes
void GetGraphInfoBetweenTargets(
const std::queue<GradNodeBase*>& init_queue,
std::unordered_map<GradNodeBase*, AutogradMeta*>*
input_target_nodes_inputmeta_map,
std::unordered_map</*child node*/ GradNodeBase*,
/*father nodes*/ std::unordered_set<GradNodeBase*>>*
depending_nodes,
std::unordered_set<GradNodeBase*>* potential_stop_nodes,
std::unordered_set<GradNodeBase*>* potential_startup_nodes) {
if (input_target_nodes_inputmeta_map->empty()) return;
VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
private:
GeneralGrad() = default;
static GeneralGrad* general_grad_;
// no_grad_vars's GradNode and GradNode's InputMeta.
std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
no_grad_var_nodes_inputmeta_map;
// inputs's GradNode and GradNode's InputMeta.
std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
input_target_nodes_inputmeta_map;
// Record all the potential startup_nodes; may be modified later.
std::unordered_set<GradNodeBase*> potential_startup_nodes;
// Record all the potential stop nodes; may be modified later.
std::unordered_set<GradNodeBase*> potential_stop_nodes;
std::unordered_map<GradNodeBase* /* next node */,
std::unordered_set<GradNodeBase*> /* pre nodes */>
depending_nodes;
std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};
std::unordered_map<GradNodeBase*, int> getInDegreeMap(
const std::queue<GradNodeBase*>& init_queue) {
// Calculate in_degree for each node
// We can completely remove this pass, if in_degree were set during forward
// pass
std::unordered_map<GradNodeBase*, int> node_in_degree_map;
// Copy nodes
......@@ -171,101 +369,30 @@ void GetGraphInfoBetweenTargets(
}
visited.insert(node);
// Check whether node is a target node; if it is, all of its
// next_nodes will be marked as potential_stop_nodes
bool is_potential_stop_nodes =
input_target_nodes_inputmeta_map->count(node);
PADDLE_ENFORCE_NOT_NULL(
node,
paddle::platform::errors::Fatal(
"We got null node when we traverse the backward graph, and this "
"should not happened please check your code and contact us."));
// Find and append next nodes
const std::vector<std::vector<Edge>>& edges = node->GetEdges();
for (const auto& edge_list : edges) {
for (const Edge& edge : edge_list) {
GradNodeBase* next_node = edge.GetMutableGradNode().get();
// Next node could be nullptr if it is a leaf tensor with no
// AccumulationNode attached
// Or it could also originate from dispensable inputs
if (!next_node) continue;
// If node is in input_target_nodes,
// all the next_nodes of the current node will be inserted into
// potential_stop_nodes
if (is_potential_stop_nodes) {
potential_stop_nodes->emplace(next_node);
}
// Update in_degree
if (!node_in_degree_map.count(next_node))
node_in_degree_map[next_node] = 0;
node_in_degree_map[next_node]++;
// Record depending relationship
(*depending_nodes)[next_node].emplace(node);
queue.push(next_node);
}
}
}
// Update graph info: remove some stop nodes from potential_stop_nodes
UpdateGraphInfo(input_target_nodes_inputmeta_map, depending_nodes,
potential_stop_nodes, potential_startup_nodes);
}
void GetTargetNodesInfo(const std::vector<paddle::experimental::Tensor>& inputs,
std::unordered_map<GradNodeBase*, AutogradMeta*>*
target_nodes_inputmeta_map) {
VLOG(6) << "Running in GetTargetNodesInfo";
if (!inputs.empty()) {
VLOG(6) << "Inputs are not empty";
size_t num_inputs = inputs.size();
for (size_t i = 0; i < num_inputs; i++) {
AutogradMeta* auto_grad_meta =
EagerUtils::unsafe_autograd_meta(inputs[i]);
auto target_node = auto_grad_meta->GetMutableGradNode().get();
PADDLE_ENFORCE_NOT_NULL(target_node,
paddle::platform::errors::Fatal(
"There is no grad op for input:%d or it's"
"stop_gradient=True",
i));
(*target_nodes_inputmeta_map)[target_node] = auto_grad_meta;
}
}
}
std::vector<paddle::experimental::Tensor> GetResults(
const std::vector<paddle::experimental::Tensor>& inputs,
std::unordered_map<GradNodeBase*, paddle::experimental::Tensor>*
results_map,
bool allow_unused, bool create_graph) {
VLOG(6) << "Running in GetResults";
if (inputs.empty()) return {};
std::vector<paddle::experimental::Tensor> results;
results.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
auto& input = inputs[i];
AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
auto target_node = auto_grad_meta->GetMutableGradNode().get();
auto iter = results_map->find(target_node);
if (iter != results_map->end()) {
// set StopGradient = !create_graph
AutogradMeta* tensor_auto_grad_meta =
EagerUtils::autograd_meta(&(iter->second));
tensor_auto_grad_meta->SetStopGradient(!create_graph);
results.emplace_back(iter->second);
} else {
PADDLE_ENFORCE_EQ(allow_unused, true,
paddle::platform::errors::InvalidArgument(
"The %d-th input does not appear in the backward "
"graph. Please check the input variable or set "
"allow_unused=True to get None result.",
i));
results.emplace_back();
}
}
return results;
return node_in_degree_map;
}
// Enforce GradNode has TensorWrappers as Input
......@@ -281,28 +408,23 @@ void EnforceGradNodeHasInput(GradNodeBase* node) {
node->name()));
}
// Purify potential_startup_nodes: remove nodes that are the same as
// input_target_nodes
void PurifyPotentialStartUpNodes(
std::unordered_set<GradNodeBase*>* potential_startup_nodes,
std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>*
input_target_nodes_inputmeta_map) {
VLOG(6) << "Running in PurifyPotentialStartUpNodes";
if (input_target_nodes_inputmeta_map->empty()) return;
std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
for (auto startup_op : *potential_startup_nodes) {
auto iter = input_target_nodes_inputmeta_map->find(startup_op);
if (iter != input_target_nodes_inputmeta_map->end()) {
potential_startup_nodes_to_be_erased.emplace(iter->first);
}
}
if (!potential_startup_nodes_to_be_erased.empty()) {
for (auto nodes : potential_startup_nodes_to_be_erased) {
potential_startup_nodes->erase(nodes);
}
void DuplicateCheck(const std::vector<paddle::experimental::Tensor>& inputs,
bool is_input) {
std::unordered_set<AutogradMeta*> visisted_ins;
std::string msg = is_input ? "inputs" : "outputs";
for (auto in : inputs) {
AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(in);
PADDLE_ENFORCE_EQ(
visited_ins.count(auto_grad_meta), 0,
paddle::platform::errors::AlreadyExists(
"%s contain duplicate tensor %s, please check %s carefully.", msg,
in.name(), msg));
visited_ins.insert(auto_grad_meta);
}
}
GeneralGrad* GeneralGrad::general_grad_ = new GeneralGrad();
std::vector<paddle::experimental::Tensor> RunBackward(
const std::vector<paddle::experimental::Tensor>& tensors, // output
const std::vector<paddle::experimental::Tensor>& grad_tensors,
......@@ -315,10 +437,8 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// *Inplace version check should perform at node-level
// *Cross-batch accumulation happens at forward pass
std::unordered_map<GradNodeBase*, AutogradMeta*>
no_grad_var_nodes_inputmeta_map;
// Get no_grad_vars's GradNodes and InputMeta Info
GetTargetNodesInfo(no_grad_vars, &no_grad_var_nodes_inputmeta_map);
// GeneralGrad
bool is_general_grad = !inputs.empty();
/* --- Initialization --- */
// 1. Init queue with starting nodes
......@@ -326,7 +446,6 @@ std::vector<paddle::experimental::Tensor> RunBackward(
std::queue<GradNodeBase*> queue;
std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
node_input_buffers_dict;
std::unordered_set<GradNodeBase*> potential_startup_nodes;
for (size_t i = 0; i < tensors.size(); i++) {
const paddle::experimental::Tensor& tensor = tensors[i];
......@@ -363,7 +482,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
paddle::platform::errors::Fatal(
"Detected size mismatch between tensors and grad_tensors"
"grad_tensors should either have "
"size = 0 or same size as tensors"));
"size = 0 or same size as tensors."));
// Feed given tensor if it's provided
VLOG(6) << "Fill grad input tensor " << i << "with give grad tensor";
......@@ -391,7 +510,9 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// Prepare queue, potential startup_nodes
queue.push(grad_node);
potential_startup_nodes.emplace(grad_node);
if (is_general_grad) {
GeneralGrad::Instance().GetPotentialStartupNodes()->emplace(grad_node);
}
}
VLOG(6) << "Update In degree Map for backward";
......@@ -399,56 +520,13 @@ std::vector<paddle::experimental::Tensor> RunBackward(
std::unordered_map<GradNodeBase*, int> node_in_degree_map =
getInDegreeMap(queue);
// Get input's GradNodes and InputMeta Info
std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
input_target_nodes_inputmeta_map;
GetTargetNodesInfo(inputs, &input_target_nodes_inputmeta_map);
// Purify potential_startup_ops, remove those nodes that are the same as
// input_target_nodes
PurifyPotentialStartUpNodes(&potential_startup_nodes,
&input_target_nodes_inputmeta_map);
// Get graph info between input target GradNodes and outputs
// Record the depending_nodes and potential_stop_nodes
std::unordered_map<GradNodeBase* /* child node */,
std::unordered_set<GradNodeBase*> /* father node */>
depending_nodes;
std::unordered_set<GradNodeBase*> potential_stop_nodes;
// std::unordered_set<GradNodeBase*> startup_ops;
GetGraphInfoBetweenTargets(queue, &input_target_nodes_inputmeta_map,
&depending_nodes, &potential_stop_nodes,
&potential_startup_nodes);
// ready_queue stores all startup nodes
std::queue<GradNodeBase*> ready_queue;
// startup op's indegree should be 0
for (auto node : potential_startup_nodes) {
if (node_in_degree_map[node] == 0) {
ready_queue.emplace(node);
}
if (is_general_grad) {
// Prepare several vital preprocessing steps for GeneralGrad
GeneralGrad::Instance().PreparedForGeneralGrad(inputs, no_grad_vars, &queue,
node_input_buffers_dict);
}
VLOG(1) << " startup_ops' size is :" << ready_queue.size();
std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
// ready_queue is empty only when 1. input equals output, or 2. input cannot
// reach output.
if (ready_queue.size() == 0) {
for (auto input_target_node : input_target_nodes_inputmeta_map) {
// out rank_info of forward op
auto rank_info = input_target_node.second->OutRankInfo();
if (node_input_buffers_dict[input_target_node.first]) {
auto& target_result =
node_input_buffers_dict[input_target_node.first]
->Buffers()[rank_info.first][rank_info.second];
// save the target result
results_map[input_target_node.first] = target_result;
}
}
}
VLOG(6) << " startup_ops' size is :" << queue.size();
/* --- Topological Visit --- */
// 1. Pop queue
......@@ -458,53 +536,55 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// |- Prepare for next node
// 3. Update queue
VLOG(6) << "Run Backward";
while (!ready_queue.empty()) {
GradNodeBase* node = ready_queue.front();
while (!queue.empty()) {
GradNodeBase* node = queue.front();
VLOG(6) << "Running GradNode:" << node->name();
ready_queue.pop();
paddle::platform::RecordEvent node_record_event(
std::string(typeid(*node).name()) + " grad_node",
paddle::platform::TracerEventType::Operator, 1);
if (queue.size() > 1 && node_in_degree_map[node] != 0) {
queue.pop();
continue;
}
queue.pop();
// Run node: This is where Hook happens
PADDLE_ENFORCE(
node_input_buffers_dict.count(node),
paddle::platform::errors::Fatal(
"Unable to find next node in the GradTensorHolder \n"
"Trying to run Node without configuring its GradTensorHolder"));
"Trying to run Node without configuring its GradTensorHolder."));
std::unique_ptr<GradTensorHolder> node_input_buffer =
std::move(node_input_buffers_dict[node]);
// get target grad_var from node_input_buffer by inputmeta
if (input_target_nodes_inputmeta_map.find(node) !=
input_target_nodes_inputmeta_map.end()) {
VLOG(6) << "Get target result by by inputmeta";
// out rank_info of forward op
auto rank_info = input_target_nodes_inputmeta_map[node]->OutRankInfo();
// rank_info is a pair, first means slot_id, second means rank.
auto& target_result =
node_input_buffer->Buffers()[rank_info.first][rank_info.second];
// save the target result
results_map[node] = target_result;
// Set input target grad_var from node_input_buffer by inputmeta
if (!inputs.empty() && is_general_grad) {
GeneralGrad::Instance().SetResultForInputTargetVar(*node_input_buffer,
node);
}
// no_grad_vars
if (no_grad_var_nodes_inputmeta_map.find(node) !=
no_grad_var_nodes_inputmeta_map.end()) {
VLOG(6) << "Change the input buffer[slot][rank] by Zeros";
auto rank_info = no_grad_var_nodes_inputmeta_map[node]->OutRankInfo();
node_input_buffer->SetBufferSlotRankZeros(rank_info.first,
rank_info.second);
if (!no_grad_vars.empty() && is_general_grad) {
auto iter =
GeneralGrad::Instance().GetNoGradVarNodesInputMetaMap()->find(node);
if (iter !=
GeneralGrad::Instance().GetNoGradVarNodesInputMetaMap()->end()) {
VLOG(6) << "Change the input buffer[slot][rank] by Zeros";
auto rank_info = (iter->second)->OutRankInfo();
node_input_buffer->SetBufferSlotRankZeros(rank_info.first,
rank_info.second);
}
}
VLOG(6) << "Running GradNode:" << node->name();
// check input
// Check input
EnforceGradNodeHasInput(node);
VLOG(6) << "Run Backward Kernel with GradTensorHolder";
VLOG(6) << "Run Backward Kernel with GradTensorHolder.";
// Run Pre Backward Node and get outputs
std::vector<std::vector<paddle::experimental::Tensor>> grad_output_tensors =
(*node)(node_input_buffer->Buffers(), create_graph);
......@@ -587,23 +667,29 @@ std::vector<paddle::experimental::Tensor> RunBackward(
node_in_degree_map[next_node] >= 0,
paddle::platform::errors::Fatal(
"Detected in-degree value smaller than zero. For Node: %s"
"Node's in-degree cannot be negative",
"Node's in-degree cannot be negative.",
next_node->name()));
bool is_potential_stop_node = potential_stop_nodes.count(next_node);
if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) {
ready_queue.emplace(std::move(next_node));
if (is_general_grad) {
bool is_potential_stop_node =
GeneralGrad::Instance().GetPotentialStopNodes()->count(next_node);
if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) {
queue.emplace(std::move(next_node));
}
} else {
if (node_in_degree_map[next_node] == 0) {
queue.emplace(std::move(next_node));
}
}
}
}
}
return GetResults(inputs, &results_map, allow_unused, create_graph);
if (!is_general_grad) return {};
return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
}
void Backward(
const std::vector<paddle::experimental::Tensor>& tensors, // output
const std::vector<paddle::experimental::Tensor>& tensors, // outputs
const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph) {
VLOG(6) << "Run in Backward";
......@@ -613,12 +699,16 @@ void Backward(
}
std::vector<paddle::experimental::Tensor> Grad(
const std::vector<paddle::experimental::Tensor>& tensors, // output
const std::vector<paddle::experimental::Tensor>& tensors, // outputs
const std::vector<paddle::experimental::Tensor>& inputs,
const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused,
const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
VLOG(6) << "Run in Grad";
DuplicateCheck(inputs, true /* is_input */);
DuplicateCheck(tensors, false /* is_input */);
return RunBackward(tensors, grad_tensors, retain_graph, create_graph, inputs,
allow_unused, no_grad_vars);
}
......
......@@ -116,6 +116,54 @@ class TestEagerGrad(TestCase):
self.func_simple_example_eager_grad_not_allow_unused()
self.func_simple_example_eager_grad_not_allow_unused()
def func_simple_example_eager_grad_duplicate_input(self):
np.random.seed(2021)
paddle.set_device('cpu')
np_x = np.random.random((3, 3))
np_y = np.random.random((3, 1))
np_z = np.random.random((3, 1))
x = paddle.to_tensor(np_x, dtype="float64", stop_gradient=False)
y = paddle.to_tensor(np_y, dtype="float64", stop_gradient=False)
z = paddle.to_tensor(np_z, dtype="float64", stop_gradient=False)
out_z = paddle.nn.functional.sigmoid(z)
out = paddle.matmul(x, y)
try:
# a duplicate input will raise a RuntimeError
dx = fluid.dygraph.grad(out, [x, x])
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("duplicate") > 0
def test_simple_example_eager_grad_duplicate_input(self):
with _test_eager_guard():
self.func_simple_example_eager_grad_duplicate_input()
self.func_simple_example_eager_grad_duplicate_input()
def func_simple_example_eager_grad_duplicate_output(self):
np.random.seed(2021)
paddle.set_device('cpu')
np_x = np.random.random((3, 3))
np_y = np.random.random((3, 1))
np_z = np.random.random((3, 1))
x = paddle.to_tensor(np_x, dtype="float64", stop_gradient=False)
y = paddle.to_tensor(np_y, dtype="float64", stop_gradient=False)
z = paddle.to_tensor(np_z, dtype="float64", stop_gradient=False)
out_z = paddle.nn.functional.sigmoid(z)
out = paddle.matmul(x, y)
try:
# a duplicate output will raise a RuntimeError
dx = fluid.dygraph.grad([out, out], [x])
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("duplicate") > 0
def test_simple_example_eager_grad_duplicate_output(self):
with _test_eager_guard():
self.func_simple_example_eager_grad_duplicate_output()
self.func_simple_example_eager_grad_duplicate_output()
class TestDygraphDoubleGrad(TestCase):
def setUp(self):
......
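
The no_grad_vars handling above (zeroing the matching input-buffer slot with SetBufferSlotRankZeros in RunBackward) backs the no_grad_vars argument of the eager grad API. A minimal sketch, assuming eager mode is enabled; names and shapes are illustrative only.

import numpy as np
import paddle

paddle.set_device('cpu')
x = paddle.to_tensor(np.random.random((3, 3)), dtype="float64", stop_gradient=False)
y = paddle.to_tensor(np.random.random((3, 3)), dtype="float64", stop_gradient=False)
z = paddle.matmul(x, y)  # intermediate tensor
out = paddle.sum(paddle.nn.functional.sigmoid(z))

# Gradients entering z's grad node are zeroed by the no_grad_vars path,
# so dx only reflects contributions that bypass z (none here, hence zeros).
dx, = paddle.grad([out], [x], no_grad_vars=[z], allow_unused=True)

Since the only path from out to x goes through z, zeroing the gradient that enters z's grad node leaves dx as all zeros.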