diff --git a/ppcls/data/postprocess/topk.py b/ppcls/data/postprocess/topk.py index df02719471300ea8e2b7c1db286d104adabe116f..76772f568eef157c4bb5e3485ea9ec5bc41f9d20 100644 --- a/ppcls/data/postprocess/topk.py +++ b/ppcls/data/postprocess/topk.py @@ -21,9 +21,9 @@ import paddle.nn.functional as F class Topk(object): def __init__(self, topk=1, class_id_map_file=None, delimiter=None): assert isinstance(topk, (int, )) - self.class_id_map = self.parse_class_id_map(class_id_map_file) self.topk = topk self.delimiter = delimiter if delimiter is not None else " " + self.class_id_map = self.parse_class_id_map(class_id_map_file) def parse_class_id_map(self, class_id_map_file): if class_id_map_file is None: diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 3e39daef8e442aeb8c6ea1a6d3c8d96fb140547d..34cb4c4da7f05a2594e2ee7a923abd4137a92f03 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -177,7 +177,7 @@ class Engine(object): self.eval_metric_func = None elif self.eval_mode == "retrieval": if "Metric" in self.config and "Eval" in self.config["Metric"]: - metric_config = metric_config["Metric"]["Eval"] + metric_config = self.config["Metric"]["Eval"] else: metric_config = [{"name": "Recallk", "topk": (1, 5)}] self.eval_metric_func = build_metrics(metric_config)