From e31a0a508764cc23b539d1dc2f2511d69811736c Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 17 Aug 2022 19:48:12 +0800 Subject: [PATCH] refine eager_gen for amp (#45211) --- .../final_state_generator/eager_gen.py | 84 +++++++------------ 1 file changed, 29 insertions(+), 55 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index ff4824d78e..fdc7819e31 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -428,20 +428,6 @@ CHECK_NAN_AND_INF_TEMPLATE = \ """ if (FLAGS_check_nan_inf) {{ egr::CheckTensorHasNanOrInf("{}", {}); }} """ -# This list contains ops that do not need to generate amp logic -# All optimizer ops in this list -no_amp_list = [ - 'adam_', 'adam', 'adamw_', 'adamw', 'average_accumulates', - 'average_accumulates_', 'decayed_adagrad_', 'decayed_adagrad', - 'dgc_momentum_', 'dgc_momentum', 'distributed_fused_lamb_', - 'distributed_fused_lamb', 'dpsgd_', 'dpsgd', 'ftrl_', 'ftrl', 'lamb_', - 'lamb', 'lars_momentum_', 'lars_momentum', 'merged_adam_', 'merged_adam', - 'merged_momentum_', 'merged_momentum', 'momentum_', 'momentum', - 'proximal_adagrad_', 'proximal_adagrad', 'proximal_gd_', 'proximal_gd', - 'rmsprop_', 'rmsprop', 'sgd_', 'sgd', 'lamb_', 'lamb', 'assign_value_', - 'sparse_momentum_', 'sparse_momentum', 'full_' -] - inplace_optional_out_type_map = { "Tensor": "paddle::optional&", @@ -1126,19 +1112,17 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): returns_str = f"{returns_type_str}{{{returns_str}}}" # Node Creation Pre-Processing - # 1. Get Input AutoGradMeta - if not self.is_forward_only: + # 1. Get Input AutoGradMeta inputs_autograd_meta_list = [] compute_require_grad_args_list = ["trace_backward"] for name, (ttype, pos) in forward_inputs_position_map.items(): # Has corresponding grad output has_corresponding_grad_output = False - if not self.is_forward_only: - for _, (_, corresponding_pos, - _) in backward_grad_outputs_map.items(): - if pos == corresponding_pos: - has_corresponding_grad_output = True + for _, (_, corresponding_pos, + _) in backward_grad_outputs_map.items(): + if pos == corresponding_pos: + has_corresponding_grad_output = True if has_corresponding_grad_output or ( name in forward_inplace_map and forward_api_name not in inplace_check_blacklist) or self.is_forward_only: @@ -1159,8 +1143,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): compute_require_grad_args_str = ",".join( compute_require_grad_args_list) - # 2. Get Output AutoGradMeta - if not self.is_forward_only: + # 2. Get Output AutoGradMeta outputs_autograd_meta_list = [] num_fwd_outputs = len(forward_outputs_position_map.keys()) @@ -1186,25 +1169,22 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): outputs_autograd_meta_list.append(output_autograd_meta) outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list) - # 3. Check Inplace - check_inplace_str = "" - bump_inplace_version_str = "" - if is_inplaced: - for inplace_name in forward_inplace_map.keys(): - if forward_api_name not in inplace_check_blacklist: - inplace_autograd_meta_name = GetAutoGradMetaName( - inplace_name) - check_inplace_str += CHECK_INPLACE_TEMPLATE.format( - inplace_name, inplace_autograd_meta_name) - bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format( - inplace_name, inplace_name) - - # Node Creation - if not self.is_forward_only: + # 3. Check Inplace + check_inplace_str = "" + bump_inplace_version_str = "" + if is_inplaced: + for inplace_name in forward_inplace_map.keys(): + if forward_api_name not in inplace_check_blacklist: + inplace_autograd_meta_name = GetAutoGradMetaName( + inplace_name) + check_inplace_str += CHECK_INPLACE_TEMPLATE.format( + inplace_name, inplace_autograd_meta_name) + bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format( + inplace_name, inplace_name) + + # Node Creation self.GenerateNodeCreationCodes() node_creation_str = self.node_creation_str - else: - node_creation_str = "" dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n" forward_function_name = GetDygraphForwardFunctionName(forward_api_name) @@ -1230,7 +1210,15 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): amp_autocast_list_str, amp_call_str) # Generate forward_definition_str and forward_declaration_str - if not self.is_forward_only: + if self.is_forward_only: + if len(amp_tensors_vector_list) == 0: + amp_logic_str = "" + self.forward_definition_str += FORWARD_ONLY_FUNCTION_TEMPLATE.format( + returns_type_str, forward_function_name, + inputs_args_definition_str, dygraph_event_str, amp_logic_str, + forward_function_name, forward_call_str, get_outputs_str, + returns_str) + else: self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format( returns_type_str, forward_function_name, inputs_args_definition_str, dygraph_event_str, amp_logic_str, @@ -1239,20 +1227,6 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): outputs_autograd_meta_str, compute_require_grad_args_str, check_inplace_str, bump_inplace_version_str, node_creation_str, returns_str) - else: - if (len(amp_tensors_vector_list) > 0) and (self.forward_api_name - not in no_amp_list): - self.forward_definition_str += FORWARD_ONLY_FUNCTION_TEMPLATE.format( - returns_type_str, forward_function_name, - inputs_args_definition_str, dygraph_event_str, - amp_logic_str, forward_function_name, forward_call_str, - get_outputs_str, returns_str) - else: - self.forward_definition_str += FORWARD_ONLY_FUNCTION_TEMPLATE.format( - returns_type_str, forward_function_name, - inputs_args_definition_str, dygraph_event_str, " ", - forward_function_name, forward_call_str, get_outputs_str, - returns_str) self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n" -- GitLab