From a66f856ed3f3db65628cbe648a8ecabaee434fcd Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Thu, 26 Aug 2021 03:30:06 +0000 Subject: [PATCH] opt code --- tests/compare_results.py | 63 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/tests/compare_results.py b/tests/compare_results.py index f5135ac7..ecedb710 100644 --- a/tests/compare_results.py +++ b/tests/compare_results.py @@ -3,6 +3,7 @@ import os import subprocess import json import argparse +import glob def init_args(): @@ -12,6 +13,7 @@ def init_args(): parser.add_argument("--rtol", type=float, default=1e-3) parser.add_argument("--gt_file", type=str, default="") parser.add_argument("--log_file", type=str, default="") + parser.add_argument("--precision", type=str, default="fp32") return parser @@ -63,6 +65,34 @@ def load_gt_from_file(gt_file): return parser_gt +def load_gt_from_txts(gt_file): + gt_list = glob.glob(gt_file) + gt_collection = {} + for gt_f in gt_list: + gt_dict = load_gt_from_file(gt_f) + basename = os.path.basename(gt_f) + if "fp32" in basename: + gt_collection["fp32"] = [gt_dict, gt_f] + elif "fp16" in basename: + gt_collection["fp16"] = [gt_dict, gt_f] + elif "int8" in basename: + gt_collection["int8"] = [gt_dict, gt_f] + else: + continue + return gt_collection + + +def collect_predict_from_logs(log_path, key_list): + log_list = glob.glob(log_path) + pred_collection = {} + for log_f in log_list: + pred_dict = parser_results_from_log_by_name(log_f, key_list) + key = os.path.basename(log_f) + pred_collection[key] = pred_dict + + return pred_collection + + def testing_assert_allclose(dict_x, dict_y, atol=1e-7, rtol=1e-7): for k in dict_x: np.testing.assert_allclose( @@ -71,12 +101,33 @@ def testing_assert_allclose(dict_x, dict_y, atol=1e-7, rtol=1e-7): if __name__ == "__main__": # Usage: - # python3.7 tests/compare_results.py --gt_file=./det_results_gpu_fp32.txt --log_file=./test_log.log + # python3.7 tests/compare_results.py --gt_file=./tests/results/*.txt --log_file=./tests/output/infer_*.log args = parse_args() - gt_dict = load_gt_from_file(args.gt_file) - key_list = list(gt_dict.keys()) - - pred_dict = parser_results_from_log_by_name(args.log_file, key_list) - testing_assert_allclose(gt_dict, pred_dict, atol=args.atol, rtol=args.rtol) + gt_collection = load_gt_from_txts(args.gt_file) + key_list = gt_collection["fp32"][0].keys() + + pred_collection = collect_predict_from_logs(args.log_file, key_list) + for filename in pred_collection.keys(): + if "fp32" in filename: + gt_dict, gt_filename = gt_collection["fp32"] + elif "fp16" in filename: + gt_dict, gt_filename = gt_collection["fp16"] + elif "int8" in filename: + gt_dict, gt_filename = gt_collection["int8"] + else: + continue + pred_dict = pred_collection[filename] + + try: + testing_assert_allclose( + gt_dict, pred_dict, atol=args.atol, rtol=args.rtol) + print( + "Assert allclose passed! The results of {} and {} are consistent!". + format(filename, gt_filename)) + except Exception as E: + print(E) + raise ValueError( + "The results of {} and the results of {} are inconsistent!". + format(filename, gt_filename)) -- GitLab