From 54382ce49753f46cb706937aa734df88ecf78852 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 4 Dec 2019 10:30:46 +0800 Subject: [PATCH] Add get_all_kernels api of registered data_type in pybind.cc (#21499) * add _get_all_register_op_kernels api test=develop * refine usage of check_op_register_type test=develop * add import in core test=develop --- paddle/fluid/pybind/pybind.cc | 16 +++++++ python/paddle/fluid/core.py | 2 + tools/check_op_register_type.py | 85 +++++++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+) create mode 100644 tools/check_op_register_type.py diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index f01ecfd6ca4..0abaf2d5f05 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -383,6 +383,22 @@ PYBIND11_MODULE(core_noavx, m) { m.def("_get_use_default_grad_op_desc_maker_ops", [] { return OpInfoMap::Instance().GetUseDefaultGradOpDescMakerOps(); }); + m.def("_get_all_register_op_kernels", [] { + auto &all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels(); + std::unordered_map> all_kernels_info; + for (auto &kernel_pair : all_kernels) { + auto op_type = kernel_pair.first; + std::vector kernel_types; + for (auto &info_pair : kernel_pair.second) { + paddle::framework::OpKernelType kernel_type = info_pair.first; + kernel_types.push_back( + paddle::framework::KernelTypeToString(kernel_type)); + } + all_kernels_info.emplace(op_type, kernel_types); + } + return all_kernels_info; + }); + // NOTE(zjl): ctest would load environment variables at the beginning even // though we have not `import paddle.fluid as fluid`. So we add this API // to enable eager deletion mode in unittest. diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index 930feeee2bb..1bc69f74fde 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -170,6 +170,7 @@ if avx_supported(): from .core_avx import _append_python_callable_object_and_return_id from .core_avx import _cleanup, _Scope from .core_avx import _get_use_default_grad_op_desc_maker_ops + from .core_avx import _get_all_register_op_kernels from .core_avx import _is_program_version_supported from .core_avx import _set_eager_deletion_mode from .core_avx import _set_fuse_parameter_group_size @@ -205,6 +206,7 @@ if load_noavx: from .core_noavx import _append_python_callable_object_and_return_id from .core_noavx import _cleanup, _Scope from .core_noavx import _get_use_default_grad_op_desc_maker_ops + from .core_noavx import _get_all_register_op_kernels from .core_noavx import _is_program_version_supported from .core_noavx import _set_eager_deletion_mode from .core_noavx import _set_fuse_parameter_group_size diff --git a/tools/check_op_register_type.py b/tools/check_op_register_type.py new file mode 100644 index 00000000000..8f45838bed9 --- /dev/null +++ b/tools/check_op_register_type.py @@ -0,0 +1,85 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Print all registered kernels of a python module in alphabet order. + +Usage: + python check_op_register_type.py > all_kernels.txt + python check_op_register_type.py OP_TYPE_DEV.spec OP_TYPE_PR.spec > is_valid +""" +from __future__ import print_function +import sys +import re +import difflib +import collections +import paddle.fluid as fluid + + +def get_all_kernels(): + all_kernels_info = fluid.core._get_all_register_op_kernels() + # [u'data_type[double]:data_layout[ANY_LAYOUT]:place[CPUPlace]:library_type[PLAIN]' + op_kernel_types = collections.defaultdict(list) + for op_type, op_infos in all_kernels_info.items(): + is_grad_op = op_type.endswith("_grad") + if is_grad_op: continue + + pattern = re.compile(r'data_type\[([^\]]+)\]') + for op_info in op_infos: + infos = pattern.findall(op_info) + if infos is None or len(infos) == 0: continue + + register_type = infos[0].split(":")[-1] + op_kernel_types[op_type].append(register_type.lower()) + + for (op_type, op_kernels) in sorted( + op_kernel_types.items(), key=lambda x: x[0]): + print(op_type, " ".join(sorted(op_kernels))) + + +def read_file(file_path): + with open(file_path, 'r') as f: + content = f.read() + content = content.splitlines() + return content + + +INTS = set(['int', 'int64_t']) +FLOATS = set(['float', 'double']) + + +def check_add_op_valid(): + origin = read_file(sys.argv[1]) + new = read_file(sys.argv[2]) + + differ = difflib.Differ() + result = differ.compare(origin, new) + + for each_diff in result: + if each_diff[0] in ['+'] and len(each_diff) > 2: # if change or add op + op_info = each_diff[1:].split() + if len(op_info) < 2: continue + register_types = set(op_info[1:]) + if len(FLOATS - register_types) == 1 or \ + len(INTS - register_types) == 1: + print(each_diff) + + +if len(sys.argv) == 1: + get_all_kernels() +elif len(sys.argv) == 3: + check_add_op_valid() +else: + print("Usage:\n" \ + "\tpython check_op_register_type.py > all_kernels.txt\n" \ + "\tpython check_op_register_type.py OP_TYPE_DEV.spec OP_TYPE_PR.spec > diff") -- GitLab