From 4700a08e99d232d2597a135ec655252f4a29cdd6 Mon Sep 17 00:00:00 2001
From: Jiabin Yang <360788950@qq.com>
Date: Tue, 31 May 2022 13:41:44 +0800
Subject: [PATCH] Support backward prune for eager intermidiate (#43111)

* support is empty

* fix error

* fix code error

* change to fake empty

* using fake empty first

* using fake empty first

* Support backward prune in fluid
---
 .../auto_code_generator/eager_generator.cc    | 67 ++++++++++++-------
 1 file changed, 41 insertions(+), 26 deletions(-)

diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
index 521b952a4df..3a9bac833d5 100644
--- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc
+++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -2189,7 +2189,6 @@ static std::string GenerateSingleOpBase(
   }
   VLOG(6) << "Generated Ins Map";
 
-
   // [Generation] Get Outs Map
   std::string outs_contents_str = "";
   for (auto iter : grad_outs) {
@@ -2238,9 +2237,12 @@ static std::string GenerateSingleOpBase(
       size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name);
 
       const char* GRAD_OUTS_CONTENT_TEMPLATE =
-          "{ \"%s\", egr::EagerUtils::TrySyncToVars(hooked_grads[%d]) },";
+          "  if((!out_metas[%d].empty()) && "
+          "(!(out_metas[%d][0].IsStopGradient()))){ \n %s.insert({ \"%s\", "
+          "egr::EagerUtils::TrySyncToVars(hooked_grads[%d])});} \n ";
       outs_contents_str += paddle::string::Sprintf(
-          GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grads_position);
+          GRAD_OUTS_CONTENT_TEMPLATE, grads_position, grads_position,
+          outs_name, grad_output_name, grads_position);
 
     } else {
       if (dispensable_input_name_set.count(fwd_name) &&
@@ -2251,18 +2253,20 @@ static std::string GenerateSingleOpBase(
         if (duplicable_input_name_set.count(fwd_name) &&
             !is_op_base_per_duplicable_input) {
           const char* GRAD_OUTS_CONTENT_TEMPLATE =
-              "{ \"%s\", egr::EagerUtils::CreateVars( "
-              "this->OutputMeta()[%d].size() ) },";
+              "  if(!out_metas[%d].empty()){ %s.insert({ \"%s\", "
+              "egr::EagerUtils::CreateVars(out_metas[%d].size())});} \n ";
           outs_contents_str += paddle::string::Sprintf(
-              GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position);
+              GRAD_OUTS_CONTENT_TEMPLATE, fwd_input_position, outs_name,
+              grad_output_name, fwd_input_position);
         } else {
           const char* GRAD_OUTS_CONTENT_TEMPLATE =
-              "{ \"%s\", "
+              "  if((!out_metas[%d].empty()) && "
+              "(!(out_metas[%d][0].IsStopGradient()))){ %s.insert({ \"%s\", "
               "{std::make_shared<egr::EagerVariable>(egr::Controller::Instance("
-              ")."
-              "GenerateUniqueName())}},";
+              ").GenerateUniqueName())}});} \n ";
           outs_contents_str += paddle::string::Sprintf(
-              GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name);
+              GRAD_OUTS_CONTENT_TEMPLATE, fwd_input_position,
+              fwd_input_position, outs_name, grad_output_name);
         }
       }
     } else {
@@ -2272,16 +2276,15 @@ static std::string GenerateSingleOpBase(
           grad_output_name));
     }
   }
-  if (outs_contents_str.size() > 0)
-    outs_contents_str.pop_back();  // // Remove trailing ","
 
   const char* BWD_OUTS_MAP_TEMPLATE =
       "  std::map<std::string, "
-      "std::vector<std::shared_ptr<egr::EagerVariable>>> %s = { "
-      "%s };\n";
-  std::string outs_map_str = paddle::string::Sprintf(
-      BWD_OUTS_MAP_TEMPLATE, outs_name, outs_contents_str);
+      "std::vector<std::shared_ptr<egr::EagerVariable>>> %s;\n";
+  std::string outs_map_str =
+      paddle::string::Sprintf(BWD_OUTS_MAP_TEMPLATE, outs_name);
+  generated_grad_function_body += outs_map_str;
+  generated_grad_function_body += outs_contents_str;
   generated_grad_function_body += "\n";
 
   for (auto iter : grad_outs) {
     const std::string& grad_output_name = iter.first;
@@ -2296,18 +2299,23 @@ static std::string GenerateSingleOpBase(
           !is_op_base_per_duplicable_input) {
         size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
         const char* DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE =
-            "  if(%s.size() > 0) %s[\"%s\"] = egr::EagerUtils::CreateVars( "
-            "this->OutputMeta()[%d].size() );\n";
+            "  if((%s.size() > 0) && (!out_metas[%d].empty()) && "
+            "(!out_metas[%d][0].IsStopGradient())) %s[\"%s\"] = "
+            "egr::EagerUtils::CreateVars( "
+            "out_metas[%d].size() );\n";
         generated_grad_function_body += paddle::string::Sprintf(
             DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE, fwd_name, outs_name,
             grad_output_name, fwd_input_position);
       } else {
+        size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
         const char* DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE =
-            "  if(%s.defined()) %s[\"%s\"] = "
+            "  if(%s.defined() && (!out_metas[%d].empty()) && "
+            "(!out_metas[%d][0].IsStopGradient())) %s[\"%s\"] = "
            "{std::make_shared<egr::EagerVariable>(egr::Controller::"
            "Instance().GenerateUniqueName())};\n";
         generated_grad_function_body += paddle::string::Sprintf(
-            DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE, fwd_name, outs_name,
+            DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE, fwd_name,
+            fwd_input_position, fwd_input_position, outs_name,
            grad_output_name);
       }
     }
@@ -2387,16 +2395,20 @@ static std::string GenerateSingleOpBase(
       size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
       if (!is_op_base_per_duplicable_input) {
         const char* BWD_OUTPUT_TEMPLATE =
-            "  outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
+            "  if (%s.find(\"%s\") != %s.end()) { outputs[%d] = "
+            "egr::EagerUtils::GetOutputs(%s[\"%s\"]); }\n";
         outputs_str += paddle::string::Sprintf(
-            BWD_OUTPUT_TEMPLATE, fwd_input_position, outs_name, grad_out_name);
+            BWD_OUTPUT_TEMPLATE, outs_name, grad_out_name, outs_name,
+            fwd_input_position, outs_name, grad_out_name);
       } else {
         const char* BWD_OUTPUT_TEMPLATE =
             "  "
+            "if (%s.find(\"%s\") != %s.end()) { "
            "outputs[0].emplace_back(egr::EagerUtils::GetOutputs(%s[\"%s\"])[0]"
-            ");\n";
+            "); }\n";
         outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, outs_name,
-                                               grad_out_name);
+                                               grad_out_name, outs_name,
+                                               outs_name, grad_out_name);
       }
       num_appended_outputs++;
     } else {
@@ -2415,9 +2427,11 @@ static std::string GenerateSingleOpBase(
 
     if (fwd_outputs_name_pos_map.count(fwd_name)) {
       const char* BWD_OUTPUT_TEMPLATE =
-          "  outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
+          "  if (%s.find(\"%s\") != %s.end()) { outputs[%d] = "
+          "egr::EagerUtils::GetOutputs(%s[\"%s\"]); }\n";
       outputs_str += paddle::string::Sprintf(
-          BWD_OUTPUT_TEMPLATE, num_appended_outputs, outs_name, grad_out_name);
+          BWD_OUTPUT_TEMPLATE, outs_name, grad_out_name, outs_name,
+          num_appended_outputs, outs_name, grad_out_name);
       num_appended_outputs++;
     }
   }
@@ -2550,6 +2564,7 @@ static std::string GenerateGradNodeCCContents(
       "  paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
       "egr::kSlotSmallVectorSize> hooked_grads = "
      "GradNode%s::ApplyGradientHooks(grads);\n"
+      "  const auto& out_metas = OutputMeta();\n"
       "  paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
      "egr::kSlotSmallVectorSize> outputs(%d);\n"
       "  %s\n"
--
GitLab
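
What the patch changes, in effect: the generated GradNode grad function now consults out_metas (the node's OutputMeta()) before allocating each grad-output slot, skips slots whose meta is empty ("fake empty") or marked stop_gradient, and only copies a slot into outputs when that slot actually exists in the outs map. Below is a minimal, self-contained sketch of that pruning pattern; it is illustrative only, with hypothetical stand-in types SlotMeta and Variable in place of Paddle's autograd meta and egr::EagerVariable, and made-up slot names.

    // backward_prune_sketch.cc -- illustrative only, not generated Paddle code.
    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for the per-slot autograd meta.
    struct SlotMeta {
      bool stop_gradient;
      bool IsStopGradient() const { return stop_gradient; }
    };

    // Hypothetical stand-in for egr::EagerVariable.
    struct Variable {
      explicit Variable(std::string n) : name(std::move(n)) {}
      std::string name;
    };

    int main() {
      // Mirrors OutputMeta(): one vector of metas per grad-output slot.
      // An empty slot models the "fake empty" case from the commit log.
      std::vector<std::vector<SlotMeta>> out_metas = {
          {SlotMeta{false}},  // slot 0: gradient required
          {SlotMeta{true}},   // slot 1: stop_gradient set -> prune
          {},                 // slot 2: fake-empty slot   -> prune
      };
      const std::vector<std::string> grad_output_names = {"X@GRAD", "Y@GRAD",
                                                          "Z@GRAD"};

      // Equivalent of the generated "outs" map: a slot only gets grad
      // variables allocated when its meta is non-empty and not stop_gradient.
      std::map<std::string, std::vector<std::shared_ptr<Variable>>> outs;
      for (size_t i = 0; i < out_metas.size(); ++i) {
        if (!out_metas[i].empty() && !out_metas[i][0].IsStopGradient()) {
          outs.insert({grad_output_names[i],
                       {std::make_shared<Variable>(grad_output_names[i])}});
        }
      }

      // Mirrors the generated "if (outs.find(name) != outs.end())" guards:
      // outputs[i] is only filled for slots that survived the prune.
      for (size_t i = 0; i < grad_output_names.size(); ++i) {
        const bool kept = outs.find(grad_output_names[i]) != outs.end();
        std::cout << grad_output_names[i] << (kept ? ": kept" : ": pruned")
                  << "\n";
      }
      return 0;
    }

In the real generated code the same guard is applied again through outs.find(name) != outs.end() before egr::EagerUtils::GetOutputs is called, so pruned slots never allocate grad variables and never appear in the returned outputs.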