diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 817a0de6e0ca9594d6e9e09d41538071def1b47f..73baf210158332ffdfab11d5de775eb9fedee767 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -1206,22 +1206,37 @@ static std::string GenerateGradNodeCreationContent( if (!input.duplicable()) { compute_require_grad_args += ", " + input_autograd_name; size_t input_position = fwd_inputs_name_pos_map.at(input_name); - - const char* SET_GRAD_OUT_META_TEMPLATE = - " grad_node->SetGradOutMeta(%s, %d);\n"; - grad_node_creation_str += - paddle::string::Sprintf(SET_GRAD_OUT_META_TEMPLATE, - LegalizeVarName(input_name), input_position); - + bool found_target_name = false; + for (const auto& iter : op_base_infos) { + const auto& grad_outs_slot_map = iter.GetGradOutsSlotnameMap(); + for (auto iter : grad_outs_slot_map) { + if ((!found_target_name) && (input_name == iter.second)) { + const char* SET_GRAD_OUT_META_TEMPLATE = + " grad_node->SetGradOutMeta(%s, %d);\n"; + grad_node_creation_str += paddle::string::Sprintf( + SET_GRAD_OUT_META_TEMPLATE, LegalizeVarName(input_name), + input_position); + found_target_name = true; + } + } + } } else { compute_require_grad_args += ", &" + input_autograd_name; size_t input_position = fwd_inputs_name_pos_map.at(input_name); - - const char* SET_GRAD_OUT_META_TEMPLATE = - " grad_node->SetGradOutMeta(%s, %d);\n"; - grad_node_creation_str += - paddle::string::Sprintf(SET_GRAD_OUT_META_TEMPLATE, - LegalizeVarName(input_name), input_position); + bool found_target_name = false; + for (const auto& iter : op_base_infos) { + const auto& grad_outs_slot_map = iter.GetGradOutsSlotnameMap(); + for (auto iter : grad_outs_slot_map) { + if ((!found_target_name) && (input_name == iter.second)) { + const char* SET_GRAD_OUT_META_TEMPLATE = + " grad_node->SetGradOutMeta(%s, %d);\n"; + grad_node_creation_str += paddle::string::Sprintf( + SET_GRAD_OUT_META_TEMPLATE, LegalizeVarName(input_name), + input_position); + found_target_name = true; + } + } + } } }