未验证 提交 89cc44e7 编写于 作者: A andyj 提交者: GitHub

Merge pull request #7775 from andyjpaddle/support_pdf_infer

Support infer pdf file
......@@ -7,6 +7,7 @@
| 参数名称 | 类型 | 默认值 | 含义 |
| :--: | :--: | :--: | :--: |
| image_dir | str | 无,必须显式指定 | 图像或者文件夹路径 |
| page_num | int | 0 | 当输入类型为pdf文件时有效,指定预测前面page_num页,默认预测所有页 |
| vis_font_path | str | "./doc/fonts/simfang.ttf" | 用于可视化的字体路径 |
| drop_score | float | 0.5 | 识别得分小于该值的结果会被丢弃,不会作为返回结果 |
| use_pdserving | bool | False | 是否使用Paddle Serving进行预测 |
......
......@@ -7,6 +7,7 @@ When using PaddleOCR for model inference, you can customize the modification par
| parameters | type | default | implication |
| :--: | :--: | :--: | :--: |
| image_dir | str | None, must be specified explicitly | Image or folder path |
| page_num | int | 0 | Valid when the input type is pdf file, specify to predict the previous page_num pages, all pages are predicted by default |
| vis_font_path | str | "./doc/fonts/simfang.ttf" | font path for visualization |
| drop_score | float | 0.5 | Results with a recognition score less than this value will be discarded and will not be returned as results |
| use_pdserving | bool | False | Whether to use Paddle Serving for prediction |
......
......@@ -282,44 +282,67 @@ if __name__ == "__main__":
args = utility.parse_args()
image_file_list = get_image_file_list(args.image_dir)
text_detector = TextDetector(args)
count = 0
total_time = 0
draw_img_save = "./inference_results"
draw_img_save_dir = args.draw_img_save_dir
os.makedirs(draw_img_save_dir, exist_ok=True)
if args.warmup:
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
for i in range(2):
res = text_detector(img)
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
save_results = []
for image_file in image_file_list:
img, flag, _ = check_and_read(image_file)
if not flag:
for idx, image_file in enumerate(image_file_list):
img, flag_gif, flag_pdf = check_and_read(image_file)
if not flag_gif and not flag_pdf:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
st = time.time()
dt_boxes, _ = text_detector(img)
elapse = time.time() - st
if count > 0:
if not flag_pdf:
if img is None:
logger.debug("error in loading image:{}".format(image_file))
continue
imgs = [img]
else:
page_num = args.page_num
if page_num > len(img) or page_num == 0:
page_num = len(img)
imgs = img[:page_num]
for index, img in enumerate(imgs):
st = time.time()
dt_boxes, _ = text_detector(img)
elapse = time.time() - st
total_time += elapse
count += 1
save_pred = os.path.basename(image_file) + "\t" + str(
json.dumps([x.tolist() for x in dt_boxes])) + "\n"
save_results.append(save_pred)
logger.info(save_pred)
logger.info("The predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save,
"det_res_{}".format(img_name_pure))
cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
if len(imgs) > 1:
save_pred = os.path.basename(image_file) + '_' + str(
index) + "\t" + str(
json.dumps([x.tolist() for x in dt_boxes])) + "\n"
else:
save_pred = os.path.basename(image_file) + "\t" + str(
json.dumps([x.tolist() for x in dt_boxes])) + "\n"
save_results.append(save_pred)
logger.info(save_pred)
if len(imgs) > 1:
logger.info("{}_{} The predict time of {}: {}".format(
idx, index, image_file, elapse))
else:
logger.info("{} The predict time of {}: {}".format(
idx, image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, img)
if flag_gif:
save_file = image_file[:-3] + "png"
elif flag_pdf:
save_file = image_file.replace('.pdf',
'_' + str(index) + '.png')
else:
save_file = image_file
img_path = os.path.join(
draw_img_save_dir,
"det_res_{}".format(os.path.basename(save_file)))
cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
with open(os.path.join(draw_img_save_dir, "det_results.txt"), 'w') as f:
f.writelines(save_results)
f.close()
if args.benchmark:
......
......@@ -159,50 +159,75 @@ def main(args):
count = 0
for idx, image_file in enumerate(image_file_list):
img, flag, _ = check_and_read(image_file)
if not flag:
img, flag_gif, flag_pdf = check_and_read(image_file)
if not flag_gif and not flag_pdf:
img = cv2.imread(image_file)
if img is None:
logger.debug("error in loading image:{}".format(image_file))
continue
starttime = time.time()
dt_boxes, rec_res, time_dict = text_sys(img)
elapse = time.time() - starttime
total_time += elapse
logger.debug(
str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
for text, score in rec_res:
logger.debug("{}, {:.3f}".format(text, score))
res = [{
"transcription": rec_res[idx][0],
"points": np.array(dt_boxes[idx]).astype(np.int32).tolist(),
} for idx in range(len(dt_boxes))]
save_pred = os.path.basename(image_file) + "\t" + json.dumps(
res, ensure_ascii=False) + "\n"
save_results.append(save_pred)
if is_visualize:
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
boxes = dt_boxes
txts = [rec_res[i][0] for i in range(len(rec_res))]
scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr_box_txt(
image,
boxes,
txts,
scores,
drop_score=drop_score,
font_path=font_path)
if flag:
image_file = image_file[:-3] + "png"
cv2.imwrite(
os.path.join(draw_img_save_dir, os.path.basename(image_file)),
draw_img[:, :, ::-1])
logger.debug("The visualized image saved in {}".format(
os.path.join(draw_img_save_dir, os.path.basename(image_file))))
if not flag_pdf:
if img is None:
logger.debug("error in loading image:{}".format(image_file))
continue
imgs = [img]
else:
page_num = args.page_num
if page_num > len(img) or page_num == 0:
page_num = len(img)
imgs = img[:page_num]
for index, img in enumerate(imgs):
starttime = time.time()
dt_boxes, rec_res, time_dict = text_sys(img)
elapse = time.time() - starttime
total_time += elapse
if len(imgs) > 1:
logger.debug(
str(idx) + '_' + str(index) + " Predict time of %s: %.3fs"
% (image_file, elapse))
else:
logger.debug(
str(idx) + " Predict time of %s: %.3fs" % (image_file,
elapse))
for text, score in rec_res:
logger.debug("{}, {:.3f}".format(text, score))
res = [{
"transcription": rec_res[i][0],
"points": np.array(dt_boxes[i]).astype(np.int32).tolist(),
} for i in range(len(dt_boxes))]
if len(imgs) > 1:
save_pred = os.path.basename(image_file) + '_' + str(
index) + "\t" + json.dumps(
res, ensure_ascii=False) + "\n"
else:
save_pred = os.path.basename(image_file) + "\t" + json.dumps(
res, ensure_ascii=False) + "\n"
save_results.append(save_pred)
if is_visualize:
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
boxes = dt_boxes
txts = [rec_res[i][0] for i in range(len(rec_res))]
scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr_box_txt(
image,
boxes,
txts,
scores,
drop_score=drop_score,
font_path=font_path)
if flag_gif:
save_file = image_file[:-3] + "png"
elif flag_pdf:
save_file = image_file.replace('.pdf',
'_' + str(index) + '.png')
else:
save_file = image_file
cv2.imwrite(
os.path.join(draw_img_save_dir,
os.path.basename(save_file)),
draw_img[:, :, ::-1])
logger.debug("The visualized image saved in {}".format(
os.path.join(draw_img_save_dir, os.path.basename(
save_file))))
logger.info("The predict total time is {}".format(time.time() - _st))
if args.benchmark:
......
......@@ -45,6 +45,7 @@ def init_args():
# params for text detector
parser.add_argument("--image_dir", type=str)
parser.add_argument("--page_num", type=int, default=0)
parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_limit_side_len", type=float, default=960)
......@@ -337,12 +338,11 @@ def draw_e2e_res(dt_boxes, strs, img_path):
return src_im
def draw_text_det_res(dt_boxes, img_path):
src_im = cv2.imread(img_path)
def draw_text_det_res(dt_boxes, img):
for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape(-1, 2)
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
return src_im
cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
return img
def resize_img(img, input_size=600):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册