未验证 提交 119d6663 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #5774 from LDOUBLEV/release/2.4

add end2end
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Reference From: https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/losses/sdmgr_loss.py
from __future__ import absolute_import
from __future__ import division
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# reference from: https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/heads/sdmgr_head.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
......@@ -27,6 +27,7 @@ import numpy as np
import time
import logging
from PIL import Image
import json
import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
......@@ -121,11 +122,31 @@ def sorted_boxes(dt_boxes):
return _boxes
def save_results_to_txt(results, path):
if os.path.isdir(path):
if not os.path.exists(path):
os.makedirs(path)
with open(os.path.join(path, "results.txt"), 'w') as f:
f.writelines(results)
f.close()
logger.info("The results will be saved in {}".format(
os.path.join(path, "results.txt")))
else:
draw_img_save = os.path.dirname(path)
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
with open(path, 'w') as f:
f.writelines(results)
f.close()
logger.info("The results will be saved in {}".format(path))
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list[args.process_id::args.total_process_num]
text_sys = TextSystem(args)
is_visualize = True
is_visualize = args.is_visualize
font_path = args.vis_font_path
drop_score = args.drop_score
......@@ -139,6 +160,7 @@ def main(args):
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
_st = time.time()
count = 0
save_res = []
for idx, image_file in enumerate(image_file_list):
img, flag = check_and_read_gif(image_file)
......@@ -152,6 +174,21 @@ def main(args):
elapse = time.time() - starttime
total_time += elapse
# save results
preds = []
dt_num = len(dt_boxes)
for dno in range(dt_num):
text, score = rec_res[dno]
if score >= drop_score:
preds.append({
"transcription": text,
"points": np.array(dt_boxes[dno]).tolist()
})
text_str = "%s, %.3f" % (text, score)
save_res.append(image_file + '\t' + json.dumps(
preds, ensure_ascii=False) + '\n')
# print predicted results
logger.debug(
str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
for text, score in rec_res:
......@@ -180,6 +217,9 @@ def main(args):
logger.debug("The visualized image saved in {}".format(
os.path.join(draw_img_save_dir, os.path.basename(image_file))))
# The predicted results will be saved in os.path.join(os.draw_img_save_dir, "results.txt")
save_results_to_txt(save_res, args.draw_img_save_dir)
logger.info("The predict total time is {}".format(time.time() - _st))
if args.benchmark:
text_sys.text_detector.autolog.report()
......
......@@ -114,6 +114,7 @@ def init_args():
#
parser.add_argument(
"--draw_img_save_dir", type=str, default="./inference_results")
parser.add_argument("--is_visualize", type=str2bool, default=True)
parser.add_argument("--save_crop_res", type=str2bool, default=False)
parser.add_argument("--crop_res_save_dir", type=str, default="./output")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册