From e61f40ef4155a5ea40463a521d2dcd127413542a Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Thu, 29 Sep 2022 07:31:45 +0000 Subject: [PATCH] support infer pdf file --- doc/doc_ch/inference_args.md | 1 + doc/doc_en/inference_args_en.md | 1 + tools/infer/predict_det.py | 79 +++++++++++++++-------- tools/infer/predict_system.py | 111 +++++++++++++++++++------------- tools/infer/utility.py | 8 ++- 5 files changed, 128 insertions(+), 72 deletions(-) diff --git a/doc/doc_ch/inference_args.md b/doc/doc_ch/inference_args.md index 36efc6fb..24e7223e 100644 --- a/doc/doc_ch/inference_args.md +++ b/doc/doc_ch/inference_args.md @@ -7,6 +7,7 @@ | 参数名称 | 类型 | 默认值 | 含义 | | :--: | :--: | :--: | :--: | | image_dir | str | 无,必须显式指定 | 图像或者文件夹路径 | +| page_num | int | 0 | 当输入类型为pdf文件时有效,指定预测前面page_num页,默认预测所有页 | | vis_font_path | str | "./doc/fonts/simfang.ttf" | 用于可视化的字体路径 | | drop_score | float | 0.5 | 识别得分小于该值的结果会被丢弃,不会作为返回结果 | | use_pdserving | bool | False | 是否使用Paddle Serving进行预测 | diff --git a/doc/doc_en/inference_args_en.md b/doc/doc_en/inference_args_en.md index f2c99fc8..b28cd843 100644 --- a/doc/doc_en/inference_args_en.md +++ b/doc/doc_en/inference_args_en.md @@ -7,6 +7,7 @@ When using PaddleOCR for model inference, you can customize the modification par | parameters | type | default | implication | | :--: | :--: | :--: | :--: | | image_dir | str | None, must be specified explicitly | Image or folder path | +| page_num | int | 0 | Valid only when the input is a PDF file; specifies that only the first page_num pages are predicted. By default (0), all pages are predicted | | vis_font_path | str | "./doc/fonts/simfang.ttf" | font path for visualization | | drop_score | float | 0.5 | Results with a recognition score less than this value will be discarded and will not be returned as results | | use_pdserving | bool | False | Whether to use Paddle Serving for prediction | diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 00fa2e9b..7e371886 100755 --- 
a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -282,44 +282,69 @@ if __name__ == "__main__": args = utility.parse_args() image_file_list = get_image_file_list(args.image_dir) text_detector = TextDetector(args) - count = 0 total_time = 0 - draw_img_save = "./inference_results" + draw_img_save_dir = args.draw_img_save_dir + os.makedirs(draw_img_save_dir, exist_ok=True) if args.warmup: img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) for i in range(2): res = text_detector(img) - if not os.path.exists(draw_img_save): - os.makedirs(draw_img_save) save_results = [] - for image_file in image_file_list: - img, flag, _ = check_and_read(image_file) - if not flag: + for idx, image_file in enumerate(image_file_list): + img, flag_gif, flag_pdf = check_and_read(image_file) + if not flag_gif and not flag_pdf: img = cv2.imread(image_file) - if img is None: - logger.info("error in loading image:{}".format(image_file)) - continue - st = time.time() - dt_boxes, _ = text_detector(img) - elapse = time.time() - st - if count > 0: + if not flag_pdf: + if img is None: + logger.debug("error in loading image:{}".format(image_file)) + continue + imgs = [img] + else: + page_num = args.page_num + if page_num > len(img) or page_num == 0: + page_num = len(img) + imgs = img[:page_num] + for index, img in enumerate(imgs): + st = time.time() + dt_boxes, _ = text_detector(img) + elapse = time.time() - st total_time += elapse - count += 1 - save_pred = os.path.basename(image_file) + "\t" + str( - json.dumps([x.tolist() for x in dt_boxes])) + "\n" - save_results.append(save_pred) - logger.info(save_pred) - logger.info("The predict time of {}: {}".format(image_file, elapse)) - src_im = utility.draw_text_det_res(dt_boxes, image_file) - img_name_pure = os.path.split(image_file)[-1] - img_path = os.path.join(draw_img_save, - "det_res_{}".format(img_name_pure)) - cv2.imwrite(img_path, src_im) - logger.info("The visualized image saved in {}".format(img_path)) + if len(imgs) 
> 1: + save_pred = os.path.basename(image_file) + '_' + str( + index) + "\t" + str( + json.dumps([x.tolist() for x in dt_boxes])) + "\n" + else: + save_pred = os.path.basename(image_file) + "\t" + str( + json.dumps([x.tolist() for x in dt_boxes])) + "\n" + save_results.append(save_pred) + logger.info(save_pred) + if len(imgs) > 1: + logger.info("{}_{} The predict time of {}: {}".format( + idx, index, image_file, elapse)) + else: + logger.info("{} The predict time of {}: {}".format( + idx, image_file, elapse)) + if flag_pdf: + src_im = utility.draw_text_det_res(dt_boxes, img, flag_pdf) + else: + src_im = utility.draw_text_det_res(dt_boxes, image_file, + flag_pdf) + if flag_gif: + save_file = image_file[:-3] + "png" + elif flag_pdf: + save_file = image_file.replace('.pdf', + '_' + str(index) + '.png') + else: + save_file = image_file + img_path = os.path.join( + draw_img_save_dir, + "det_res_{}".format(os.path.basename(save_file))) + cv2.imwrite(img_path, src_im) + logger.info("The visualized image saved in {}".format(img_path)) - with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f: + with open(os.path.join(draw_img_save_dir, "det_results.txt"), 'w') as f: f.writelines(save_results) f.close() if args.benchmark: diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index e0f2c41f..affd0d1b 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -159,50 +159,75 @@ def main(args): count = 0 for idx, image_file in enumerate(image_file_list): - img, flag, _ = check_and_read(image_file) - if not flag: + img, flag_gif, flag_pdf = check_and_read(image_file) + if not flag_gif and not flag_pdf: img = cv2.imread(image_file) - if img is None: - logger.debug("error in loading image:{}".format(image_file)) - continue - starttime = time.time() - dt_boxes, rec_res, time_dict = text_sys(img) - elapse = time.time() - starttime - total_time += elapse - - logger.debug( - str(idx) + " Predict time of %s: %.3fs" % (image_file, 
elapse)) - for text, score in rec_res: - logger.debug("{}, {:.3f}".format(text, score)) - - res = [{ - "transcription": rec_res[idx][0], - "points": np.array(dt_boxes[idx]).astype(np.int32).tolist(), - } for idx in range(len(dt_boxes))] - save_pred = os.path.basename(image_file) + "\t" + json.dumps( - res, ensure_ascii=False) + "\n" - save_results.append(save_pred) - - if is_visualize: - image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - boxes = dt_boxes - txts = [rec_res[i][0] for i in range(len(rec_res))] - scores = [rec_res[i][1] for i in range(len(rec_res))] - - draw_img = draw_ocr_box_txt( - image, - boxes, - txts, - scores, - drop_score=drop_score, - font_path=font_path) - if flag: - image_file = image_file[:-3] + "png" - cv2.imwrite( - os.path.join(draw_img_save_dir, os.path.basename(image_file)), - draw_img[:, :, ::-1]) - logger.debug("The visualized image saved in {}".format( - os.path.join(draw_img_save_dir, os.path.basename(image_file)))) + if not flag_pdf: + if img is None: + logger.debug("error in loading image:{}".format(image_file)) + continue + imgs = [img] + else: + page_num = args.page_num + if page_num > len(img) or page_num == 0: + page_num = len(img) + imgs = img[:page_num] + for index, img in enumerate(imgs): + starttime = time.time() + dt_boxes, rec_res, time_dict = text_sys(img) + elapse = time.time() - starttime + total_time += elapse + if len(imgs) > 1: + logger.debug( + str(idx) + '_' + str(index) + " Predict time of %s: %.3fs" + % (image_file, elapse)) + else: + logger.debug( + str(idx) + " Predict time of %s: %.3fs" % (image_file, + elapse)) + for text, score in rec_res: + logger.debug("{}, {:.3f}".format(text, score)) + + res = [{ + "transcription": rec_res[i][0], + "points": np.array(dt_boxes[i]).astype(np.int32).tolist(), + } for i in range(len(dt_boxes))] + if len(imgs) > 1: + save_pred = os.path.basename(image_file) + '_' + str( + index) + "\t" + json.dumps( + res, ensure_ascii=False) + "\n" + else: + save_pred = 
os.path.basename(image_file) + "\t" + json.dumps( + res, ensure_ascii=False) + "\n" + save_results.append(save_pred) + + if is_visualize: + image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + boxes = dt_boxes + txts = [rec_res[i][0] for i in range(len(rec_res))] + scores = [rec_res[i][1] for i in range(len(rec_res))] + + draw_img = draw_ocr_box_txt( + image, + boxes, + txts, + scores, + drop_score=drop_score, + font_path=font_path) + if flag_gif: + save_file = image_file[:-3] + "png" + elif flag_pdf: + save_file = image_file.replace('.pdf', + '_' + str(index) + '.png') + else: + save_file = image_file + cv2.imwrite( + os.path.join(draw_img_save_dir, + os.path.basename(save_file)), + draw_img[:, :, ::-1]) + logger.debug("The visualized image saved in {}".format( + os.path.join(draw_img_save_dir, os.path.basename( + save_file)))) logger.info("The predict total time is {}".format(time.time() - _st)) if args.benchmark: diff --git a/tools/infer/utility.py b/tools/infer/utility.py index b9c9490b..5de672be 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -45,6 +45,7 @@ def init_args(): # params for text detector parser.add_argument("--image_dir", type=str) + parser.add_argument("--page_num", type=int, default=0) parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_limit_side_len", type=float, default=960) @@ -337,8 +338,11 @@ def draw_e2e_res(dt_boxes, strs, img_path): return src_im -def draw_text_det_res(dt_boxes, img_path): - src_im = cv2.imread(img_path) +def draw_text_det_res(dt_boxes, img_path, flag_pdf=False): + if not flag_pdf: + src_im = cv2.imread(img_path) + else: + src_im = img_path for box in dt_boxes: box = np.array(box).astype(np.int32).reshape(-1, 2) cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) -- GitLab