diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 06273e9f9e5b42a9ecc829c435662e9aabcdd224..b7232885927acbebc2bf7ad203712089d5ef4837 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -62,8 +62,8 @@ class TextRecognizer(object): def resize_norm_img(self, img, max_wh_ratio): imgC, imgH, imgW = self.rec_image_shape assert imgC == img.shape[2] - if self.character_type == "ch": - imgW = int((32 * max_wh_ratio)) + #if self.character_type == "ch": + #imgW = int((32 * max_wh_ratio)) h, w = img.shape[:2] ratio = w / float(h) if math.ceil(imgH * ratio) > imgW: @@ -314,17 +314,12 @@ def main(args): valid_image_file_list.append(image_file) img_list.append(img) - try: - rec_res, predict_time = text_recognizer(img_list) + rec_res, predict_time = text_recognizer(img_list) + """ except Exception as e: print(e) - logger.info( - "ERROR!!!! \n" - "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" - "If your model has tps module: " - "TPS does not support variable shape.\n" - "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") exit() + """ for ino in range(len(img_list)): print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Total predict time for %d images:%.3f" % diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 647a76b20496335cd059242890f86fffe1e3ac1a..6ca8eade9665e11f84443fce68925723ce1e5a00 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -123,50 +123,68 @@ def main(args): text_sys = TextSystem(args) is_visualize = True tackle_img_num = 0 - for image_file in image_file_list: - img, flag = check_and_read_gif(image_file) - if not flag: - img = cv2.imread(image_file) - if img is None: - logger.info("error in loading image:{}".format(image_file)) - continue - starttime = time.time() - tackle_img_num += 1 - if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0: - text_sys = TextSystem(args) - dt_boxes, rec_res = text_sys(img) - elapse = time.time() - starttime - print("Predict time of %s: %.3fs" % (image_file, elapse)) - - drop_score = 0.5 - dt_num = len(dt_boxes) - for dno in range(dt_num): - text, score = rec_res[dno] - if score >= drop_score: - text_str = "%s, %.3f" % (text, score) - print(text_str) - - if is_visualize: - image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - boxes = dt_boxes - txts = [rec_res[i][0] for i in range(len(rec_res))] - scores = [rec_res[i][1] for i in range(len(rec_res))] - - draw_img = draw_ocr( - image, - boxes, - txts, - scores, - drop_score=drop_score) - draw_img_save = "./inference_results/" - if not os.path.exists(draw_img_save): - os.makedirs(draw_img_save) - cv2.imwrite( - os.path.join(draw_img_save, os.path.basename(image_file)), - draw_img[:, :, ::-1]) - print("The visualized image saved in {}".format( - os.path.join(draw_img_save, os.path.basename(image_file)))) - - + if not args.enable_benchmark: + for image_file in image_file_list: + img, flag = check_and_read_gif(image_file) + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + starttime = time.time() + tackle_img_num += 1 + if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0: + text_sys = TextSystem(args) + dt_boxes, rec_res = text_sys(img) + elapse = time.time() - starttime + print("Predict time of %s: %.3fs" % (image_file, elapse)) + + drop_score = 0.5 + dt_num = len(dt_boxes) + for dno in range(dt_num): + text, score = rec_res[dno] + if score >= drop_score: + text_str = "%s, %.3f" % (text, score) + print(text_str) + + if is_visualize: + image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + boxes = dt_boxes + txts = [rec_res[i][0] for i in range(len(rec_res))] + scores = [rec_res[i][1] for i in range(len(rec_res))] + + draw_img = draw_ocr( + image, + boxes, + txts, + scores, + drop_score=drop_score) + draw_img_save = "./inference_results/" + if not os.path.exists(draw_img_save): + os.makedirs(draw_img_save) + cv2.imwrite( + os.path.join(draw_img_save, os.path.basename(image_file)), + draw_img[:, :, ::-1]) + print("The visualized image saved in {}".format( + os.path.join(draw_img_save, os.path.basename(image_file)))) + else: + test_num = 10 + test_time = 0.0 + for i in range(0, test_num + 10): + #inputs = np.random.rand(640, 640, 3).astype(np.float32) + #print(image_file_list) + image_file = image_file_list[0] + inputs = cv2.imread(image_file) + inputs = cv2.resize(inputs, (int(640), int(640))) + start_time = time.time() + dt_boxes,rec_res = text_sys(inputs) + if i >= 10: + test_time += time.time() - start_time + time.sleep(0.01) + fp_message = "FP16" if args.use_fp16 else "FP32" + trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt" + print("model\t{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format( + trt_msg, fp_message, args.max_batch_size, 1000 * + test_time / test_num)) if __name__ == "__main__": main(utility.parse_args()) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 8d66f8671815c53d00476033f772cb58069606c8..f416e39a9834439b6ac695532aed51b9bbd3af57 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -36,7 +36,9 @@ def parse_args(): parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--gpu_mem", type=int, default=8000) - + parser.add_argument("--use_fp16", type=str2bool, default=False) + parser.add_argument("--max_batch_size", type=int, default=10) + parser.add_argument("--enable_benchmark", type=str2bool, default=True) #params for text detector parser.add_argument("--image_dir", type=str) parser.add_argument("--det_algorithm", type=str, default='DB') @@ -112,6 +114,12 @@ def create_predictor(args, mode): else: config.switch_use_feed_fetch_ops(True) + if args.use_tensorrt: + config.enable_tensorrt_engine( + precision_mode=AnalysisConfig.Precision.Half + if args.use_fp16 else AnalysisConfig.Precision.Float32, + max_batch_size=args.max_batch_size) + predictor = create_paddle_predictor(config) input_names = predictor.get_input_names() for name in input_names: