未验证 提交 119d6663 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #5774 from LDOUBLEV/release/2.4

add end2end
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # Reference From: https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/losses/sdmgr_loss.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # reference from: https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/heads/sdmgr_head.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -27,6 +27,7 @@ import numpy as np ...@@ -27,6 +27,7 @@ import numpy as np
import time import time
import logging import logging
from PIL import Image from PIL import Image
import json
import tools.infer.utility as utility import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det import tools.infer.predict_det as predict_det
...@@ -121,11 +122,31 @@ def sorted_boxes(dt_boxes): ...@@ -121,11 +122,31 @@ def sorted_boxes(dt_boxes):
return _boxes return _boxes
def save_results_to_txt(results, path):
if os.path.isdir(path):
if not os.path.exists(path):
os.makedirs(path)
with open(os.path.join(path, "results.txt"), 'w') as f:
f.writelines(results)
f.close()
logger.info("The results will be saved in {}".format(
os.path.join(path, "results.txt")))
else:
draw_img_save = os.path.dirname(path)
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
with open(path, 'w') as f:
f.writelines(results)
f.close()
logger.info("The results will be saved in {}".format(path))
def main(args): def main(args):
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list[args.process_id::args.total_process_num] image_file_list = image_file_list[args.process_id::args.total_process_num]
text_sys = TextSystem(args) text_sys = TextSystem(args)
is_visualize = True is_visualize = args.is_visualize
font_path = args.vis_font_path font_path = args.vis_font_path
drop_score = args.drop_score drop_score = args.drop_score
...@@ -139,6 +160,7 @@ def main(args): ...@@ -139,6 +160,7 @@ def main(args):
cpu_mem, gpu_mem, gpu_util = 0, 0, 0 cpu_mem, gpu_mem, gpu_util = 0, 0, 0
_st = time.time() _st = time.time()
count = 0 count = 0
save_res = []
for idx, image_file in enumerate(image_file_list): for idx, image_file in enumerate(image_file_list):
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
...@@ -152,6 +174,21 @@ def main(args): ...@@ -152,6 +174,21 @@ def main(args):
elapse = time.time() - starttime elapse = time.time() - starttime
total_time += elapse total_time += elapse
# save results
preds = []
dt_num = len(dt_boxes)
for dno in range(dt_num):
text, score = rec_res[dno]
if score >= drop_score:
preds.append({
"transcription": text,
"points": np.array(dt_boxes[dno]).tolist()
})
text_str = "%s, %.3f" % (text, score)
save_res.append(image_file + '\t' + json.dumps(
preds, ensure_ascii=False) + '\n')
# print predicted results
logger.debug( logger.debug(
str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse)) str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
for text, score in rec_res: for text, score in rec_res:
...@@ -180,6 +217,9 @@ def main(args): ...@@ -180,6 +217,9 @@ def main(args):
logger.debug("The visualized image saved in {}".format( logger.debug("The visualized image saved in {}".format(
os.path.join(draw_img_save_dir, os.path.basename(image_file)))) os.path.join(draw_img_save_dir, os.path.basename(image_file))))
# The predicted results will be saved in os.path.join(os.draw_img_save_dir, "results.txt")
save_results_to_txt(save_res, args.draw_img_save_dir)
logger.info("The predict total time is {}".format(time.time() - _st)) logger.info("The predict total time is {}".format(time.time() - _st))
if args.benchmark: if args.benchmark:
text_sys.text_detector.autolog.report() text_sys.text_detector.autolog.report()
......
...@@ -114,6 +114,7 @@ def init_args(): ...@@ -114,6 +114,7 @@ def init_args():
# #
parser.add_argument( parser.add_argument(
"--draw_img_save_dir", type=str, default="./inference_results") "--draw_img_save_dir", type=str, default="./inference_results")
parser.add_argument("--is_visualize", type=str2bool, default=True)
parser.add_argument("--save_crop_res", type=str2bool, default=False) parser.add_argument("--save_crop_res", type=str2bool, default=False)
parser.add_argument("--crop_res_save_dir", type=str, default="./output") parser.add_argument("--crop_res_save_dir", type=str, default="./output")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册