From f3b2b2f4ad9ecfc372d50605c69d3bcc16bb46ca Mon Sep 17 00:00:00 2001 From: zhangyubo0722 <94225063+zhangyubo0722@users.noreply.github.com> Date: Tue, 29 Aug 2023 14:32:07 +0800 Subject: [PATCH] [uapi]Save predict result (#2926) * sava predict result --- ppcls/engine/engine.py | 5 ++++- ppcls/utils/__init__.py | 1 + ppcls/utils/save_result.py | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 ppcls/utils/save_result.py diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 3a8ebdef..3f93ceb7 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -38,7 +38,7 @@ from ppcls.utils.amp import AutoCast, build_scaler from ppcls.utils.ema import ExponentialMovingAverage from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url from ppcls.utils.save_load import init_model -from ppcls.utils import save_load +from ppcls.utils import save_load, save_predict_result from ppcls.data.utils.get_image_list import get_image_list from ppcls.data.postprocess import build_postprocess @@ -477,6 +477,9 @@ class Engine(object): results.extend(result) batch_data.clear() image_file_list.clear() + save_path = self.config["Infer"].get("save_dir", None) + if save_path: + save_predict_result(save_path, results) return results def export(self): diff --git a/ppcls/utils/__init__.py b/ppcls/utils/__init__.py index f9307ffd..4701384e 100644 --- a/ppcls/utils/__init__.py +++ b/ppcls/utils/__init__.py @@ -26,3 +26,4 @@ from .metrics import multi_hot_encode from .metrics import precision_recall_fscore from .misc import AverageMeter from .save_load import init_model, save_model +from .save_result import save_predict_result diff --git a/ppcls/utils/save_result.py b/ppcls/utils/save_result.py new file mode 100644 index 00000000..863113ce --- /dev/null +++ b/ppcls/utils/save_result.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json + +from . import logger + +def save_predict_result(save_path, result): + if os.path.splitext(save_path)[-1] == '': + if save_path[-1] == "/": + save_path = save_path[:-1] + save_path = save_path + '.json' + elif os.path.splitext(save_path)[-1] == '.json': + save_path = save_path + else: + logger.warning( + f"{save_path} is invalid input path, only files in json format are supported." + ) + if os.path.exists(save_path): + logger.warning( + f"The file {save_path} will be overwritten." + ) + with open(save_path, 'w', encoding='utf-8') as f: + json.dump(result, f) -- GitLab