未验证 提交 0f8c304a 编写于 作者: Z zyfncg 提交者: GitHub

Add inferface of get registered phi kernels (#50814)

* add inferface of get registered phi kernels

* change KernelType to KernelKey

* add test

* refactor code
上级 d2a0577a
......@@ -863,6 +863,58 @@ PYBIND11_MODULE(libpaddle, m) {
lib[string]: the libarary, could be 'phi', 'fluid' and 'all'.
)DOC");
m.def(
"_get_registered_phi_kernels",
[](const std::string &kernel_registered_type) {
std::unordered_map<std::string, std::vector<std::string>>
all_kernels_info;
auto phi_kernels = phi::KernelFactory::Instance().kernels();
for (auto &kernel_pair : phi_kernels) {
auto kernel_name = kernel_pair.first;
std::vector<std::string> kernel_keys;
for (auto &info_pair : kernel_pair.second) {
bool get_function_kernel =
kernel_registered_type == "function" &&
info_pair.second.GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION;
bool get_structure_kernel =
kernel_registered_type == "structure" &&
info_pair.second.GetKernelRegisteredType() ==
phi::KernelRegisteredType::STRUCTURE;
if (kernel_registered_type == "all" || get_function_kernel ||
get_structure_kernel) {
std::ostringstream stream;
stream << info_pair.first;
std::string kernel_key_str = stream.str();
if (all_kernels_info.count(kernel_name)) {
bool kernel_exist =
std::find(all_kernels_info[kernel_name].begin(),
all_kernels_info[kernel_name].end(),
kernel_key_str) !=
all_kernels_info[kernel_name].end();
if (!kernel_exist) {
all_kernels_info[kernel_name].emplace_back(kernel_key_str);
}
} else {
kernel_keys.emplace_back(kernel_key_str);
}
}
}
if (!kernel_keys.empty()) {
all_kernels_info.emplace(kernel_name, kernel_keys);
}
}
return all_kernels_info;
},
py::arg("kernel_registered_type") = "function",
R"DOC(
Return the registered kernels in phi.
Args:
kernel_registered_type[string]: the libarary, could be 'function', 'structure', and 'all'.
)DOC");
// NOTE(Aganlengzi): KernelFactory static instance is initialized BEFORE
// plugins are loaded for custom kernels, but de-initialized AFTER they are
// unloaded. We need manually clear symbols(may contain plugins' symbols)
......
......@@ -286,6 +286,7 @@ try:
from .libpaddle import _cleanup, _Scope
from .libpaddle import _get_use_default_grad_op_desc_maker_ops
from .libpaddle import _get_all_register_op_kernels
from .libpaddle import _get_registered_phi_kernels
from .libpaddle import _is_program_version_supported
from .libpaddle import _set_eager_deletion_mode
from .libpaddle import _get_eager_deletion_vars
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import re
import unittest
import yaml
import paddle.fluid.core as core
def parse_kernels_name(op_item):
result = []
if 'kernel' in op_item:
kernel_config = op_item['kernel']
kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(
kernel_config['func']
)
for func_item in kernel_funcs:
result.append(func_item[0])
return result
def get_all_kernels(op_list, all_registerd_kernels):
kernels = []
for op in op_list:
op_kernels = parse_kernels_name(op)
for op_kernel in op_kernels:
if op_kernel not in kernels and op_kernel in all_registerd_kernels:
kernels.append(op_kernel)
if op_kernel not in all_registerd_kernels:
print("********** wrong kernel: ", op_kernel)
return kernels
def remove_forward_kernels(bw_kernels, forward_kernels):
new_bw_kernels = []
for bw_kernel in bw_kernels:
if bw_kernel not in forward_kernels:
new_bw_kernels.append(bw_kernel)
return new_bw_kernels
class TestRegisteredPhiKernels(unittest.TestCase):
"""TestRegisteredPhiKernels."""
def setUp(self):
self.forward_ops = []
self.backward_ops = []
root_path = pathlib.Path(__file__).parents[6]
ops_yaml_path = [
'paddle/phi/api/yaml/ops.yaml',
'paddle/phi/api/yaml/legacy_ops.yaml',
]
bw_ops_yaml_path = [
'paddle/phi/api/yaml/backward.yaml',
'paddle/phi/api/yaml/legacy_backward.yaml',
]
for each_ops_yaml in ops_yaml_path:
with open(root_path.joinpath(each_ops_yaml), 'r') as f:
op_list = yaml.load(f, Loader=yaml.FullLoader)
if op_list:
self.forward_ops.extend(op_list)
for each_ops_yaml in bw_ops_yaml_path:
with open(root_path.joinpath(each_ops_yaml), 'r') as f:
api_list = yaml.load(f, Loader=yaml.FullLoader)
if api_list:
self.backward_ops.extend(api_list)
def test_registered_phi_kernels(self):
phi_function_kernel_infos = core._get_registered_phi_kernels("function")
registered_kernel_list = [
name for name in phi_function_kernel_infos.keys()
]
forward_kernels = get_all_kernels(
self.forward_ops, registered_kernel_list
)
backward_kernels = remove_forward_kernels(
get_all_kernels(self.backward_ops, registered_kernel_list),
forward_kernels,
)
for kernel_name in forward_kernels:
self.assertIn(kernel_name, registered_kernel_list)
for kernel_name in backward_kernels:
self.assertIn(kernel_name, registered_kernel_list)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册