提交 a43f8539 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: warn and fix when topk parameter setting is wrong

上级 80358efb
...@@ -40,7 +40,7 @@ class TopkAcc(AvgMetrics): ...@@ -40,7 +40,7 @@ class TopkAcc(AvgMetrics):
def reset(self): def reset(self):
self.avg_meters = { self.avg_meters = {
"top{}".format(k): AverageMeter("top{}".format(k)) f"top{k}": AverageMeter(f"top{k}")
for k in self.topk for k in self.topk
} }
...@@ -55,11 +55,14 @@ class TopkAcc(AvgMetrics): ...@@ -55,11 +55,14 @@ class TopkAcc(AvgMetrics):
if output_dims < k: if output_dims < k:
msg = f"The output dims({output_dims}) is less than k({k}), and the argument {k} of Topk has been removed." msg = f"The output dims({output_dims}) is less than k({k}), and the argument {k} of Topk has been removed."
logger.warning(msg) logger.warning(msg)
self.topk.pop(idx) self.avg_meters.pop(f"top{k}")
continue continue
metric_dict["top{}".format(k)] = paddle.metric.accuracy( metric_dict[f"top{k}"] = paddle.metric.accuracy(x, label, k=k)
x, label, k=k) self.avg_meters[f"top{k}"].update(metric_dict[f"top{k}"],
self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0]) x.shape[0])
self.topk = filter(lambda k: k <= output_dims, self.topk)
return metric_dict return metric_dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册