未验证 提交 a227ae2b 编写于 作者: J Jiabin Yang 提交者: GitHub

Fix prim paddle c make (#49643)

* proto type of composite grad in paddle

* proto type of composite grad in paddle

* refactor composite api with phi

* fix compile error

* support static graph code-gen for squeeze op

* generate static graph code of unsqueeze

* refine op name

* fix compile error

* add extra output in op_compat

* remove debug log

* fix clang compile error

* support prim switch flag

* support prim switch flag

* fix dygraph error

* merge develop

* add code_gen

* add necessary files without codegen

* fix code_gen bug

* add deps

* modify igmnore

* add ignore

* delete std cout

* add composite logic for backward.py

* add tanh first order grad composite

* support enable_prim flag for static graph

* throw expection when both GrapOpMaker and GradCompOpMaker not been registered

* reorganize the directory of prim api tests

* fix windows error

* add eager_utils

* add eager_utils

* modify code gen

* add composite parse

* add unittest for get_grad_op_desc

* code optimize

* fix static test on windows

* support generate static graph code for imag and real op

* fix windows compile error in test_static_prim

* merge develop

* disable test eager in inference

* prim code gen

* disable eager compile in inference

* rm other file

* rm gitignore file

* code_style

* add eager test

* code_style

* merge develop

* remove useless files

* modify static test

* support bool flag from singlton

* merge develop

* recover git ignore

* fix conflict

* recover git ignore for generated op

* fix test compile error

* remove some tests

* add python test

* fix some name issue

* add composite code gen

* modify backward yaml

* fix static composite grad maker code gen

* remove addtional files

* add some static funcs unit test

* fix some bugs

* fix composite grad maker register code gen

* optimize some functions

* remove duplicated cmake

* fix cmake and codegen problem
Co-authored-by: Nzyfncg <zhangyunfei07@baidu.com>
Co-authored-by: Nwangruting <wangruting@baidu.com>
Co-authored-by: Ncxxly <chenxx_id@163.com>
Co-authored-by: Ncharles-hit <wanghao107@baidu.com>
Co-authored-by: Nxiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com>
上级 39210ed0
...@@ -411,11 +411,14 @@ def ParseYamlCompositeInfo(string): ...@@ -411,11 +411,14 @@ def ParseYamlCompositeInfo(string):
pattern = fr'{fname}{wspace}\({wspace}{fargs}{wspace}\)' pattern = fr'{fname}{wspace}\({wspace}{fargs}{wspace}\)'
m = re.search(pattern, string) m = re.search(pattern, string)
composite_fun_info = [] composite_fun_info = {}
composite_fun_info.append(m.group(1)) composite_fun_info.update({"name": m.group(1)})
func_args = m.group(2).split(",") func_args = m.group(2).split(",")
for fun_arg in func_args: for fun_arg in func_args:
composite_fun_info.append(fun_arg.strip()) if "args" in composite_fun_info:
composite_fun_info["args"].append(fun_arg.strip())
else:
composite_fun_info.update({"args": [fun_arg.strip()]})
return composite_fun_info return composite_fun_info
...@@ -455,7 +458,9 @@ class FunctionGeneratorBase: ...@@ -455,7 +458,9 @@ class FunctionGeneratorBase:
# Special Op Attributes # Special Op Attributes
self.optional_inputs = [] # [name, ...] self.optional_inputs = [] # [name, ...]
self.no_need_buffers = [] # [name, ...] self.no_need_buffers = [] # [name, ...]
self.composite_func_info = [] # [func_name, input_name, ...] self.composite_func_info = (
{}
) # {name: func_name, args: [input_name, ...]}
self.intermediate_outputs = [] # [name, ...] self.intermediate_outputs = [] # [name, ...]
self.forward_inplace_map = {} # {name : name, ...} self.forward_inplace_map = {} # {name : name, ...}
......
...@@ -876,7 +876,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -876,7 +876,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
backward_attrs_list = self.backward_attrs_list backward_attrs_list = self.backward_attrs_list
optional_inputs = self.optional_inputs optional_inputs = self.optional_inputs
is_composite_grad_api = ( is_composite_grad_api = (
False if self.composite_func_info == [] else True False if self.composite_func_info == {} else True
) )
# Pass Stop Gradient Args # Pass Stop Gradient Args
...@@ -1836,7 +1836,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ...@@ -1836,7 +1836,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
self.grad_api_contents, self.forward_apis_dict self.grad_api_contents, self.forward_apis_dict
) )
is_composite_grad_api = ( is_composite_grad_api = (
False if self.composite_func_info == [] else True False if self.composite_func_info == {} else True
) )
if next_node_generator is not None: if next_node_generator is not None:
...@@ -1970,7 +1970,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ...@@ -1970,7 +1970,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
forward_api_name = self.forward_api_name forward_api_name = self.forward_api_name
backward_api_name = self.backward_api_name backward_api_name = self.backward_api_name
composite_grad_api_name = ( composite_grad_api_name = (
self.composite_func_info[0] if is_composite_grad_api else None self.composite_func_info["name"] if is_composite_grad_api else None
) )
backward_forward_inputs_map = self.backward_forward_inputs_map backward_forward_inputs_map = self.backward_forward_inputs_map
backward_grad_inputs_map = self.backward_grad_inputs_map backward_grad_inputs_map = self.backward_grad_inputs_map
...@@ -2257,10 +2257,10 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ...@@ -2257,10 +2257,10 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_function_call_str = f""" grad_function_call_str = f"""
if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{ if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
{indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str}); {indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
VLOG(4) << paddle::string::Sprintf("composite api %s is called" , "{composite_grad_api_name}"); VLOG(4) << "Composite api {composite_grad_api_name} is called ";
}}else{{ }}else{{
{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str}); {indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});
VLOG(4) << paddle::string::Sprintf("origin api %s is called" , "{backward_api_name}"); VLOG(4) << "Fused api {backward_api_name} is called ";
}} }}
""" """
else: else:
......
...@@ -156,13 +156,6 @@ set(generated_static_files ...@@ -156,13 +156,6 @@ set(generated_static_files
"${generated_static_argument_mapping_path}" "${generated_static_argument_mapping_path}"
"${generated_sparse_argument_mapping_path}") "${generated_sparse_argument_mapping_path}")
set(generated_static_files
"${generated_op_path}"
"${generated_static_op_path}"
"${generated_sparse_ops_path}"
"${generated_argument_mapping_path}"
"${generated_static_argument_mapping_path}"
"${generated_sparse_argument_mapping_path}")
foreach(generated_static_file ${generated_static_files}) foreach(generated_static_file ${generated_static_files})
if(EXISTS "${generated_static_file}.tmp" AND EXISTS if(EXISTS "${generated_static_file}.tmp" AND EXISTS
"${generated_static_file}") "${generated_static_file}")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册