diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 265ab592f7d93af8ad7c766412842d5921d47711..48b1e025bdc31844bd318f83264047d91b63b40b 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -202,6 +202,12 @@ if __name__ == "__main__": count = 0 total_time = 0 draw_img_save = "./inference_results" + # warmup 10 times + if args.warmup: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(10): + res = text_detector(img) + if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) for image_file in image_file_list: diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 24388026b8f395427c93e285ed550446e3aa9b9c..c7808e2e8a001db06a0e1efae119965068e45dce 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -254,6 +254,12 @@ def main(args): total_images_num = 0 valid_image_file_list = [] img_list = [] + # warmup 10 times + if args.warmup: + img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8) + for i in range(10): + res = text_recognizer([img]) + for idx, image_file in enumerate(image_file_list): img, flag = check_and_read_gif(image_file) if not flag: diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 235a075bd431f258ea0bd10e4017a4212c3f3f1b..d9433ffb5c06c55815365057d708e3619a40998a 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -147,6 +147,12 @@ def main(args): is_visualize = True font_path = args.vis_font_path drop_score = args.drop_score + # warm up 10 times + if args.warmup: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(10): + res = text_sys(img) + for image_file in image_file_list: img, flag = check_and_read_gif(image_file) if not flag: diff --git a/tools/infer/utility.py b/tools/infer/utility.py index a558f490f941ab0dd940329ff7c82c49b6eb31e7..38cd6d765b8f83276d2414c81d441cc452ec7a41 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -105,6 +105,7 @@ def init_args(): parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--cpu_threads", type=int, default=10) parser.add_argument("--use_pdserving", type=str2bool, default=False) + parser.add_argument("--warmup", type=str2bool, default=True) parser.add_argument("--use_mp", type=str2bool, default=False) parser.add_argument("--total_process_num", type=int, default=1)