diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index 8e8a20c2a784c72cf06212da85f353dd89072c57..3e39daef8e442aeb8c6ea1a6d3c8d96fb140547d 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -152,39 +152,34 @@ class Engine(object):
             self.eval_loss_func = None
 
         # build metric
-        if self.mode == 'train':
-            metric_config = self.config.get("Metric")
-            if metric_config is not None:
-                metric_config = metric_config.get("Train")
-                if metric_config is not None:
-                    if hasattr(
-                            self.train_dataloader, "collate_fn"
-                    ) and self.train_dataloader.collate_fn is not None:
-                        for m_idx, m in enumerate(metric_config):
-                            if "TopkAcc" in m:
-                                msg = f"'TopkAcc' metric can not be used when setting 'batch_transform_ops' in config. The 'TopkAcc' metric has been removed."
-                                logger.warning(msg)
-                                break
-                        metric_config.pop(m_idx)
-                    self.train_metric_func = build_metrics(metric_config)
-                else:
-                    self.train_metric_func = None
+        if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[
+                "Metric"]:
+            metric_config = self.config["Metric"]["Train"]
+            if hasattr(self.train_dataloader, "collate_fn"
+                       ) and self.train_dataloader.collate_fn is not None:
+                for m_idx, m in enumerate(metric_config):
+                    if "TopkAcc" in m:
+                        msg = f"'TopkAcc' metric can not be used when setting 'batch_transform_ops' in config. The 'TopkAcc' metric has been removed."
+                        logger.warning(msg)
+                        break
+                metric_config.pop(m_idx)
+            self.train_metric_func = build_metrics(metric_config)
         else:
             self.train_metric_func = None
 
         if self.mode == "eval" or (self.mode == "train" and
                                    self.config["Global"]["eval_during_train"]):
-            metric_config = self.config.get("Metric")
             if self.eval_mode == "classification":
-                if metric_config is not None:
-                    metric_config = metric_config.get("Eval")
-                    if metric_config is not None:
-                        self.eval_metric_func = build_metrics(metric_config)
+                if "Metric" in self.config and "Eval" in self.config["Metric"]:
+                    self.eval_metric_func = build_metrics(self.config["Metric"]
+                                                          ["Eval"])
+                else:
+                    self.eval_metric_func = None
             elif self.eval_mode == "retrieval":
-                if metric_config is None:
-                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
+                if "Metric" in self.config and "Eval" in self.config["Metric"]:
+                    metric_config = self.config["Metric"]["Eval"]
                 else:
-                    metric_config = metric_config["Eval"]
+                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                 self.eval_metric_func = build_metrics(metric_config)
         else:
             self.eval_metric_func = None
diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index 1f9b55fc33ff6b49e9e7f7bd3e9bcebdbf3e0093..61aa92b4f8be528b2a26b03b8af167608457308c 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -34,7 +34,6 @@ def classification_eval(engine, epoch_id=0):
     }
     print_batch_step = engine.config["Global"]["print_batch_step"]
 
-    metric_key = None
     tic = time.time()
     accum_samples = 0
     total_samples = len(
diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py
index 2161ca86ae51c1c1aa551dd08c1924adc3d9c59b..02fc688735fd89ed061646b26f4cd1f25c80af83 100644
--- a/ppcls/metric/metrics.py
+++ b/ppcls/metric/metrics.py
@@ -26,6 +26,7 @@ from easydict import EasyDict
 
 from ppcls.metric.avg_metrics import AvgMetrics
 from ppcls.utils.misc import AverageMeter, AttrMeter
+from ppcls.utils import logger
 
 
 class TopkAcc(AvgMetrics):
@@ -47,8 +48,15 @@ class TopkAcc(AvgMetrics):
         if isinstance(x, dict):
             x = x["logits"]
 
+        output_dims = x.shape[-1]
+
         metric_dict = dict()
-        for k in self.topk:
+        for idx, k in enumerate(self.topk):
+            if output_dims < k:
+                msg = f"The output dims({output_dims}) is less than k({k}), and the argument {k} of Topk has been removed."
+                logger.warning(msg)
+                self.topk.pop(idx)
+                continue
             metric_dict["top{}".format(k)] = paddle.metric.accuracy(
                 x, label, k=k)
             self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0])
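
Review note: for context, here is a minimal, hypothetical sketch of the behavior the `TopkAcc` change introduces (standalone Python, not the PaddleClas implementation). Any `k` larger than the model's output dimension is skipped rather than passed to `paddle.metric.accuracy`, since top-k accuracy is undefined when `k` exceeds the number of classes. The tensor shapes and variable names below are illustrative.

```python
import paddle

# Hypothetical standalone sketch of the guard this patch adds to
# TopkAcc.forward: drop any k that exceeds the model's output
# dimension before computing top-k accuracy.
logits = paddle.rand([8, 3])                 # batch of 8 samples, 3 classes
labels = paddle.randint(0, 3, shape=[8, 1])  # int64 labels in [0, 3)

topk = [1, 5]                                # top-5 is impossible with 3 classes
valid_topk = [k for k in topk if k <= logits.shape[-1]]

for k in valid_topk:
    acc = paddle.metric.accuracy(logits, labels, k=k)
    print(f"top{k} accuracy: {float(acc):.4f}")
```

One caveat worth flagging: the patched loop removes oversized entries with `self.topk.pop(idx)` while iterating over `enumerate(self.topk)`, so the element that shifts into the popped slot can be skipped on that pass when two oversized values are adjacent. Filtering into a new list, as in the sketch above, sidesteps the mutate-while-iterating pattern.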