From 9874f5023bdfa8833f8d1965f11e0dc4842a168e Mon Sep 17 00:00:00 2001 From: tink2123 Date: Tue, 8 Sep 2020 16:53:19 +0800 Subject: [PATCH] support tensorrt --- tools/infer/predict_rec.py | 17 +++++++++++------ tools/infer/predict_system.py | 18 +++++++----------- tools/infer/utility.py | 4 ++-- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index b7232885..9e59c824 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -42,6 +42,7 @@ class TextRecognizer(object): self.rec_algorithm = args.rec_algorithm self.text_len = args.max_text_length self.use_zero_copy_run = args.use_zero_copy_run + self.benchmark = args.enable_benchmark char_ops_params = { "character_type": args.rec_char_type, "character_dict_path": args.rec_char_dict_path, @@ -62,8 +63,8 @@ class TextRecognizer(object): def resize_norm_img(self, img, max_wh_ratio): imgC, imgH, imgW = self.rec_image_shape assert imgC == img.shape[2] - #if self.character_type == "ch": - #imgW = int((32 * max_wh_ratio)) + if self.character_type == "ch" and not self.benchmark: + imgW = int((32 * max_wh_ratio)) h, w = img.shape[:2] ratio = w / float(h) if math.ceil(imgH * ratio) > imgW: @@ -313,13 +314,17 @@ def main(args): continue valid_image_file_list.append(image_file) img_list.append(img) - - rec_res, predict_time = text_recognizer(img_list) - """ + try: + rec_res, predict_time = text_recognizer(img_list) except Exception as e: print(e) + logger.info( + "ERROR!!!! \n" + "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" + "If your model has tps module: " + "TPS does not support variable shape.\n" + "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") exit() - """ for ino in range(len(img_list)): print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Total predict time for %d images:%.3f" % diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 6ca8eade..53f5f9e3 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -154,11 +154,7 @@ def main(args): scores = [rec_res[i][1] for i in range(len(rec_res))] draw_img = draw_ocr( - image, - boxes, - txts, - scores, - drop_score=drop_score) + image, boxes, txts, scores, drop_score=drop_score) draw_img_save = "./inference_results/" if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) @@ -171,20 +167,20 @@ def main(args): test_num = 10 test_time = 0.0 for i in range(0, test_num + 10): - #inputs = np.random.rand(640, 640, 3).astype(np.float32) - #print(image_file_list) image_file = image_file_list[0] inputs = cv2.imread(image_file) inputs = cv2.resize(inputs, (int(640), int(640))) start_time = time.time() - dt_boxes,rec_res = text_sys(inputs) + dt_boxes, rec_res = text_sys(inputs) if i >= 10: test_time += time.time() - start_time time.sleep(0.01) fp_message = "FP16" if args.use_fp16 else "FP32" trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt" - print("model\t{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format( - trt_msg, fp_message, args.max_batch_size, 1000 * - test_time / test_num)) + print("Benchmark\t{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format( + trt_msg, fp_message, args.max_batch_size, 1000 * test_time / + test_num)) + + if __name__ == "__main__": main(utility.parse_args()) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index f416e39a..3ea1052b 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -37,8 +37,8 @@ def parse_args(): parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--gpu_mem", type=int, default=8000) parser.add_argument("--use_fp16", type=str2bool, default=False) - parser.add_argument("--max_batch_size", type=int, default=10) - parser.add_argument("--enable_benchmark", type=str2bool, default=True) + parser.add_argument("--max_batch_size", type=int, default=1) + parser.add_argument("--enable_benchmark", type=str2bool, default=False) #params for text detector parser.add_argument("--image_dir", type=str) parser.add_argument("--det_algorithm", type=str, default='DB') -- GitLab