From b23e72b17c6ede0dc52b4abe5198350dfee74fcc Mon Sep 17 00:00:00 2001 From: weishengyu Date: Fri, 4 Jun 2021 22:27:32 +0800 Subject: [PATCH] add combined_metrics --- ppcls/engine/trainer.py | 7 ++++- ppcls/metric/__init__.py | 55 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py index fe53ecfe..a5a941ec 100644 --- a/ppcls/engine/trainer.py +++ b/ppcls/engine/trainer.py @@ -31,7 +31,7 @@ from ppcls.utils import logger from ppcls.data import build_dataloader from ppcls.arch import build_model from ppcls.loss import build_loss -from ppcls.arch.loss_metrics import build_metrics +from ppcls.metric import build_metrics from ppcls.optimizer import build_optimizer from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.utils.save_load import init_model @@ -379,6 +379,11 @@ class Trainer(object): query_img_id, gallery_img_id) else: metric_dict = {metric_key: 0.} + metric_msg = ", ".join([ + "{}: {:.5f}".format(key, metric_dict[key].avg) + for key in metric_dict + ]) + logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg)) return metric_dict[metric_key] diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py index e69de29b..bcb13839 100644 --- a/ppcls/metric/__init__.py +++ b/ppcls/metric/__init__.py @@ -0,0 +1,55 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from paddle import nn +import copy +from collections import OrderedDict + +from .metrics import Topk, mAP, mINP, Recallk + + +class CombinedMetrics(nn.Layer): + def __init__(self, config_list): + super().__init__() + self.metric_func_list = [] + assert isinstance(config_list, list), ( + 'operator config should be a list') + for config in config_list: + print(config) + assert isinstance(config, + dict) and len(config) == 1, "yaml format error" + metric_name = list(config)[0] + metric_params = config[metric_name] + self.metric_func_list.append(eval(metric_name)(**metric_params)) + + def __call__(self, + similarities_matrix, + query_img_id, + gallery_img_id, + x=None, + label=None): + metric_dict = OrderedDict() + for idx, metric_func in enumerate(self.metric_func_list): + if x is None: + metric_dict.update(metric_func(x, label)) + else: + metric_dict.update( + metric_func(similarities_matrix, query_img_id, + gallery_img_id)) + return metric_dict + + +def build_metrics(config): + metrics_list = CombinedMetrics(copy.deepcopy(config)) + return metrics_list -- GitLab