未验证 提交 86554d91 编写于 作者: Z Zhanlue Yang 提交者: GitHub

[DoubleGrad PR #1] Decoupled code generation logics for Dygraph...

[DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes (#40937)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue
上级 0c024cb9
......@@ -311,7 +311,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
dygraph_function_call_list[pos] = f"{name}"
dygraph_function_call_str = ",".join(dygraph_function_call_list)
# Generate Python-C Function Definitions
# Generate Python-C Function Definitions
if is_forward_only:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace, forward_api_name)
......@@ -337,21 +337,21 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
forward_api_name_prefix, forward_api_name, namespace,
forward_api_name, forward_api_name)
if len(inplace_map) > 0:
assert len(
inplace_map
) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}"
if inplace_map:
inplaced_forward_api_name = GetInplacedFunctionName(
self.forward_api_name)
# Generate Python-C Function Definitions
if is_forward_only:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace,
inplaced_forward_api_name)
elif len(inplace_map) > 0:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
else:
inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace,
GetForwardFunctionName(inplaced_forward_api_name))
assert len(
inplace_map
) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}"
for inplace_input, inplace_output in inplace_map.items():
return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
inplaced_forward_api_name, inplace_input,
......@@ -361,7 +361,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, fwd_function_name,
parse_attributes_str, inplaced_fwd_function_name,
dygraph_function_call_str, return_str)
# Generate Python-C Function Registration
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册