未验证 提交 54382ce4 编写于 作者: A Aurelius84 提交者: GitHub

Add get_all_kernels api of registered data_type in pybind.cc (#21499)

* add _get_all_register_op_kernels api test=develop

* refine usage of check_op_register_type test=develop

* add import in core test=develop
上级 b4ad7bdf
......@@ -383,6 +383,22 @@ PYBIND11_MODULE(core_noavx, m) {
m.def("_get_use_default_grad_op_desc_maker_ops",
[] { return OpInfoMap::Instance().GetUseDefaultGradOpDescMakerOps(); });
m.def("_get_all_register_op_kernels", [] {
auto &all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels();
std::unordered_map<std::string, std::vector<std::string>> all_kernels_info;
for (auto &kernel_pair : all_kernels) {
auto op_type = kernel_pair.first;
std::vector<std::string> kernel_types;
for (auto &info_pair : kernel_pair.second) {
paddle::framework::OpKernelType kernel_type = info_pair.first;
kernel_types.push_back(
paddle::framework::KernelTypeToString(kernel_type));
}
all_kernels_info.emplace(op_type, kernel_types);
}
return all_kernels_info;
});
// NOTE(zjl): ctest would load environment variables at the beginning even
// though we have not `import paddle.fluid as fluid`. So we add this API
// to enable eager deletion mode in unittest.
......
......@@ -170,6 +170,7 @@ if avx_supported():
from .core_avx import _append_python_callable_object_and_return_id
from .core_avx import _cleanup, _Scope
from .core_avx import _get_use_default_grad_op_desc_maker_ops
from .core_avx import _get_all_register_op_kernels
from .core_avx import _is_program_version_supported
from .core_avx import _set_eager_deletion_mode
from .core_avx import _set_fuse_parameter_group_size
......@@ -205,6 +206,7 @@ if load_noavx:
from .core_noavx import _append_python_callable_object_and_return_id
from .core_noavx import _cleanup, _Scope
from .core_noavx import _get_use_default_grad_op_desc_maker_ops
from .core_noavx import _get_all_register_op_kernels
from .core_noavx import _is_program_version_supported
from .core_noavx import _set_eager_deletion_mode
from .core_noavx import _set_fuse_parameter_group_size
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Print all registered kernels of a python module in alphabet order.
Usage:
python check_op_register_type.py > all_kernels.txt
python check_op_register_type.py OP_TYPE_DEV.spec OP_TYPE_PR.spec > is_valid
"""
from __future__ import print_function
import sys
import re
import difflib
import collections
import paddle.fluid as fluid
def get_all_kernels():
all_kernels_info = fluid.core._get_all_register_op_kernels()
# [u'data_type[double]:data_layout[ANY_LAYOUT]:place[CPUPlace]:library_type[PLAIN]'
op_kernel_types = collections.defaultdict(list)
for op_type, op_infos in all_kernels_info.items():
is_grad_op = op_type.endswith("_grad")
if is_grad_op: continue
pattern = re.compile(r'data_type\[([^\]]+)\]')
for op_info in op_infos:
infos = pattern.findall(op_info)
if infos is None or len(infos) == 0: continue
register_type = infos[0].split(":")[-1]
op_kernel_types[op_type].append(register_type.lower())
for (op_type, op_kernels) in sorted(
op_kernel_types.items(), key=lambda x: x[0]):
print(op_type, " ".join(sorted(op_kernels)))
def read_file(file_path):
with open(file_path, 'r') as f:
content = f.read()
content = content.splitlines()
return content
INTS = set(['int', 'int64_t'])
FLOATS = set(['float', 'double'])
def check_add_op_valid():
origin = read_file(sys.argv[1])
new = read_file(sys.argv[2])
differ = difflib.Differ()
result = differ.compare(origin, new)
for each_diff in result:
if each_diff[0] in ['+'] and len(each_diff) > 2: # if change or add op
op_info = each_diff[1:].split()
if len(op_info) < 2: continue
register_types = set(op_info[1:])
if len(FLOATS - register_types) == 1 or \
len(INTS - register_types) == 1:
print(each_diff)
if len(sys.argv) == 1:
get_all_kernels()
elif len(sys.argv) == 3:
check_add_op_valid()
else:
print("Usage:\n" \
"\tpython check_op_register_type.py > all_kernels.txt\n" \
"\tpython check_op_register_type.py OP_TYPE_DEV.spec OP_TYPE_PR.spec > diff")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册