未验证 提交 2323caec 编写于 作者: B Bin Lu 提交者: GitHub

Update __init__.py

上级 65780c29
...@@ -16,39 +16,31 @@ from paddle import nn ...@@ -16,39 +16,31 @@ from paddle import nn
import copy import copy
from collections import OrderedDict from collections import OrderedDict
from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric from .metrics import TopkAcc, mAP, mINP, Recallk
from .metrics import DistillationTopkAcc from .metrics import DistillationTopkAcc
class CombinedMetrics(nn.Layer): class CombinedMetrics(nn.Layer):
def __init__(self, config_list): def __init__(self, config_list):
super().__init__() super().__init__()
self.metric_func_list = [] self.metric_func_list = []
assert isinstance(config_list, list), ( assert isinstance(config_list, list), (
'operator config should be a list') 'operator config should be a list')
self.retri_config = dict() # retrieval metrics config
for config in config_list: for config in config_list:
assert isinstance(config, assert isinstance(config,
dict) and len(config) == 1, "yaml format error" dict) and len(config) == 1, "yaml format error"
metric_name = list(config)[0] metric_name = list(config)[0]
if metric_name in ["Recallk", "mAP", "mINP"]:
self.retri_config[metric_name] = config[metric_name]
continue
metric_params = config[metric_name] metric_params = config[metric_name]
if metric_params is not None:
self.metric_func_list.append(eval(metric_name)(**metric_params)) self.metric_func_list.append(eval(metric_name)(**metric_params))
else:
if self.retri_config: self.metric_func_list.append(eval(metric_name)())
self.metric_func_list.append(RetriMetric(self.retri_config))
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
metric_dict = OrderedDict() metric_dict = OrderedDict()
for idx, metric_func in enumerate(self.metric_func_list): for idx, metric_func in enumerate(self.metric_func_list):
metric_dict.update(metric_func(*args, **kwargs)) metric_dict.update(metric_func(*args, **kwargs))
return metric_dict return metric_dict
def build_metrics(config): def build_metrics(config):
metrics_list = CombinedMetrics(copy.deepcopy(config)) metrics_list = CombinedMetrics(copy.deepcopy(config))
return metrics_list return metrics_list
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册