提交 b3fcc986 编写于 作者: G gaotingquan 提交者: Tingquan Gao

to be compatible with training and evaluation

上级 f3b2b2f4
...@@ -38,6 +38,7 @@ class TopkAcc(AvgMetrics): ...@@ -38,6 +38,7 @@ class TopkAcc(AvgMetrics):
topk = [topk] topk = [topk]
self.topk = topk self.topk = topk
self.reset() self.reset()
self.warned = False
def reset(self): def reset(self):
self.avg_meters = { self.avg_meters = {
...@@ -54,17 +55,15 @@ class TopkAcc(AvgMetrics): ...@@ -54,17 +55,15 @@ class TopkAcc(AvgMetrics):
metric_dict = dict() metric_dict = dict()
for idx, k in enumerate(self.topk): for idx, k in enumerate(self.topk):
if output_dims < k: if output_dims < k:
msg = f"The output dims({output_dims}) is less than k({k}), so the Top-{k} metric is meaningless." if not self.warned:
logger.warning(msg) msg = f"The output dims({output_dims}) is less than k({k}), so the Top-{k} metric is meaningless."
logger.warning(msg)
self.warned = True
metric_dict[f"top{k}"] = 1 metric_dict[f"top{k}"] = 1
self.avg_meters.pop(f"top{k}") else:
continue metric_dict[f"top{k}"] = paddle.metric.accuracy(x, label, k=k)
metric_dict[f"top{k}"] = paddle.metric.accuracy(x, label, k=k)
self.avg_meters[f"top{k}"].update(metric_dict[f"top{k}"], self.avg_meters[f"top{k}"].update(metric_dict[f"top{k}"],
x.shape[0]) x.shape[0])
self.topk = list(filter(lambda k: k <= output_dims, self.topk))
return metric_dict return metric_dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册