提交 b3fcc986 编写于 作者: G gaotingquan 提交者: Tingquan Gao

to be compatible with training and evaluation

上级 f3b2b2f4
......@@ -38,6 +38,7 @@ class TopkAcc(AvgMetrics):
topk = [topk]
self.topk = topk
self.reset()
self.warned = False
def reset(self):
self.avg_meters = {
......@@ -54,17 +55,15 @@ class TopkAcc(AvgMetrics):
metric_dict = dict()
for idx, k in enumerate(self.topk):
if output_dims < k:
if not self.warned:
msg = f"The output dims({output_dims}) is less than k({k}), so the Top-{k} metric is meaningless."
logger.warning(msg)
self.warned = True
metric_dict[f"top{k}"] = 1
self.avg_meters.pop(f"top{k}")
continue
else:
metric_dict[f"top{k}"] = paddle.metric.accuracy(x, label, k=k)
self.avg_meters[f"top{k}"].update(metric_dict[f"top{k}"],
x.shape[0])
self.topk = list(filter(lambda k: k <= output_dims, self.topk))
return metric_dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册