From b3fcc98610e1be70b89c569b0f3ff5d511c61765 Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Tue, 29 Aug 2023 03:38:42 +0000 Subject: [PATCH] to be compatible with training and evaluation --- ppcls/metric/metrics.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 20282e22..723cff84 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -38,6 +38,7 @@ class TopkAcc(AvgMetrics): topk = [topk] self.topk = topk self.reset() + self.warned = False def reset(self): self.avg_meters = { @@ -54,17 +55,15 @@ class TopkAcc(AvgMetrics): metric_dict = dict() for idx, k in enumerate(self.topk): if output_dims < k: - msg = f"The output dims({output_dims}) is less than k({k}), so the Top-{k} metric is meaningless." - logger.warning(msg) + if not self.warned: + msg = f"The output dims({output_dims}) is less than k({k}), so the Top-{k} metric is meaningless." + logger.warning(msg) + self.warned = True metric_dict[f"top{k}"] = 1 - self.avg_meters.pop(f"top{k}") - continue - metric_dict[f"top{k}"] = paddle.metric.accuracy(x, label, k=k) + else: + metric_dict[f"top{k}"] = paddle.metric.accuracy(x, label, k=k) self.avg_meters[f"top{k}"].update(metric_dict[f"top{k}"], x.shape[0]) - - self.topk = list(filter(lambda k: k <= output_dims, self.topk)) - return metric_dict -- GitLab