From a43f8539eaa3f26479ecc08fef1a887c42f749c1 Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Mon, 6 Jun 2022 06:23:25 +0000 Subject: [PATCH] fix: warn and fix when topk parameter setting is wrong --- ppcls/metric/metrics.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 02fc6887..7928ecb0 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -40,7 +40,7 @@ class TopkAcc(AvgMetrics): def reset(self): self.avg_meters = { - "top{}".format(k): AverageMeter("top{}".format(k)) + f"top{k}": AverageMeter(f"top{k}") for k in self.topk } @@ -55,11 +55,14 @@ class TopkAcc(AvgMetrics): if output_dims < k: msg = f"The output dims({output_dims}) is less than k({k}), and the argument {k} of Topk has been removed." logger.warning(msg) - self.topk.pop(idx) + self.avg_meters.pop(f"top{k}") continue - metric_dict["top{}".format(k)] = paddle.metric.accuracy( - x, label, k=k) - self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0]) + metric_dict[f"top{k}"] = paddle.metric.accuracy(x, label, k=k) + self.avg_meters[f"top{k}"].update(metric_dict[f"top{k}"], + x.shape[0]) + + self.topk = filter(lambda k: k <= output_dims, self.topk) + return metric_dict -- GitLab