未验证 提交 755a6c53 编写于 作者: W Wilber 提交者: GitHub

support register with attr (#40564)

* support register with attr

* add infrt_with_gpu macor
上级 35a5e8ee
...@@ -41,7 +41,37 @@ python3 ${PADDLE_ROOT}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py \ ...@@ -41,7 +41,37 @@ python3 ${PADDLE_ROOT}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py \
grep PD_REGISTER_INFER_META_FN ${temp_path}/generate.cc \ grep PD_REGISTER_INFER_META_FN ${temp_path}/generate.cc \
| awk -F "\(|,|::|\)" '{print $2, $4}' > ${temp_path}/wrap_info.txt | awk -F "\(|,|::|\)" '{print $2, $4}' > ${temp_path}/wrap_info.txt
#step 3: merge all infos
#step 3:get ir's attr_name.
ir_attr_name_info_file=`mktemp`
# phi_cpu attr
all_ir_name=`grep -Eo "PDTCPU_Kernel<.*\"" paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td | awk -v FS="<" '{gsub(/\"/,"");print $2}'`
for ir in $all_ir_name
do
attr_name=`grep "<\"$ir" -A 3 paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td | grep -Eo "Attr:.*)" \
| awk '{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BOOLAttr/,""); \
gsub(/SI1Attr/,"");gsub(/SI8Attr/,"");gsub(/SI16Attr/,"");gsub(/SI32Attr/,"");gsub(/SI64Attr/,""); \
gsub(/UI1Attr/,"");gsub(/UI8Attr/,"");gsub(/I16Attr/,"");gsub(/I32Attr/,"");gsub(/I64Attr/,""); \
gsub(/I1Attr/,"");gsub(/I8Attr/,"");gsub(/UI16Attr/,"");gsub(/UI32Attr/,"");gsub(/UI64Attr/,""); \
gsub(/Attr/,"");gsub(/\)/,""); \
gsub(/[,:]/,"");print $a}'`
echo phi_cpu.$ir $attr_name >> $ir_attr_name_info_file
done
# phi_gpu attr
all_ir_name=`grep -Eo "PDTGPU_Kernel<.*\"" paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td | awk -v FS="<" '{gsub(/\"/,"");print $2}'`
for ir in $all_ir_name
do
attr_name=`grep "<\"$ir" -A 3 paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td | grep -Eo "Attr:.*)" \
| awk '{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BOOLAttr/,""); \
gsub(/SI1Attr/,"");gsub(/SI8Attr/,"");gsub(/SI16Attr/,"");gsub(/SI32Attr/,"");gsub(/SI64Attr/,""); \
gsub(/UI1Attr/,"");gsub(/UI8Attr/,"");gsub(/I16Attr/,"");gsub(/I32Attr/,"");gsub(/I64Attr/,""); \
gsub(/I1Attr/,"");gsub(/I8Attr/,"");gsub(/UI16Attr/,"");gsub(/UI32Attr/,"");gsub(/UI64Attr/,""); \
gsub(/Attr/,"");gsub(/\)/,""); \
gsub(/[,:]/,"");print $a}'`
echo phi_gpu.$ir $attr_name >> $ir_attr_name_info_file
done
#step 4: merge all infos
# @input1 => phi kernel infomation : kernel_name kernel_key(GPU/CPU, precision, layout) # @input1 => phi kernel infomation : kernel_name kernel_key(GPU/CPU, precision, layout)
# @input2 => information from api.yaml : kernel_name kernel_function_name inferMeta_function_name # @input2 => information from api.yaml : kernel_name kernel_function_name inferMeta_function_name
# @input3 => information from wrapped_infermeta_gen : ensure the inferMeta function has # @input3 => information from wrapped_infermeta_gen : ensure the inferMeta function has
...@@ -50,4 +80,5 @@ python3 ${PADDLE_ROOT}/tools/infrt/get_phi_kernel_info.py \ ...@@ -50,4 +80,5 @@ python3 ${PADDLE_ROOT}/tools/infrt/get_phi_kernel_info.py \
--paddle_root_path ${PADDLE_ROOT} \ --paddle_root_path ${PADDLE_ROOT} \
--kernel_info_file $kernel_register_info_file \ --kernel_info_file $kernel_register_info_file \
--infermeta_wrap_file ${temp_path}/wrap_info.txt \ --infermeta_wrap_file ${temp_path}/wrap_info.txt \
--attr_info_file $ir_attr_name_info_file \
--generate_file ${PADDLE_ROOT}/paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.cc --generate_file ${PADDLE_ROOT}/paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.cc
...@@ -37,6 +37,8 @@ def parse_args(): ...@@ -37,6 +37,8 @@ def parse_args():
type=str, type=str,
required=True, required=True,
help="inferMeta wrap info file.") help="inferMeta wrap info file.")
parser.add_argument(
"--attr_info_file", type=str, required=True, help="attr info file.")
parser.add_argument( parser.add_argument(
"--generate_file", "--generate_file",
type=str, type=str,
...@@ -59,6 +61,23 @@ def get_kernel_info(file_path): ...@@ -59,6 +61,23 @@ def get_kernel_info(file_path):
return [l.strip() for l in cont] return [l.strip() for l in cont]
def get_attr_info(file_path):
"""
phi_gpu.argsort.float64.any $axisBool$descending
"""
ret = {}
with open(file_path, 'r') as f:
cont = f.readlines()
for l in cont:
datas = l.strip().split(' ')
if len(datas) == 2:
attrs = datas[1].split('$')
ret[datas[0]] = attrs[1:]
else:
ret[datas[0]] = None
return ret
def merge(infer_meta_data, kernel_data, wrap_data): def merge(infer_meta_data, kernel_data, wrap_data):
meta_map = {} meta_map = {}
for api in infer_meta_data: for api in infer_meta_data:
...@@ -114,14 +133,14 @@ namespace kernel { ...@@ -114,14 +133,14 @@ namespace kernel {
def gen_context(val): def gen_context(val):
if val == "CPU": if val == "CPU":
return "phi::CPUContext" return "phi::CPUContext", "phi_cpu"
# elif val == "GPU": elif val == "GPU":
# return "phi::GPUContext" return "phi::GPUContext", "phi_gpu"
# elif val == "XPU": # elif val == "XPU":
# return "phi::XPUContext" # return "phi::XPUContext", "phi_xpu"
else: else:
# raise Exception(f"Unknown context type {val}") # raise Exception(f"Unknown context type {val}")
return "" return "", ""
def gen_layout(val): def gen_layout(val):
...@@ -195,34 +214,53 @@ def gen_dtype(vals: List[str]): ...@@ -195,34 +214,53 @@ def gen_dtype(vals: List[str]):
return ir_dtypes, origin_dtypes return ir_dtypes, origin_dtypes
# TODO(wilber): Now only process CPUContext. # Note: Now only process CPUContext and GPUContext.
def gen_register_info(resources: List[List[str]]):
def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]):
""" """
resources: [['add', 'CPU', 'ALL_LAYOUT', 'AddKernel', 'float', 'double', '...'(varaidic types), 'ElementwiseInferMeta'], ...] item: ['add', 'CPU', 'ALL_LAYOUT', 'AddKernel', 'float', 'double', '...'(varaidic types), 'ElementwiseInferMeta']
attr_data: {'phi_cpu.arg_min.float32.any': ['axisBool', 'keepdimsBool', 'flatten', 'dtype']}
""" """
res = "void RegisterInferShapeLaunchers(host_context::KernelRegistry* registry) {" ctx_name, ir_ctx_name = gen_context(item[1])
for item in resources:
# The output string is polluted by C++ macros, here the \ is removed
update_item = [v.strip('\\') for v in item]
ctx_name = gen_context(update_item[1])
if (ctx_name == ""): if (ctx_name == ""):
continue return ""
update_item[2] = gen_layout(update_item[2]) item[2] = gen_layout(item[2])
ir_dtypes, origin_dtypes = gen_dtype(update_item[4:-1]) ir_dtypes, origin_dtypes = gen_dtype(item[4:-1])
infer_shape_func = "&phi::" + update_item[-1] infer_shape_func = "&phi::" + item[-1]
if update_item[-1] == "unknown": res = ""
if item[-1] == "unknown":
# TODO(wilber): handle the unknown inferShape func. # TODO(wilber): handle the unknown inferShape func.
continue return ""
for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes): for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes):
kernel_func = gen_kernel_func(update_item[3], ctx_name, kernel_func = gen_kernel_func(item[3], ctx_name, origin_dtype)
origin_dtype) ir_name = ir_ctx_name + '.' + item[0].lower(
ir_name = 'phi_cpu.' + update_item[0].lower( ) + '.' + ir_dtype + '.' + item[2].lower()
) + '.' + ir_dtype + '.' + update_item[2].lower() if ir_name in attr_data.keys() and attr_data[ir_name] is not None:
attr_names = ', '.join(
["\"" + a + "\"" for a in attr_data[ir_name]])
res += f""" res += f"""
registry->AddKernel("{ir_name}",""" registry->AddKernelWithAttrs("{ir_name}","""
res += f"""
std::bind(&KernelLauncherFunc<decltype({kernel_func}),
{kernel_func},
decltype({infer_shape_func}),
{infer_shape_func}>,
KernelLauncher<decltype({kernel_func}),
{kernel_func},
decltype({infer_shape_func}),
{infer_shape_func}>(),
std::placeholders::_1),
{{{attr_names}}});
"""
else:
res += f"""
registry->AddKernel("{ir_name}","""
res += f""" res += f"""
std::bind(&KernelLauncherFunc<decltype({kernel_func}), std::bind(&KernelLauncherFunc<decltype({kernel_func}),
...@@ -236,18 +274,54 @@ def gen_register_info(resources: List[List[str]]): ...@@ -236,18 +274,54 @@ def gen_register_info(resources: List[List[str]]):
std::placeholders::_1)); std::placeholders::_1));
""" """
return res
def gen_register_info(resources: List[List[str]],
attr_data: Dict[str, List[str]]):
"""
resources: [['add', 'CPU', 'ALL_LAYOUT', 'AddKernel', 'float', 'double', '...'(varaidic types), 'ElementwiseInferMeta'], ...]
attr_data: {'phi_cpu.arg_min.float32.any': ['axisBool', 'keepdimsBool', 'flatten', 'dtype']}
"""
res = "void RegisterInferShapeLaunchers(host_context::KernelRegistry* registry) {"
# register cpu kernels.
for item in resources:
# The output string is polluted by C++ macros, here the \ is removed
update_item = [v.strip('\\') for v in item]
if update_item[1] != "CPU":
continue
code = gen_register_code_info(item, attr_data)
if (code == ""):
continue
res += code
# register gpu kernels.
res += "\n#ifdef INFRT_WITH_GPU"
for item in resources:
# The output string is polluted by C++ macros, here the \ is removed
update_item = [v.strip('\\') for v in item]
if update_item[1] != "GPU":
continue
code = gen_register_code_info(item, attr_data)
if (code == ""):
continue
res += code
res += "#endif // INFRT_WITH_GPU"
res += "\n}" res += "\n}"
return res return res
def gen_phi_kernel_register_code(resources: List[List[str]], def gen_phi_kernel_register_code(resources: List[List[str]],
attr_data: Dict[str, List[str]],
src_file_path: str): src_file_path: str):
source_file = open(src_file_path, 'w') source_file = open(src_file_path, 'w')
source_file.write(gen_warn_info()) source_file.write(gen_warn_info())
source_file.write(gen_include_headers()) source_file.write(gen_include_headers())
namespace = gen_namespace() namespace = gen_namespace()
source_file.write(namespace[0]) source_file.write(namespace[0])
source_file.write(gen_register_info(resources)) source_file.write(gen_register_info(resources, attr_data))
source_file.write(namespace[1]) source_file.write(namespace[1])
source_file.close() source_file.close()
...@@ -257,5 +331,6 @@ if __name__ == "__main__": ...@@ -257,5 +331,6 @@ if __name__ == "__main__":
infer_meta_data = get_api_yaml_info(args.paddle_root_path) infer_meta_data = get_api_yaml_info(args.paddle_root_path)
kernel_data = get_kernel_info(args.kernel_info_file) kernel_data = get_kernel_info(args.kernel_info_file)
info_meta_wrap_data = get_kernel_info(args.infermeta_wrap_file) info_meta_wrap_data = get_kernel_info(args.infermeta_wrap_file)
attr_data = get_attr_info(args.attr_info_file)
out = merge(infer_meta_data, kernel_data, info_meta_wrap_data) out = merge(infer_meta_data, kernel_data, info_meta_wrap_data)
gen_phi_kernel_register_code(out, args.generate_file) gen_phi_kernel_register_code(out, attr_data, args.generate_file)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册