未验证 提交 0f987097 编写于 作者: F Felix 提交者: GitHub

Update __init__.py

上级 3e1e25af
...@@ -16,8 +16,7 @@ from paddle import nn ...@@ -16,8 +16,7 @@ from paddle import nn
import copy import copy
from collections import OrderedDict from collections import OrderedDict
from .metrics import TopkAcc, mAP, mINP, Recallk from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric
class CombinedMetrics(nn.Layer): class CombinedMetrics(nn.Layer):
def __init__(self, config_list): def __init__(self, config_list):
...@@ -25,12 +24,20 @@ class CombinedMetrics(nn.Layer): ...@@ -25,12 +24,20 @@ class CombinedMetrics(nn.Layer):
self.metric_func_list = [] self.metric_func_list = []
assert isinstance(config_list, list), ( assert isinstance(config_list, list), (
'operator config should be a list') 'operator config should be a list')
self.retri_config = dict() # retrieval metrics config
for config in config_list: for config in config_list:
assert isinstance(config, assert isinstance(config,
dict) and len(config) == 1, "yaml format error" dict) and len(config) == 1, "yaml format error"
metric_name = list(config)[0] metric_name = list(config)[0]
if metric_name in ["Recallk", "mAP", "mINP"]:
self.retri_config[metric_name] = config[metric_name]
continue
metric_params = config[metric_name] metric_params = config[metric_name]
self.metric_func_list.append(eval(metric_name)(**metric_params)) self.metric_func_list.append(eval(metric_name)(**metric_params))
if self.retri_config:
self.metric_func_list.append(RetriMetric(self.retri_config))
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
metric_dict = OrderedDict() metric_dict = OrderedDict()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册