未验证 提交 1840349a 编写于 作者: H huzhiqiang 提交者: GitHub

[Infrt] add skip method for inferShape codegen (#41014)

上级 cc52501e
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import json import json
import yaml
import sys import sys
import os import os
from get_compat_kernel_signature import get_compat_kernels_info from get_compat_kernel_signature import get_compat_kernels_info
...@@ -52,6 +53,28 @@ precision_type_converter = { ...@@ -52,6 +53,28 @@ precision_type_converter = {
kernel_types_info_file = "./kernels.json" kernel_types_info_file = "./kernels.json"
kernel_signature_info_file = "./kernel_signature.json" kernel_signature_info_file = "./kernel_signature.json"
skipped_phi_api_list_file = "./skipped_phi_api.json"
def get_skipped_kernel_list():
skiped_kernel_list = []
with open(skipped_phi_api_list_file, 'r') as f:
skiped_api_list = json.load(f)
infer_meta_data = get_api_yaml_info("../../")
for api in infer_meta_data:
if "kernel" not in api or "infer_meta" not in api:
continue
if api["api"] in skiped_api_list["phi_apis"]:
skiped_kernel_list.append(api["kernel"]["func"])
skiped_kernel_list += skiped_api_list["phi_kernels"]
return skiped_kernel_list
def get_api_yaml_info(file_path):
f = open(file_path + "/python/paddle/utils/code_gen/api.yaml", "r")
cont = f.read()
return yaml.load(cont, Loader=yaml.FullLoader)
def generate_kernel_name(op_name, place_str): def generate_kernel_name(op_name, place_str):
[target_, layout_, precision_] = place_str[1:-1].split(',') [target_, layout_, precision_] = place_str[1:-1].split(',')
...@@ -140,6 +163,10 @@ def generate_supported_kernel_list(load_dict): ...@@ -140,6 +163,10 @@ def generate_supported_kernel_list(load_dict):
if flag and op_name in kernel_attrs_names: if flag and op_name in kernel_attrs_names:
supported_kernels_list_.append(op_name) supported_kernels_list_.append(op_name)
supported_kernels_list_ = list(set(supported_kernels_list_)) supported_kernels_list_ = list(set(supported_kernels_list_))
skipped_kernel_list = get_skipped_kernel_list()
for skipped_kernel in skipped_kernel_list:
if skipped_kernel in skipped_kernel_list:
supported_kernels_list_.remove(skipped_kernel)
return supported_kernels_list_ return supported_kernels_list_
...@@ -250,6 +277,7 @@ def main(): ...@@ -250,6 +277,7 @@ def main():
cpu_registry_ = "" cpu_registry_ = ""
gpu_registry_ = "" gpu_registry_ = ""
supported_kernels = generate_supported_kernel_list(load_dict) supported_kernels = generate_supported_kernel_list(load_dict)
print("Supported kernels:") print("Supported kernels:")
print(supported_kernels) print(supported_kernels)
for op_name in load_dict: for op_name in load_dict:
......
...@@ -19,6 +19,23 @@ import json ...@@ -19,6 +19,23 @@ import json
import yaml import yaml
from typing import List, Dict, Any from typing import List, Dict, Any
skipped_phi_api_list_file = "/tools/infrt/skipped_phi_api.json"
api_yaml_file = "/python/paddle/utils/code_gen/api.yaml"
def get_skipped_kernel_list():
skiped_kernel_list = []
with open(skipped_phi_api_list_file, 'r') as f:
skiped_api_list = json.load(f)
infer_meta_data = get_api_yaml_info(api_yaml_file)
for api in infer_meta_data:
if "kernel" not in api or "infer_meta" not in api:
continue
if api["api"] in skiped_api_list["phi_apis"]:
skiped_kernel_list.append(api["kernel"]["func"])
skiped_kernel_list += skiped_api_list["phi_kernels"]
return skiped_kernel_list
def parse_args(): def parse_args():
parser = argparse.ArgumentParser("gather phi kernel and infermate info") parser = argparse.ArgumentParser("gather phi kernel and infermate info")
...@@ -50,7 +67,7 @@ def parse_args(): ...@@ -50,7 +67,7 @@ def parse_args():
def get_api_yaml_info(file_path): def get_api_yaml_info(file_path):
f = open(file_path + "/python/paddle/utils/code_gen/api.yaml", "r") f = open(file_path, "r")
cont = f.read() cont = f.read()
return yaml.load(cont, Loader=yaml.FullLoader) return yaml.load(cont, Loader=yaml.FullLoader)
...@@ -259,8 +276,11 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]): ...@@ -259,8 +276,11 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]):
# TODO(wilber): handle the unknown inferShape func. # TODO(wilber): handle the unknown inferShape func.
return "" return ""
skipped_kernel_list = get_skipped_kernel_list()
for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes): for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes):
kernel_func = gen_kernel_func(item[3], ctx_name, origin_dtype) kernel_func = gen_kernel_func(item[3], ctx_name, origin_dtype)
if item[0].lower() in skipped_kernel_list:
continue
ir_name = ir_ctx_name + '.' + item[0].lower( ir_name = ir_ctx_name + '.' + item[0].lower(
) + '.' + ir_dtype + '.' + item[2].lower() ) + '.' + ir_dtype + '.' + item[2].lower()
if ir_name in attr_data.keys() and attr_data[ir_name] is not None: if ir_name in attr_data.keys() and attr_data[ir_name] is not None:
...@@ -342,7 +362,9 @@ def gen_phi_kernel_register_code(resources: List[List[str]], ...@@ -342,7 +362,9 @@ def gen_phi_kernel_register_code(resources: List[List[str]],
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
infer_meta_data = get_api_yaml_info(args.paddle_root_path) skipped_phi_api_list_file = args.paddle_root_path + skipped_phi_api_list_file
api_yaml_file = args.paddle_root_path + api_yaml_file
infer_meta_data = get_api_yaml_info(api_yaml_file)
kernel_data = get_kernel_info(args.kernel_info_file) kernel_data = get_kernel_info(args.kernel_info_file)
info_meta_wrap_data = get_infermeta_info(args.infermeta_wrap_file) info_meta_wrap_data = get_infermeta_info(args.infermeta_wrap_file)
attr_data = get_attr_info(args.attr_info_file) attr_data = get_attr_info(args.attr_info_file)
......
{
"phi_apis":["conj"],
"phi_kernels":["equal_all"]
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册