diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 3edd13ccd597f10ccd4541bf7fd21e1fe16e6dbc..521b952a4dfcd77c7b12c569ae98fa1db45034ab 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -231,6 +231,15 @@ class GradNodeGenerationInfo { return &no_need_buffer_ins_; } + const std::unordered_map& GetBackwardInplaceMap() + const { + return backward_inplace_map_; + } + std::unordered_map* + GetMutableBackwardInplaceMap() { + return &backward_inplace_map_; + } + private: std::string op_base_type_; std::map grad_outs_slotname_map_; @@ -244,6 +253,7 @@ class GradNodeGenerationInfo { grad_outs_; paddle::framework::AttributeMap grad_attrs_; std::unordered_set no_need_buffer_ins_; + std::unordered_map backward_inplace_map_; }; public: @@ -979,6 +989,12 @@ static bool CollectGradInformationFromOpInfo( *(*op_base_infos)[index].GetMutableNoNeedBufferInputs() = inferer(g_ins, g_outs, *op_base_grad_attrs); } + + auto& infer_backward_inplace = op_base.Info().infer_inplace_; + if (infer_backward_inplace) { + *(*op_base_infos)[index].GetMutableBackwardInplaceMap() = + infer_backward_inplace(true); + } } /* ------ Slot Name Matching ---- */ @@ -1005,7 +1021,7 @@ static std::string GenerateGradNodeCreationContent( const ForwardGenerationInfo& fwd_info, const GradNodeGenerationInfo& bwd_info, const std::string& trace_op_body_str, - std::map inplace_map = {}) { + std::map forward_inplace_map = {}) { VLOG(6) << "Generating GradNode Creation codes"; const std::string& op_type = fwd_info.GetOpType(); @@ -1045,8 +1061,10 @@ static std::string GenerateGradNodeCreationContent( } else { // In inplace op, the case where output is duplicable is not considered. // Replace output directly with input in inplace op. - if (!inplace_map.empty() && inplace_map.count(output_name)) { - auto inplace_input_name = LegalizeVarName(inplace_map[output_name]); + if (!forward_inplace_map.empty() && + forward_inplace_map.count(output_name)) { + auto inplace_input_name = + LegalizeVarName(forward_inplace_map[output_name]); const std::string& inplace_input_autograd_name = "p_autograd_" + inplace_input_name; const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE = @@ -1103,12 +1121,12 @@ static std::string GenerateGradNodeCreationContent( // check inplace input to avoid inplace operations on leaf nodes with // stop_gradient=False. std::string check_inplace_str = ""; - if (!inplace_map.empty()) { + if (!forward_inplace_map.empty()) { const char* CHECKING_INPLACE_TEMPLATE = " // Check Inplace\n" " egr::EagerUtils::CheckInplace(%s, p_autograd_%s, " "require_any_grad);\n"; - for (auto& inplace_pair : inplace_map) { + for (auto& inplace_pair : forward_inplace_map) { std::string inplace_name = LegalizeVarName(inplace_pair.second); check_inplace_str += paddle::string::Sprintf(CHECKING_INPLACE_TEMPLATE, inplace_name, inplace_name); @@ -1161,8 +1179,9 @@ static std::string GenerateGradNodeCreationContent( const char* SET_TENSOR_WRAPPER_TEMPLATE = " grad_node->SetTensorWrapper%s(%s);\n"; // Replace output directly with input in inplace op. - if (!inplace_map.empty() && inplace_map.count(tensor_wrapper_name)) { - auto inplace_input_name = inplace_map[tensor_wrapper_name]; + if (!forward_inplace_map.empty() && + forward_inplace_map.count(tensor_wrapper_name)) { + auto inplace_input_name = forward_inplace_map[tensor_wrapper_name]; grad_node_creation_str += paddle::string::Sprintf( SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name), LegalizeVarName(inplace_input_name)); @@ -1213,8 +1232,9 @@ static std::string GenerateGradNodeCreationContent( for (const proto::OpProto::Var& output : out_vars) { const std::string& output_name = output.name(); // Replace output directly with input in inplace op. - if (!inplace_map.empty() && inplace_map.count(output_name)) { - auto inplace_input_name = inplace_map[output_name]; + if (!forward_inplace_map.empty() && + forward_inplace_map.count(output_name)) { + auto inplace_input_name = forward_inplace_map[output_name]; const std::string& inplace_input_autograd_name = "p_autograd_" + LegalizeVarName(inplace_input_name); size_t output_position = fwd_outputs_name_pos_map.at(output_name); @@ -1345,7 +1365,7 @@ static std::string GenerateGradNodeCreationContent( static std::pair GenerateForwardFunctionContents( const ForwardGenerationInfo& fwd_info, const GradNodeGenerationInfo& bwd_info, - std::map inplace_map = {}) { + std::map forward_inplace_map = {}) { /* --- Process Forward Info ---*/ const std::string& op_type = fwd_info.GetOpType(); const std::unordered_map& fwd_inputs_name_pos_map = @@ -1434,8 +1454,8 @@ static std::pair GenerateForwardFunctionContents( // inplace tensor can't be const const char* FWD_INS_ARG_TEMPLATE; bool flag_find_input_name = false; - if (!inplace_map.empty()) { - for (auto& inplace_pair : inplace_map) { + if (!forward_inplace_map.empty()) { + for (auto& inplace_pair : forward_inplace_map) { if (inplace_pair.second == input_name) { flag_find_input_name = true; FWD_INS_ARG_TEMPLATE = "paddle::experimental::Tensor& %s"; @@ -1605,15 +1625,16 @@ static std::pair GenerateForwardFunctionContents( } core_ops_args_info[op_type].push_back(output_name); - } else if (!inplace_map.empty() && inplace_map.count(output_name)) { + } else if (!forward_inplace_map.empty() && + forward_inplace_map.count(output_name)) { // In inplace op, replace the output with the input directly. PADDLE_ENFORCE_NE( - inplace_map[output_name], "", + forward_inplace_map[output_name], "", paddle::platform::errors::InvalidArgument( "Inplace op %s has no input corresponding to output %s.", op_type, output_name)); const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", ins[\"%s\"] },"; - auto inplace_input_name = inplace_map[output_name]; + auto inplace_input_name = forward_inplace_map[output_name]; outs_contents_str += paddle::string::Sprintf( FWD_OUTS_CONTENT_TEMPLATE, output_name, inplace_input_name); @@ -1651,7 +1672,7 @@ static std::pair GenerateForwardFunctionContents( if (inplace_mapping_str.size() > 0) inplace_mapping_str.pop_back(); // Remove trailing "," - if ((op_type != "cast") && (inplace_map.empty())) { + if ((op_type != "cast") && (forward_inplace_map.empty())) { VLOG(6) << "Generating Dygraph Forward AMP"; const char* AMP_LOGIC_CONTEXT = " if (egr::Controller::Instance().GetAMPLevel() != " @@ -1743,7 +1764,7 @@ static std::pair GenerateForwardFunctionContents( VLOG(6) << "Generated Outs Map"; // [Generation] Apply View Strategy (Tensor) - if (inplace_map.empty() && view_op_map.count(op_type)) { + if (forward_inplace_map.empty() && view_op_map.count(op_type)) { const char* HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT = " if (ins.count(\"%s\") && outs.count(\"%s\")) {\n" " egr::EagerUtils::HandleViewBetweenInputAndOutput(ins[\"%s\"][0], " @@ -1852,10 +1873,11 @@ static std::pair GenerateForwardFunctionContents( output_varname, output_var_args_name); } } else { - if (!inplace_map.empty() && inplace_map.count(output_name)) { + if (!forward_inplace_map.empty() && + forward_inplace_map.count(output_name)) { // Modify meta info of inplace tensor. // Bump inplace version of inplace tensor. - auto inplace_input_name = inplace_map[output_name]; + auto inplace_input_name = forward_inplace_map[output_name]; const char* FWD_OUT_TENSOR_TEMPLATE = " egr::EagerUtils::GetOutput(outs[\"%s\"][0], &%s);\n" " %s.bump_inplace_version();\n" @@ -1878,10 +1900,11 @@ static std::pair GenerateForwardFunctionContents( return_types[return_position] = "paddle::experimental::Tensor"; } - if (!inplace_map.empty() && inplace_map.count(output_name)) { + if (!forward_inplace_map.empty() && + forward_inplace_map.count(output_name)) { // Replace output directly with input in inplace op. return_contents[return_position] = - LegalizeVarName(inplace_map[output_name]); + LegalizeVarName(forward_inplace_map[output_name]); } else { return_contents[return_position] = output_varname; } @@ -1903,7 +1926,7 @@ static std::pair GenerateForwardFunctionContents( // If GradNode needs to be generated, pass `trace_op_body_str` // into `GenerateGradNodeCreationContent`. std::string grad_node_creation_body_str = GenerateGradNodeCreationContent( - fwd_info, bwd_info, trace_op_body_str, inplace_map); + fwd_info, bwd_info, trace_op_body_str, forward_inplace_map); generated_function_body += grad_node_creation_body_str; generated_function_body += "\n"; @@ -1960,7 +1983,7 @@ static std::pair GenerateForwardFunctionContents( // [Generation] Get Full Function std::string function_name; - if (inplace_map.empty()) { + if (forward_inplace_map.empty()) { function_name = op_type + "_dygraph_function"; } else { // change function_name for inplace op. @@ -2013,6 +2036,7 @@ static std::string GenerateSingleOpBase( std::vector>>& grad_outs, const paddle::framework::AttributeMap& grad_attrs, + const std::unordered_map& backward_inplace_map, bool is_op_base_per_duplicable_input, size_t* outs_size) { std::string generated_grad_function_body = ""; @@ -2029,6 +2053,23 @@ static std::string GenerateSingleOpBase( for (const auto& in : in_vars) { if (in.duplicable()) duplicable_input_name_set.insert(in.name()); } + const char* CHECK_BACKWARD_INPLACE_TEMPLATE = + " // Check backward inplace info\n" + " bool %s = false;\n" + " %s\n" + " if (%s.initialized()) {\n" + " VLOG(10) << %s.name() << \"(%s) use_count: \" << " + "%s.impl().use_count();\n" + " if (%s.impl().use_count() == 1 || (%s.impl().use_count() == 2 && " + "%s.impl().get() == %s.impl().get())) {\n" + " %s = true;\n" + " }\n" + " }\n"; + const std::string& can_be_inplaced_name = + "can_be_inplaced" + std::to_string(*outs_size); + const std::string& bwd_inplace_input_name = + "backward_inplace_tensor" + std::to_string(*outs_size); + bool process_backward_inplace = false; std::string ins_contents_str = ""; for (auto iter : grad_ins) { const std::string& grad_input_name = iter.first; @@ -2051,7 +2092,26 @@ static std::string GenerateSingleOpBase( ins_contents_str += paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE, grad_input_name, struct_fwd_input_name); - + if (!backward_inplace_map.empty() && + backward_inplace_map.count(grad_input_name)) { + process_backward_inplace = true; + const char* GRAD_INS_FWD_TENSOR_WRAPPER_TEMPLATE = + "auto %s = egr::EagerUtils::RecoverTensorWrapper(&this->%s);"; + std::string tensor_wrapper_str = paddle::string::Sprintf( + GRAD_INS_FWD_TENSOR_WRAPPER_TEMPLATE, bwd_inplace_input_name, + struct_fwd_input_name); + const char* GRAD_INS_FWD_TENSOR_TEMPLATE = + "(&this->%s)->get_intermidiate_tensor()"; + std::string tensor_wrapper_intermidiate_tensor_str = + paddle::string::Sprintf(GRAD_INS_FWD_TENSOR_TEMPLATE, + struct_fwd_input_name); + generated_grad_function_body += paddle::string::Sprintf( + CHECK_BACKWARD_INPLACE_TEMPLATE, can_be_inplaced_name, + tensor_wrapper_str, bwd_inplace_input_name, bwd_inplace_input_name, + grad_input_name, bwd_inplace_input_name, bwd_inplace_input_name, + bwd_inplace_input_name, bwd_inplace_input_name, + tensor_wrapper_intermidiate_tensor_str, can_be_inplaced_name); + } } else if (grad_ins_grad_slotname_map.count(grad_input_name)) { // Fwd Tensor's Grad size_t fwd_output_position = fwd_outputs_name_pos_map.at( @@ -2060,7 +2120,24 @@ static std::string GenerateSingleOpBase( "{ \"%s\", egr::EagerUtils::TrySyncToVars(hooked_grads[%d]) },"; ins_contents_str += paddle::string::Sprintf( GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position); - + if (!backward_inplace_map.empty() && + backward_inplace_map.count(grad_input_name)) { + process_backward_inplace = true; + const char* GRAD_INS_HOOKED_GRAD_TEMPLATE = + "auto& %s = hooked_grads[%d][0];"; + std::string hooked_grads_tensor_str = paddle::string::Sprintf( + GRAD_INS_HOOKED_GRAD_TEMPLATE, bwd_inplace_input_name, + fwd_output_position); + const char* GRAD_INS_GRAD_TENSOR_TEMPLATE = "grads[%d][0]"; + std::string grads_tensor_str = paddle::string::Sprintf( + GRAD_INS_GRAD_TENSOR_TEMPLATE, fwd_output_position); + generated_grad_function_body += paddle::string::Sprintf( + CHECK_BACKWARD_INPLACE_TEMPLATE, can_be_inplaced_name, + hooked_grads_tensor_str, bwd_inplace_input_name, + bwd_inplace_input_name, grad_input_name, bwd_inplace_input_name, + bwd_inplace_input_name, bwd_inplace_input_name, + bwd_inplace_input_name, grads_tensor_str, can_be_inplaced_name); + } } else { PADDLE_THROW(platform::errors::Fatal( "Detected mismatched slot names." @@ -2245,6 +2322,27 @@ static std::string GenerateSingleOpBase( VLOG(6) << "Generated Outs Map"; + // [Generation] Process Backward Inplace + if (process_backward_inplace) { + const char* HANDLE_BACKWARD_INPLACE_BETWEEN_INPUT_AND_OUTPUT = + " if (%s && %s.count(\"%s\") && %s.count(\"%s\")) {\n" + " egr::EagerUtils::HandleViewBetweenInputAndOutput(%s[\"%s\"][0], " + "%s[\"%s\"][0]);\n" + " };\n"; + std::string backward_inplace_map_str = ""; + for (auto iter : backward_inplace_map) { + std::string backward_inplace_input_name = iter.first; + std::string backward_inplace_output_name = iter.second; + backward_inplace_map_str += paddle::string::Sprintf( + HANDLE_BACKWARD_INPLACE_BETWEEN_INPUT_AND_OUTPUT, + can_be_inplaced_name, ins_name, backward_inplace_input_name, + outs_name, backward_inplace_output_name, ins_name, + backward_inplace_input_name, outs_name, backward_inplace_output_name); + } + generated_grad_function_body += backward_inplace_map_str; + VLOG(6) << "Process Backward Inplace"; + } + // [Generation] Get Attrs Map const char* ATTRS_TEMPLATE = " auto& %s = this->attr_map_;\n"; std::string grad_attrs_str = @@ -2428,13 +2526,15 @@ static std::string GenerateGradNodeCCContents( const auto& grad_ins = op_base_info.GetGradIns(); const auto& grad_outs = op_base_info.GetGradOuts(); const auto& grad_attrs = op_base_info.GetGradAttrs(); + const auto& backward_inplace_map = op_base_info.GetBackwardInplaceMap(); const std::string& op_base_type = op_base_info.GetOpBaseType(); generated_grad_function_body += GenerateSingleOpBase( fwd_op_type, op_base_type, fwd_inputs_name_pos_map, fwd_outputs_name_pos_map, in_vars, grad_ins_fwd_slotname_map, grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, grad_outs, - grad_attrs, is_op_base_per_duplicable_input, &outs_size); + grad_attrs, backward_inplace_map, is_op_base_per_duplicable_input, + &outs_size); } if (is_op_base_per_duplicable_input) { @@ -2847,19 +2947,20 @@ static void DygraphCodeGeneration(const std::string& output_dir) { auto& infer_inplace = paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_; - std::map inplace_map; + std::map forward_inplace_map; // Inplace Function Generator. // `sum` op has duplicate input. Don't consider adding inplace strategy // for `sum` in temporary. if (infer_inplace && !special_inplace_op_set.count(op_type)) { auto in_to_outs = infer_inplace(true); for (auto& inplace_pair : in_to_outs) { - inplace_map[inplace_pair.second] = inplace_pair.first; + forward_inplace_map[inplace_pair.second] = inplace_pair.first; } VLOG(6) << "-------- GenerateInplaceForwardFunctionContents -------"; std::pair inplace_body_and_declaration = - GenerateForwardFunctionContents(fwd_info, bwd_info, inplace_map); + GenerateForwardFunctionContents(fwd_info, bwd_info, + forward_inplace_map); fwd_function_str += inplace_body_and_declaration.first + "\n";