diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
index 2a5b158d315c39d7aeaa11cb1d20875a77f60666..ca60c8ee9a40509e3af9ec4bbb44109f551925d3 100644
--- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc
+++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -646,96 +646,6 @@ static void PurifyGradNodeGenerationInfo(const proto::OpProto& op_proto,
   }
 }
 
-static void PurifyGradOpProto(
-    const proto::OpProto& op_proto,
-    std::map<std::string, std::string>* grad_outs_slotname_map,
-    std::map<std::string, std::string>* grad_ins_fwd_slotname_map,
-    std::map<std::string, std::string>* grad_ins_grad_slotname_map,
-    std::map<std::string,
-             std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
-        grad_ins,
-    std::map<std::string,
-             std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
-        grad_outs) {
-  // Op Name
-  const std::string op_name = op_proto.type();
-
-  // Handle dispensable inputs
-  for (const proto::OpProto::Var& input : op_proto.inputs()) {
-    std::string input_name = input.name();
-
-    // Delete dispensable tensor unless specified in op_ins_map
-    if (input.dispensable()) {
-      if (!op_ins_map.count(op_name) ||
-          !op_ins_map[op_name].count(input_name)) {
-        VLOG(6) << "Removing Dispensable Input: " << input_name;
-
-        // grad_outs_slotname_map
-        auto grad_outs_slotname_map_purified = *grad_outs_slotname_map;
-        for (const auto& iter : *grad_outs_slotname_map) {
-          const std::string& grad_output_name = iter.first;
-          const std::string& matched_input_name = iter.second;
-          if (matched_input_name == input_name) {
-            grad_outs_slotname_map_purified.erase(grad_output_name);
-
-            PADDLE_ENFORCE(
-                grad_outs->count(grad_output_name) > 0,
-                paddle::platform::errors::Fatal(
-                    "Unable to find gradient output name in grad_outs."));
-            // grad_outs
-            grad_outs->erase(grad_output_name);
-          }
-        }
-        *grad_outs_slotname_map = grad_outs_slotname_map_purified;
-
-        // grad_ins_fwd_slotname_map: output as tensorwrapper
-        if (grad_ins_fwd_slotname_map->count(input_name))
-          grad_ins_fwd_slotname_map->erase(input_name);
-
-        // grad_ins: output as tensorwrapper
-        if (grad_ins->count(input_name)) grad_ins->erase(input_name);
-      }
-    }
-  }
-
-  for (const proto::OpProto::Var& output : op_proto.outputs()) {
-    std::string output_name = output.name();
-
-    // Delete dispensable tensor unless specified in op_outs_map
-    if (output.dispensable()) {
-      if (!op_outs_map.count(op_name) ||
-          !op_outs_map[op_name].count(output_name)) {
-        VLOG(6) << "Removing Dispensable Output: " << output_name;
-
-        // grad_ins_grad_slotname_map
-        auto grad_ins_grad_slotname_map_purified = *grad_ins_grad_slotname_map;
-        for (const auto& iter : *grad_ins_grad_slotname_map) {
-          const std::string& grad_input_name = iter.first;
-          const std::string& matched_output_name = iter.second;
-          if (matched_output_name == output_name) {
-            grad_ins_grad_slotname_map_purified.erase(grad_input_name);
-
-            PADDLE_ENFORCE(
-                grad_ins->count(grad_input_name) > 0,
-                paddle::platform::errors::Fatal(
-                    "Unable to find gradient input name in grad_ins."));
-            // grad_ins
-            grad_ins->erase(grad_input_name);
-          }
-        }
-        *grad_ins_grad_slotname_map = grad_ins_grad_slotname_map_purified;
-
-        // grad_ins_fwd_slotname_map: output as tensorwrapper
-        if (grad_ins_fwd_slotname_map->count(output_name))
-          grad_ins_fwd_slotname_map->erase(output_name);
-
-        // grad_ins: output as tensorwrapper
-        if (grad_ins->count(output_name)) grad_ins->erase(output_name);
-      }
-    }
-  }
-}
-
 /* -------------------------------- */
 /* --------- Collect Info --------- */
 /* -------------------------------- */
@@ -980,6 +890,13 @@ static std::string GenerateGradNodeCreationContent(
       get_autograd_meta_str += paddle::string::Sprintf(
           GET_MULTI_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name);
+    } else if (input.dispensable()) {
+      const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
+          " egr::AutogradMeta* %s = "
+          "egr::EagerUtils::nullable_autograd_meta(%s);\n";
+      get_autograd_meta_str += paddle::string::Sprintf(
+          GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name);
+
     } else {
       const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
           " egr::AutogradMeta& %s = "
@@ -1068,17 +985,36 @@ static std::string GenerateGradNodeCreationContent(
   for (const proto::OpProto::Var& input : in_vars) {
     const std::string& input_name = input.name();
     const std::string& input_autograd_name = "p_autograd_" + input_name;
-    compute_require_grad_args += ", &" + input_autograd_name;
-    size_t input_position = fwd_inputs_name_pos_map.at(input_name);
-    const char* SET_GRAD_OUT_META_TEMPLATE =
-        " grad_node->SetGradOutMeta(%s, %d);\n";
-    grad_node_creation_str += paddle::string::Sprintf(
-        SET_GRAD_OUT_META_TEMPLATE, input_autograd_name, input_position);
+    if (input.dispensable() && !input.duplicable()) {
+      compute_require_grad_args += ", " + input_autograd_name;
+      size_t input_position = fwd_inputs_name_pos_map.at(input_name);
 
-    const char* ADD_EDGES_TEMPLATE = " grad_node->AddEdges(%s, %d);\n";
-    grad_node_creation_str += paddle::string::Sprintf(
-        ADD_EDGES_TEMPLATE, input_autograd_name, input_position);
+      const char* SET_GRAD_OUT_META_TEMPLATE =
+          " if(%s) grad_node->SetGradOutMeta(*%s, %d);\n";
+      grad_node_creation_str += paddle::string::Sprintf(
+          SET_GRAD_OUT_META_TEMPLATE, input_autograd_name, input_autograd_name,
+          input_position);
+
+      const char* ADD_EDGES_TEMPLATE =
+          " if(%s) grad_node->AddEdges(*%s, %d);\n";
+      grad_node_creation_str +=
+          paddle::string::Sprintf(ADD_EDGES_TEMPLATE, input_autograd_name,
+                                  input_autograd_name, input_position);
+
+    } else {
+      compute_require_grad_args += ", &" + input_autograd_name;
+      size_t input_position = fwd_inputs_name_pos_map.at(input_name);
+
+      const char* SET_GRAD_OUT_META_TEMPLATE =
+          " grad_node->SetGradOutMeta(%s, %d);\n";
+      grad_node_creation_str += paddle::string::Sprintf(
+          SET_GRAD_OUT_META_TEMPLATE, input_autograd_name, input_position);
+
+      const char* ADD_EDGES_TEMPLATE = " grad_node->AddEdges(%s, %d);\n";
+      grad_node_creation_str += paddle::string::Sprintf(
+          ADD_EDGES_TEMPLATE, input_autograd_name, input_position);
+    }
   }
 
   // [GradOpNode] SetGradInMeta
@@ -1188,6 +1124,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
   for (const proto::OpProto::Var& input : in_vars) {
     const std::string& input_name = input.name();
     size_t input_position = fwd_inputs_name_pos_map.at(input_name);
+
     if (input.duplicable()) {
       const char* FWD_INS_ARG_TEMPLATE =
           "const std::vector<egr::EagerTensor>& %s";
@@ -1198,6 +1135,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
       input_args_str_list[input_position] =
           paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name);
     }
+
+    if (input.dispensable()) continue;
+
     const char* FWD_INS_CONTENT_TEMPLATE =
         "{ \"%s\", egr::EagerUtils::SyncToVars(%s) },";
     ins_contents_str += paddle::string::Sprintf(FWD_INS_CONTENT_TEMPLATE,
@@ -1222,6 +1162,26 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
   generated_function_body += ins_map_str;
   generated_function_body += "\n";
 
+  // Handle Dispensable Inputs
+  for (const proto::OpProto::Var& input : in_vars) {
+    const std::string& input_name = input.name();
+    if (input.dispensable()) {
+      if (input.duplicable()) {
+        const char* FWD_INS_CONTENT_TEMPLATE =
+            " if(%s.size() > 0) "
+            "ins[\"%s\"] = egr::EagerUtils::SyncToVars(%s)\n;";
+        generated_function_body += paddle::string::Sprintf(
+            FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name);
+      } else {
+        const char* FWD_INS_CONTENT_TEMPLATE =
+            " if(%s.initialized()) "
+            "ins[\"%s\"] = egr::EagerUtils::SyncToVars(%s)\n;";
+        generated_function_body += paddle::string::Sprintf(
+            FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name);
+      }
+    }
+  }
+
   VLOG(6) << "Generated Ins Map";
 
   // [Generation] Get Outs Map
diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc
index 9288c23a34a1496e117d3bf3b3bc597004c9da6a..bee7124b55cd9d1226d4ead9a44acb927a743848 100644
--- a/paddle/fluid/eager/backward.cc
+++ b/paddle/fluid/eager/backward.cc
@@ -53,6 +53,12 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap(
   for (const auto& edge_list : edges) {
     for (const Edge& edge : edge_list) {
       GradNodeBase* next_node = edge.GetMutableGradNode().get();
+
+      // Next node could be nullptr if it is a leaf tensor with no
+      // AccumulationNode attached,
+      // or it could also originate from dispensable inputs
+      if (!next_node) continue;
+
       // Update in_degree
       if (!node_in_degree_map.count(next_node))
         node_in_degree_map[next_node] = 0;
@@ -91,11 +97,6 @@ void RunBackward(const std::vector<egr::EagerTensor>& tensors,
     // Get target GradNodeBase from target tensors
     GradNodeBase* grad_node = auto_grad_meta->GetMutableGradNode().get();
 
-    PADDLE_ENFORCE(grad_node,
-                   paddle::platform::errors::Fatal(
-                       "Detected null grad_node."
-                       "Grad Node is nullptr for grad input tensor %d",
-                       i));
     // Prepare GradTensorHolder
     if (!node_input_buffers_dict.count(grad_node)) {
       VLOG(6) << "Create Value for grad input tensor " << i;
@@ -186,6 +187,11 @@ void RunBackward(const std::vector<egr::EagerTensor>& tensors,
       }
       GradNodeBase* next_node = edge.GetMutableGradNode().get();
 
+      // Next node could be nullptr if it is a leaf tensor with no
+      // AccumulationNode attached,
+      // or it could also originate from dispensable inputs
+      if (!next_node) continue;
+
       if (!node_input_buffers_dict.count(next_node)) {
         node_input_buffers_dict[next_node] =
             std::make_unique<GradTensorHolder>(next_node->InputMeta());
diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc
index be06bf9eb344ba3c15a5fc4bdf21099772562cb0..a27595468725b2572b4024dc5886631a751adcb7 100644
--- a/paddle/fluid/eager/utils.cc
+++ b/paddle/fluid/eager/utils.cc
@@ -56,6 +56,14 @@ std::vector<AutogradMeta*> EagerUtils::unsafe_autograd_meta(
   return metas;
 }
 
+AutogradMeta* EagerUtils::nullable_autograd_meta(
+    const egr::EagerTensor& target) {
+  auto* p_autograd_meta = target.get_autograd_meta();
+  if (!p_autograd_meta) return nullptr;
+
+  return static_cast<AutogradMeta*>(p_autograd_meta);
+}
+
 std::vector<AutogradMeta*> EagerUtils::multi_autograd_meta(
     std::vector<egr::EagerTensor>* targets) {
   std::vector<AutogradMeta*> ret;
diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h
index 03f922e5bf9ba9759762e14ede2c998918016d9a..851f665bbabe6bdcfe3f6d3bb7a62c4348887111 100644
--- a/paddle/fluid/eager/utils.h
+++ b/paddle/fluid/eager/utils.h
@@ -56,6 +56,9 @@ class ComputeRequireGradIter : public IterHelper<bool> {
 
  private:
   void visit(AutogradMeta* element) override {
+    // Dispensable tensors feed in a nullptr autograd_meta
+    if (!element) return;
+
     bool stop_gradient = element->StopGradient();
     if (!stop_gradient) require_grad_ = true;
   }
@@ -112,6 +115,7 @@ class EagerUtils {
   static void SetOutRankWithSlot(AutogradMeta* target, size_t slot_id);
 
   // This method will return an AutogradMeta pointer unsafely.
+  static AutogradMeta* nullable_autograd_meta(const egr::EagerTensor& target);
   static AutogradMeta* unsafe_autograd_meta(const egr::EagerTensor& target);
   static std::vector<AutogradMeta*> unsafe_autograd_meta(
       const std::vector<egr::EagerTensor>& targets);
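
To illustrate what the new string templates expand to: for a hypothetical op with a dispensable, non-duplicable input named Bias sitting at input slot 1 (both the name and the slot index are made up for this example, they are filled in from input_name and input_position by the generator), the emitted forward-function code would look roughly like this:

    // Dispensable input gets a nullable AutogradMeta instead of a reference.
    egr::AutogradMeta* p_autograd_Bias =
        egr::EagerUtils::nullable_autograd_meta(Bias);

    // Grad-node wiring is guarded, so a missing Bias contributes no edge.
    if(p_autograd_Bias) grad_node->SetGradOutMeta(*p_autograd_Bias, 1);
    if(p_autograd_Bias) grad_node->AddEdges(*p_autograd_Bias, 1);

    // The input only enters the op's ins map when it actually holds data;
    // otherwise the op runs without it.
    if(Bias.initialized()) ins["Bias"] = egr::EagerUtils::SyncToVars(Bias);

The possibly-null p_autograd_Bias is also passed directly into ComputeRequireGrad (note the ", " instead of ", &" in compute_require_grad_args for the dispensable branch), which is why ComputeRequireGradIter::visit now returns early on a nullptr element and why the backward pass skips edges whose next node is nullptr.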