diff --git a/tools/infrt/get_phi_kernel_function.sh b/tools/infrt/get_phi_kernel_function.sh index 3b9f4b7273500f23d67a3062a2d4ee367c0b473b..6b2586d40819b9e25eef823dff59687114664197 100644 --- a/tools/infrt/get_phi_kernel_function.sh +++ b/tools/infrt/get_phi_kernel_function.sh @@ -41,7 +41,37 @@ python3 ${PADDLE_ROOT}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py \ grep PD_REGISTER_INFER_META_FN ${temp_path}/generate.cc \ | awk -F "\(|,|::|\)" '{print $2, $4}' > ${temp_path}/wrap_info.txt -#step 3: merge all infos + +#step 3:get ir's attr_name. +ir_attr_name_info_file=`mktemp` +# phi_cpu attr +all_ir_name=`grep -Eo "PDTCPU_Kernel<.*\"" paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td | awk -v FS="<" '{gsub(/\"/,"");print $2}'` +for ir in $all_ir_name +do + attr_name=`grep "<\"$ir" -A 3 paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td | grep -Eo "Attr:.*)" \ + | awk '{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BOOLAttr/,""); \ + gsub(/SI1Attr/,"");gsub(/SI8Attr/,"");gsub(/SI16Attr/,"");gsub(/SI32Attr/,"");gsub(/SI64Attr/,""); \ + gsub(/UI1Attr/,"");gsub(/UI8Attr/,"");gsub(/I16Attr/,"");gsub(/I32Attr/,"");gsub(/I64Attr/,""); \ + gsub(/I1Attr/,"");gsub(/I8Attr/,"");gsub(/UI16Attr/,"");gsub(/UI32Attr/,"");gsub(/UI64Attr/,""); \ + gsub(/Attr/,"");gsub(/\)/,""); \ + gsub(/[,:]/,"");print $a}'` + echo phi_cpu.$ir $attr_name >> $ir_attr_name_info_file +done +# phi_gpu attr +all_ir_name=`grep -Eo "PDTGPU_Kernel<.*\"" paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td | awk -v FS="<" '{gsub(/\"/,"");print $2}'` +for ir in $all_ir_name +do + attr_name=`grep "<\"$ir" -A 3 paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td | grep -Eo "Attr:.*)" \ + | awk '{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BOOLAttr/,""); \ + gsub(/SI1Attr/,"");gsub(/SI8Attr/,"");gsub(/SI16Attr/,"");gsub(/SI32Attr/,"");gsub(/SI64Attr/,""); \ + gsub(/UI1Attr/,"");gsub(/UI8Attr/,"");gsub(/I16Attr/,"");gsub(/I32Attr/,"");gsub(/I64Attr/,""); \ + gsub(/I1Attr/,"");gsub(/I8Attr/,"");gsub(/UI16Attr/,"");gsub(/UI32Attr/,"");gsub(/UI64Attr/,""); \ + gsub(/Attr/,"");gsub(/\)/,""); \ + gsub(/[,:]/,"");print $a}'` + echo phi_gpu.$ir $attr_name >> $ir_attr_name_info_file +done + +#step 4: merge all infos # @input1 => phi kernel infomation : kernel_name kernel_key(GPU/CPU, precision, layout) # @input2 => information from api.yaml : kernel_name kernel_function_name inferMeta_function_name # @input3 => information from wrapped_infermeta_gen : ensure the inferMeta function has @@ -50,4 +80,5 @@ python3 ${PADDLE_ROOT}/tools/infrt/get_phi_kernel_info.py \ --paddle_root_path ${PADDLE_ROOT} \ --kernel_info_file $kernel_register_info_file \ --infermeta_wrap_file ${temp_path}/wrap_info.txt \ + --attr_info_file $ir_attr_name_info_file \ --generate_file ${PADDLE_ROOT}/paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.cc diff --git a/tools/infrt/get_phi_kernel_info.py b/tools/infrt/get_phi_kernel_info.py index 774f6cd6bf3648a0de7a34e01e893d212bce9770..85ad585cdefa9cbb4ac8d029e699af4d5ffaeaf7 100644 --- a/tools/infrt/get_phi_kernel_info.py +++ b/tools/infrt/get_phi_kernel_info.py @@ -37,6 +37,8 @@ def parse_args(): type=str, required=True, help="inferMeta wrap info file.") + parser.add_argument( + "--attr_info_file", type=str, required=True, help="attr info file.") parser.add_argument( "--generate_file", type=str, @@ -59,6 +61,23 @@ def get_kernel_info(file_path): return [l.strip() for l in cont] +def get_attr_info(file_path): + """ + phi_gpu.argsort.float64.any $axisBool$descending + """ + ret = {} + with open(file_path, 'r') as f: + cont = f.readlines() + for l in cont: + datas = l.strip().split(' ') + if len(datas) == 2: + attrs = datas[1].split('$') + ret[datas[0]] = attrs[1:] + else: + ret[datas[0]] = None + return ret + + def merge(infer_meta_data, kernel_data, wrap_data): meta_map = {} for api in infer_meta_data: @@ -114,14 +133,14 @@ namespace kernel { def gen_context(val): if val == "CPU": - return "phi::CPUContext" - # elif val == "GPU": - # return "phi::GPUContext" + return "phi::CPUContext", "phi_cpu" + elif val == "GPU": + return "phi::GPUContext", "phi_gpu" # elif val == "XPU": - # return "phi::XPUContext" + # return "phi::XPUContext", "phi_xpu" else: # raise Exception(f"Unknown context type {val}") - return "" + return "", "" def gen_layout(val): @@ -195,34 +214,53 @@ def gen_dtype(vals: List[str]): return ir_dtypes, origin_dtypes -# TODO(wilber): Now only process CPUContext. -def gen_register_info(resources: List[List[str]]): +# Note: Now only process CPUContext and GPUContext. + + +def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]): """ - resources: [['add', 'CPU', 'ALL_LAYOUT', 'AddKernel', 'float', 'double', '...'(varaidic types), 'ElementwiseInferMeta'], ...] + item: ['add', 'CPU', 'ALL_LAYOUT', 'AddKernel', 'float', 'double', '...'(varaidic types), 'ElementwiseInferMeta'] + attr_data: {'phi_cpu.arg_min.float32.any': ['axisBool', 'keepdimsBool', 'flatten', 'dtype']} """ - res = "void RegisterInferShapeLaunchers(host_context::KernelRegistry* registry) {" - for item in resources: - # The output string is polluted by C++ macros, here the \ is removed - update_item = [v.strip('\\') for v in item] + ctx_name, ir_ctx_name = gen_context(item[1]) + if (ctx_name == ""): + return "" + item[2] = gen_layout(item[2]) + ir_dtypes, origin_dtypes = gen_dtype(item[4:-1]) + infer_shape_func = "&phi::" + item[-1] - ctx_name = gen_context(update_item[1]) - if (ctx_name == ""): - continue - update_item[2] = gen_layout(update_item[2]) - ir_dtypes, origin_dtypes = gen_dtype(update_item[4:-1]) - infer_shape_func = "&phi::" + update_item[-1] + res = "" - if update_item[-1] == "unknown": - # TODO(wilber): handle the unknown inferShape func. - continue + if item[-1] == "unknown": + # TODO(wilber): handle the unknown inferShape func. + return "" + + for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes): + kernel_func = gen_kernel_func(item[3], ctx_name, origin_dtype) + ir_name = ir_ctx_name + '.' + item[0].lower( + ) + '.' + ir_dtype + '.' + item[2].lower() + if ir_name in attr_data.keys() and attr_data[ir_name] is not None: + attr_names = ', '.join( + ["\"" + a + "\"" for a in attr_data[ir_name]]) + res += f""" +registry->AddKernelWithAttrs("{ir_name}",""" + + res += f""" + std::bind(&KernelLauncherFunc, + KernelLauncher(), + std::placeholders::_1), + {{{attr_names}}}); +""" - for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes): - kernel_func = gen_kernel_func(update_item[3], ctx_name, - origin_dtype) - ir_name = 'phi_cpu.' + update_item[0].lower( - ) + '.' + ir_dtype + '.' + update_item[2].lower() + else: res += f""" - registry->AddKernel("{ir_name}",""" +registry->AddKernel("{ir_name}",""" res += f""" std::bind(&KernelLauncherFunc