diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py
index 20282e228870e21dcc983051fa4af95eeb9b1340..723cff84b954317a5fe7d46e8e298977bf199b62 100644
--- a/ppcls/metric/metrics.py
+++ b/ppcls/metric/metrics.py
@@ -38,6 +38,7 @@ class TopkAcc(AvgMetrics):
             topk = [topk]
         self.topk = topk
         self.reset()
+        self.warned = False
 
     def reset(self):
         self.avg_meters = {
@@ -54,17 +55,15 @@ class TopkAcc(AvgMetrics):
         metric_dict = dict()
         for idx, k in enumerate(self.topk):
             if output_dims < k:
-                msg = f"The output dims({output_dims}) is less than k({k}), so the Top-{k} metric is meaningless."
-                logger.warning(msg)
+                if not self.warned:
+                    msg = f"The output dims({output_dims}) is less than k({k}), so the Top-{k} metric is meaningless."
+                    logger.warning(msg)
+                    self.warned = True
                 metric_dict[f"top{k}"] = 1
-                self.avg_meters.pop(f"top{k}")
-                continue
-            metric_dict[f"top{k}"] = paddle.metric.accuracy(x, label, k=k)
+            else:
+                metric_dict[f"top{k}"] = paddle.metric.accuracy(x, label, k=k)
             self.avg_meters[f"top{k}"].update(metric_dict[f"top{k}"],
                                               x.shape[0])
-
-        self.topk = list(filter(lambda k: k <= output_dims, self.topk))
-
         return metric_dict
 
 
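Note (not part of the patch): the change above makes TopkAcc emit the "Top-{k} metric is meaningless" warning only once per instance, and keeps reporting a constant 1 for any k larger than the number of output classes instead of popping the corresponding avg_meters entry and mutating self.topk mid-evaluation. Below is a minimal, framework-free sketch of that warn-once behavior; WarnOnceTopk and its forward(output_dims) signature are hypothetical simplifications of TopkAcc.forward, which actually takes logits and labels and computes real accuracies via paddle.metric.accuracy.

    import logging

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger(__name__)


    class WarnOnceTopk:
        """Hypothetical, simplified stand-in for TopkAcc (illustration only)."""

        def __init__(self, topk=(1, 5)):
            self.topk = list(topk)
            self.warned = False  # flipped once; suppresses repeat warnings, as in the patch

        def forward(self, output_dims):
            metric_dict = {}
            for k in self.topk:
                if output_dims < k:
                    if not self.warned:
                        logger.warning(
                            "The output dims(%d) is less than k(%d), "
                            "so the Top-%d metric is meaningless.",
                            output_dims, k, k)
                        self.warned = True
                    # with fewer classes than k, every label is trivially inside
                    # the top-k, so the metric is pinned to 1 rather than dropped
                    metric_dict[f"top{k}"] = 1
                else:
                    # placeholder value; the real code calls
                    # paddle.metric.accuracy(x, label, k=k) here
                    metric_dict[f"top{k}"] = 0.0
            return metric_dict


    m = WarnOnceTopk(topk=(1, 5))
    m.forward(output_dims=3)  # logs the Top-5 warning once
    m.forward(output_dims=3)  # silent: self.warned is already True

Because self.warned is initialized in __init__ rather than in reset(), the suppression persists across reset() calls for the lifetime of the metric object, so repeated evaluation epochs do not re-trigger the warning.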