diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 7fe8674ef1296dc0916ee3be8e9331f58519eddd..08cc1570e0b57a298b06ea60d3bea265ffa37ed0 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -863,6 +863,58 @@ PYBIND11_MODULE(libpaddle, m) { lib[string]: the libarary, could be 'phi', 'fluid' and 'all'. )DOC"); + m.def( + "_get_registered_phi_kernels", + [](const std::string &kernel_registered_type) { + std::unordered_map> + all_kernels_info; + auto phi_kernels = phi::KernelFactory::Instance().kernels(); + for (auto &kernel_pair : phi_kernels) { + auto kernel_name = kernel_pair.first; + std::vector kernel_keys; + for (auto &info_pair : kernel_pair.second) { + bool get_function_kernel = + kernel_registered_type == "function" && + info_pair.second.GetKernelRegisteredType() == + phi::KernelRegisteredType::FUNCTION; + bool get_structure_kernel = + kernel_registered_type == "structure" && + info_pair.second.GetKernelRegisteredType() == + phi::KernelRegisteredType::STRUCTURE; + if (kernel_registered_type == "all" || get_function_kernel || + get_structure_kernel) { + std::ostringstream stream; + stream << info_pair.first; + std::string kernel_key_str = stream.str(); + if (all_kernels_info.count(kernel_name)) { + bool kernel_exist = + std::find(all_kernels_info[kernel_name].begin(), + all_kernels_info[kernel_name].end(), + kernel_key_str) != + all_kernels_info[kernel_name].end(); + if (!kernel_exist) { + all_kernels_info[kernel_name].emplace_back(kernel_key_str); + } + } else { + kernel_keys.emplace_back(kernel_key_str); + } + } + } + if (!kernel_keys.empty()) { + all_kernels_info.emplace(kernel_name, kernel_keys); + } + } + + return all_kernels_info; + }, + py::arg("kernel_registered_type") = "function", + R"DOC( + Return the registered kernels in phi. + + Args: + kernel_registered_type[string]: the libarary, could be 'function', 'structure', and 'all'. + )DOC"); + // NOTE(Aganlengzi): KernelFactory static instance is initialized BEFORE // plugins are loaded for custom kernels, but de-initialized AFTER they are // unloaded. We need manually clear symbols(may contain plugins' symbols) diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index c94d92f17f8723d74a76e530a07bed48809a87c1..c3a50f7767aaaef42ce74d9c92aed9e66d97af77 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -286,6 +286,7 @@ try: from .libpaddle import _cleanup, _Scope from .libpaddle import _get_use_default_grad_op_desc_maker_ops from .libpaddle import _get_all_register_op_kernels + from .libpaddle import _get_registered_phi_kernels from .libpaddle import _is_program_version_supported from .libpaddle import _set_eager_deletion_mode from .libpaddle import _get_eager_deletion_vars diff --git a/python/paddle/fluid/tests/unittests/test_registered_phi_kernels.py b/python/paddle/fluid/tests/unittests/test_registered_phi_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..4ebc598e4683811f4b2b98bdeda8743bf6ca8159 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_registered_phi_kernels.py @@ -0,0 +1,108 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pathlib +import re +import unittest + +import yaml + +import paddle.fluid.core as core + + +def parse_kernels_name(op_item): + result = [] + if 'kernel' in op_item: + kernel_config = op_item['kernel'] + kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall( + kernel_config['func'] + ) + for func_item in kernel_funcs: + result.append(func_item[0]) + + return result + + +def get_all_kernels(op_list, all_registerd_kernels): + kernels = [] + for op in op_list: + op_kernels = parse_kernels_name(op) + for op_kernel in op_kernels: + if op_kernel not in kernels and op_kernel in all_registerd_kernels: + kernels.append(op_kernel) + if op_kernel not in all_registerd_kernels: + print("********** wrong kernel: ", op_kernel) + return kernels + + +def remove_forward_kernels(bw_kernels, forward_kernels): + new_bw_kernels = [] + for bw_kernel in bw_kernels: + if bw_kernel not in forward_kernels: + new_bw_kernels.append(bw_kernel) + return new_bw_kernels + + +class TestRegisteredPhiKernels(unittest.TestCase): + """TestRegisteredPhiKernels.""" + + def setUp(self): + self.forward_ops = [] + self.backward_ops = [] + + root_path = pathlib.Path(__file__).parents[6] + + ops_yaml_path = [ + 'paddle/phi/api/yaml/ops.yaml', + 'paddle/phi/api/yaml/legacy_ops.yaml', + ] + bw_ops_yaml_path = [ + 'paddle/phi/api/yaml/backward.yaml', + 'paddle/phi/api/yaml/legacy_backward.yaml', + ] + + for each_ops_yaml in ops_yaml_path: + with open(root_path.joinpath(each_ops_yaml), 'r') as f: + op_list = yaml.load(f, Loader=yaml.FullLoader) + if op_list: + self.forward_ops.extend(op_list) + + for each_ops_yaml in bw_ops_yaml_path: + with open(root_path.joinpath(each_ops_yaml), 'r') as f: + api_list = yaml.load(f, Loader=yaml.FullLoader) + if api_list: + self.backward_ops.extend(api_list) + + def test_registered_phi_kernels(self): + phi_function_kernel_infos = core._get_registered_phi_kernels("function") + registered_kernel_list = [ + name for name in phi_function_kernel_infos.keys() + ] + forward_kernels = get_all_kernels( + self.forward_ops, registered_kernel_list + ) + backward_kernels = remove_forward_kernels( + get_all_kernels(self.backward_ops, registered_kernel_list), + forward_kernels, + ) + + for kernel_name in forward_kernels: + self.assertIn(kernel_name, registered_kernel_list) + + for kernel_name in backward_kernels: + self.assertIn(kernel_name, registered_kernel_list) + + +if __name__ == '__main__': + unittest.main()