未验证 提交 755a6c53 编写于 作者: W Wilber 提交者: GitHub

support register with attr (#40564)

* support register with attr

* add infrt_with_gpu macro
上级 35a5e8ee
......@@ -41,7 +41,37 @@ python3 ${PADDLE_ROOT}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py \
grep PD_REGISTER_INFER_META_FN ${temp_path}/generate.cc \
| awk -F "\(|,|::|\)" '{print $2, $4}' > ${temp_path}/wrap_info.txt
#step 3: merge all infos
#step 3:get ir's attr_name.
# Collect, for every phi IR kernel declared in the TableGen (.td) files, the
# attribute text found in its declaration. One line per kernel is appended to
# a temp file in the form "<dialect>.<ir_name> <attr text>"; the file is later
# passed to get_phi_kernel_info.py through --attr_info_file.
ir_attr_name_info_file=`mktemp`
# phi_cpu attr
# Kernel names: take everything after the '<' of each PDTCPU_Kernel<"..."
# occurrence and drop the double quotes.
all_ir_name=`grep -Eo "PDTCPU_Kernel<.*\"" paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td | awk -v FS="<" '{gsub(/\"/,"");print $2}'`
for ir in $all_ir_name
do
# Look at the declaration line plus the 3 lines after it (-A 3), keep only the
# "Attr:...)" part, then delete the attribute type names and the ')' ',' ':'
# punctuation so that only the '$'-prefixed names remain.
# The awk variable `a` is never assigned, so `print $a` prints $0 (whole line).
# NOTE(review): I16Attr/I32Attr/I64Attr are removed before UI16Attr/UI32Attr/
# UI64Attr, so those UI* patterns can never match and a stray "U" is left in
# front of the '$' for such attributes - confirm this is harmless downstream.
attr_name=`grep "<\"$ir" -A 3 paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td | grep -Eo "Attr:.*)" \
| awk '{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BOOLAttr/,""); \
gsub(/SI1Attr/,"");gsub(/SI8Attr/,"");gsub(/SI16Attr/,"");gsub(/SI32Attr/,"");gsub(/SI64Attr/,""); \
gsub(/UI1Attr/,"");gsub(/UI8Attr/,"");gsub(/I16Attr/,"");gsub(/I32Attr/,"");gsub(/I64Attr/,""); \
gsub(/I1Attr/,"");gsub(/I8Attr/,"");gsub(/UI16Attr/,"");gsub(/UI32Attr/,"");gsub(/UI64Attr/,""); \
gsub(/Attr/,"");gsub(/\)/,""); \
gsub(/[,:]/,"");print $a}'`
# $attr_name is intentionally unquoted here: echo collapses its whitespace.
echo phi_cpu.$ir $attr_name >> $ir_attr_name_info_file
done
# phi_gpu attr
# Same extraction as above, applied to the PDTGPU_Kernel declarations in
# phi_gpu_kernels.td and written with the "phi_gpu." prefix.
all_ir_name=`grep -Eo "PDTGPU_Kernel<.*\"" paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td | awk -v FS="<" '{gsub(/\"/,"");print $2}'`
for ir in $all_ir_name
do
attr_name=`grep "<\"$ir" -A 3 paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td | grep -Eo "Attr:.*)" \
| awk '{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BOOLAttr/,""); \
gsub(/SI1Attr/,"");gsub(/SI8Attr/,"");gsub(/SI16Attr/,"");gsub(/SI32Attr/,"");gsub(/SI64Attr/,""); \
gsub(/UI1Attr/,"");gsub(/UI8Attr/,"");gsub(/I16Attr/,"");gsub(/I32Attr/,"");gsub(/I64Attr/,""); \
gsub(/I1Attr/,"");gsub(/I8Attr/,"");gsub(/UI16Attr/,"");gsub(/UI32Attr/,"");gsub(/UI64Attr/,""); \
gsub(/Attr/,"");gsub(/\)/,""); \
gsub(/[,:]/,"");print $a}'`
echo phi_gpu.$ir $attr_name >> $ir_attr_name_info_file
done
#step 4: merge all infos
# @input1 => phi kernel information : kernel_name kernel_key(GPU/CPU, precision, layout)
# @input2 => information from api.yaml : kernel_name kernel_function_name inferMeta_function_name
# @input3 => information from wrapped_infermeta_gen : ensure the inferMeta function has
......@@ -50,4 +80,5 @@ python3 ${PADDLE_ROOT}/tools/infrt/get_phi_kernel_info.py \
--paddle_root_path ${PADDLE_ROOT} \
--kernel_info_file $kernel_register_info_file \
--infermeta_wrap_file ${temp_path}/wrap_info.txt \
--attr_info_file $ir_attr_name_info_file \
--generate_file ${PADDLE_ROOT}/paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.cc
......@@ -37,6 +37,8 @@ def parse_args():
type=str,
required=True,
help="inferMeta wrap info file.")
parser.add_argument(
"--attr_info_file", type=str, required=True, help="attr info file.")
parser.add_argument(
"--generate_file",
type=str,
......@@ -59,6 +61,23 @@ def get_kernel_info(file_path):
return [l.strip() for l in cont]
def get_attr_info(file_path):
    """Load the per-kernel attribute names from the attr info file.

    Every non-blank line is expected to look like::

        phi_gpu.argsort.float64.any $axisBool$descending

    i.e. an IR kernel name, a single space, and a '$'-delimited attribute
    string. The text before the first '$' is discarded.

    Args:
        file_path: path of the attr info file to read.

    Returns:
        A dict mapping the IR kernel name to the list of attribute names,
        or to None when the line does not consist of exactly two
        space-separated fields (for example a kernel without attributes).
    """
    ret = {}
    with open(file_path, 'r') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # A blank line carries no kernel name; recording it would
                # only add a meaningless '' key to the result.
                continue
            datas = stripped.split(' ')
            if len(datas) == 2:
                # attrs[0] is whatever precedes the first '$'; drop it.
                attrs = datas[1].split('$')
                ret[datas[0]] = attrs[1:]
            else:
                ret[datas[0]] = None
    return ret
def merge(infer_meta_data, kernel_data, wrap_data):
meta_map = {}
for api in infer_meta_data:
......@@ -114,14 +133,14 @@ namespace kernel {
def gen_context(val):
    """Map a kernel backend tag to its phi context type and IR name prefix.

    Args:
        val: backend name from the kernel registration info, e.g. "CPU"
            or "GPU".

    Returns:
        A ``(context_type, ir_prefix)`` tuple such as
        ``("phi::CPUContext", "phi_cpu")``. Any other backend yields
        ``("", "")``; the result is always a 2-tuple so callers can unpack
        it unconditionally and test the first element for emptiness.
    """
    if val == "CPU":
        return "phi::CPUContext", "phi_cpu"
    elif val == "GPU":
        return "phi::GPUContext", "phi_gpu"
    # elif val == "XPU":
    #     return "phi::XPUContext", "phi_xpu"
    else:
        # raise Exception(f"Unknown context type {val}")
        return "", ""
def gen_layout(val):
......@@ -195,34 +214,53 @@ def gen_dtype(vals: List[str]):
return ir_dtypes, origin_dtypes
# TODO(wilber): Now only process CPUContext.
def gen_register_info(resources: List[List[str]]):
# Note: Now only process CPUContext and GPUContext.
def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]):
"""
resources: [['add', 'CPU', 'ALL_LAYOUT', 'AddKernel', 'float', 'double', '...'(varaidic types), 'ElementwiseInferMeta'], ...]
item: ['add', 'CPU', 'ALL_LAYOUT', 'AddKernel', 'float', 'double', '...'(varaidic types), 'ElementwiseInferMeta']
attr_data: {'phi_cpu.arg_min.float32.any': ['axisBool', 'keepdimsBool', 'flatten', 'dtype']}
"""
res = "void RegisterInferShapeLaunchers(host_context::KernelRegistry* registry) {"
for item in resources:
# The output string is polluted by C++ macros, here the \ is removed
update_item = [v.strip('\\') for v in item]
ctx_name, ir_ctx_name = gen_context(item[1])
if (ctx_name == ""):
return ""
item[2] = gen_layout(item[2])
ir_dtypes, origin_dtypes = gen_dtype(item[4:-1])
infer_shape_func = "&phi::" + item[-1]
ctx_name = gen_context(update_item[1])
if (ctx_name == ""):
continue
update_item[2] = gen_layout(update_item[2])
ir_dtypes, origin_dtypes = gen_dtype(update_item[4:-1])
infer_shape_func = "&phi::" + update_item[-1]
res = ""
if update_item[-1] == "unknown":
# TODO(wilber): handle the unknown inferShape func.
continue
if item[-1] == "unknown":
# TODO(wilber): handle the unknown inferShape func.
return ""
for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes):
kernel_func = gen_kernel_func(item[3], ctx_name, origin_dtype)
ir_name = ir_ctx_name + '.' + item[0].lower(
) + '.' + ir_dtype + '.' + item[2].lower()
if ir_name in attr_data.keys() and attr_data[ir_name] is not None:
attr_names = ', '.join(
["\"" + a + "\"" for a in attr_data[ir_name]])
res += f"""
registry->AddKernelWithAttrs("{ir_name}","""
res += f"""
std::bind(&KernelLauncherFunc<decltype({kernel_func}),
{kernel_func},
decltype({infer_shape_func}),
{infer_shape_func}>,
KernelLauncher<decltype({kernel_func}),
{kernel_func},
decltype({infer_shape_func}),
{infer_shape_func}>(),
std::placeholders::_1),
{{{attr_names}}});
"""
for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes):
kernel_func = gen_kernel_func(update_item[3], ctx_name,
origin_dtype)
ir_name = 'phi_cpu.' + update_item[0].lower(
) + '.' + ir_dtype + '.' + update_item[2].lower()
else:
res += f"""
registry->AddKernel("{ir_name}","""
registry->AddKernel("{ir_name}","""
res += f"""
std::bind(&KernelLauncherFunc<decltype({kernel_func}),
......@@ -236,18 +274,54 @@ def gen_register_info(resources: List[List[str]]):
std::placeholders::_1));
"""
return res
def gen_register_info(resources: List[List[str]],
                      attr_data: Dict[str, List[str]]):
    """Generate the C++ body of ``RegisterInferShapeLaunchers``.

    CPU kernels are registered first; GPU kernels follow inside an
    ``#ifdef INFRT_WITH_GPU`` / ``#endif`` guard. Items for any other
    backend are ignored.

    Args:
        resources: [['add', 'CPU', 'ALL_LAYOUT', 'AddKernel', 'float', 'double', '...'(variadic types), 'ElementwiseInferMeta'], ...]
        attr_data: {'phi_cpu.arg_min.float32.any': ['axisBool', 'keepdimsBool', 'flatten', 'dtype']}

    Returns:
        The complete function definition as a string.
    """

    def _register_backend(backend: str) -> str:
        # Concatenate the registration code of every item of one backend.
        chunks = []
        for item in resources:
            # The output string is polluted by C++ macros, here the \ is
            # removed. The cleaned copy is what gets passed on, so the
            # generator sees the same fields this filter was applied to.
            update_item = [v.strip('\\') for v in item]
            if update_item[1] != backend:
                continue
            code = gen_register_code_info(update_item, attr_data)
            if code:
                chunks.append(code)
        return "".join(chunks)

    res = "void RegisterInferShapeLaunchers(host_context::KernelRegistry* registry) {"
    # register cpu kernels.
    res += _register_backend("CPU")
    # register gpu kernels.
    res += "\n#ifdef INFRT_WITH_GPU"
    res += _register_backend("GPU")
    # The leading newline keeps #endif on its own line even when no GPU code
    # was emitted; otherwise it would be glued onto the #ifdef line.
    res += "\n#endif // INFRT_WITH_GPU"
    res += "\n}"
    return res
def gen_phi_kernel_register_code(resources: List[List[str]],
                                 attr_data: Dict[str, List[str]],
                                 src_file_path: str):
    """Write the generated kernel-registration C++ source file.

    The file is overwritten with, in order: the warning banner, the include
    headers, the namespace opening, the registration function and the
    namespace closing.

    Args:
        resources: merged kernel records, one list of strings per kernel.
        attr_data: mapping from IR kernel name to its attribute names,
            forwarded to gen_register_info.
        src_file_path: path of the source file to (over)write.
    """
    # The context manager closes the file even if a generator raises.
    with open(src_file_path, 'w') as source_file:
        source_file.write(gen_warn_info())
        source_file.write(gen_include_headers())
        namespace = gen_namespace()
        source_file.write(namespace[0])
        # The registration body is written exactly once, with attr_data.
        source_file.write(gen_register_info(resources, attr_data))
        source_file.write(namespace[1])
......@@ -257,5 +331,6 @@ if __name__ == "__main__":
infer_meta_data = get_api_yaml_info(args.paddle_root_path)
kernel_data = get_kernel_info(args.kernel_info_file)
info_meta_wrap_data = get_kernel_info(args.infermeta_wrap_file)
attr_data = get_attr_info(args.attr_info_file)
out = merge(infer_meta_data, kernel_data, info_meta_wrap_data)
gen_phi_kernel_register_code(out, args.generate_file)
gen_phi_kernel_register_code(out, attr_data, args.generate_file)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册