diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 3bac4046909e4ca51d34c9ebc0ea8ae07339e1db..f5ae3da91dadc6cabc29af972027455b00ebdc09 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -162,9 +162,24 @@ FUNCTION_TEMPLATE = \ FORWARD_FUNCTION_TEMPLATE = \ """ {} {}({}) {{ - {} - {} - {} + // Dygraph Record Event +{} + // AMP Logic +{} + + // Get Input AutoGradMeta +{} + // Forward API Call +{} + // Get Output AutoGradMeta +{} + bool trace_backward = egr::Controller::Instance().HasGrad(); + bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({}); + // Check Inplace & Bump Inplace Version +{} +{} + // Node Creation +{} // Returns return {}; @@ -174,18 +189,8 @@ FORWARD_FUNCTION_TEMPLATE = \ FORWARD_BODY_TEMPLATE = \ """ - // Get AutoGradMeta -{} - bool trace_backward = egr::Controller::Instance().HasGrad(); - bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({}); -{} - // Forward API Call - {} -{} - {{ -{} -{} if(require_any_grad) {{ +{} egr::EagerUtils::PassStopGradient({}); // Node Construction @@ -203,7 +208,6 @@ FORWARD_BODY_TEMPLATE = \ {} {} }} - }} """ NAMESPACE_WRAPPER_TEMPLATE = \ @@ -294,7 +298,6 @@ CORE_OPS_DECLARATION_TEMPLATE = \ CHECK_INPLACE_TEMPLATE = \ """ - // Check Inplace egr::EagerUtils::CheckInplace({}, {}, require_any_grad);\n """ @@ -625,7 +628,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}" ) - def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced): + def GenerateNodeCreationCodes(self): forward_api_name = self.forward_api_name forward_inputs_position_map = self.forward_inputs_position_map forward_outputs_position_map = self.forward_outputs_position_map @@ -635,67 +638,14 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): backward_grad_outputs_map = self.backward_grad_outputs_map backward_attrs_list = self.backward_attrs_list optional_inputs = self.optional_inputs - inplace_map = self.inplace_map if is_inplaced else {} - # Get Input AutoGradMeta - inputs_autograd_meta_list = [] - compute_require_grad_args_list = ["trace_backward"] - for name, (ttype, pos) in forward_inputs_position_map.items(): - input_autograd_meta_name = GetAutoGradMetaName(name) - if IsPlainTensorType(ttype): - input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});" - else: - assert IsVectorTensorType(ttype) - input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) - input_autograd_meta = f" std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n" - input_autograd_meta += f" std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" - - inputs_autograd_meta_list.append(input_autograd_meta) - compute_require_grad_args_list.append(input_autograd_meta_name) - inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list) - compute_require_grad_args_str = ",".join(compute_require_grad_args_list) - - # Get Output AutoGradMeta - outputs_autograd_meta_list = [] + # Pass Stop Gradient Args pass_stop_gradient_args_list = ["false"] - num_fwd_outputs = len(forward_outputs_position_map.keys()) - for name, (rtype, pos) in forward_outputs_position_map.items(): + for name, (_, _) in forward_outputs_position_map.items(): output_autograd_meta_name = GetAutoGradMetaName(name) - output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) - if num_fwd_outputs == 1: - if IsPlainTensorType(rtype): - output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);" - else: - assert IsVectorTensorType(rtype) - output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n" - output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" - else: - # Tuple api_result - if IsPlainTensorType(rtype): - output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));" - else: - assert IsVectorTensorType(rtype) - output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n" - output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" - - outputs_autograd_meta_list.append(output_autograd_meta) pass_stop_gradient_args_list.append(output_autograd_meta_name) - - # ComputeRequireGrad & PassStopGradient - outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list) pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list) - # Check Inplace - check_inplace_str = "" - bump_inplace_version_str = "" - if is_inplaced: - for inplace_name in inplace_map.keys(): - inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name) - check_inplace_str += CHECK_INPLACE_TEMPLATE.format( - inplace_name, inplace_autograd_meta_name) - bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format( - inplace_name, inplace_name) - # Node Construction num_backward_inputs = len(forward_outputs_position_map.keys()) num_backward_outputs = len(forward_inputs_position_map.keys()) @@ -719,6 +669,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): # SetTensorWrappers set_tensor_wrappers_list = [] + num_fwd_outputs = len(forward_outputs_position_map.keys()) for name, (atype, is_fwd_input, pos) in backward_forward_inputs_map.items(): is_optional = (name in optional_inputs) @@ -794,13 +745,10 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n" self.node_creation_str = FORWARD_BODY_TEMPLATE.format( - inputs_autograd_meta_str, compute_require_grad_args_str, - check_inplace_str, forward_call_str, bump_inplace_version_str, - node_creation_event_str, outputs_autograd_meta_str, - pass_stop_gradient_args_str, node_construction_str, - set_attributes_str, set_tensor_wrappers_str, set_grad_out_meta_str, - set_edges_str, set_out_rank_str, set_history_str, - set_grad_in_meta_str, set_retain_grad_str) + node_creation_event_str, pass_stop_gradient_args_str, + node_construction_str, set_attributes_str, set_tensor_wrappers_str, + set_grad_out_meta_str, set_edges_str, set_out_rank_str, + set_history_str, set_grad_in_meta_str, set_retain_grad_str) def run(self): # Basic Validation Check @@ -973,7 +921,64 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): returns_str = ", ".join(returns_list) returns_str = f"std::make_tuple({returns_str})" - self.GenerateNodeCreationCodes(forward_call_str, is_inplaced) + # Node Creation Pre-Processing + # 1. Get Input AutoGradMeta + inputs_autograd_meta_list = [] + compute_require_grad_args_list = ["trace_backward"] + for name, (ttype, pos) in forward_inputs_position_map.items(): + input_autograd_meta_name = GetAutoGradMetaName(name) + if IsPlainTensorType(ttype): + input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});" + else: + assert IsVectorTensorType(ttype) + input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) + input_autograd_meta = f" std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n" + input_autograd_meta += f" std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" + + inputs_autograd_meta_list.append(input_autograd_meta) + compute_require_grad_args_list.append(input_autograd_meta_name) + inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list) + compute_require_grad_args_str = ",".join(compute_require_grad_args_list) + + # 2. Get Output AutoGradMeta + outputs_autograd_meta_list = [] + num_fwd_outputs = len(forward_outputs_position_map.keys()) + for name, (rtype, pos) in forward_outputs_position_map.items(): + output_autograd_meta_name = GetAutoGradMetaName(name) + output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) + if num_fwd_outputs == 1: + if IsPlainTensorType(rtype): + output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);" + else: + assert IsVectorTensorType(rtype) + output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n" + output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" + else: + # Tuple api_result + if IsPlainTensorType(rtype): + output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));" + else: + assert IsVectorTensorType(rtype) + output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n" + output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" + + outputs_autograd_meta_list.append(output_autograd_meta) + + # 3. ComputeRequireGrad & PassStopGradient + outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list) + + # 4. Check Inplace + check_inplace_str = "" + bump_inplace_version_str = "" + if is_inplaced: + for inplace_name in inplace_map.keys(): + inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name) + check_inplace_str += CHECK_INPLACE_TEMPLATE.format( + inplace_name, inplace_autograd_meta_name) + bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format( + inplace_name, inplace_name) + + self.GenerateNodeCreationCodes() node_creation_str = self.node_creation_str dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);" @@ -1001,7 +1006,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format( returns_type_str, forward_function_name, inputs_args_definition_str, - dygraph_event_str, amp_logic_str, node_creation_str, returns_str) + dygraph_event_str, amp_logic_str, inputs_autograd_meta_str, + forward_call_str, outputs_autograd_meta_str, + compute_require_grad_args_str, check_inplace_str, + bump_inplace_version_str, node_creation_str, returns_str) self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n" logging.info(