From 1840349ac4abaa2fd00b33688578e92b71f67096 Mon Sep 17 00:00:00 2001 From: huzhiqiang <912790387@qq.com> Date: Wed, 30 Mar 2022 07:39:46 +0800 Subject: [PATCH] [Infrt] add skip method for inferShape codegen (#41014) --- tools/infrt/generate_phi_kernel_dialect.py | 28 ++++++++++++++++++++++ tools/infrt/get_phi_kernel_info.py | 26 ++++++++++++++++++-- tools/infrt/skipped_phi_api.json | 4 ++++ 3 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 tools/infrt/skipped_phi_api.json diff --git a/tools/infrt/generate_phi_kernel_dialect.py b/tools/infrt/generate_phi_kernel_dialect.py index 4ac8a2e127..0b67c6ba44 100644 --- a/tools/infrt/generate_phi_kernel_dialect.py +++ b/tools/infrt/generate_phi_kernel_dialect.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import yaml import sys import os from get_compat_kernel_signature import get_compat_kernels_info @@ -52,6 +53,28 @@ precision_type_converter = { kernel_types_info_file = "./kernels.json" kernel_signature_info_file = "./kernel_signature.json" +skipped_phi_api_list_file = "./skipped_phi_api.json" + + +def get_skipped_kernel_list(): + skiped_kernel_list = [] + with open(skipped_phi_api_list_file, 'r') as f: + skiped_api_list = json.load(f) + infer_meta_data = get_api_yaml_info("../../") + for api in infer_meta_data: + if "kernel" not in api or "infer_meta" not in api: + continue + if api["api"] in skiped_api_list["phi_apis"]: + skiped_kernel_list.append(api["kernel"]["func"]) + skiped_kernel_list += skiped_api_list["phi_kernels"] + return skiped_kernel_list + + +def get_api_yaml_info(file_path): + f = open(file_path + "/python/paddle/utils/code_gen/api.yaml", "r") + cont = f.read() + return yaml.load(cont, Loader=yaml.FullLoader) + def generate_kernel_name(op_name, place_str): [target_, layout_, precision_] = place_str[1:-1].split(',') @@ -140,6 +163,10 @@ def generate_supported_kernel_list(load_dict): if flag and op_name in kernel_attrs_names: supported_kernels_list_.append(op_name) supported_kernels_list_ = list(set(supported_kernels_list_)) + skipped_kernel_list = get_skipped_kernel_list() + for skipped_kernel in skipped_kernel_list: + if skipped_kernel in skipped_kernel_list: + supported_kernels_list_.remove(skipped_kernel) return supported_kernels_list_ @@ -250,6 +277,7 @@ def main(): cpu_registry_ = "" gpu_registry_ = "" supported_kernels = generate_supported_kernel_list(load_dict) + print("Supported kernels:") print(supported_kernels) for op_name in load_dict: diff --git a/tools/infrt/get_phi_kernel_info.py b/tools/infrt/get_phi_kernel_info.py index 2d428adb1d..3fb40706e2 100644 --- a/tools/infrt/get_phi_kernel_info.py +++ b/tools/infrt/get_phi_kernel_info.py @@ -19,6 +19,23 @@ import json import yaml from typing import List, Dict, Any +skipped_phi_api_list_file = "/tools/infrt/skipped_phi_api.json" +api_yaml_file = "/python/paddle/utils/code_gen/api.yaml" + + +def get_skipped_kernel_list(): + skiped_kernel_list = [] + with open(skipped_phi_api_list_file, 'r') as f: + skiped_api_list = json.load(f) + infer_meta_data = get_api_yaml_info(api_yaml_file) + for api in infer_meta_data: + if "kernel" not in api or "infer_meta" not in api: + continue + if api["api"] in skiped_api_list["phi_apis"]: + skiped_kernel_list.append(api["kernel"]["func"]) + skiped_kernel_list += skiped_api_list["phi_kernels"] + return skiped_kernel_list + def parse_args(): parser = argparse.ArgumentParser("gather phi kernel and infermate info") @@ -50,7 +67,7 @@ def parse_args(): def get_api_yaml_info(file_path): - f = open(file_path + "/python/paddle/utils/code_gen/api.yaml", "r") + f = open(file_path, "r") cont = f.read() return yaml.load(cont, Loader=yaml.FullLoader) @@ -259,8 +276,11 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]): # TODO(wilber): handle the unknown inferShape func. return "" + skipped_kernel_list = get_skipped_kernel_list() for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes): kernel_func = gen_kernel_func(item[3], ctx_name, origin_dtype) + if item[0].lower() in skipped_kernel_list: + continue ir_name = ir_ctx_name + '.' + item[0].lower( ) + '.' + ir_dtype + '.' + item[2].lower() if ir_name in attr_data.keys() and attr_data[ir_name] is not None: @@ -342,7 +362,9 @@ def gen_phi_kernel_register_code(resources: List[List[str]], if __name__ == "__main__": args = parse_args() - infer_meta_data = get_api_yaml_info(args.paddle_root_path) + skipped_phi_api_list_file = args.paddle_root_path + skipped_phi_api_list_file + api_yaml_file = args.paddle_root_path + api_yaml_file + infer_meta_data = get_api_yaml_info(api_yaml_file) kernel_data = get_kernel_info(args.kernel_info_file) info_meta_wrap_data = get_infermeta_info(args.infermeta_wrap_file) attr_data = get_attr_info(args.attr_info_file) diff --git a/tools/infrt/skipped_phi_api.json b/tools/infrt/skipped_phi_api.json new file mode 100644 index 0000000000..7e03e01d0f --- /dev/null +++ b/tools/infrt/skipped_phi_api.json @@ -0,0 +1,4 @@ +{ +"phi_apis":["conj"], +"phi_kernels":["equal_all"] +} -- GitLab