未验证 提交 8803f6bb 编写于 作者: S Shang Zhizhou 提交者: GitHub

add print pten kernel tool (#39371)

* test=document_fix;add print pten kernel tool

* test=document_fix

* test=document_fix

* test=document_fix

* test=document_fix

* add print_pten_kernels tool

* add print_pten_kernels tool

* fix windows complie

* notest,test=rocm_ci

* add merge tool

* add comments
上级 f38c2e5c
......@@ -17,6 +17,8 @@ add_subdirectory(kernels)
add_subdirectory(infermeta)
# pten operator definitions
add_subdirectory(ops)
# pten tools
add_subdirectory(tools)
# pten tests
add_subdirectory(tests)
......
......@@ -95,25 +95,81 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
KernelKey(backend, layout, dtype));
}
// print kernel info with json format:
// {
// "(CPU, Undefined(AnyLayout), complex64)": {
// "input": ["CPU, NCHW, complex64", "CPU, NCHW, complex64"],
// "output": ["CPU, NCHW, complex64"],
// "attribute": ["i"]
// }
std::ostream& operator<<(std::ostream& os, const Kernel& kernel) {
os << "InputNum(" << kernel.args_def().input_defs().size() << "): [";
// input
os << "{\"input\":[";
bool need_comma = false;
for (auto& in_def : kernel.args_def().input_defs()) {
os << "<" << in_def.backend << ", " << in_def.layout << ", " << in_def.dtype
<< ">";
if (need_comma) os << ",";
os << "\"" << in_def.backend << ", " << in_def.layout << ", "
<< in_def.dtype << "\"";
need_comma = true;
}
os << "]), AttributeNum(" << kernel.args_def().attribute_defs().size()
<< "), OutputNum(" << kernel.args_def().output_defs().size() << ")";
os << "],";
// output
os << "\"output\":[";
need_comma = false;
for (auto& out_def : kernel.args_def().output_defs()) {
if (need_comma) os << ",";
os << "\"" << out_def.backend << ", " << out_def.layout << ", "
<< out_def.dtype << "\"";
need_comma = true;
}
os << "],";
// attr
os << "\"attribute\":[";
need_comma = false;
for (auto& arg_def : kernel.args_def().attribute_defs()) {
if (need_comma) os << ",";
os << "\"" << arg_def.type_index.name() << "\"";
need_comma = true;
}
os << "]}";
return os;
}
// print all kernels info with json format:
// {
// "kernel_name1":
// [
// {
// "(CPU, Undefined(AnyLayout), complex64)": {
// "input": ["CPU, NCHW, complex64", "CPU, NCHW, complex64"],
// "output": ["CPU, NCHW, complex64"],
// "attribute": ["i"]
// },
// ...
// ],
// "kernel_name2": []
// ...
// }
std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory) {
os << "{";
bool need_comma_kernels = false;
for (const auto& op_kernel_pair : kernel_factory.kernels()) {
os << "- kernel name: " << op_kernel_pair.first << "\n";
if (need_comma_kernels) os << ",";
os << "\"" << op_kernel_pair.first << "\":[";
bool need_comma_per_kernel = false;
for (const auto& kernel_pair : op_kernel_pair.second) {
os << "\t- kernel key: " << kernel_pair.first << " | "
<< "kernel: " << kernel_pair.second << "\n";
if (need_comma_per_kernel) os << ",";
os << "{\"" << kernel_pair.first << "\":" << kernel_pair.second << "}";
need_comma_per_kernel = true;
}
os << "]";
need_comma_kernels = true;
}
os << "}";
return os;
}
......
add_executable(print_pten_kernels print_pten_kernels.cc)
target_link_libraries(print_pten_kernels pten pten_api_utils)
if(WIN32)
target_link_libraries(print_pten_kernels shlwapi.lib)
endif()
if(WITH_ROCM)
target_link_libraries(print_pten_kernels ${ROCM_HIPRTC_LIB})
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/declarations.h"
int main(int argc, char** argv) {
std::cout << pten::KernelFactory::Instance() << std::endl;
return 0;
}
#!/usr/bin/env bash
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#=================================================
# Utils
#=================================================
set -e
EXIT_CODE=0;
tmp_dir=`mktemp -d`
PADDLE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../../" && pwd )"
unset GREP_OPTIONS && find ${PADDLE_ROOT}/paddle/pten/kernels -name "*.c*" | xargs sed -e '/PT_REGISTER_\(GENERAL_\)\?KERNEL(/,/)/!d' | awk 'BEGIN { RS="{" }{ gsub(/\n /,""); print $0 }' | grep PT_REGISTER | awk -F ",|\(" '{gsub(/ /,"");print $2, $3, $4, $5}' | sort -u | awk '{gsub(/pten::/,"");print $0}' | grep -v "_grad"
#!/bin/python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import yaml
def parse_args():
parser = argparse.ArgumentParser("gather pten kernel and infermate info")
parser.add_argument(
"--paddle_root_path",
type=str,
required=True,
help="root path of paddle src[WORK_PATH/Paddle] .")
parser.add_argument(
"--kernel_info_file",
type=str,
required=True,
help="kernel info file generated by get_pten_kernel_function.sh .")
args = parser.parse_args()
return args
def get_api_yaml_info(file_path):
f = open(file_path + "/python/paddle/utils/code_gen/api.yaml", "r")
cont = f.read()
return yaml.load(cont, Loader=yaml.FullLoader)
def get_kernel_info(file_path):
f = open(file_path, "r")
cont = f.readlines()
return [l.strip() for l in cont]
def merge(infer_meta_data, kernel_data):
meta_map = {}
for api in infer_meta_data:
if not api.has_key("kernel") or not api.has_key("infer_meta"):
continue
meta_map[api["kernel"]["func"]] = api["infer_meta"]["func"]
full_kernel_data = []
for l in kernel_data:
key = l.split()[0]
if meta_map.has_key(key):
full_kernel_data.append((l + " " + meta_map[key]).split())
else:
full_kernel_data.append((l + " unknown").split())
return full_kernel_data
if __name__ == "__main__":
args = parse_args()
infer_meta_data = get_api_yaml_info(args.paddle_root_path)
kernel_data = get_kernel_info(args.kernel_info_file)
out = merge(infer_meta_data, kernel_data)
print(json.dumps(out))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册