未验证 提交 f3b2b2f4 编写于 作者: Z zhangyubo0722 提交者: GitHub

[uapi]Save predict result (#2926)

* sava predict result
上级 ae96c979
......@@ -38,7 +38,7 @@ from ppcls.utils.amp import AutoCast, build_scaler
from ppcls.utils.ema import ExponentialMovingAverage
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load
from ppcls.utils import save_load, save_predict_result
from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
......@@ -477,6 +477,9 @@ class Engine(object):
results.extend(result)
batch_data.clear()
image_file_list.clear()
save_path = self.config["Infer"].get("save_dir", None)
if save_path:
save_predict_result(save_path, results)
return results
def export(self):
......
......@@ -26,3 +26,4 @@ from .metrics import multi_hot_encode
from .metrics import precision_recall_fscore
from .misc import AverageMeter
from .save_load import init_model, save_model
from .save_result import save_predict_result
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
from . import logger
def save_predict_result(save_path, result):
if os.path.splitext(save_path)[-1] == '':
if save_path[-1] == "/":
save_path = save_path[:-1]
save_path = save_path + '.json'
elif os.path.splitext(save_path)[-1] == '.json':
save_path = save_path
else:
logger.warning(
f"{save_path} is invalid input path, only files in json format are supported."
)
if os.path.exists(save_path):
logger.warning(
f"The file {save_path} will be overwritten."
)
with open(save_path, 'w', encoding='utf-8') as f:
json.dump(result, f)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册