未验证 提交 3ce97f18 编写于 作者: L littletomatodonkey 提交者: GitHub

fix predict rec (#2065)

* fix predict rec

* fix lopp
上级 d8719969
...@@ -248,9 +248,11 @@ class TextRecognizer(object): ...@@ -248,9 +248,11 @@ class TextRecognizer(object):
def main(args): def main(args):
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
text_recognizer = TextRecognizer(args) text_recognizer = TextRecognizer(args)
total_run_time = 0.0
total_images_num = 0
valid_image_file_list = [] valid_image_file_list = []
img_list = [] img_list = []
for image_file in image_file_list: for idx, image_file in enumerate(image_file_list):
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
img = cv2.imread(image_file) img = cv2.imread(image_file)
...@@ -259,8 +261,11 @@ def main(args): ...@@ -259,8 +261,11 @@ def main(args):
continue continue
valid_image_file_list.append(image_file) valid_image_file_list.append(image_file)
img_list.append(img) img_list.append(img)
if len(img_list) >= args.rec_batch_num or idx == len(
image_file_list) - 1:
try: try:
rec_res, predict_time = text_recognizer(img_list) rec_res, predict_time = text_recognizer(img_list)
total_run_time += predict_time
except: except:
logger.info(traceback.format_exc()) logger.info(traceback.format_exc())
logger.info( logger.info(
...@@ -268,13 +273,17 @@ def main(args): ...@@ -268,13 +273,17 @@ def main(args):
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: " "If your model has tps module: "
"TPS does not support variable shape.\n" "TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
)
exit() exit()
for ino in range(len(img_list)): for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], logger.info("Predicts of {}:{}".format(valid_image_file_list[
rec_res[ino])) ino], rec_res[ino]))
total_images_num += len(valid_image_file_list)
valid_image_file_list = []
img_list = []
logger.info("Total predict time for {} images, cost: {:.3f}".format( logger.info("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time)) total_images_num, total_run_time))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册