未验证 提交 0f987097 编写于 作者: F Felix 提交者: GitHub

Update __init__.py

上级 3e1e25af
......@@ -16,8 +16,7 @@ from paddle import nn
import copy
from collections import OrderedDict
from .metrics import TopkAcc, mAP, mINP, Recallk
from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric
class CombinedMetrics(nn.Layer):
def __init__(self, config_list):
......@@ -25,12 +24,20 @@ class CombinedMetrics(nn.Layer):
self.metric_func_list = []
assert isinstance(config_list, list), (
'operator config should be a list')
self.retri_config = dict() # retrieval metrics config
for config in config_list:
assert isinstance(config,
dict) and len(config) == 1, "yaml format error"
metric_name = list(config)[0]
if metric_name in ["Recallk", "mAP", "mINP"]:
self.retri_config[metric_name] = config[metric_name]
continue
metric_params = config[metric_name]
self.metric_func_list.append(eval(metric_name)(**metric_params))
if self.retri_config:
self.metric_func_list.append(RetriMetric(self.retri_config))
def __call__(self, *args, **kwargs):
metric_dict = OrderedDict()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册