未验证 提交 89cc44e7 编写于 作者: A andyj 提交者: GitHub

Merge pull request #7775 from andyjpaddle/support_pdf_infer

Support infer pdf file
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
| 参数名称 | 类型 | 默认值 | 含义 | | 参数名称 | 类型 | 默认值 | 含义 |
| :--: | :--: | :--: | :--: | | :--: | :--: | :--: | :--: |
| image_dir | str | 无,必须显式指定 | 图像或者文件夹路径 | | image_dir | str | 无,必须显式指定 | 图像或者文件夹路径 |
| page_num | int | 0 | 当输入类型为pdf文件时有效,指定预测前面page_num页,默认预测所有页 |
| vis_font_path | str | "./doc/fonts/simfang.ttf" | 用于可视化的字体路径 | | vis_font_path | str | "./doc/fonts/simfang.ttf" | 用于可视化的字体路径 |
| drop_score | float | 0.5 | 识别得分小于该值的结果会被丢弃,不会作为返回结果 | | drop_score | float | 0.5 | 识别得分小于该值的结果会被丢弃,不会作为返回结果 |
| use_pdserving | bool | False | 是否使用Paddle Serving进行预测 | | use_pdserving | bool | False | 是否使用Paddle Serving进行预测 |
......
...@@ -7,6 +7,7 @@ When using PaddleOCR for model inference, you can customize the modification par ...@@ -7,6 +7,7 @@ When using PaddleOCR for model inference, you can customize the modification par
| parameters | type | default | implication | | parameters | type | default | implication |
| :--: | :--: | :--: | :--: | | :--: | :--: | :--: | :--: |
| image_dir | str | None, must be specified explicitly | Image or folder path | | image_dir | str | None, must be specified explicitly | Image or folder path |
| page_num | int | 0 | Valid when the input type is pdf file, specify to predict the previous page_num pages, all pages are predicted by default |
| vis_font_path | str | "./doc/fonts/simfang.ttf" | font path for visualization | | vis_font_path | str | "./doc/fonts/simfang.ttf" | font path for visualization |
| drop_score | float | 0.5 | Results with a recognition score less than this value will be discarded and will not be returned as results | | drop_score | float | 0.5 | Results with a recognition score less than this value will be discarded and will not be returned as results |
| use_pdserving | bool | False | Whether to use Paddle Serving for prediction | | use_pdserving | bool | False | Whether to use Paddle Serving for prediction |
......
...@@ -15,3 +15,4 @@ premailer ...@@ -15,3 +15,4 @@ premailer
openpyxl openpyxl
attrdict attrdict
Polygon3 Polygon3
PyMuPDF==1.18.7
...@@ -282,44 +282,67 @@ if __name__ == "__main__": ...@@ -282,44 +282,67 @@ if __name__ == "__main__":
args = utility.parse_args() args = utility.parse_args()
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
text_detector = TextDetector(args) text_detector = TextDetector(args)
count = 0
total_time = 0 total_time = 0
draw_img_save = "./inference_results" draw_img_save_dir = args.draw_img_save_dir
os.makedirs(draw_img_save_dir, exist_ok=True)
if args.warmup: if args.warmup:
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
for i in range(2): for i in range(2):
res = text_detector(img) res = text_detector(img)
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
save_results = [] save_results = []
for image_file in image_file_list: for idx, image_file in enumerate(image_file_list):
img, flag, _ = check_and_read(image_file) img, flag_gif, flag_pdf = check_and_read(image_file)
if not flag: if not flag_gif and not flag_pdf:
img = cv2.imread(image_file) img = cv2.imread(image_file)
if img is None: if not flag_pdf:
logger.info("error in loading image:{}".format(image_file)) if img is None:
continue logger.debug("error in loading image:{}".format(image_file))
st = time.time() continue
dt_boxes, _ = text_detector(img) imgs = [img]
elapse = time.time() - st else:
if count > 0: page_num = args.page_num
if page_num > len(img) or page_num == 0:
page_num = len(img)
imgs = img[:page_num]
for index, img in enumerate(imgs):
st = time.time()
dt_boxes, _ = text_detector(img)
elapse = time.time() - st
total_time += elapse total_time += elapse
count += 1 if len(imgs) > 1:
save_pred = os.path.basename(image_file) + "\t" + str( save_pred = os.path.basename(image_file) + '_' + str(
json.dumps([x.tolist() for x in dt_boxes])) + "\n" index) + "\t" + str(
save_results.append(save_pred) json.dumps([x.tolist() for x in dt_boxes])) + "\n"
logger.info(save_pred) else:
logger.info("The predict time of {}: {}".format(image_file, elapse)) save_pred = os.path.basename(image_file) + "\t" + str(
src_im = utility.draw_text_det_res(dt_boxes, image_file) json.dumps([x.tolist() for x in dt_boxes])) + "\n"
img_name_pure = os.path.split(image_file)[-1] save_results.append(save_pred)
img_path = os.path.join(draw_img_save, logger.info(save_pred)
"det_res_{}".format(img_name_pure)) if len(imgs) > 1:
cv2.imwrite(img_path, src_im) logger.info("{}_{} The predict time of {}: {}".format(
logger.info("The visualized image saved in {}".format(img_path)) idx, index, image_file, elapse))
else:
logger.info("{} The predict time of {}: {}".format(
idx, image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, img)
if flag_gif:
save_file = image_file[:-3] + "png"
elif flag_pdf:
save_file = image_file.replace('.pdf',
'_' + str(index) + '.png')
else:
save_file = image_file
img_path = os.path.join(
draw_img_save_dir,
"det_res_{}".format(os.path.basename(save_file)))
cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f: with open(os.path.join(draw_img_save_dir, "det_results.txt"), 'w') as f:
f.writelines(save_results) f.writelines(save_results)
f.close() f.close()
if args.benchmark: if args.benchmark:
......
...@@ -159,50 +159,75 @@ def main(args): ...@@ -159,50 +159,75 @@ def main(args):
count = 0 count = 0
for idx, image_file in enumerate(image_file_list): for idx, image_file in enumerate(image_file_list):
img, flag, _ = check_and_read(image_file) img, flag_gif, flag_pdf = check_and_read(image_file)
if not flag: if not flag_gif and not flag_pdf:
img = cv2.imread(image_file) img = cv2.imread(image_file)
if img is None: if not flag_pdf:
logger.debug("error in loading image:{}".format(image_file)) if img is None:
continue logger.debug("error in loading image:{}".format(image_file))
starttime = time.time() continue
dt_boxes, rec_res, time_dict = text_sys(img) imgs = [img]
elapse = time.time() - starttime else:
total_time += elapse page_num = args.page_num
if page_num > len(img) or page_num == 0:
logger.debug( page_num = len(img)
str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse)) imgs = img[:page_num]
for text, score in rec_res: for index, img in enumerate(imgs):
logger.debug("{}, {:.3f}".format(text, score)) starttime = time.time()
dt_boxes, rec_res, time_dict = text_sys(img)
res = [{ elapse = time.time() - starttime
"transcription": rec_res[idx][0], total_time += elapse
"points": np.array(dt_boxes[idx]).astype(np.int32).tolist(), if len(imgs) > 1:
} for idx in range(len(dt_boxes))] logger.debug(
save_pred = os.path.basename(image_file) + "\t" + json.dumps( str(idx) + '_' + str(index) + " Predict time of %s: %.3fs"
res, ensure_ascii=False) + "\n" % (image_file, elapse))
save_results.append(save_pred) else:
logger.debug(
if is_visualize: str(idx) + " Predict time of %s: %.3fs" % (image_file,
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) elapse))
boxes = dt_boxes for text, score in rec_res:
txts = [rec_res[i][0] for i in range(len(rec_res))] logger.debug("{}, {:.3f}".format(text, score))
scores = [rec_res[i][1] for i in range(len(rec_res))]
res = [{
draw_img = draw_ocr_box_txt( "transcription": rec_res[i][0],
image, "points": np.array(dt_boxes[i]).astype(np.int32).tolist(),
boxes, } for i in range(len(dt_boxes))]
txts, if len(imgs) > 1:
scores, save_pred = os.path.basename(image_file) + '_' + str(
drop_score=drop_score, index) + "\t" + json.dumps(
font_path=font_path) res, ensure_ascii=False) + "\n"
if flag: else:
image_file = image_file[:-3] + "png" save_pred = os.path.basename(image_file) + "\t" + json.dumps(
cv2.imwrite( res, ensure_ascii=False) + "\n"
os.path.join(draw_img_save_dir, os.path.basename(image_file)), save_results.append(save_pred)
draw_img[:, :, ::-1])
logger.debug("The visualized image saved in {}".format( if is_visualize:
os.path.join(draw_img_save_dir, os.path.basename(image_file)))) image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
boxes = dt_boxes
txts = [rec_res[i][0] for i in range(len(rec_res))]
scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr_box_txt(
image,
boxes,
txts,
scores,
drop_score=drop_score,
font_path=font_path)
if flag_gif:
save_file = image_file[:-3] + "png"
elif flag_pdf:
save_file = image_file.replace('.pdf',
'_' + str(index) + '.png')
else:
save_file = image_file
cv2.imwrite(
os.path.join(draw_img_save_dir,
os.path.basename(save_file)),
draw_img[:, :, ::-1])
logger.debug("The visualized image saved in {}".format(
os.path.join(draw_img_save_dir, os.path.basename(
save_file))))
logger.info("The predict total time is {}".format(time.time() - _st)) logger.info("The predict total time is {}".format(time.time() - _st))
if args.benchmark: if args.benchmark:
......
...@@ -45,6 +45,7 @@ def init_args(): ...@@ -45,6 +45,7 @@ def init_args():
# params for text detector # params for text detector
parser.add_argument("--image_dir", type=str) parser.add_argument("--image_dir", type=str)
parser.add_argument("--page_num", type=int, default=0)
parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_limit_side_len", type=float, default=960) parser.add_argument("--det_limit_side_len", type=float, default=960)
...@@ -337,12 +338,11 @@ def draw_e2e_res(dt_boxes, strs, img_path): ...@@ -337,12 +338,11 @@ def draw_e2e_res(dt_boxes, strs, img_path):
return src_im return src_im
def draw_text_det_res(dt_boxes, img_path): def draw_text_det_res(dt_boxes, img):
src_im = cv2.imread(img_path)
for box in dt_boxes: for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape(-1, 2) box = np.array(box).astype(np.int32).reshape(-1, 2)
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
return src_im return img
def resize_img(img, input_size=600): def resize_img(img, input_size=600):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册