diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py index d566a962d1fc2769f3e6ade61c80d32348adc08a..696e8f8580f089f53dc9808fd479535f25188ea5 100644 --- a/ppcls/metric/__init__.py +++ b/ppcls/metric/__init__.py @@ -16,8 +16,7 @@ from paddle import nn import copy from collections import OrderedDict -from .metrics import TopkAcc, mAP, mINP, Recallk - +from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric class CombinedMetrics(nn.Layer): def __init__(self, config_list): @@ -25,12 +24,20 @@ class CombinedMetrics(nn.Layer): self.metric_func_list = [] assert isinstance(config_list, list), ( 'operator config should be a list') + + self.retri_config = dict() # retrieval metrics config for config in config_list: assert isinstance(config, dict) and len(config) == 1, "yaml format error" metric_name = list(config)[0] + if metric_name in ["Recallk", "mAP", "mINP"]: + self.retri_config[metric_name] = config[metric_name] + continue metric_params = config[metric_name] self.metric_func_list.append(eval(metric_name)(**metric_params)) + + if self.retri_config: + self.metric_func_list.append(RetriMetric(self.retri_config)) def __call__(self, *args, **kwargs): metric_dict = OrderedDict()