From 2323caec99c86df9af73fd4e34164c5e35b28ae3 Mon Sep 17 00:00:00 2001 From: Bin Lu Date: Fri, 11 Jun 2021 10:38:17 +0800 Subject: [PATCH] Update __init__.py --- ppcls/metric/__init__.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py index 95f86e4a..36161b69 100644 --- a/ppcls/metric/__init__.py +++ b/ppcls/metric/__init__.py @@ -16,39 +16,31 @@ from paddle import nn import copy from collections import OrderedDict -from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric +from .metrics import TopkAcc, mAP, mINP, Recallk from .metrics import DistillationTopkAcc - class CombinedMetrics(nn.Layer): def __init__(self, config_list): super().__init__() self.metric_func_list = [] assert isinstance(config_list, list), ( 'operator config should be a list') - - self.retri_config = dict() # retrieval metrics config for config in config_list: assert isinstance(config, dict) and len(config) == 1, "yaml format error" metric_name = list(config)[0] - if metric_name in ["Recallk", "mAP", "mINP"]: - self.retri_config[metric_name] = config[metric_name] - continue metric_params = config[metric_name] - self.metric_func_list.append(eval(metric_name)(**metric_params)) - - if self.retri_config: - self.metric_func_list.append(RetriMetric(self.retri_config)) + if metric_params is not None: + self.metric_func_list.append(eval(metric_name)(**metric_params)) + else: + self.metric_func_list.append(eval(metric_name)()) def __call__(self, *args, **kwargs): metric_dict = OrderedDict() for idx, metric_func in enumerate(self.metric_func_list): metric_dict.update(metric_func(*args, **kwargs)) - return metric_dict - def build_metrics(config): metrics_list = CombinedMetrics(copy.deepcopy(config)) return metrics_list -- GitLab