diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index fa0a29525658e50a9c4dcb6532f8f955a5168d74..3bac4046909e4ca51d34c9ebc0ea8ae07339e1db 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -172,7 +172,7 @@ FORWARD_FUNCTION_TEMPLATE = \ """ -NODE_CREATION_TEMPLATE = \ +FORWARD_BODY_TEMPLATE = \ """ // Get AutoGradMeta {} @@ -305,7 +305,6 @@ BUMP_INPLACE_VERSION_TEMPLATE = \ VLOG(3) << \"Tensor(\" << {}.name() << \") uses Inplace Strategy.\";\n """ - AMP_LOGIC_TEMPLATE = \ """ if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{ @@ -363,7 +362,7 @@ def GenerateCoreOpInfoDefinition(): ##################### ## Generator Class ## ##################### -class DygraphSingleFunctionGenerator(FunctionGeneratorBase): +class DygraphFunctionGeneratorBase(FunctionGeneratorBase): def __init__(self, forward_api_contents, grad_api_contents, namespace): self.forward_api_contents = forward_api_contents # Members from Parent: @@ -409,12 +408,6 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): self.backward_grad_outputs_map = { } #{ "name" : [type, fwd_position, orig_position] ...} - # Generated Results - self.forward_definition_str = "" - self.forward_declaration_str = "" - self.node_declaration_str = "" - self.node_definition_str = "" - def DygraphYamlValidationCheck(self): forward_api_contents = self.forward_api_contents grad_api_contents = self.grad_api_contents @@ -632,139 +625,237 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}" ) - def GenerateNodeDeclaration(self): - forward_op_name = self.forward_api_name + def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced): + forward_api_name = self.forward_api_name + forward_inputs_position_map = self.forward_inputs_position_map + forward_outputs_position_map = self.forward_outputs_position_map + forward_attrs_list = self.forward_attrs_list backward_forward_inputs_map = self.backward_forward_inputs_map + backward_grad_inputs_map = self.backward_grad_inputs_map + backward_grad_outputs_map = self.backward_grad_outputs_map backward_attrs_list = self.backward_attrs_list - no_need_buffers = self.no_need_buffers + optional_inputs = self.optional_inputs + inplace_map = self.inplace_map if is_inplaced else {} - # SetTensorWrapper Methods & TensorWrapper Members - set_tensor_wrapper_methods_str = "" - tensor_wrapper_members_str = "" - clear_tensor_wrapper_str = "" - for tname, (ttype, is_fwd_input, - _) in backward_forward_inputs_map.items(): - no_need_buffer = "true" if tname in no_need_buffers else "false" - tensor_wrapper_name = GetSavedName(tname) + # Get Input AutoGradMeta + inputs_autograd_meta_list = [] + compute_require_grad_args_list = ["trace_backward"] + for name, (ttype, pos) in forward_inputs_position_map.items(): + input_autograd_meta_name = GetAutoGradMetaName(name) if IsPlainTensorType(ttype): - set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format( - tname, tname, tensor_wrapper_name, tname, no_need_buffer) - - tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format( - tensor_wrapper_name) + input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});" + else: + assert IsVectorTensorType(ttype) + input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) + input_autograd_meta = f" std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n" + input_autograd_meta += f" std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" - clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPER_TEMPLATE.format( - tensor_wrapper_name) + inputs_autograd_meta_list.append(input_autograd_meta) + compute_require_grad_args_list.append(input_autograd_meta_name) + inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list) + compute_require_grad_args_str = ",".join(compute_require_grad_args_list) + # Get Output AutoGradMeta + outputs_autograd_meta_list = [] + pass_stop_gradient_args_list = ["false"] + num_fwd_outputs = len(forward_outputs_position_map.keys()) + for name, (rtype, pos) in forward_outputs_position_map.items(): + output_autograd_meta_name = GetAutoGradMetaName(name) + output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) + if num_fwd_outputs == 1: + if IsPlainTensorType(rtype): + output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);" + else: + assert IsVectorTensorType(rtype) + output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n" + output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" else: - assert IsVectorTensorType(ttype) - set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format( - tname, tname, tname, tensor_wrapper_name, no_need_buffer) + # Tuple api_result + if IsPlainTensorType(rtype): + output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));" + else: + assert IsVectorTensorType(rtype) + output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n" + output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" - tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format( - tensor_wrapper_name) + outputs_autograd_meta_list.append(output_autograd_meta) + pass_stop_gradient_args_list.append(output_autograd_meta_name) - clear_tensor_wrapper_str += CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE.format( - tensor_wrapper_name) + # ComputeRequireGrad & PassStopGradient + outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list) + pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list) - # SetAttributes & Attribute Members - set_attribute_methods_str = "" - attribute_members_str = "" - for aname, atype, default_val, _ in backward_attrs_list: - saved_attr_name = GetSavedName(aname) - set_attribute_methods_str += SET_ATTR_METHOD_TEMPLATE.format( - aname, GetConstReference(atype), aname, saved_attr_name, aname) + # Check Inplace + check_inplace_str = "" + bump_inplace_version_str = "" + if is_inplaced: + for inplace_name in inplace_map.keys(): + inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name) + check_inplace_str += CHECK_INPLACE_TEMPLATE.format( + inplace_name, inplace_autograd_meta_name) + bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format( + inplace_name, inplace_name) - if default_val: - attribute_members_str += ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE.format( - RemoveConstAndReference(atype), saved_attr_name, - default_val) + # Node Construction + num_backward_inputs = len(forward_outputs_position_map.keys()) + num_backward_outputs = len(forward_inputs_position_map.keys()) + grad_node_name = GetGradNodeName(forward_api_name) + + node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_backward_inputs}, {num_backward_outputs});" + + # SetAttributes + set_attributes_list = [] + forward_attrs_name_set = set() + for name, _, _, _ in forward_attrs_list: + forward_attrs_name_set.add(name) + + for name, _, default_val_attr, _ in backward_attrs_list: + if name in forward_attrs_name_set: + set_attributes = f" grad_node->SetAttribute{name}({name});" else: - attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format( - RemoveConstAndReference(atype), saved_attr_name) + set_attributes = f" grad_node->SetAttribute{name}({default_val_attr});" + set_attributes_list.append(set_attributes) + set_attributes_str = "\n".join(set_attributes_list) - grad_node_name = GetGradNodeName(forward_op_name) - self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format( - grad_node_name, grad_node_name, grad_node_name, grad_node_name, - grad_node_name, clear_tensor_wrapper_str, - set_tensor_wrapper_methods_str, set_attribute_methods_str, - tensor_wrapper_members_str, attribute_members_str) + # SetTensorWrappers + set_tensor_wrappers_list = [] + for name, (atype, is_fwd_input, + pos) in backward_forward_inputs_map.items(): + is_optional = (name in optional_inputs) - logging.info(f"Generated Node Declaration: {self.node_declaration_str}") + if is_fwd_input: + if is_optional: + set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);" + else: + set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);" + else: + if num_fwd_outputs > 1: + # Aligned with forward output position + assert name in forward_outputs_position_map.keys( + ), AssertMessage(name, forward_outputs_position_map.keys()) + fwd_output_pos = forward_outputs_position_map[name][1] + tw_name = f"std::get<{fwd_output_pos}>(api_result)" + else: + tw_name = f"api_result" - def GenerateNodeDefinition(self): - namespace = self.namespace - forward_api_name = self.forward_api_name - backward_api_name = self.backward_api_name - backward_forward_inputs_map = self.backward_forward_inputs_map - backward_grad_inputs_map = self.backward_grad_inputs_map - backward_grad_outputs_map = self.backward_grad_outputs_map - backward_attrs_list = self.backward_attrs_list + if is_optional: + set_tensor_wrappers = f" if({tw_name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({tw_name}.get_ptr()), false);" + else: + set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);" + set_tensor_wrappers_list.append(set_tensor_wrappers) + set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list) - # Construct grad_api function args - # Order: TensorWrappers, GradTensors, Attributes - grad_api_args_len = len(backward_forward_inputs_map.keys()) + len( - backward_grad_inputs_map.keys()) + len(backward_attrs_list) - grad_api_args = ["" for i in range(grad_api_args_len)] - for name, (_, is_fwd_input, - grad_api_position), in backward_forward_inputs_map.items(): - tensor_wrapper_name = GetSavedName(name) + # SetGradOutMeta & SetEdges + set_grad_out_meta_list = [] + set_edges_list = [] + for name, (_, pos) in forward_inputs_position_map.items(): + input_autograd_meta_name = GetAutoGradMetaName(name) is_optional = (name in self.optional_inputs) if is_optional: - grad_api_args[ - grad_api_position] = f"egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr)" + set_grad_out_meta = f" if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});" + set_edges = f" if({name}.get_ptr() != nullptr) grad_node->AddEdges({input_autograd_meta_name}, {pos});" else: - grad_api_args[ - grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr)" + set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});" + set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});" - for _, (ttype, fwd_position, - grad_api_position) in backward_grad_inputs_map.items(): - if IsPlainTensorType(ttype): - grad_api_args[ - grad_api_position] = f"hooked_grads[{fwd_position}][0]" - else: - assert IsVectorTensorType(ttype) - grad_api_args[ - grad_api_position] = f"hooked_grads[{fwd_position}]" + set_grad_out_meta_list.append(set_grad_out_meta) + set_edges_list.append(set_edges) + set_grad_out_meta_str = "\n".join(set_grad_out_meta_list) + set_edges_str = "\n".join(set_edges_list) - for name, _, _, grad_api_position in backward_attrs_list: - saved_attribute_name = GetSavedName(name) - grad_api_args[grad_api_position] = f"this->{saved_attribute_name}" - grad_api_args_str = ", ".join(grad_api_args) + # SetOutRank & SetHistory & SetGradInMeta + set_out_rank_list = [] + set_history_list = [] + set_grad_in_meta_list = [] + set_retain_grad_list = [] + num_outputs = len(forward_outputs_position_map.keys()) + for name, (_, pos) in forward_outputs_position_map.items(): + output_autograd_meta_name = GetAutoGradMetaName(name) + set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});" + set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);" - # Construct grad_api returns - num_bwd_outputs = len(backward_grad_outputs_map.keys()) - slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys()) - returns_str = f"std::vector> returns({slot_num_bwd_outputs});\n" - for _, (ttype, fwd_position, - grad_api_position) in backward_grad_outputs_map.items(): - # Infer Grad API Return Type - if num_bwd_outputs == 1: - # Single tensor output, return as is - if IsPlainTensorType(ttype): - returns_str += "returns[0] = { grad_api_returns };\n" - else: - assert IsVectorTensorType(ttype) - returns_str += "returns[0] = grad_api_returns;\n" + if num_outputs == 1: + set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);" + set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});" else: - # Rearrange output order accordingly - returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n" - returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n" - returns_str += f"return returns;\n" + set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));" + set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});" + set_out_rank_list.append(set_out_rank) + set_history_list.append(set_history) + set_grad_in_meta_list.append(set_grad_in_meta) + set_retain_grad_list.append(set_retain_grad) - grad_node_name = GetGradNodeName(forward_api_name) + set_out_rank_str = "\n".join(set_out_rank_list) + set_history_str = "\n".join(set_history_list) + set_grad_in_meta_str = "\n".join(set_grad_in_meta_list) + set_retain_grad_str = "\n".join(set_retain_grad_list) - fill_zero_str = "" - if forward_api_name in ops_to_fill_zero_for_empty_grads: - fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n" + node_event_name = forward_api_name + " node_creation" + node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n" - grad_api_namespace = f"paddle::experimental::{namespace}" + self.node_creation_str = FORWARD_BODY_TEMPLATE.format( + inputs_autograd_meta_str, compute_require_grad_args_str, + check_inplace_str, forward_call_str, bump_inplace_version_str, + node_creation_event_str, outputs_autograd_meta_str, + pass_stop_gradient_args_str, node_construction_str, + set_attributes_str, set_tensor_wrappers_str, set_grad_out_meta_str, + set_edges_str, set_out_rank_str, set_history_str, + set_grad_in_meta_str, set_retain_grad_str) - self.node_definition_str = FUNCTION_TEMPLATE.format( - grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace, - backward_api_name, grad_api_args_str, returns_str) + def run(self): + # Basic Validation Check + self.DygraphYamlValidationCheck() - logging.info(f"Generated Node Definition: {self.node_definition_str}") + ########################## + ## Parsing Raw Contents ## + ########################## + # Parse inplace_map + self.ParseInplaceInfo() + + # Parse no_need_buffer + self.ParseNoNeedBuffer() + + # Parse optional_inputs + self.ParseDispensable() + + # Parse intermediate_outputs + self.ParseIntermediate() + self.IntermediateValidationCheck() + + # Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list + self.CollectBackwardInfo() + + # Initialize forward_inputs_list, forward_attrs_list, forward_returns_list + self.CollectForwardInfoFromBackwardContents() + + # Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list + self.CollectOriginalForwardInfo() + + # Forwards Validation Check + self.ForwardsValidationCheck() + + ############################# + ## Process Parsed Contents ## + ############################# + # Initialize forward_inputs_position_map, forward_outputs_position_map + self.DetermineForwardPositionMap(self.forward_inputs_list, + self.forward_returns_list) + + # Initialize forward_inputs_position_map, forward_outputs_position_map + self.SlotNameMatching() + + # Backward Validation Check + self.BackwardValidationCheck() + + +class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): + def __init__(self, forward_api_contents, grad_api_contents, namespace): + DygraphFunctionGeneratorBase.__init__(self, forward_api_contents, + grad_api_contents, namespace) + + # Generated Results + self.forward_definition_str = "" + self.forward_declaration_str = "" def GenerateForwardDefinition(self, is_inplaced): namespace = self.namespace @@ -780,7 +871,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): backward_attrs_list = self.backward_attrs_list optional_inputs = self.optional_inputs intermediate_outputs = self.intermediate_outputs - inplace_map = self.inplace_map + inplace_map = self.inplace_map if is_inplaced else {} # Get Function Args num_inputs = len(forward_attrs_list) + len( @@ -918,182 +1009,6 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): logging.info( f"Generated Forward Declaration: {self.forward_declaration_str}") - def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced): - forward_api_name = self.forward_api_name - forward_inputs_position_map = self.forward_inputs_position_map - forward_outputs_position_map = self.forward_outputs_position_map - forward_attrs_list = self.forward_attrs_list - backward_forward_inputs_map = self.backward_forward_inputs_map - backward_grad_inputs_map = self.backward_grad_inputs_map - backward_grad_outputs_map = self.backward_grad_outputs_map - backward_attrs_list = self.backward_attrs_list - optional_inputs = self.optional_inputs - inplace_map = self.inplace_map - - # Get Input AutoGradMeta - inputs_autograd_meta_list = [] - compute_require_grad_args_list = ["trace_backward"] - for name, (ttype, pos) in forward_inputs_position_map.items(): - input_autograd_meta_name = GetAutoGradMetaName(name) - if IsPlainTensorType(ttype): - input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});" - else: - assert IsVectorTensorType(ttype) - input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) - input_autograd_meta = f" std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n" - input_autograd_meta += f" std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" - - inputs_autograd_meta_list.append(input_autograd_meta) - compute_require_grad_args_list.append(input_autograd_meta_name) - inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list) - compute_require_grad_args_str = ",".join(compute_require_grad_args_list) - - # Get Output AutoGradMeta - outputs_autograd_meta_list = [] - pass_stop_gradient_args_list = ["false"] - num_fwd_outputs = len(forward_outputs_position_map.keys()) - for name, (rtype, pos) in forward_outputs_position_map.items(): - output_autograd_meta_name = GetAutoGradMetaName(name) - output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) - if num_fwd_outputs == 1: - if IsPlainTensorType(rtype): - output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);" - else: - assert IsVectorTensorType(rtype) - output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n" - output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" - else: - # Tuple api_result - if IsPlainTensorType(rtype): - output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));" - else: - assert IsVectorTensorType(rtype) - output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n" - output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" - - outputs_autograd_meta_list.append(output_autograd_meta) - pass_stop_gradient_args_list.append(output_autograd_meta_name) - - # ComputeRequireGrad & PassStopGradient - outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list) - pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list) - - # Check Inplace - check_inplace_str = "" - bump_inplace_version_str = "" - if is_inplaced: - for inplace_name in inplace_map.keys(): - inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name) - check_inplace_str += CHECK_INPLACE_TEMPLATE.format( - inplace_name, inplace_autograd_meta_name) - bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format( - inplace_name, inplace_name) - - # Node Construction - num_backward_inputs = len(forward_outputs_position_map.keys()) - num_backward_outputs = len(forward_inputs_position_map.keys()) - grad_node_name = GetGradNodeName(forward_api_name) - - node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_backward_inputs}, {num_backward_outputs});" - - # SetAttributes - set_attributes_list = [] - forward_attrs_name_set = set() - for name, _, _, _ in forward_attrs_list: - forward_attrs_name_set.add(name) - - for name, _, default_val_attr, _ in backward_attrs_list: - if name in forward_attrs_name_set: - set_attributes = f" grad_node->SetAttribute{name}({name});" - else: - set_attributes = f" grad_node->SetAttribute{name}({default_val_attr});" - set_attributes_list.append(set_attributes) - set_attributes_str = "\n".join(set_attributes_list) - - # SetTensorWrappers - set_tensor_wrappers_list = [] - for name, (atype, is_fwd_input, - pos) in backward_forward_inputs_map.items(): - is_optional = (name in optional_inputs) - - if is_fwd_input: - if is_optional: - set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);" - else: - set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);" - else: - if num_fwd_outputs > 1: - # Aligned with forward output position - assert name in forward_outputs_position_map.keys( - ), AssertMessage(name, forward_outputs_position_map.keys()) - fwd_output_pos = forward_outputs_position_map[name][1] - tw_name = f"std::get<{fwd_output_pos}>(api_result)" - else: - tw_name = f"api_result" - - if is_optional: - set_tensor_wrappers = f" if({tw_name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({tw_name}.get_ptr()), false);" - else: - set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);" - set_tensor_wrappers_list.append(set_tensor_wrappers) - set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list) - - # SetGradOutMeta & SetEdges - set_grad_out_meta_list = [] - set_edges_list = [] - for name, (_, pos) in forward_inputs_position_map.items(): - input_autograd_meta_name = GetAutoGradMetaName(name) - is_optional = (name in self.optional_inputs) - if is_optional: - set_grad_out_meta = f" if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});" - set_edges = f" if({name}.get_ptr() != nullptr) grad_node->AddEdges({input_autograd_meta_name}, {pos});" - else: - set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});" - set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});" - set_grad_out_meta_list.append(set_grad_out_meta) - set_edges_list.append(set_edges) - set_grad_out_meta_str = "\n".join(set_grad_out_meta_list) - set_edges_str = "\n".join(set_edges_list) - - # SetOutRank & SetHistory & SetGradInMeta - set_out_rank_list = [] - set_history_list = [] - set_grad_in_meta_list = [] - set_retain_grad_list = [] - num_outputs = len(forward_outputs_position_map.keys()) - for name, (_, pos) in forward_outputs_position_map.items(): - output_autograd_meta_name = GetAutoGradMetaName(name) - set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});" - set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);" - - if num_outputs == 1: - set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);" - set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});" - else: - set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));" - set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});" - set_out_rank_list.append(set_out_rank) - set_history_list.append(set_history) - set_grad_in_meta_list.append(set_grad_in_meta) - set_retain_grad_list.append(set_retain_grad) - - set_out_rank_str = "\n".join(set_out_rank_list) - set_history_str = "\n".join(set_history_list) - set_grad_in_meta_str = "\n".join(set_grad_in_meta_list) - set_retain_grad_str = "\n".join(set_retain_grad_list) - - node_event_name = forward_api_name + " node_creation" - node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n" - - self.node_creation_str = NODE_CREATION_TEMPLATE.format( - inputs_autograd_meta_str, compute_require_grad_args_str, - check_inplace_str, forward_call_str, bump_inplace_version_str, - node_creation_event_str, outputs_autograd_meta_str, - pass_stop_gradient_args_str, node_construction_str, - set_attributes_str, set_tensor_wrappers_str, set_grad_out_meta_str, - set_edges_str, set_out_rank_str, set_history_str, - set_grad_in_meta_str, set_retain_grad_str) - def GenerateInplacedForwardDygraphFunctions(self): # Inplaced Version Dygraph Function Generation forward_api_name = self.forward_api_name @@ -1139,60 +1054,168 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): core_ops_returns_info[final_state_fwd_api_name][pos] = name def run(self): - # Basic Validation Check - self.DygraphYamlValidationCheck() + super().run() - ########################## - ## Parsing Raw Contents ## - ########################## - # Parse inplace_map - self.ParseInplaceInfo() + ##################### + ## Code Generation ## + ##################### + self.GenerateForwardDefinition(is_inplaced=False) - # Parse no_need_buffer - self.ParseNoNeedBuffer() + self.UpdateCoreOpsInformation(is_inplaced=False) - # Parse optional_inputs - self.ParseDispensable() + self.GenerateInplacedForwardDygraphFunctions() - # Parse intermediate_outputs - self.ParseIntermediate() - self.IntermediateValidationCheck() - # Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list - self.CollectBackwardInfo() +class DygraphNodeGenerator(DygraphFunctionGeneratorBase): + def __init__(self, forward_api_contents, grad_api_contents, namespace): + DygraphFunctionGeneratorBase.__init__(self, forward_api_contents, + grad_api_contents, namespace) - # Initialize forward_inputs_list, forward_attrs_list, forward_returns_list - self.CollectForwardInfoFromBackwardContents() + # Generated Results + self.node_declaration_str = "" + self.node_definition_str = "" - # Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list - self.CollectOriginalForwardInfo() + def GenerateNodeDeclaration(self): + forward_op_name = self.forward_api_name + backward_forward_inputs_map = self.backward_forward_inputs_map + backward_attrs_list = self.backward_attrs_list + no_need_buffers = self.no_need_buffers - # Forwards Validation Check - self.ForwardsValidationCheck() + # SetTensorWrapper Methods & TensorWrapper Members + set_tensor_wrapper_methods_str = "" + tensor_wrapper_members_str = "" + clear_tensor_wrapper_str = "" + for tname, (ttype, is_fwd_input, + _) in backward_forward_inputs_map.items(): + no_need_buffer = "true" if tname in no_need_buffers else "false" + tensor_wrapper_name = GetSavedName(tname) + if IsPlainTensorType(ttype): + set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format( + tname, tname, tensor_wrapper_name, tname, no_need_buffer) - ############################# - ## Process Parsed Contents ## - ############################# - # Initialize forward_inputs_position_map, forward_outputs_position_map - self.DetermineForwardPositionMap(self.forward_inputs_list, - self.forward_returns_list) + tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format( + tensor_wrapper_name) - # Initialize forward_inputs_position_map, forward_outputs_position_map - self.SlotNameMatching() + clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPER_TEMPLATE.format( + tensor_wrapper_name) - # Backward Validation Check - self.BackwardValidationCheck() + else: + assert IsVectorTensorType(ttype) + set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format( + tname, tname, tname, tensor_wrapper_name, no_need_buffer) + + tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format( + tensor_wrapper_name) + + clear_tensor_wrapper_str += CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE.format( + tensor_wrapper_name) + + # SetAttributes & Attribute Members + set_attribute_methods_str = "" + attribute_members_str = "" + for aname, atype, default_val, _ in backward_attrs_list: + saved_attr_name = GetSavedName(aname) + set_attribute_methods_str += SET_ATTR_METHOD_TEMPLATE.format( + aname, GetConstReference(atype), aname, saved_attr_name, aname) + + if default_val: + attribute_members_str += ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE.format( + RemoveConstAndReference(atype), saved_attr_name, + default_val) + else: + attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format( + RemoveConstAndReference(atype), saved_attr_name) + + grad_node_name = GetGradNodeName(forward_op_name) + self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format( + grad_node_name, grad_node_name, grad_node_name, grad_node_name, + grad_node_name, clear_tensor_wrapper_str, + set_tensor_wrapper_methods_str, set_attribute_methods_str, + tensor_wrapper_members_str, attribute_members_str) + + logging.info(f"Generated Node Declaration: {self.node_declaration_str}") + + def GenerateNodeDefinition(self): + namespace = self.namespace + forward_api_name = self.forward_api_name + backward_api_name = self.backward_api_name + backward_forward_inputs_map = self.backward_forward_inputs_map + backward_grad_inputs_map = self.backward_grad_inputs_map + backward_grad_outputs_map = self.backward_grad_outputs_map + backward_attrs_list = self.backward_attrs_list + + # Construct grad_api function args + # Order: TensorWrappers, GradTensors, Attributes + grad_api_args_len = len(backward_forward_inputs_map.keys()) + len( + backward_grad_inputs_map.keys()) + len(backward_attrs_list) + grad_api_args = ["" for i in range(grad_api_args_len)] + for name, (_, is_fwd_input, + grad_api_position), in backward_forward_inputs_map.items(): + tensor_wrapper_name = GetSavedName(name) + is_optional = (name in self.optional_inputs) + if is_optional: + grad_api_args[ + grad_api_position] = f"egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr)" + else: + grad_api_args[ + grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr)" + for _, (ttype, fwd_position, + grad_api_position) in backward_grad_inputs_map.items(): + if IsPlainTensorType(ttype): + grad_api_args[ + grad_api_position] = f"hooked_grads[{fwd_position}][0]" + else: + assert IsVectorTensorType(ttype) + grad_api_args[ + grad_api_position] = f"hooked_grads[{fwd_position}]" + + for name, _, _, grad_api_position in backward_attrs_list: + saved_attribute_name = GetSavedName(name) + grad_api_args[grad_api_position] = f"this->{saved_attribute_name}" + grad_api_args_str = ", ".join(grad_api_args) + + # Construct grad_api returns + num_bwd_outputs = len(backward_grad_outputs_map.keys()) + slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys()) + returns_str = f"std::vector> returns({slot_num_bwd_outputs});\n" + for _, (ttype, fwd_position, + grad_api_position) in backward_grad_outputs_map.items(): + # Infer Grad API Return Type + if num_bwd_outputs == 1: + # Single tensor output, return as is + if IsPlainTensorType(ttype): + returns_str += "returns[0] = { grad_api_returns };\n" + else: + assert IsVectorTensorType(ttype) + returns_str += "returns[0] = grad_api_returns;\n" + else: + # Rearrange output order accordingly + returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n" + returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n" + returns_str += f"return returns;\n" + + grad_node_name = GetGradNodeName(forward_api_name) + + fill_zero_str = "" + if forward_api_name in ops_to_fill_zero_for_empty_grads: + fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n" + + grad_api_namespace = f"paddle::experimental::{namespace}" + + self.node_definition_str = FUNCTION_TEMPLATE.format( + grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace, + backward_api_name, grad_api_args_str, returns_str) + + logging.info(f"Generated Node Definition: {self.node_definition_str}") + + def run(self): + super().run() ##################### ## Code Generation ## ##################### self.GenerateNodeDeclaration() self.GenerateNodeDefinition() - self.GenerateForwardDefinition(is_inplaced=False) - - self.UpdateCoreOpsInformation(is_inplaced=False) - - self.GenerateInplacedForwardDygraphFunctions() class DygraphYamlGenerator(YamlGeneratorBase): @@ -1239,14 +1262,18 @@ class DygraphYamlGenerator(YamlGeneratorBase): forward_api_contents) if backward_api_contents is None: continue - d_generator = DygraphSingleFunctionGenerator( + function_generator = DygraphForwardFunctionGenerator( + forward_api_contents, backward_api_contents, namespace) + function_generator.run() + + node_generator = DygraphNodeGenerator( forward_api_contents, backward_api_contents, namespace) - d_generator.run() + node_generator.run() - self.forward_definition_str += d_generator.forward_definition_str + "\n" - self.forward_declaration_str += d_generator.forward_declaration_str + "\n" - self.node_declaration_str += d_generator.node_declaration_str + "\n" - self.node_definition_str += d_generator.node_definition_str + "\n" + self.forward_definition_str += function_generator.forward_definition_str + "\n" + self.forward_declaration_str += function_generator.forward_declaration_str + "\n" + self.node_declaration_str += node_generator.node_declaration_str + "\n" + self.node_definition_str += node_generator.node_definition_str + "\n" if len(namespace) > 0: if namespace.endswith("::"): diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index 797fec9752fab67aa1d8529f2ae4658183d10dc4..63eb1ee46a822be651f6427f8544112f1ed56132 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -311,7 +311,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): dygraph_function_call_list[pos] = f"{name}" dygraph_function_call_str = ",".join(dygraph_function_call_list) - # Generate Python-C Function Definitions + # Generate Python-C Function Definitions if is_forward_only: fwd_function_name = FUNCTION_NAME_TEMPLATE.format( "paddle::experimental::", namespace, forward_api_name) @@ -337,21 +337,21 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): forward_api_name_prefix, forward_api_name, namespace, forward_api_name, forward_api_name) - if len(inplace_map) > 0: - assert len( - inplace_map - ) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}" + if inplace_map: inplaced_forward_api_name = GetInplacedFunctionName( self.forward_api_name) - # Generate Python-C Function Definitions if is_forward_only: - fwd_function_name = FUNCTION_NAME_TEMPLATE.format( + inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format( "paddle::experimental::", namespace, inplaced_forward_api_name) - elif len(inplace_map) > 0: - fwd_function_name = FUNCTION_NAME_TEMPLATE.format( + else: + inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format( "::", namespace, GetForwardFunctionName(inplaced_forward_api_name)) + + assert len( + inplace_map + ) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}" for inplace_input, inplace_output in inplace_map.items(): return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format( inplaced_forward_api_name, inplace_input, @@ -361,7 +361,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format( inplaced_forward_api_name, pythonc_record_event_str, inplaced_forward_api_name, get_eager_tensor_str, - parse_attributes_str, fwd_function_name, + parse_attributes_str, inplaced_fwd_function_name, dygraph_function_call_str, return_str) # Generate Python-C Function Registration