diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 38dd9d2ff8ea9543620e66b934d3db103cbf6567..c5e25903c17378bb992ebdd4fc7ead012528dd46 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -207,10 +207,12 @@ if __name__ == "__main__": total_time = 0 draw_img_save = "./inference_results" - # warmup 10 times - fake_img = np.random.uniform(-1, 1, [640, 640, 3]).astype(np.float32) - for i in range(10): - dt_boxes, _ = text_detector(fake_img) + if args.warmup: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(10): + res = text_detector(img) + + cpu_mem, gpu_mem, gpu_util = 0, 0, 0 if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 2eeb39b2a0bff15241ea7762b4981e4daaada096..0d847046530c02c9b0591bb4b379fd7ddeac1263 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -257,13 +257,15 @@ def main(args): text_recognizer = TextRecognizer(args) valid_image_file_list = [] img_list = [] - cpu_mem, gpu_mem, gpu_util = 0, 0, 0 - count = 0 # warmup 10 times - fake_img = np.random.uniform(-1, 1, [1, 32, 320, 3]).astype(np.float32) - for i in range(10): - dt_boxes, _ = text_recognizer(fake_img) + if args.warmup: + img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8) + for i in range(10): + res = text_recognizer([img]) + + cpu_mem, gpu_mem, gpu_util = 0, 0, 0 + count = 0 for image_file in image_file_list: img, flag = check_and_read_gif(image_file) diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 133b9a9df243cc25536f9a0b49120be8c07bdd5d..c008f9679684e2433859cd104261aeff56b410a2 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -152,11 +152,19 @@ def main(args): is_visualize = True font_path = args.vis_font_path drop_score = args.drop_score + + # warm up 10 times + if args.warmup: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(10): + res = text_sys(img) + total_time = 0 cpu_mem, gpu_mem, gpu_util = 0, 0, 0 _st = time.time() count = 0 for idx, image_file in enumerate(image_file_list): + img, flag = check_and_read_gif(image_file) if not flag: img = cv2.imread(image_file) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index d87666e605a72cd8f35c6169d0c651c4bfc64a03..d491d6013869da5cc5e7cc7975a3324a460182a2 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -105,6 +105,7 @@ def init_args(): parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--cpu_threads", type=int, default=10) parser.add_argument("--use_pdserving", type=str2bool, default=False) + parser.add_argument("--warmup", type=str2bool, default=True) # multi-process parser.add_argument("--use_mp", type=str2bool, default=False)