未验证 提交 3ce97f18 编写于 作者: L littletomatodonkey 提交者: GitHub

fix predict rec (#2065)

* fix predict rec

* fix lopp
上级 d8719969
...@@ -248,9 +248,11 @@ class TextRecognizer(object): ...@@ -248,9 +248,11 @@ class TextRecognizer(object):
def main(args): def main(args):
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
text_recognizer = TextRecognizer(args) text_recognizer = TextRecognizer(args)
total_run_time = 0.0
total_images_num = 0
valid_image_file_list = [] valid_image_file_list = []
img_list = [] img_list = []
for image_file in image_file_list: for idx, image_file in enumerate(image_file_list):
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
img = cv2.imread(image_file) img = cv2.imread(image_file)
...@@ -259,22 +261,29 @@ def main(args): ...@@ -259,22 +261,29 @@ def main(args):
continue continue
valid_image_file_list.append(image_file) valid_image_file_list.append(image_file)
img_list.append(img) img_list.append(img)
try: if len(img_list) >= args.rec_batch_num or idx == len(
rec_res, predict_time = text_recognizer(img_list) image_file_list) - 1:
except: try:
logger.info(traceback.format_exc()) rec_res, predict_time = text_recognizer(img_list)
logger.info( total_run_time += predict_time
"ERROR!!!! \n" except:
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" logger.info(traceback.format_exc())
"If your model has tps module: " logger.info(
"TPS does not support variable shape.\n" "ERROR!!!! \n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
exit() "If your model has tps module: "
for ino in range(len(img_list)): "TPS does not support variable shape.\n"
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
rec_res[ino])) )
exit()
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[
ino], rec_res[ino]))
total_images_num += len(valid_image_file_list)
valid_image_file_list = []
img_list = []
logger.info("Total predict time for {} images, cost: {:.3f}".format( logger.info("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time)) total_images_num, total_run_time))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册