diff --git a/ppocr/losses/kie_sdmgr_loss.py b/ppocr/losses/kie_sdmgr_loss.py index 8f2173e49904926ebab2c450890c4fafe3f36b50..354fbded3daedcb09028f07e4ff16346b46a26c1 100644 --- a/ppocr/losses/kie_sdmgr_loss.py +++ b/ppocr/losses/kie_sdmgr_loss.py @@ -1,16 +1,4 @@ -# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Reference From: https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/losses/sdmgr_loss.py from __future__ import absolute_import from __future__ import division diff --git a/ppocr/modeling/heads/kie_sdmgr_head.py b/ppocr/modeling/heads/kie_sdmgr_head.py index 46ac0ed8dcaccb7628ef87fbe851a2b6acd60d55..156aa9177acbc31c77a29ec76989c7d92cdb146f 100644 --- a/ppocr/modeling/heads/kie_sdmgr_head.py +++ b/ppocr/modeling/heads/kie_sdmgr_head.py @@ -1,17 +1,4 @@ -# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- +# reference from: https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/heads/sdmgr_head.py from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 8d674809a5fe22e458fcb0c68419a7313e71d5f6..e9aff6d210114a1ebcb42409a7b9480f69ead664 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -27,6 +27,7 @@ import numpy as np import time import logging from PIL import Image +import json import tools.infer.utility as utility import tools.infer.predict_rec as predict_rec import tools.infer.predict_det as predict_det @@ -121,11 +122,31 @@ def sorted_boxes(dt_boxes): return _boxes +def save_results_to_txt(results, path): + if os.path.isdir(path): + if not os.path.exists(path): + os.makedirs(path) + with open(os.path.join(path, "results.txt"), 'w') as f: + f.writelines(results) + f.close() + logger.info("The results will be saved in {}".format( + os.path.join(path, "results.txt"))) + else: + draw_img_save = os.path.dirname(path) + if not os.path.exists(draw_img_save): + os.makedirs(draw_img_save) + + with open(path, 'w') as f: + f.writelines(results) + f.close() + logger.info("The results will be saved in {}".format(path)) + + def main(args): image_file_list = get_image_file_list(args.image_dir) image_file_list = image_file_list[args.process_id::args.total_process_num] text_sys = TextSystem(args) - is_visualize = True + is_visualize = args.is_visualize font_path = args.vis_font_path drop_score = args.drop_score @@ -139,6 +160,7 @@ def main(args): cpu_mem, gpu_mem, gpu_util = 0, 0, 0 _st = time.time() count = 0 + save_res = [] for idx, image_file in enumerate(image_file_list): img, flag = check_and_read_gif(image_file) @@ -152,6 +174,21 @@ def main(args): elapse = time.time() - starttime total_time += elapse + # save results + preds = [] + dt_num = len(dt_boxes) + for dno in range(dt_num): + text, score = 
rec_res[dno] + if score >= drop_score: + preds.append({ + "transcription": text, + "points": np.array(dt_boxes[dno]).tolist() + }) + text_str = "%s, %.3f" % (text, score) + save_res.append(image_file + '\t' + json.dumps( + preds, ensure_ascii=False) + '\n') + + # print predicted results logger.debug( str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse)) for text, score in rec_res: @@ -180,6 +217,9 @@ def main(args): logger.debug("The visualized image saved in {}".format( os.path.join(draw_img_save_dir, os.path.basename(image_file)))) + # The predicted results will be saved in os.path.join(args.draw_img_save_dir, "results.txt") if that directory exists + save_results_to_txt(save_res, args.draw_img_save_dir) + logger.info("The predict total time is {}".format(time.time() - _st)) if args.benchmark: text_sys.text_detector.autolog.report() diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 33ed62125c0b59b5f23b72b5b8f6ecb3b0835cf3..7b7b81e3cd22c12561b30e6705eded1c92ec7761 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -114,6 +114,7 @@ def init_args(): # parser.add_argument( "--draw_img_save_dir", type=str, default="./inference_results") + parser.add_argument("--is_visualize", type=str2bool, default=True) parser.add_argument("--save_crop_res", type=str2bool, default=False) parser.add_argument("--crop_res_save_dir", type=str, default="./output")