diff --git a/paddle/phi/api/yaml/api_compat.yaml b/paddle/phi/api/yaml/api_compat.yaml index 912d63a776b43ec04660b91617fca781ca3ea6ca..a4a5eea778a8e53d0d1709a466abc86a5661a11d 100644 --- a/paddle/phi/api/yaml/api_compat.yaml +++ b/paddle/phi/api/yaml/api_compat.yaml @@ -123,9 +123,8 @@ str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()] -- api : diag - op_name : diag_v2 - grad_op_name : diag_v2_grad +- api : diag (diag_v2) + backward : diag_grad (diag_v2_grad) inputs : x : X outputs : diff --git a/paddle/phi/api/yaml/generator/generate_op.py b/paddle/phi/api/yaml/generator/generate_op.py index ac43db18e57c1825b55775f82b4c5bd8d3d8bb90..24f30323a935b2db64dd0e22f111f36f4ac6e9ce 100644 --- a/paddle/phi/api/yaml/generator/generate_op.py +++ b/paddle/phi/api/yaml/generator/generate_op.py @@ -56,18 +56,29 @@ def restruct_io(api): # replace name of op and params for OpMaker def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict): + + def get_api_and_op_name(api_item): + names = api_item.split('(') + if len(names) == 1: + return names[0].strip(), names[0].strip() + else: + return names[0].strip(), names[1].split(')')[0].strip() + for api_args in api_op_map: - if api_args['api'] not in forward_api_dict: + api_name, op_name = get_api_and_op_name(api_args['api']) + if api_name not in forward_api_dict: continue - forward_api_item = forward_api_dict[api_args['api']] + forward_api_item = forward_api_dict[api_name] has_backward = True if forward_api_item['backward'] else False if has_backward: backward_api_item = backward_api_dict[forward_api_item['backward']] - if 'op_name' in api_args: - forward_api_item['op_name'] = api_args['op_name'] - if 'grad_op_name' in api_args and has_backward: - forward_api_item['backward'] = api_args['grad_op_name'] - backward_api_item['op_name'] = api_args['grad_op_name'] + if api_name != op_name: + forward_api_item['op_name'] = op_name + if 'backward' in api_args and has_backward: + bw_api_name, bw_op_name = get_api_and_op_name( + api_args['backward'].split(',')[0]) + forward_api_item['backward'] = bw_op_name + backward_api_item['op_name'] = bw_op_name key_set = ['inputs', 'attrs', 'outputs'] args_map = {} diff --git a/paddle/phi/api/yaml/generator/ops_extra_info_gen.py b/paddle/phi/api/yaml/generator/ops_extra_info_gen.py index 675b889e4450cf14fa4eea736e390fe7ca688756..d7ece0d2a4563d8e45759445243caa9fdbffadd6 100644 --- a/paddle/phi/api/yaml/generator/ops_extra_info_gen.py +++ b/paddle/phi/api/yaml/generator/ops_extra_info_gen.py @@ -70,6 +70,13 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path): with open(api_compat_yaml_path, 'rt') as f: compat_apis = yaml.safe_load(f) + def get_op_name(api_item): + names = api_item.split('(') + if len(names) == 1: + return names[0].strip() + else: + return names[1].split(')')[0].strip() + extra_map_str_list = [] extra_checker_str_list = [] @@ -96,18 +103,19 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path): api_extra_attr_checkers = ",\n ".join( attr_checker_func_list) extra_map_str_list.append( - f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_map} }}}}" + f"{{\"{get_op_name(api_compat_args['api'])}\", {{ {api_extra_attr_map} }}}}" ) extra_checker_str_list.append( - f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_checkers} }}}}" + f"{{\"{get_op_name(api_compat_args['api'])}\", {{ {api_extra_attr_checkers} }}}}" ) if 'backward' in api_compat_args: - extra_map_str_list.append( - f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_map} }}}}" - ) - extra_checker_str_list.append( - f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_checkers} }}}}" - ) + for bw_item in api_compat_args['backward'].split(','): + bw_op_name = get_op_name(bw_item) + extra_map_str_list.append( + f"{{\"{bw_op_name}\", {{ {api_extra_attr_map} }}}}") + extra_checker_str_list.append( + f"{{\"{bw_op_name}\", {{ {api_extra_attr_checkers} }}}}" + ) ops_extra_info_file = open(ops_extra_info_path, 'w') ops_extra_info_file.write(