diff --git a/tools/get_op_list.sh b/tools/get_op_list.sh new file mode 100644 index 0000000000000000000000000000000000000000..2e4cad13582df3cda395306ecf9b2ee270c47f98 --- /dev/null +++ b/tools/get_op_list.sh @@ -0,0 +1,27 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +dir=$1 +UT_list=$(ctest -N | awk -F ': ' '{print $2}' | sed '/^$/d' | sed '$d') +for case in $UT_list; do + flag=$(echo $case|grep -oE '_op') + if [ -n "$flag" ];then + if [ -z "$UT_list_prec" ];then + UT_list_prec="$case" + else + UT_list_prec="${UT_list_prec}\n${case}" + fi + fi +done +echo -e "$UT_list_prec" > "$dir" diff --git a/tools/jetson_infer_op.py b/tools/jetson_infer_op.py new file mode 100644 index 0000000000000000000000000000000000000000..70b9eb82b6ed461d8e01d7de2a4edad156870bc0 --- /dev/null +++ b/tools/jetson_infer_op.py @@ -0,0 +1,246 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import os +import math +import argparse +from threading import Thread + +# some particular ops +black_list = [ + # special op + 'test_custom_relu_op_setup', + 'test_custom_relu_op_jit', + 'test_python_operator_overriding', + 'test_c_comm_init_all_op', + 'test_c_embedding_op', + # train op + 'test_imperative_optimizer', + 'test_imperative_optimizer_v2', + 'test_momentum_op', + 'test_sgd_op', + 'test_sgd_op_bf16', + 'test_warpctc_op', + # sync op + 'test_sync_batch_norm_op', + # case too large + 'test_reduce_op', + 'test_transpose_op' +] + +op_diff_list = [ + # diff<1E-7,it's right + 'test_elementwise_mul_op' +] + + +def parse_arguments(): + """ + :return: + """ + parser = argparse.ArgumentParser() + parser.add_argument( + '--shell_name', + type=str, + default='get_op_list.sh', + help='please input right name') + parser.add_argument( + '--op_list_file', + type=str, + default='list_op.txt', + help='please input right name') + return parser.parse_args() + + +def search_file(file_name, path, file_path): + """ + :param file_name:target + :param path: to search this path + :param file_path: result + :return: + """ + for item in os.listdir(path): + if os.path.isdir(os.path.join(path, item)): + search_file(file_name, os.path.join(path, item), file_path) + else: + if file_name in item: + file_path.append(os.path.join(path, file_name)) + + +def get_prefix(line, end_char='d'): + """ + :param line: string_demo + :param end_char: copy the prefix of string_demo until end_char + :return: prefix + """ + i = 0 + prefix = '' + while (line[i] != end_char): + prefix += line[i] + i += 1 + return prefix + + +def add_import_skip_return(file, pattern_import, pattern_skip, pattern_return): + """ + :param file: the file need to be changed + :param pattern_import: import skip + :param pattern_skip: @skip + :param pattern_return: add return + :return: + """ + pattern_1 = re.compile(pattern_import) + pattern_2 = re.compile(pattern_skip) + pattern_3 = re.compile(pattern_return) + file_data = "" + + # change file + with open(file, "r", encoding="utf-8") as f: + for line in f: + # add import skip_check_grad_ci + match_obj = pattern_1.search(line) + if match_obj is not None: + line = line[:-1] + ", skip_check_grad_ci\n" + print("### add import skip_check_grad_ci ####") + + # add @skip_check_grad_ci + match_obj = pattern_2.search(line) + if match_obj is not None: + file_data += "@skip_check_grad_ci(reason='jetson do n0t neeed this !')\n" + print("### add @skip_check_grad_ci ####") + + # delete test_grad_output + match_obj = pattern_3.search(line) + if match_obj is not None: + file_data += line + file_data += get_prefix(line) + file_data += " return\n" + print("### add return for function ####") + continue + file_data += line + # save file + with open(file, "w", encoding="utf-8") as f: + f.write(file_data) + + +def get_op_list(op_list_file='list_op.txt'): + """ + :param op_list_file: op list file + :return: list of op + """ + op_list = [] + with open(op_list_file, "r", encoding="utf-8") as f: + for line in f: + if line in black_list: + continue + # delete /n + op_list.append(line[:-1]) + return op_list + + +def set_diff_value(file, atol="1e-5", inplace_atol="1e-7"): + """ + :param file: refer to op_test.py + :param atol: refer to op_test.py + :param inplace_atol: + :return: + """ + os.system("sed -i 's/self.check_output(/self\.check_output\(atol=" + atol + + ",inplace_atol=" + inplace_atol + ",/g\' " + file) + + +def change_op_file(start=0, end=0, op_list_file='list_op.txt', path='.'): + """ + :param start: + :param end: + :param op_list_file: op_list + :param path: just the file in this path + :return: + """ + test_op_list = get_op_list(op_list_file) + file_path = [] + for id in range(start, end): + item = test_op_list[id] + print(id, ":", item) + search_file(item + '.py', os.path.abspath(path), file_path) + if len(file_path) == 0: + print("'", item, "' is not a python file!") + continue + file_with_path = file_path[0] + # pattern + pattern_import = ".*import OpTest.*" + pattern_skip = "^class .*\(OpTest\):$" + pattern_return = "def test.*grad.*\):$" + # change file + add_import_skip_return(file_with_path, pattern_import, pattern_skip, + pattern_return) + # op_diff + if item in op_diff_list: + set_diff_value(file_with_path) + + file_path.clear() + + +def run_multi_thread(list_file, thread_num=4): + """ + :param list_file: + :param thread_num: + :return: + """ + length = len(get_op_list(list_file)) + thread_list = [] + start = 0 + end = 0 + for item in range(thread_num): + # set argument + start = math.floor(item / thread_num * length) + end = math.floor((item + 1) / thread_num * length) + print("thread num-", item, ":", start, end) + thread = Thread(target=change_op_file, args=(start, end)) + thread_list.append(thread) + # start a thread + thread.start() + + # wait all thread + for item in thread_list: + item.join() + + print("------change successfully!-------") + + +def transform_list_to_str(list_op): + """ + :param list_op: + :return: + """ + res = "" + for item in list_op: + tmp = "^" + item + "$|" + res += tmp + return res[:-1] + + +if __name__ == '__main__': + arg = parse_arguments() + print("------start get op list!------") + os.system("bash " + arg.shell_name + " " + arg.op_list_file) + print("------start change op file!------") + run_multi_thread(arg.op_list_file) + old_list = get_op_list(arg.op_list_file) + new_list = filter(lambda x: x not in black_list, old_list) + op_test = transform_list_to_str(new_list) + print("------start run op test!------") + os.system("ctest -R \"(" + op_test + ")\" --output-on-failure") + print("------do well!------")