未验证 提交 86554d91 编写于 作者: Z Zhanlue Yang 提交者: GitHub

[DoubleGrad PR #1] Decoupled code generation logics for Dygraph...

[DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes (#40937)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue
上级 0c024cb9
...@@ -311,7 +311,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -311,7 +311,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
dygraph_function_call_list[pos] = f"{name}" dygraph_function_call_list[pos] = f"{name}"
dygraph_function_call_str = ",".join(dygraph_function_call_list) dygraph_function_call_str = ",".join(dygraph_function_call_list)
# Generate Python-C Function Definitions # Generate Python-C Function Definitions
if is_forward_only: if is_forward_only:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format( fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace, forward_api_name) "paddle::experimental::", namespace, forward_api_name)
...@@ -337,21 +337,21 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -337,21 +337,21 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
forward_api_name_prefix, forward_api_name, namespace, forward_api_name_prefix, forward_api_name, namespace,
forward_api_name, forward_api_name) forward_api_name, forward_api_name)
if len(inplace_map) > 0: if inplace_map:
assert len(
inplace_map
) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}"
inplaced_forward_api_name = GetInplacedFunctionName( inplaced_forward_api_name = GetInplacedFunctionName(
self.forward_api_name) self.forward_api_name)
# Generate Python-C Function Definitions
if is_forward_only: if is_forward_only:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format( inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace, "paddle::experimental::", namespace,
inplaced_forward_api_name) inplaced_forward_api_name)
elif len(inplace_map) > 0: else:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format( inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace, "::", namespace,
GetForwardFunctionName(inplaced_forward_api_name)) GetForwardFunctionName(inplaced_forward_api_name))
assert len(
inplace_map
) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}"
for inplace_input, inplace_output in inplace_map.items(): for inplace_input, inplace_output in inplace_map.items():
return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format( return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
inplaced_forward_api_name, inplace_input, inplaced_forward_api_name, inplace_input,
...@@ -361,7 +361,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -361,7 +361,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format( self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str, inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str, inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, fwd_function_name, parse_attributes_str, inplaced_fwd_function_name,
dygraph_function_call_str, return_str) dygraph_function_call_str, return_str)
# Generate Python-C Function Registration # Generate Python-C Function Registration
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册