diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py index d566a962d1fc2769f3e6ade61c80d32348adc08a..696e8f8580f089f53dc9808fd479535f25188ea5 100644 --- a/ppcls/metric/__init__.py +++ b/ppcls/metric/__init__.py @@ -16,8 +16,7 @@ from paddle import nn import copy from collections import OrderedDict -from .metrics import TopkAcc, mAP, mINP, Recallk - +from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric class CombinedMetrics(nn.Layer): def __init__(self, config_list): @@ -25,12 +24,20 @@ class CombinedMetrics(nn.Layer): self.metric_func_list = [] assert isinstance(config_list, list), ( 'operator config should be a list') + + self.retri_config = dict() # retrieval metrics config for config in config_list: assert isinstance(config, dict) and len(config) == 1, "yaml format error" metric_name = list(config)[0] + if metric_name in ["Recallk", "mAP", "mINP"]: + self.retri_config[metric_name] = config[metric_name] + continue metric_params = config[metric_name] self.metric_func_list.append(eval(metric_name)(**metric_params)) + + if self.retri_config: + self.metric_func_list.append(RetriMetric(self.retri_config)) def __call__(self, *args, **kwargs): metric_dict = OrderedDict() diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 05723c12303fdf616110a0cccffb883058f62674..d2e66bc54dc298ec329b45305684eb33f39da11b 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -84,6 +84,29 @@ class Recallk(nn.Layer): metric_dict["recall{}".format(k)] = all_cmc[k - 1] return metric_dict +# retrieval metrics +class RetriMetric(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.max_rank = 50 #max(self.topk) if max(self.topk) > 50 else 50 + + def forward(self, similarities_matrix, query_img_id, gallery_img_id): + metric_dict = dict() + all_cmc, all_AP, all_INP = get_metrics(similarities_matrix, query_img_id, + gallery_img_id, self.max_rank) + if "Recallk" in self.config.keys(): + topk = self.config['Recallk']['topk'] + for k in topk: + metric_dict["recall{}".format(k)] = all_cmc[k - 1] + if "mAP" in self.config.keys(): + mAP = np.mean(all_AP) + metric_dict["mAP"] = mAP + if "mINP" in self.config.keys(): + mINP = np.mean(all_INP) + metric_dict["mINP"] = mINP + return metric_dict + @lru_cache() def get_metrics(similarities_matrix, query_img_id, gallery_img_id,