提交 e61f40ef 编写于 作者: A andyjpaddle

Support inference on PDF files

上级 2f312ae0
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
| 参数名称 | 类型 | 默认值 | 含义 | | 参数名称 | 类型 | 默认值 | 含义 |
| :--: | :--: | :--: | :--: | | :--: | :--: | :--: | :--: |
| image_dir | str | 无,必须显式指定 | 图像或者文件夹路径 | | image_dir | str | 无,必须显式指定 | 图像或者文件夹路径 |
| page_num | int | 0 | 当输入类型为pdf文件时有效,指定预测前面page_num页,默认预测所有页 |
| vis_font_path | str | "./doc/fonts/simfang.ttf" | 用于可视化的字体路径 | | vis_font_path | str | "./doc/fonts/simfang.ttf" | 用于可视化的字体路径 |
| drop_score | float | 0.5 | 识别得分小于该值的结果会被丢弃,不会作为返回结果 | | drop_score | float | 0.5 | 识别得分小于该值的结果会被丢弃,不会作为返回结果 |
| use_pdserving | bool | False | 是否使用Paddle Serving进行预测 | | use_pdserving | bool | False | 是否使用Paddle Serving进行预测 |
......
...@@ -7,6 +7,7 @@ When using PaddleOCR for model inference, you can customize the modification par ...@@ -7,6 +7,7 @@ When using PaddleOCR for model inference, you can customize the modification par
| parameters | type | default | implication | | parameters | type | default | implication |
| :--: | :--: | :--: | :--: | | :--: | :--: | :--: | :--: |
| image_dir | str | None, must be specified explicitly | Image or folder path | | image_dir | str | None, must be specified explicitly | Image or folder path |
| page_num | int | 0 | Only takes effect when the input is a PDF file: predict only the first page_num pages. With the default value 0 (or a value larger than the page count), all pages are predicted |
| vis_font_path | str | "./doc/fonts/simfang.ttf" | font path for visualization | | vis_font_path | str | "./doc/fonts/simfang.ttf" | font path for visualization |
| drop_score | float | 0.5 | Results with a recognition score less than this value will be discarded and will not be returned as results | | drop_score | float | 0.5 | Results with a recognition score less than this value will be discarded and will not be returned as results |
| use_pdserving | bool | False | Whether to use Paddle Serving for prediction | | use_pdserving | bool | False | Whether to use Paddle Serving for prediction |
......
...@@ -282,44 +282,69 @@ if __name__ == "__main__": ...@@ -282,44 +282,69 @@ if __name__ == "__main__":
args = utility.parse_args() args = utility.parse_args()
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
text_detector = TextDetector(args) text_detector = TextDetector(args)
count = 0
total_time = 0 total_time = 0
draw_img_save = "./inference_results" draw_img_save_dir = args.draw_img_save_dir
os.makedirs(draw_img_save_dir, exist_ok=True)
if args.warmup: if args.warmup:
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
for i in range(2): for i in range(2):
res = text_detector(img) res = text_detector(img)
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
save_results = [] save_results = []
for image_file in image_file_list: for idx, image_file in enumerate(image_file_list):
img, flag, _ = check_and_read(image_file) img, flag_gif, flag_pdf = check_and_read(image_file)
if not flag: if not flag_gif and not flag_pdf:
img = cv2.imread(image_file) img = cv2.imread(image_file)
if not flag_pdf:
if img is None: if img is None:
logger.info("error in loading image:{}".format(image_file)) logger.debug("error in loading image:{}".format(image_file))
continue continue
imgs = [img]
else:
page_num = args.page_num
if page_num > len(img) or page_num == 0:
page_num = len(img)
imgs = img[:page_num]
for index, img in enumerate(imgs):
st = time.time() st = time.time()
dt_boxes, _ = text_detector(img) dt_boxes, _ = text_detector(img)
elapse = time.time() - st elapse = time.time() - st
if count > 0:
total_time += elapse total_time += elapse
count += 1 if len(imgs) > 1:
save_pred = os.path.basename(image_file) + '_' + str(
index) + "\t" + str(
json.dumps([x.tolist() for x in dt_boxes])) + "\n"
else:
save_pred = os.path.basename(image_file) + "\t" + str( save_pred = os.path.basename(image_file) + "\t" + str(
json.dumps([x.tolist() for x in dt_boxes])) + "\n" json.dumps([x.tolist() for x in dt_boxes])) + "\n"
save_results.append(save_pred) save_results.append(save_pred)
logger.info(save_pred) logger.info(save_pred)
logger.info("The predict time of {}: {}".format(image_file, elapse)) if len(imgs) > 1:
src_im = utility.draw_text_det_res(dt_boxes, image_file) logger.info("{}_{} The predict time of {}: {}".format(
img_name_pure = os.path.split(image_file)[-1] idx, index, image_file, elapse))
img_path = os.path.join(draw_img_save, else:
"det_res_{}".format(img_name_pure)) logger.info("{} The predict time of {}: {}".format(
idx, image_file, elapse))
if flag_pdf:
src_im = utility.draw_text_det_res(dt_boxes, img, flag_pdf)
else:
src_im = utility.draw_text_det_res(dt_boxes, image_file,
flag_pdf)
if flag_gif:
save_file = image_file[:-3] + "png"
elif flag_pdf:
save_file = image_file.replace('.pdf',
'_' + str(index) + '.png')
else:
save_file = image_file
img_path = os.path.join(
draw_img_save_dir,
"det_res_{}".format(os.path.basename(save_file)))
cv2.imwrite(img_path, src_im) cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path)) logger.info("The visualized image saved in {}".format(img_path))
with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f: with open(os.path.join(draw_img_save_dir, "det_results.txt"), 'w') as f:
f.writelines(save_results) f.writelines(save_results)
f.close() f.close()
if args.benchmark: if args.benchmark:
......
...@@ -159,26 +159,44 @@ def main(args): ...@@ -159,26 +159,44 @@ def main(args):
count = 0 count = 0
for idx, image_file in enumerate(image_file_list): for idx, image_file in enumerate(image_file_list):
img, flag, _ = check_and_read(image_file) img, flag_gif, flag_pdf = check_and_read(image_file)
if not flag: if not flag_gif and not flag_pdf:
img = cv2.imread(image_file) img = cv2.imread(image_file)
if not flag_pdf:
if img is None: if img is None:
logger.debug("error in loading image:{}".format(image_file)) logger.debug("error in loading image:{}".format(image_file))
continue continue
imgs = [img]
else:
page_num = args.page_num
if page_num > len(img) or page_num == 0:
page_num = len(img)
imgs = img[:page_num]
for index, img in enumerate(imgs):
starttime = time.time() starttime = time.time()
dt_boxes, rec_res, time_dict = text_sys(img) dt_boxes, rec_res, time_dict = text_sys(img)
elapse = time.time() - starttime elapse = time.time() - starttime
total_time += elapse total_time += elapse
if len(imgs) > 1:
logger.debug(
str(idx) + '_' + str(index) + " Predict time of %s: %.3fs"
% (image_file, elapse))
else:
logger.debug( logger.debug(
str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse)) str(idx) + " Predict time of %s: %.3fs" % (image_file,
elapse))
for text, score in rec_res: for text, score in rec_res:
logger.debug("{}, {:.3f}".format(text, score)) logger.debug("{}, {:.3f}".format(text, score))
res = [{ res = [{
"transcription": rec_res[idx][0], "transcription": rec_res[i][0],
"points": np.array(dt_boxes[idx]).astype(np.int32).tolist(), "points": np.array(dt_boxes[i]).astype(np.int32).tolist(),
} for idx in range(len(dt_boxes))] } for i in range(len(dt_boxes))]
if len(imgs) > 1:
save_pred = os.path.basename(image_file) + '_' + str(
index) + "\t" + json.dumps(
res, ensure_ascii=False) + "\n"
else:
save_pred = os.path.basename(image_file) + "\t" + json.dumps( save_pred = os.path.basename(image_file) + "\t" + json.dumps(
res, ensure_ascii=False) + "\n" res, ensure_ascii=False) + "\n"
save_results.append(save_pred) save_results.append(save_pred)
...@@ -196,13 +214,20 @@ def main(args): ...@@ -196,13 +214,20 @@ def main(args):
scores, scores,
drop_score=drop_score, drop_score=drop_score,
font_path=font_path) font_path=font_path)
if flag: if flag_gif:
image_file = image_file[:-3] + "png" save_file = image_file[:-3] + "png"
elif flag_pdf:
save_file = image_file.replace('.pdf',
'_' + str(index) + '.png')
else:
save_file = image_file
cv2.imwrite( cv2.imwrite(
os.path.join(draw_img_save_dir, os.path.basename(image_file)), os.path.join(draw_img_save_dir,
os.path.basename(save_file)),
draw_img[:, :, ::-1]) draw_img[:, :, ::-1])
logger.debug("The visualized image saved in {}".format( logger.debug("The visualized image saved in {}".format(
os.path.join(draw_img_save_dir, os.path.basename(image_file)))) os.path.join(draw_img_save_dir, os.path.basename(
save_file))))
logger.info("The predict total time is {}".format(time.time() - _st)) logger.info("The predict total time is {}".format(time.time() - _st))
if args.benchmark: if args.benchmark:
......
...@@ -45,6 +45,7 @@ def init_args(): ...@@ -45,6 +45,7 @@ def init_args():
# params for text detector # params for text detector
parser.add_argument("--image_dir", type=str) parser.add_argument("--image_dir", type=str)
parser.add_argument("--page_num", type=int, default=0)
parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_limit_side_len", type=float, default=960) parser.add_argument("--det_limit_side_len", type=float, default=960)
...@@ -337,8 +338,11 @@ def draw_e2e_res(dt_boxes, strs, img_path): ...@@ -337,8 +338,11 @@ def draw_e2e_res(dt_boxes, strs, img_path):
return src_im return src_im
def draw_text_det_res(dt_boxes, img_path): def draw_text_det_res(dt_boxes, img_path, flag_pdf=False):
if not flag_pdf:
src_im = cv2.imread(img_path) src_im = cv2.imread(img_path)
else:
src_im = img_path
for box in dt_boxes: for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape(-1, 2) box = np.array(box).astype(np.int32).reshape(-1, 2)
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册