diff --git a/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py index bf10763e2ff5883ca3b574205084e940182b2d7c..9caa759a251fa2c9bccff3df137213a3b3f27519 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py @@ -411,11 +411,14 @@ def ParseYamlCompositeInfo(string): pattern = fr'{fname}{wspace}\({wspace}{fargs}{wspace}\)' m = re.search(pattern, string) - composite_fun_info = [] - composite_fun_info.append(m.group(1)) + composite_fun_info = {} + composite_fun_info.update({"name": m.group(1)}) func_args = m.group(2).split(",") for fun_arg in func_args: - composite_fun_info.append(fun_arg.strip()) + if "args" in composite_fun_info: + composite_fun_info["args"].append(fun_arg.strip()) + else: + composite_fun_info.update({"args": [fun_arg.strip()]}) return composite_fun_info @@ -455,7 +458,9 @@ class FunctionGeneratorBase: # Special Op Attributes self.optional_inputs = [] # [name, ...] self.no_need_buffers = [] # [name, ...] - self.composite_func_info = [] # [func_name, input_name, ...] + self.composite_func_info = ( + {} + ) # {name: func_name, args: [input_name, ...]} self.intermediate_outputs = [] # [name, ...] self.forward_inplace_map = {} # {name : name, ...} diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index e03529bbe0970a8dd81e61c7197fdb450fa6ef22..414270a54456eebc8af4c37d18cc9f9132e8f298 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -876,7 +876,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): backward_attrs_list = self.backward_attrs_list optional_inputs = self.optional_inputs is_composite_grad_api = ( - False if self.composite_func_info == [] else True + False if self.composite_func_info == {} else True ) # Pass Stop Gradient Args @@ -1836,7 +1836,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): self.grad_api_contents, self.forward_apis_dict ) is_composite_grad_api = ( - False if self.composite_func_info == [] else True + False if self.composite_func_info == {} else True ) if next_node_generator is not None: @@ -1970,7 +1970,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): forward_api_name = self.forward_api_name backward_api_name = self.backward_api_name composite_grad_api_name = ( - self.composite_func_info[0] if is_composite_grad_api else None + self.composite_func_info["name"] if is_composite_grad_api else None ) backward_forward_inputs_map = self.backward_forward_inputs_map backward_grad_inputs_map = self.backward_grad_inputs_map @@ -2257,10 +2257,10 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): grad_function_call_str = f""" if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{ {indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str}); - VLOG(4) << paddle::string::Sprintf("composite api %s is called" , "{composite_grad_api_name}"); + VLOG(4) << "Composite api {composite_grad_api_name} is called "; }}else{{ {indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str}); - VLOG(4) << paddle::string::Sprintf("origin api %s is called" , "{backward_api_name}"); + VLOG(4) << "Fused api {backward_api_name} is called "; }} """ else: diff --git a/paddle/fluid/operators/generator/CMakeLists.txt b/paddle/fluid/operators/generator/CMakeLists.txt index 759a0fbaae2f58f74f22bec6d713840cf6004ca8..62c11faadaf209e0d2be499eff64aa174aa87d77 100644 --- a/paddle/fluid/operators/generator/CMakeLists.txt +++ b/paddle/fluid/operators/generator/CMakeLists.txt @@ -156,13 +156,6 @@ set(generated_static_files "${generated_static_argument_mapping_path}" "${generated_sparse_argument_mapping_path}") -set(generated_static_files - "${generated_op_path}" - "${generated_static_op_path}" - "${generated_sparse_ops_path}" - "${generated_argument_mapping_path}" - "${generated_static_argument_mapping_path}" - "${generated_sparse_argument_mapping_path}") foreach(generated_static_file ${generated_static_files}) if(EXISTS "${generated_static_file}.tmp" AND EXISTS "${generated_static_file}")