From 9bb39d489972ad85eec43d6418619e9e5a2a22f0 Mon Sep 17 00:00:00 2001 From: Jiabin Yang <360788950@qq.com> Date: Tue, 7 Jun 2022 12:00:17 +0800 Subject: [PATCH] support prune (#43250) --- .../auto_code_generator/eager_generator.cc | 41 +++++++++++++------ 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 817a0de6e0c..73baf210158 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -1206,22 +1206,37 @@ static std::string GenerateGradNodeCreationContent( if (!input.duplicable()) { compute_require_grad_args += ", " + input_autograd_name; size_t input_position = fwd_inputs_name_pos_map.at(input_name); - - const char* SET_GRAD_OUT_META_TEMPLATE = - " grad_node->SetGradOutMeta(%s, %d);\n"; - grad_node_creation_str += - paddle::string::Sprintf(SET_GRAD_OUT_META_TEMPLATE, - LegalizeVarName(input_name), input_position); - + bool found_target_name = false; + for (const auto& iter : op_base_infos) { + const auto& grad_outs_slot_map = iter.GetGradOutsSlotnameMap(); + for (auto iter : grad_outs_slot_map) { + if ((!found_target_name) && (input_name == iter.second)) { + const char* SET_GRAD_OUT_META_TEMPLATE = + " grad_node->SetGradOutMeta(%s, %d);\n"; + grad_node_creation_str += paddle::string::Sprintf( + SET_GRAD_OUT_META_TEMPLATE, LegalizeVarName(input_name), + input_position); + found_target_name = true; + } + } + } } else { compute_require_grad_args += ", &" + input_autograd_name; size_t input_position = fwd_inputs_name_pos_map.at(input_name); - - const char* SET_GRAD_OUT_META_TEMPLATE = - " grad_node->SetGradOutMeta(%s, %d);\n"; - grad_node_creation_str += - paddle::string::Sprintf(SET_GRAD_OUT_META_TEMPLATE, - LegalizeVarName(input_name), input_position); + bool found_target_name = false; + for (const auto& iter : op_base_infos) { + const auto& grad_outs_slot_map = iter.GetGradOutsSlotnameMap(); + for (auto iter : grad_outs_slot_map) { + if ((!found_target_name) && (input_name == iter.second)) { + const char* SET_GRAD_OUT_META_TEMPLATE = + " grad_node->SetGradOutMeta(%s, %d);\n"; + grad_node_creation_str += paddle::string::Sprintf( + SET_GRAD_OUT_META_TEMPLATE, LegalizeVarName(input_name), + input_position); + found_target_name = true; + } + } + } } } -- GitLab