未验证 提交 a7d4ddc4 编写于 作者: S Shang Zhizhou 提交者: GitHub

update tools for infrt build (#39552)

上级 eb3c7d00
...@@ -20,9 +20,33 @@ ...@@ -20,9 +20,33 @@
set -e set -e
EXIT_CODE=0; #step 1:get kernel registered info
tmp_dir=`mktemp -d` kernel_register_info_file=`mktemp`
PADDLE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../../" && pwd )" PADDLE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../../" && pwd )"
unset GREP_OPTIONS && find ${PADDLE_ROOT}/paddle/pten/kernels -name "*.c*" \
| xargs sed -e '/PT_REGISTER_\(GENERAL_\)\?KERNEL(/,/)/!d' \
| awk 'BEGIN { RS="{" }{ gsub(/\n /,""); print $0 }' \
| grep PT_REGISTER \
| awk -F ",|\(" '{gsub(/ /,"");print $2, $3, $4, $5}' \
| sort -u | awk '{gsub(/pten::/,"");print $0}' \
| grep -v "_grad" > $kernel_register_info_file
#step 2:get simple general inferMeta function wrap info
temp_path=`mktemp -d`
python3 ${PADDLE_ROOT}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py \
--api_yaml_path ${PADDLE_ROOT}/python/paddle/utils/code_gen/api.yaml \
--wrapped_infermeta_header_path ${temp_path}/generate.h \
--wrapped_infermeta_source_path ${temp_path}/generate.cc
grep PT_REGISTER_INFER_META_FN ${temp_path}/generate.cc \
| awk -F "\(|,|::|\)" '{print $2, $4}' > ${temp_path}/wrap_info.txt
unset GREP_OPTIONS && find ${PADDLE_ROOT}/paddle/pten/kernels -name "*.c*" | xargs sed -e '/PT_REGISTER_\(GENERAL_\)\?KERNEL(/,/)/!d' | awk 'BEGIN { RS="{" }{ gsub(/\n /,""); print $0 }' | grep PT_REGISTER | awk -F ",|\(" '{gsub(/ /,"");print $2, $3, $4, $5}' | sort -u | awk '{gsub(/pten::/,"");print $0}' | grep -v "_grad" #step 3: merge all infos
# @input1 => pten kernel infomation : kernel_name kernel_key(GPU/CPU, precision, layout)
# @input2 => information from api.yaml : kernel_name kernel_function_name inferMeta_function_name
# @input3 => information from wrapped_infermeta_gen : ensure the inferMeta function has
# same signature with kernel function
python3 ${PADDLE_ROOT}/tools/infrt/get_pten_kernel_info.py \
--paddle_root_path ${PADDLE_ROOT} \
--kernel_info_file $kernel_register_info_file \
--infermeta_wrap_file ${temp_path}/wrap_info.txt
...@@ -31,6 +31,11 @@ def parse_args(): ...@@ -31,6 +31,11 @@ def parse_args():
type=str, type=str,
required=True, required=True,
help="kernel info file generated by get_pten_kernel_function.sh .") help="kernel info file generated by get_pten_kernel_function.sh .")
parser.add_argument(
"--infermeta_wrap_file",
type=str,
required=True,
help="inferMeta wrap info file .")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -47,16 +52,23 @@ def get_kernel_info(file_path): ...@@ -47,16 +52,23 @@ def get_kernel_info(file_path):
return [l.strip() for l in cont] return [l.strip() for l in cont]
def merge(infer_meta_data, kernel_data): def merge(infer_meta_data, kernel_data, wrap_data):
meta_map = {} meta_map = {}
for api in infer_meta_data: for api in infer_meta_data:
if not api.has_key("kernel") or not api.has_key("infer_meta"): if "kernel" not in api or "infer_meta" not in api:
continue continue
meta_map[api["kernel"]["func"]] = api["infer_meta"]["func"] meta_map[api["kernel"]["func"]] = api["infer_meta"]["func"]
wrap_map = {}
for l in wrap_data:
wrap_map[l.split()[0]] = l.split()[1]
full_kernel_data = [] full_kernel_data = []
for l in kernel_data: for l in kernel_data:
key = l.split()[0] key = l.split()[0]
if meta_map.has_key(key): if key in meta_map:
if key in meta_map:
full_kernel_data.append((l + " " + wrap_map[key]).split())
else:
full_kernel_data.append((l + " " + meta_map[key]).split()) full_kernel_data.append((l + " " + meta_map[key]).split())
else: else:
full_kernel_data.append((l + " unknown").split()) full_kernel_data.append((l + " unknown").split())
...@@ -68,5 +80,6 @@ if __name__ == "__main__": ...@@ -68,5 +80,6 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
infer_meta_data = get_api_yaml_info(args.paddle_root_path) infer_meta_data = get_api_yaml_info(args.paddle_root_path)
kernel_data = get_kernel_info(args.kernel_info_file) kernel_data = get_kernel_info(args.kernel_info_file)
out = merge(infer_meta_data, kernel_data) info_meta_wrap_data = get_kernel_info(args.infermeta_wrap_file)
out = merge(infer_meta_data, kernel_data, info_meta_wrap_data)
print(json.dumps(out)) print(json.dumps(out))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册