diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 02fc688735fd89ed061646b26f4cd1f25c80af83..7928ecb0e7da824fa7437819abb524c1ca73b923 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -40,7 +40,7 @@ class TopkAcc(AvgMetrics): def reset(self): self.avg_meters = { - "top{}".format(k): AverageMeter("top{}".format(k)) + f"top{k}": AverageMeter(f"top{k}") for k in self.topk } @@ -55,11 +55,14 @@ class TopkAcc(AvgMetrics): if output_dims < k: msg = f"The output dims({output_dims}) is less than k({k}), and the argument {k} of Topk has been removed." logger.warning(msg) - self.topk.pop(idx) + self.avg_meters.pop(f"top{k}") continue - metric_dict["top{}".format(k)] = paddle.metric.accuracy( - x, label, k=k) - self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0]) + metric_dict[f"top{k}"] = paddle.metric.accuracy(x, label, k=k) + self.avg_meters[f"top{k}"].update(metric_dict[f"top{k}"], + x.shape[0]) + + self.topk = filter(lambda k: k <= output_dims, self.topk) + return metric_dict