diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index b3d9d4907ba35f7cfade795b6d3897c525d41e6d..b24e57dd973bc0216f2875232bcec6e36ab47e29 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -248,9 +248,11 @@ class TextRecognizer(object): def main(args): image_file_list = get_image_file_list(args.image_dir) text_recognizer = TextRecognizer(args) + total_run_time = 0.0 + total_images_num = 0 valid_image_file_list = [] img_list = [] - for image_file in image_file_list: + for idx, image_file in enumerate(image_file_list): img, flag = check_and_read_gif(image_file) if not flag: img = cv2.imread(image_file) @@ -259,22 +261,29 @@ def main(args): continue valid_image_file_list.append(image_file) img_list.append(img) - try: - rec_res, predict_time = text_recognizer(img_list) - except: - logger.info(traceback.format_exc()) - logger.info( - "ERROR!!!! \n" - "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" - "If your model has tps module: " - "TPS does not support variable shape.\n" - "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") - exit() - for ino in range(len(img_list)): - logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], - rec_res[ino])) + if len(img_list) >= args.rec_batch_num or idx == len( + image_file_list) - 1: + try: + rec_res, predict_time = text_recognizer(img_list) + total_run_time += predict_time + except: + logger.info(traceback.format_exc()) + logger.info( + "ERROR!!!! \n" + "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" + "If your model has tps module: " + "TPS does not support variable shape.\n" + "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' " + ) + exit() + for ino in range(len(img_list)): + logger.info("Predicts of {}:{}".format(valid_image_file_list[ + ino], rec_res[ino])) + total_images_num += len(valid_image_file_list) + valid_image_file_list = [] + img_list = [] logger.info("Total predict time for {} images, cost: {:.3f}".format( - len(img_list), predict_time)) + total_images_num, total_run_time)) if __name__ == "__main__":