From f98828540cf17d66067ddf1fdf8c7eeb4fa4ec6d Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 2 Sep 2022 15:16:42 +0800 Subject: [PATCH] Adjust the rule of configure in api_compat.yaml (#45672) * set use_cudnn=true for conv2d * refine the config rule of api_compat --- paddle/phi/api/yaml/api_compat.yaml | 5 ++-- paddle/phi/api/yaml/generator/generate_op.py | 25 +++++++++++++------ .../api/yaml/generator/ops_extra_info_gen.py | 24 ++++++++++++------ 3 files changed, 36 insertions(+), 18 deletions(-) diff --git a/paddle/phi/api/yaml/api_compat.yaml b/paddle/phi/api/yaml/api_compat.yaml index 912d63a776..a4a5eea778 100644 --- a/paddle/phi/api/yaml/api_compat.yaml +++ b/paddle/phi/api/yaml/api_compat.yaml @@ -123,9 +123,8 @@ str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()] -- api : diag - op_name : diag_v2 - grad_op_name : diag_v2_grad +- api : diag (diag_v2) + backward : diag_grad (diag_v2_grad) inputs : x : X outputs : diff --git a/paddle/phi/api/yaml/generator/generate_op.py b/paddle/phi/api/yaml/generator/generate_op.py index ac43db18e5..24f30323a9 100644 --- a/paddle/phi/api/yaml/generator/generate_op.py +++ b/paddle/phi/api/yaml/generator/generate_op.py @@ -56,18 +56,29 @@ def restruct_io(api): # replace name of op and params for OpMaker def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict): + + def get_api_and_op_name(api_item): + names = api_item.split('(') + if len(names) == 1: + return names[0].strip(), names[0].strip() + else: + return names[0].strip(), names[1].split(')')[0].strip() + for api_args in api_op_map: - if api_args['api'] not in forward_api_dict: + api_name, op_name = get_api_and_op_name(api_args['api']) + if api_name not in forward_api_dict: continue - forward_api_item = forward_api_dict[api_args['api']] + forward_api_item = forward_api_dict[api_name] has_backward = True if forward_api_item['backward'] else False if has_backward: backward_api_item = backward_api_dict[forward_api_item['backward']] - if 'op_name' in api_args: - forward_api_item['op_name'] = api_args['op_name'] - if 'grad_op_name' in api_args and has_backward: - forward_api_item['backward'] = api_args['grad_op_name'] - backward_api_item['op_name'] = api_args['grad_op_name'] + if api_name != op_name: + forward_api_item['op_name'] = op_name + if 'backward' in api_args and has_backward: + bw_api_name, bw_op_name = get_api_and_op_name( + api_args['backward'].split(',')[0]) + forward_api_item['backward'] = bw_op_name + backward_api_item['op_name'] = bw_op_name key_set = ['inputs', 'attrs', 'outputs'] args_map = {} diff --git a/paddle/phi/api/yaml/generator/ops_extra_info_gen.py b/paddle/phi/api/yaml/generator/ops_extra_info_gen.py index 675b889e44..d7ece0d2a4 100644 --- a/paddle/phi/api/yaml/generator/ops_extra_info_gen.py +++ b/paddle/phi/api/yaml/generator/ops_extra_info_gen.py @@ -70,6 +70,13 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path): with open(api_compat_yaml_path, 'rt') as f: compat_apis = yaml.safe_load(f) + def get_op_name(api_item): + names = api_item.split('(') + if len(names) == 1: + return names[0].strip() + else: + return names[1].split(')')[0].strip() + extra_map_str_list = [] extra_checker_str_list = [] @@ -96,18 +103,19 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path): api_extra_attr_checkers = ",\n ".join( attr_checker_func_list) extra_map_str_list.append( - f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_map} }}}}" + f"{{\"{get_op_name(api_compat_args['api'])}\", {{ {api_extra_attr_map} }}}}" ) extra_checker_str_list.append( - f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_checkers} }}}}" + f"{{\"{get_op_name(api_compat_args['api'])}\", {{ {api_extra_attr_checkers} }}}}" ) if 'backward' in api_compat_args: - extra_map_str_list.append( - f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_map} }}}}" - ) - extra_checker_str_list.append( - f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_checkers} }}}}" - ) + for bw_item in api_compat_args['backward'].split(','): + bw_op_name = get_op_name(bw_item) + extra_map_str_list.append( + f"{{\"{bw_op_name}\", {{ {api_extra_attr_map} }}}}") + extra_checker_str_list.append( + f"{{\"{bw_op_name}\", {{ {api_extra_attr_checkers} }}}}" + ) ops_extra_info_file = open(ops_extra_info_path, 'w') ops_extra_info_file.write( -- GitLab