提交 aa48cda3 作者: tink2123

save for tensorrt

上级 5fb3c419
...@@ -62,8 +62,8 @@ class TextRecognizer(object): ...@@ -62,8 +62,8 @@ class TextRecognizer(object):
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2] assert imgC == img.shape[2]
if self.character_type == "ch": #if self.character_type == "ch":
imgW = int((32 * max_wh_ratio)) #imgW = int((32 * max_wh_ratio))
h, w = img.shape[:2] h, w = img.shape[:2]
ratio = w / float(h) ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW: if math.ceil(imgH * ratio) > imgW:
...@@ -314,17 +314,12 @@ def main(args): ...@@ -314,17 +314,12 @@ def main(args):
valid_image_file_list.append(image_file) valid_image_file_list.append(image_file)
img_list.append(img) img_list.append(img)
try: rec_res, predict_time = text_recognizer(img_list)
rec_res, predict_time = text_recognizer(img_list) """
except Exception as e: except Exception as e:
print(e) print(e)
logger.info(
"ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: "
"TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit() exit()
"""
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
print("Total predict time for %d images:%.3f" % print("Total predict time for %d images:%.3f" %
......
...@@ -123,50 +123,68 @@ def main(args): ...@@ -123,50 +123,68 @@ def main(args):
text_sys = TextSystem(args) text_sys = TextSystem(args)
is_visualize = True is_visualize = True
tackle_img_num = 0 tackle_img_num = 0
for image_file in image_file_list: if not args.enable_benchmark:
img, flag = check_and_read_gif(image_file) for image_file in image_file_list:
if not flag: img, flag = check_and_read_gif(image_file)
img = cv2.imread(image_file) if not flag:
if img is None: img = cv2.imread(image_file)
logger.info("error in loading image:{}".format(image_file)) if img is None:
continue logger.info("error in loading image:{}".format(image_file))
starttime = time.time() continue
tackle_img_num += 1 starttime = time.time()
if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0: tackle_img_num += 1
text_sys = TextSystem(args) if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0:
dt_boxes, rec_res = text_sys(img) text_sys = TextSystem(args)
elapse = time.time() - starttime dt_boxes, rec_res = text_sys(img)
print("Predict time of %s: %.3fs" % (image_file, elapse)) elapse = time.time() - starttime
print("Predict time of %s: %.3fs" % (image_file, elapse))
drop_score = 0.5
dt_num = len(dt_boxes) drop_score = 0.5
for dno in range(dt_num): dt_num = len(dt_boxes)
text, score = rec_res[dno] for dno in range(dt_num):
if score >= drop_score: text, score = rec_res[dno]
text_str = "%s, %.3f" % (text, score) if score >= drop_score:
print(text_str) text_str = "%s, %.3f" % (text, score)
print(text_str)
if is_visualize:
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) if is_visualize:
boxes = dt_boxes image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
txts = [rec_res[i][0] for i in range(len(rec_res))] boxes = dt_boxes
scores = [rec_res[i][1] for i in range(len(rec_res))] txts = [rec_res[i][0] for i in range(len(rec_res))]
scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr(
image, draw_img = draw_ocr(
boxes, image,
txts, boxes,
scores, txts,
drop_score=drop_score) scores,
draw_img_save = "./inference_results/" drop_score=drop_score)
if not os.path.exists(draw_img_save): draw_img_save = "./inference_results/"
os.makedirs(draw_img_save) if not os.path.exists(draw_img_save):
cv2.imwrite( os.makedirs(draw_img_save)
os.path.join(draw_img_save, os.path.basename(image_file)), cv2.imwrite(
draw_img[:, :, ::-1]) os.path.join(draw_img_save, os.path.basename(image_file)),
print("The visualized image saved in {}".format( draw_img[:, :, ::-1])
os.path.join(draw_img_save, os.path.basename(image_file)))) print("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file))))
else:
test_num = 10
test_time = 0.0
for i in range(0, test_num + 10):
#inputs = np.random.rand(640, 640, 3).astype(np.float32)
#print(image_file_list)
image_file = image_file_list[0]
inputs = cv2.imread(image_file)
inputs = cv2.resize(inputs, (int(640), int(640)))
start_time = time.time()
dt_boxes,rec_res = text_sys(inputs)
if i >= 10:
test_time += time.time() - start_time
time.sleep(0.01)
fp_message = "FP16" if args.use_fp16 else "FP32"
trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt"
print("model\t{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format(
trt_msg, fp_message, args.max_batch_size, 1000 *
test_time / test_num))
if __name__ == "__main__": if __name__ == "__main__":
main(utility.parse_args()) main(utility.parse_args())
...@@ -36,7 +36,9 @@ def parse_args(): ...@@ -36,7 +36,9 @@ def parse_args():
parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000) parser.add_argument("--gpu_mem", type=int, default=8000)
parser.add_argument("--use_fp16", type=str2bool, default=False)
parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--enable_benchmark", type=str2bool, default=True)
#params for text detector #params for text detector
parser.add_argument("--image_dir", type=str) parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_algorithm", type=str, default='DB')
...@@ -112,6 +114,12 @@ def create_predictor(args, mode): ...@@ -112,6 +114,12 @@ def create_predictor(args, mode):
else: else:
config.switch_use_feed_fetch_ops(True) config.switch_use_feed_fetch_ops(True)
if args.use_tensorrt:
config.enable_tensorrt_engine(
precision_mode=AnalysisConfig.Precision.Half
if args.use_fp16 else AnalysisConfig.Precision.Float32,
max_batch_size=args.max_batch_size)
predictor = create_paddle_predictor(config) predictor = create_paddle_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
for name in input_names: for name in input_names:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册