diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 3a8ebdef1ed099ade59818c240fe43200c07cc38..3f93ceb7ababd97781966ceff46b88be62d58e33 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -38,7 +38,7 @@ from ppcls.utils.amp import AutoCast, build_scaler from ppcls.utils.ema import ExponentialMovingAverage from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url from ppcls.utils.save_load import init_model -from ppcls.utils import save_load +from ppcls.utils import save_load, save_predict_result from ppcls.data.utils.get_image_list import get_image_list from ppcls.data.postprocess import build_postprocess @@ -477,6 +477,9 @@ class Engine(object): results.extend(result) batch_data.clear() image_file_list.clear() + save_path = self.config["Infer"].get("save_dir", None) + if save_path: + save_predict_result(save_path, results) return results def export(self): diff --git a/ppcls/utils/__init__.py b/ppcls/utils/__init__.py index f9307ffd27a9ab0c1f4bab04ca6b21b9f21098e4..4701384e7f2511e1a11270fde951f3547a68cae6 100644 --- a/ppcls/utils/__init__.py +++ b/ppcls/utils/__init__.py @@ -26,3 +26,4 @@ from .metrics import multi_hot_encode from .metrics import precision_recall_fscore from .misc import AverageMeter from .save_load import init_model, save_model +from .save_result import save_predict_result diff --git a/ppcls/utils/save_result.py b/ppcls/utils/save_result.py new file mode 100644 index 0000000000000000000000000000000000000000..863113ce876595b856bd85b816dce4b828745cde --- /dev/null +++ b/ppcls/utils/save_result.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json + +from . import logger + +def save_predict_result(save_path, result): + if os.path.splitext(save_path)[-1] == '': + if save_path[-1] == "/": + save_path = save_path[:-1] + save_path = save_path + '.json' + elif os.path.splitext(save_path)[-1] == '.json': + save_path = save_path + else: + logger.warning( + f"{save_path} is invalid input path, only files in json format are supported." + ) + if os.path.exists(save_path): + logger.warning( + f"The file {save_path} will be overwritten." + ) + with open(save_path, 'w', encoding='utf-8') as f: + json.dump(result, f)