提交 aa48cda3 编写于 作者: T tink2123

save for tensorrt

上级 5fb3c419
...@@ -62,8 +62,8 @@ class TextRecognizer(object): ...@@ -62,8 +62,8 @@ class TextRecognizer(object):
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2] assert imgC == img.shape[2]
if self.character_type == "ch": #if self.character_type == "ch":
imgW = int((32 * max_wh_ratio)) #imgW = int((32 * max_wh_ratio))
h, w = img.shape[:2] h, w = img.shape[:2]
ratio = w / float(h) ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW: if math.ceil(imgH * ratio) > imgW:
...@@ -314,17 +314,12 @@ def main(args): ...@@ -314,17 +314,12 @@ def main(args):
valid_image_file_list.append(image_file) valid_image_file_list.append(image_file)
img_list.append(img) img_list.append(img)
try:
rec_res, predict_time = text_recognizer(img_list) rec_res, predict_time = text_recognizer(img_list)
"""
except Exception as e: except Exception as e:
print(e) print(e)
logger.info(
"ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: "
"TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit() exit()
"""
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
print("Total predict time for %d images:%.3f" % print("Total predict time for %d images:%.3f" %
......
...@@ -123,6 +123,7 @@ def main(args): ...@@ -123,6 +123,7 @@ def main(args):
text_sys = TextSystem(args) text_sys = TextSystem(args)
is_visualize = True is_visualize = True
tackle_img_num = 0 tackle_img_num = 0
if not args.enable_benchmark:
for image_file in image_file_list: for image_file in image_file_list:
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
...@@ -166,7 +167,24 @@ def main(args): ...@@ -166,7 +167,24 @@ def main(args):
draw_img[:, :, ::-1]) draw_img[:, :, ::-1])
print("The visualized image saved in {}".format( print("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file)))) os.path.join(draw_img_save, os.path.basename(image_file))))
else:
test_num = 10
test_time = 0.0
for i in range(0, test_num + 10):
#inputs = np.random.rand(640, 640, 3).astype(np.float32)
#print(image_file_list)
image_file = image_file_list[0]
inputs = cv2.imread(image_file)
inputs = cv2.resize(inputs, (int(640), int(640)))
start_time = time.time()
dt_boxes,rec_res = text_sys(inputs)
if i >= 10:
test_time += time.time() - start_time
time.sleep(0.01)
fp_message = "FP16" if args.use_fp16 else "FP32"
trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt"
print("model\t{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format(
trt_msg, fp_message, args.max_batch_size, 1000 *
test_time / test_num))
if __name__ == "__main__": if __name__ == "__main__":
main(utility.parse_args()) main(utility.parse_args())
...@@ -36,7 +36,9 @@ def parse_args(): ...@@ -36,7 +36,9 @@ def parse_args():
parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000) parser.add_argument("--gpu_mem", type=int, default=8000)
parser.add_argument("--use_fp16", type=str2bool, default=False)
parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--enable_benchmark", type=str2bool, default=True)
#params for text detector #params for text detector
parser.add_argument("--image_dir", type=str) parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_algorithm", type=str, default='DB')
...@@ -112,6 +114,12 @@ def create_predictor(args, mode): ...@@ -112,6 +114,12 @@ def create_predictor(args, mode):
else: else:
config.switch_use_feed_fetch_ops(True) config.switch_use_feed_fetch_ops(True)
if args.use_tensorrt:
config.enable_tensorrt_engine(
precision_mode=AnalysisConfig.Precision.Half
if args.use_fp16 else AnalysisConfig.Precision.Float32,
max_batch_size=args.max_batch_size)
predictor = create_paddle_predictor(config) predictor = create_paddle_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
for name in input_names: for name in input_names:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册