Unverified commit 3983c724, authored by Zhanlue Yang, committed by GitHub

[DoubleGrad PR #2] Adjusted logic of GenerateNodeCreationCodes and GenerateForwardDefinition (#41016)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logic for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logic of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Fixed minor issue
Parent 93a2f565
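In short, the diff below moves the input/output AutogradMeta collection and the inplace checks out of GenerateNodeCreationCodes and into GenerateForwardDefinition, which now fills a much richer FORWARD_FUNCTION_TEMPLATE; GenerateNodeCreationCodes() keeps only the grad-node construction block and takes no arguments. The following is a minimal, self-contained sketch of that division of labor (the templates, the "relu" example, and every body here are illustrative stand-ins, not the real eager_gen.py code):

# Simplified sketch of the post-refactor structure; all names below are stand-ins.
FORWARD_FUNCTION_TEMPLATE_SKETCH = "{} {{\n{}\n  // Node Creation\n{}\n  return out;\n}}\n"
FORWARD_BODY_TEMPLATE_SKETCH = "  if(require_any_grad) {{\n{}\n  }}"

class ForwardFunctionGeneratorSketch:
    def GenerateNodeCreationCodes(self):
        # After this PR the method takes no extra arguments: it only renders the
        # grad-node construction block from state the generator already holds.
        self.node_creation_str = FORWARD_BODY_TEMPLATE_SKETCH.format(
            "    // GradNode construction, SetAttributes, SetTensorWrappers, ...")

    def GenerateForwardDefinition(self):
        # The forward-definition generator now prepares the AutogradMeta and
        # inplace strings itself and owns the full forward-function template.
        signature = "paddle::experimental::Tensor relu_dygraph_function(const paddle::experimental::Tensor& x)"
        forward_call_str = "  auto api_result = paddle::experimental::relu(x);"
        self.GenerateNodeCreationCodes()
        return FORWARD_FUNCTION_TEMPLATE_SKETCH.format(
            signature, forward_call_str, self.node_creation_str)

print(ForwardFunctionGeneratorSketch().GenerateForwardDefinition())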
@@ -162,9 +162,24 @@ FUNCTION_TEMPLATE = \
 FORWARD_FUNCTION_TEMPLATE = \
 """
 {} {}({}) {{
-{}
-{}
-{}
+  // Dygraph Record Event
+{}
+  // AMP Logic
+{}
+  // Get Input AutoGradMeta
+{}
+  // Forward API Call
+{}
+  // Get Output AutoGradMeta
+{}
+  bool trace_backward = egr::Controller::Instance().HasGrad();
+  bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
+  // Check Inplace & Bump Inplace Version
+{}
+{}
+  // Node Creation
+{}
   // Returns
   return {};
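For reference, the thirteen format slots in the rewritten FORWARD_FUNCTION_TEMPLATE above line up positionally with the arguments passed to format() in the last hunk of this diff. The annotated list below restates that pairing (argument names are copied from the diff; the comments paraphrase the corresponding template region):

# Slot order of the new FORWARD_FUNCTION_TEMPLATE, paired with the positional
# arguments of the FORWARD_FUNCTION_TEMPLATE.format(...) call at the end of this diff.
forward_function_template_slots = [
    "returns_type_str",               # {}   -> return type
    "forward_function_name",          # {}(  -> function name
    "inputs_args_definition_str",     # ({}) -> argument list
    "dygraph_event_str",              # // Dygraph Record Event
    "amp_logic_str",                  # // AMP Logic
    "inputs_autograd_meta_str",       # // Get Input AutoGradMeta
    "forward_call_str",               # // Forward API Call
    "outputs_autograd_meta_str",      # // Get Output AutoGradMeta
    "compute_require_grad_args_str",  # ComputeRequireGrad({})
    "check_inplace_str",              # // Check Inplace ...
    "bump_inplace_version_str",       # ... & Bump Inplace Version
    "node_creation_str",              # // Node Creation
    "returns_str",                    # return {};
]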
@@ -174,18 +189,8 @@ FORWARD_FUNCTION_TEMPLATE = \
 FORWARD_BODY_TEMPLATE = \
 """
-  // Get AutoGradMeta
-{}
-  bool trace_backward = egr::Controller::Instance().HasGrad();
-  bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
-{}
-  // Forward API Call
-{}
-{}
-  {{
-{}
-{}
   if(require_any_grad) {{
-{}
     egr::EagerUtils::PassStopGradient({});
     // Node Construction
@@ -203,7 +208,6 @@ FORWARD_BODY_TEMPLATE = \
 {}
 {}
   }}
-}}
 """
 NAMESPACE_WRAPPER_TEMPLATE = \
@@ -294,7 +298,6 @@ CORE_OPS_DECLARATION_TEMPLATE = \
 CHECK_INPLACE_TEMPLATE = \
 """
-  // Check Inplace
   egr::EagerUtils::CheckInplace({}, {}, require_any_grad);\n
 """
@@ -625,7 +628,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
             f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}"
         )

-    def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
+    def GenerateNodeCreationCodes(self):
         forward_api_name = self.forward_api_name
         forward_inputs_position_map = self.forward_inputs_position_map
         forward_outputs_position_map = self.forward_outputs_position_map
@@ -635,67 +638,14 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         backward_grad_outputs_map = self.backward_grad_outputs_map
         backward_attrs_list = self.backward_attrs_list
         optional_inputs = self.optional_inputs
-        inplace_map = self.inplace_map if is_inplaced else {}
-        # Get Input AutoGradMeta
-        inputs_autograd_meta_list = []
-        compute_require_grad_args_list = ["trace_backward"]
-        for name, (ttype, pos) in forward_inputs_position_map.items():
-            input_autograd_meta_name = GetAutoGradMetaName(name)
-            if IsPlainTensorType(ttype):
-                input_autograd_meta = f"  egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
-            else:
-                assert IsVectorTensorType(ttype)
-                input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
-                input_autograd_meta = f"  std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
-                input_autograd_meta += f"  std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
-            inputs_autograd_meta_list.append(input_autograd_meta)
-            compute_require_grad_args_list.append(input_autograd_meta_name)
-        inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
-        compute_require_grad_args_str = ",".join(compute_require_grad_args_list)
-        # Get Output AutoGradMeta
-        outputs_autograd_meta_list = []
+        # Pass Stop Gradient Args
         pass_stop_gradient_args_list = ["false"]
-        num_fwd_outputs = len(forward_outputs_position_map.keys())
-        for name, (rtype, pos) in forward_outputs_position_map.items():
+        for name, (_, _) in forward_outputs_position_map.items():
             output_autograd_meta_name = GetAutoGradMetaName(name)
-            output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
-            if num_fwd_outputs == 1:
-                if IsPlainTensorType(rtype):
-                    output_autograd_meta = f"  egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
-                else:
-                    assert IsVectorTensorType(rtype)
-                    output_autograd_meta = f"  std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
-                    output_autograd_meta += f"  std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
-            else:
-                # Tuple api_result
-                if IsPlainTensorType(rtype):
-                    output_autograd_meta = f"  egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
-                else:
-                    assert IsVectorTensorType(rtype)
-                    output_autograd_meta = f"  std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
-                    output_autograd_meta += f"  std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
-            outputs_autograd_meta_list.append(output_autograd_meta)
             pass_stop_gradient_args_list.append(output_autograd_meta_name)
-        # ComputeRequireGrad & PassStopGradient
-        outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
         pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
-        # Check Inplace
-        check_inplace_str = ""
-        bump_inplace_version_str = ""
-        if is_inplaced:
-            for inplace_name in inplace_map.keys():
-                inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
-                check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
-                    inplace_name, inplace_autograd_meta_name)
-                bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
-                    inplace_name, inplace_name)

         # Node Construction
         num_backward_inputs = len(forward_outputs_position_map.keys())
         num_backward_outputs = len(forward_inputs_position_map.keys())
@@ -719,6 +669,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         # SetTensorWrappers
         set_tensor_wrappers_list = []
+        num_fwd_outputs = len(forward_outputs_position_map.keys())
         for name, (atype, is_fwd_input,
                    pos) in backward_forward_inputs_map.items():
             is_optional = (name in optional_inputs)
@@ -794,13 +745,10 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"

         self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
-            inputs_autograd_meta_str, compute_require_grad_args_str,
-            check_inplace_str, forward_call_str, bump_inplace_version_str,
-            node_creation_event_str, outputs_autograd_meta_str,
-            pass_stop_gradient_args_str, node_construction_str,
-            set_attributes_str, set_tensor_wrappers_str, set_grad_out_meta_str,
-            set_edges_str, set_out_rank_str, set_history_str,
-            set_grad_in_meta_str, set_retain_grad_str)
+            node_creation_event_str, pass_stop_gradient_args_str,
+            node_construction_str, set_attributes_str, set_tensor_wrappers_str,
+            set_grad_out_meta_str, set_edges_str, set_out_rank_str,
+            set_history_str, set_grad_in_meta_str, set_retain_grad_str)

     def run(self):
         # Basic Validation Check
@@ -973,7 +921,64 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
             returns_str = ", ".join(returns_list)
             returns_str = f"std::make_tuple({returns_str})"

-        self.GenerateNodeCreationCodes(forward_call_str, is_inplaced)
+        # Node Creation Pre-Processing
+        # 1. Get Input AutoGradMeta
+        inputs_autograd_meta_list = []
+        compute_require_grad_args_list = ["trace_backward"]
+        for name, (ttype, pos) in forward_inputs_position_map.items():
+            input_autograd_meta_name = GetAutoGradMetaName(name)
+            if IsPlainTensorType(ttype):
+                input_autograd_meta = f"  egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
+            else:
+                assert IsVectorTensorType(ttype)
+                input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
+                input_autograd_meta = f"  std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
+                input_autograd_meta += f"  std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
+            inputs_autograd_meta_list.append(input_autograd_meta)
+            compute_require_grad_args_list.append(input_autograd_meta_name)
+        inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
+        compute_require_grad_args_str = ",".join(compute_require_grad_args_list)
+
+        # 2. Get Output AutoGradMeta
+        outputs_autograd_meta_list = []
+        num_fwd_outputs = len(forward_outputs_position_map.keys())
+        for name, (rtype, pos) in forward_outputs_position_map.items():
+            output_autograd_meta_name = GetAutoGradMetaName(name)
+            output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
+            if num_fwd_outputs == 1:
+                if IsPlainTensorType(rtype):
+                    output_autograd_meta = f"  egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
+                else:
+                    assert IsVectorTensorType(rtype)
+                    output_autograd_meta = f"  std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
+                    output_autograd_meta += f"  std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
+            else:
+                # Tuple api_result
+                if IsPlainTensorType(rtype):
+                    output_autograd_meta = f"  egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
+                else:
+                    assert IsVectorTensorType(rtype)
+                    output_autograd_meta = f"  std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
+                    output_autograd_meta += f"  std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
+            outputs_autograd_meta_list.append(output_autograd_meta)
+
+        # 3. ComputeRequireGrad & PassStopGradient
+        outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
+
+        # 4. Check Inplace
+        check_inplace_str = ""
+        bump_inplace_version_str = ""
+        if is_inplaced:
+            for inplace_name in inplace_map.keys():
+                inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
+                check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
+                    inplace_name, inplace_autograd_meta_name)
+                bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
+                    inplace_name, inplace_name)
+
+        self.GenerateNodeCreationCodes()
         node_creation_str = self.node_creation_str

         dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
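The single- versus multi-output branch added above decides whether the generated C++ reads api_result directly or indexes it with std::get<pos>. A minimal, runnable sketch of that behavior (simplified helper and made-up output names, for illustration only):

# Sketch of the output-AutogradMeta line produced for one vs. several forward outputs.
def output_autograd_meta_line(name, pos, num_fwd_outputs):
    meta_name = f"{name}_autograd_meta"  # simplified GetAutoGradMetaName
    target = "api_result" if num_fwd_outputs == 1 else f"std::get<{pos}>(api_result)"
    return f"  egr::AutogradMeta* {meta_name} = egr::EagerUtils::autograd_meta(&{target});"

print(output_autograd_meta_line("out", 0, 1))   # single output: &api_result
print(output_autograd_meta_line("mean", 1, 2))  # tuple output: &std::get<1>(api_result)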
@@ -1001,7 +1006,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
         self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
             returns_type_str, forward_function_name, inputs_args_definition_str,
-            dygraph_event_str, amp_logic_str, node_creation_str, returns_str)
+            dygraph_event_str, amp_logic_str, inputs_autograd_meta_str,
+            forward_call_str, outputs_autograd_meta_str,
+            compute_require_grad_args_str, check_inplace_str,
+            bump_inplace_version_str, node_creation_str, returns_str)

         self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"

         logging.info(
......