提交 9874f502 编写于 作者: T tink2123

support tensorrt

上级 aa48cda3
...@@ -42,6 +42,7 @@ class TextRecognizer(object): ...@@ -42,6 +42,7 @@ class TextRecognizer(object):
self.rec_algorithm = args.rec_algorithm self.rec_algorithm = args.rec_algorithm
self.text_len = args.max_text_length self.text_len = args.max_text_length
self.use_zero_copy_run = args.use_zero_copy_run self.use_zero_copy_run = args.use_zero_copy_run
self.benchmark = args.enable_benchmark
char_ops_params = { char_ops_params = {
"character_type": args.rec_char_type, "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path, "character_dict_path": args.rec_char_dict_path,
...@@ -62,8 +63,8 @@ class TextRecognizer(object): ...@@ -62,8 +63,8 @@ class TextRecognizer(object):
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2] assert imgC == img.shape[2]
#if self.character_type == "ch": if self.character_type == "ch" and not self.benchmark:
#imgW = int((32 * max_wh_ratio)) imgW = int((32 * max_wh_ratio))
h, w = img.shape[:2] h, w = img.shape[:2]
ratio = w / float(h) ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW: if math.ceil(imgH * ratio) > imgW:
...@@ -313,13 +314,17 @@ def main(args): ...@@ -313,13 +314,17 @@ def main(args):
continue continue
valid_image_file_list.append(image_file) valid_image_file_list.append(image_file)
img_list.append(img) img_list.append(img)
try:
rec_res, predict_time = text_recognizer(img_list) rec_res, predict_time = text_recognizer(img_list)
"""
except Exception as e: except Exception as e:
print(e) print(e)
logger.info(
"ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: "
"TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit() exit()
"""
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
print("Total predict time for %d images:%.3f" % print("Total predict time for %d images:%.3f" %
......
...@@ -154,11 +154,7 @@ def main(args): ...@@ -154,11 +154,7 @@ def main(args):
scores = [rec_res[i][1] for i in range(len(rec_res))] scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr( draw_img = draw_ocr(
image, image, boxes, txts, scores, drop_score=drop_score)
boxes,
txts,
scores,
drop_score=drop_score)
draw_img_save = "./inference_results/" draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save): if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save) os.makedirs(draw_img_save)
...@@ -171,20 +167,20 @@ def main(args): ...@@ -171,20 +167,20 @@ def main(args):
test_num = 10 test_num = 10
test_time = 0.0 test_time = 0.0
for i in range(0, test_num + 10): for i in range(0, test_num + 10):
#inputs = np.random.rand(640, 640, 3).astype(np.float32)
#print(image_file_list)
image_file = image_file_list[0] image_file = image_file_list[0]
inputs = cv2.imread(image_file) inputs = cv2.imread(image_file)
inputs = cv2.resize(inputs, (int(640), int(640))) inputs = cv2.resize(inputs, (int(640), int(640)))
start_time = time.time() start_time = time.time()
dt_boxes,rec_res = text_sys(inputs) dt_boxes, rec_res = text_sys(inputs)
if i >= 10: if i >= 10:
test_time += time.time() - start_time test_time += time.time() - start_time
time.sleep(0.01) time.sleep(0.01)
fp_message = "FP16" if args.use_fp16 else "FP32" fp_message = "FP16" if args.use_fp16 else "FP32"
trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt" trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt"
print("model\t{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format( print("Benchmark\t{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format(
trt_msg, fp_message, args.max_batch_size, 1000 * trt_msg, fp_message, args.max_batch_size, 1000 * test_time /
test_time / test_num)) test_num))
if __name__ == "__main__": if __name__ == "__main__":
main(utility.parse_args()) main(utility.parse_args())
...@@ -37,8 +37,8 @@ def parse_args(): ...@@ -37,8 +37,8 @@ def parse_args():
parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000) parser.add_argument("--gpu_mem", type=int, default=8000)
parser.add_argument("--use_fp16", type=str2bool, default=False) parser.add_argument("--use_fp16", type=str2bool, default=False)
parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--max_batch_size", type=int, default=1)
parser.add_argument("--enable_benchmark", type=str2bool, default=True) parser.add_argument("--enable_benchmark", type=str2bool, default=False)
#params for text detector #params for text detector
parser.add_argument("--image_dir", type=str) parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_algorithm", type=str, default='DB')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册