提交 df3e75dd 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: warn when the topk parameter setting is wrong

上级 4003cdb7
...@@ -152,39 +152,34 @@ class Engine(object): ...@@ -152,39 +152,34 @@ class Engine(object):
self.eval_loss_func = None self.eval_loss_func = None
# build metric # build metric
if self.mode == 'train': if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[
metric_config = self.config.get("Metric") "Metric"]:
if metric_config is not None: metric_config = self.config["Metric"]["Train"]
metric_config = metric_config.get("Train") if hasattr(self.train_dataloader, "collate_fn"
if metric_config is not None: ) and self.train_dataloader.collate_fn is not None:
if hasattr( for m_idx, m in enumerate(metric_config):
self.train_dataloader, "collate_fn" if "TopkAcc" in m:
) and self.train_dataloader.collate_fn is not None: msg = f"'TopkAcc' metric can not be used when setting 'batch_transform_ops' in config. The 'TopkAcc' metric has been removed."
for m_idx, m in enumerate(metric_config): logger.warning(msg)
if "TopkAcc" in m: break
msg = f"'TopkAcc' metric can not be used when setting 'batch_transform_ops' in config. The 'TopkAcc' metric has been removed." metric_config.pop(m_idx)
logger.warning(msg) self.train_metric_func = build_metrics(metric_config)
break
metric_config.pop(m_idx)
self.train_metric_func = build_metrics(metric_config)
else:
self.train_metric_func = None
else: else:
self.train_metric_func = None self.train_metric_func = None
if self.mode == "eval" or (self.mode == "train" and if self.mode == "eval" or (self.mode == "train" and
self.config["Global"]["eval_during_train"]): self.config["Global"]["eval_during_train"]):
metric_config = self.config.get("Metric")
if self.eval_mode == "classification": if self.eval_mode == "classification":
if metric_config is not None: if "Metric" in self.config and "Eval" in self.config["Metric"]:
metric_config = metric_config.get("Eval") self.eval_metric_func = build_metrics(self.config["Metric"]
if metric_config is not None: ["Eval"])
self.eval_metric_func = build_metrics(metric_config) else:
self.eval_metric_func = None
elif self.eval_mode == "retrieval": elif self.eval_mode == "retrieval":
if metric_config is None: if "Metric" in self.config and "Eval" in self.config["Metric"]:
metric_config = [{"name": "Recallk", "topk": (1, 5)}] metric_config = metric_config["Metric"]["Eval"]
else: else:
metric_config = metric_config["Eval"] metric_config = [{"name": "Recallk", "topk": (1, 5)}]
self.eval_metric_func = build_metrics(metric_config) self.eval_metric_func = build_metrics(metric_config)
else: else:
self.eval_metric_func = None self.eval_metric_func = None
......
...@@ -34,7 +34,6 @@ def classification_eval(engine, epoch_id=0): ...@@ -34,7 +34,6 @@ def classification_eval(engine, epoch_id=0):
} }
print_batch_step = engine.config["Global"]["print_batch_step"] print_batch_step = engine.config["Global"]["print_batch_step"]
metric_key = None
tic = time.time() tic = time.time()
accum_samples = 0 accum_samples = 0
total_samples = len( total_samples = len(
......
...@@ -26,6 +26,7 @@ from easydict import EasyDict ...@@ -26,6 +26,7 @@ from easydict import EasyDict
from ppcls.metric.avg_metrics import AvgMetrics from ppcls.metric.avg_metrics import AvgMetrics
from ppcls.utils.misc import AverageMeter, AttrMeter from ppcls.utils.misc import AverageMeter, AttrMeter
from ppcls.utils import logger
class TopkAcc(AvgMetrics): class TopkAcc(AvgMetrics):
...@@ -47,8 +48,15 @@ class TopkAcc(AvgMetrics): ...@@ -47,8 +48,15 @@ class TopkAcc(AvgMetrics):
if isinstance(x, dict): if isinstance(x, dict):
x = x["logits"] x = x["logits"]
output_dims = x.shape[-1]
metric_dict = dict() metric_dict = dict()
for k in self.topk: for idx, k in enumerate(self.topk):
if output_dims < k:
msg = f"The output dims({output_dims}) is less than k({k}), and the argument {k} of Topk has been removed."
logger.warning(msg)
self.topk.pop(idx)
continue
metric_dict["top{}".format(k)] = paddle.metric.accuracy( metric_dict["top{}".format(k)] = paddle.metric.accuracy(
x, label, k=k) x, label, k=k)
self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0]) self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册