Unverified commit e31a0a50, authored by zyfncg, committed by GitHub

refine eager_gen for amp (#45211)

Parent e51ea538
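Summary of the change, as a minimal sketch rather than code from this commit: the hunks below drop the module-level no_amp_list and collapse the two forward-only branches into one, so the generator chooses between the two forward templates purely from self.is_forward_only and blanks amp_logic_str when an op has no AMP-candidate tensor inputs. The helper name pick_forward_template is hypothetical; is_forward_only, amp_tensors_vector_list, amp_logic_str, and the two template names come from the diff.

# Minimal sketch only -- the real logic lives in
# DygraphForwardFunctionGenerator.GenerateForwardDefinitionAndDeclaration;
# pick_forward_template is a hypothetical helper name used for illustration.
def pick_forward_template(is_forward_only, amp_tensors_vector_list,
                          amp_logic_str):
    if is_forward_only:
        # Forward-only ops always take FORWARD_ONLY_FUNCTION_TEMPLATE; when no
        # tensor inputs could participate in AMP, the AMP block is emitted as
        # an empty string instead of checking a hard-coded no_amp_list of op names.
        if len(amp_tensors_vector_list) == 0:
            amp_logic_str = ""
        return "FORWARD_ONLY_FUNCTION_TEMPLATE", amp_logic_str
    # Ops with autograd info keep FORWARD_FUNCTION_TEMPLATE plus node creation.
    return "FORWARD_FUNCTION_TEMPLATE", amp_logic_str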
@@ -428,20 +428,6 @@ CHECK_NAN_AND_INF_TEMPLATE = \
""" if (FLAGS_check_nan_inf) {{ egr::CheckTensorHasNanOrInf("{}", {}); }}
"""
# This list contains ops that do not need to generate amp logic
# All optimizer ops in this list
no_amp_list = [
    'adam_', 'adam', 'adamw_', 'adamw', 'average_accumulates',
    'average_accumulates_', 'decayed_adagrad_', 'decayed_adagrad',
    'dgc_momentum_', 'dgc_momentum', 'distributed_fused_lamb_',
    'distributed_fused_lamb', 'dpsgd_', 'dpsgd', 'ftrl_', 'ftrl', 'lamb_',
    'lamb', 'lars_momentum_', 'lars_momentum', 'merged_adam_', 'merged_adam',
    'merged_momentum_', 'merged_momentum', 'momentum_', 'momentum',
    'proximal_adagrad_', 'proximal_adagrad', 'proximal_gd_', 'proximal_gd',
    'rmsprop_', 'rmsprop', 'sgd_', 'sgd', 'lamb_', 'lamb', 'assign_value_',
    'sparse_momentum_', 'sparse_momentum', 'full_'
]
inplace_optional_out_type_map = {
"Tensor":
"paddle::optional<paddle::experimental::Tensor>&",
@@ -1126,15 +1112,13 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
returns_str = f"{returns_type_str}{{{returns_str}}}"
# Node Creation Pre-Processing
# 1. Get Input AutoGradMeta
if not self.is_forward_only:
# 1. Get Input AutoGradMeta
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items():
# Has corresponding grad output
has_corresponding_grad_output = False
if not self.is_forward_only:
for _, (_, corresponding_pos,
_) in backward_grad_outputs_map.items():
if pos == corresponding_pos:
@@ -1160,7 +1144,6 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
compute_require_grad_args_list)
# 2. Get Output AutoGradMeta
if not self.is_forward_only:
outputs_autograd_meta_list = []
num_fwd_outputs = len(forward_outputs_position_map.keys())
@@ -1200,11 +1183,8 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
inplace_name, inplace_name)
# Node Creation
if not self.is_forward_only:
self.GenerateNodeCreationCodes()
node_creation_str = self.node_creation_str
else:
node_creation_str = ""
dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n"
forward_function_name = GetDygraphForwardFunctionName(forward_api_name)
@@ -1230,7 +1210,15 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
amp_autocast_list_str, amp_call_str)
# Generate forward_definition_str and forward_declaration_str
if not self.is_forward_only:
if self.is_forward_only:
if len(amp_tensors_vector_list) == 0:
amp_logic_str = ""
self.forward_definition_str += FORWARD_ONLY_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name,
inputs_args_definition_str, dygraph_event_str, amp_logic_str,
forward_function_name, forward_call_str, get_outputs_str,
returns_str)
else:
self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name,
inputs_args_definition_str, dygraph_event_str, amp_logic_str,
@@ -1239,20 +1227,6 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
outputs_autograd_meta_str, compute_require_grad_args_str,
check_inplace_str, bump_inplace_version_str, node_creation_str,
returns_str)
else:
if (len(amp_tensors_vector_list) > 0) and (self.forward_api_name
not in no_amp_list):
self.forward_definition_str += FORWARD_ONLY_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name,
inputs_args_definition_str, dygraph_event_str,
amp_logic_str, forward_function_name, forward_call_str,
get_outputs_str, returns_str)
else:
self.forward_definition_str += FORWARD_ONLY_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name,
inputs_args_definition_str, dygraph_event_str, " ",
forward_function_name, forward_call_str, get_outputs_str,
returns_str)
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"
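For intuition only, a toy illustration of why passing an empty amp_logic_str is enough to strip AMP handling from a generated forward-only function. The template string and names below are hypothetical simplifications, not the real FORWARD_ONLY_FUNCTION_TEMPLATE in eager_gen.py, which also carries the dygraph event, NaN/Inf check, and logging slots.

# Hypothetical, heavily simplified stand-in for the generated-code template.
TOY_FORWARD_ONLY_TEMPLATE = """\
{returns_type} {func_name}({args}) {{
{amp_logic}
{forward_call}
{get_outputs}
  return {returns};
}}
"""

def render_toy(returns_type, func_name, args, amp_logic, forward_call,
               get_outputs, returns):
    # An empty amp_logic string simply leaves no AMP branch in the emitted C++.
    return TOY_FORWARD_ONLY_TEMPLATE.format(returns_type=returns_type,
                                            func_name=func_name,
                                            args=args,
                                            amp_logic=amp_logic,
                                            forward_call=forward_call,
                                            get_outputs=get_outputs,
                                            returns=returns)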