未验证 提交 7e0bada3 编写于 作者: C Chen Weihang 提交者: GitHub

[Ci] Add kernel dtype check py script (#45329)

* add kernel dtype check py

* add kernel dtype check py, test=document_fix
上级 db937b5a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Print all registered kernels of a python module in alphabet order.
Usage:
python check_op_kernel_same_dtypes.py > all_kernels.txt
python check_op_kernel_same_dtypes.py OP_KERNEL_DTYPE_DEV.spec OP_KERNEL_DTYPE_PR.spec > is_valid
"""
from __future__ import print_function
import sys
import re
import collections
import paddle
def get_all_kernels():
all_kernels_info = paddle.framework.core._get_all_register_op_kernels()
print(all_kernels_info)
# [u'{data_type[float]; data_layout[Undefined(AnyLayout)]; place[Place(gpu:0)]; library_type[PLAIN]}']
op_kernel_types = collections.defaultdict(list)
for op_type, op_infos in all_kernels_info.items():
# is_grad_op = op_type.endswith("_grad")
# if is_grad_op: continue
pattern = re.compile(r'data_type\[([^\]]+)\]')
for op_info in op_infos:
infos = pattern.findall(op_info)
if infos is None or len(infos) == 0: continue
register_type = infos[0].split(":")[-1]
op_kernel_types[op_type].append(register_type.lower())
for (op_type, op_kernels) in sorted(op_kernel_types.items(),
key=lambda x: x[0]):
print(op_type, " ".join(sorted(op_kernels)))
def read_file(file_path):
with open(file_path, 'r') as f:
content = f.read()
content = content.splitlines()
return content
def print_diff(op_type, op_kernel_dtype_set, grad_op_kernel_dtype_set):
if len(op_kernel_dtype_set) > len(grad_op_kernel_dtype_set):
lack_dtypes = list(op_kernel_dtype_set - grad_op_kernel_dtype_set)
print("{} supports [{}] now, but its grad op kernel not supported.".
format(op_type, " ".join(lack_dtypes)))
else:
lack_dtypes = list(grad_op_kernel_dtype_set - op_kernel_dtype_set)
print("{} supports [{}] now, but its forward op kernel not supported.".
format(op_type + "_grad", " ".join(lack_dtypes)))
def check_change_or_add_op_kernel_dtypes_valid():
origin = read_file(sys.argv[1])
new = read_file(sys.argv[2])
origin_all_kernel_dtype_dict = dict()
for op_msg in origin:
op_info = op_msg.split()
origin_all_kernel_dtype_dict[op_info[0]] = set(op_info[1:])
new_all_kernel_dtype_dict = dict()
for op_msg in new:
op_info = op_msg.split()
new_all_kernel_dtype_dict[op_info[0]] = set(op_info[1:])
added_or_changed_op_info = dict()
for op_type, dtype_set in new_all_kernel_dtype_dict.items():
if op_type in origin_all_kernel_dtype_dict:
origin_dtype_set = origin_all_kernel_dtype_dict[op_type]
# op kernel changed
if origin_dtype_set != dtype_set:
added_or_changed_op_info[op_type] = dtype_set
else:
# do nothing
continue
else:
# op kernel added
added_or_changed_op_info[op_type] = dtype_set
for op_type, dtype_set in added_or_changed_op_info.items():
# if changed forward op
if not op_type.endswith("_grad"):
# only support grad op
grad_op_type = op_type + "_grad"
if grad_op_type in new_all_kernel_dtype_dict:
grad_op_kernel_dtype_set = set(
new_all_kernel_dtype_dict[grad_op_type])
if dtype_set != grad_op_kernel_dtype_set:
print_diff(op_type, dtype_set, grad_op_kernel_dtype_set)
# if changed grad op
else:
forward_op_type = op_type.rstrip("_grad")
if forward_op_type in new_all_kernel_dtype_dict:
op_kernel_dtype_set = set(
new_all_kernel_dtype_dict[forward_op_type])
if op_kernel_dtype_set != dtype_set:
print_diff(forward_op_type, op_kernel_dtype_set, dtype_set)
if len(sys.argv) == 1:
get_all_kernels()
elif len(sys.argv) == 3:
check_change_or_add_op_kernel_dtypes_valid()
else:
print("Usage:\n" \
"\tpython check_op_kernel_same_dtypes.py > all_kernels.txt\n" \
"\tpython check_op_kernel_same_dtypes.py OP_KERNEL_DTYPE_DEV.spec OP_KERNEL_DTYPE_PR.spec > diff")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册