提交 df3e75dd 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: warn when the topk parameter setting is wrong

上级 4003cdb7
......@@ -152,39 +152,34 @@ class Engine(object):
self.eval_loss_func = None
# build metric
if self.mode == 'train':
metric_config = self.config.get("Metric")
if metric_config is not None:
metric_config = metric_config.get("Train")
if metric_config is not None:
if hasattr(
self.train_dataloader, "collate_fn"
) and self.train_dataloader.collate_fn is not None:
for m_idx, m in enumerate(metric_config):
if "TopkAcc" in m:
msg = f"'TopkAcc' metric can not be used when setting 'batch_transform_ops' in config. The 'TopkAcc' metric has been removed."
logger.warning(msg)
break
metric_config.pop(m_idx)
self.train_metric_func = build_metrics(metric_config)
else:
self.train_metric_func = None
if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[
"Metric"]:
metric_config = self.config["Metric"]["Train"]
if hasattr(self.train_dataloader, "collate_fn"
) and self.train_dataloader.collate_fn is not None:
for m_idx, m in enumerate(metric_config):
if "TopkAcc" in m:
msg = f"'TopkAcc' metric can not be used when setting 'batch_transform_ops' in config. The 'TopkAcc' metric has been removed."
logger.warning(msg)
break
metric_config.pop(m_idx)
self.train_metric_func = build_metrics(metric_config)
else:
self.train_metric_func = None
if self.mode == "eval" or (self.mode == "train" and
self.config["Global"]["eval_during_train"]):
metric_config = self.config.get("Metric")
if self.eval_mode == "classification":
if metric_config is not None:
metric_config = metric_config.get("Eval")
if metric_config is not None:
self.eval_metric_func = build_metrics(metric_config)
if "Metric" in self.config and "Eval" in self.config["Metric"]:
self.eval_metric_func = build_metrics(self.config["Metric"]
["Eval"])
else:
self.eval_metric_func = None
elif self.eval_mode == "retrieval":
if metric_config is None:
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
if "Metric" in self.config and "Eval" in self.config["Metric"]:
metric_config = metric_config["Metric"]["Eval"]
else:
metric_config = metric_config["Eval"]
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
self.eval_metric_func = build_metrics(metric_config)
else:
self.eval_metric_func = None
......
......@@ -34,7 +34,6 @@ def classification_eval(engine, epoch_id=0):
}
print_batch_step = engine.config["Global"]["print_batch_step"]
metric_key = None
tic = time.time()
accum_samples = 0
total_samples = len(
......
......@@ -26,6 +26,7 @@ from easydict import EasyDict
from ppcls.metric.avg_metrics import AvgMetrics
from ppcls.utils.misc import AverageMeter, AttrMeter
from ppcls.utils import logger
class TopkAcc(AvgMetrics):
......@@ -47,8 +48,15 @@ class TopkAcc(AvgMetrics):
if isinstance(x, dict):
x = x["logits"]
output_dims = x.shape[-1]
metric_dict = dict()
for k in self.topk:
for idx, k in enumerate(self.topk):
if output_dims < k:
msg = f"The output dims({output_dims}) is less than k({k}), and the argument {k} of Topk has been removed."
logger.warning(msg)
self.topk.pop(idx)
continue
metric_dict["top{}".format(k)] = paddle.metric.accuracy(
x, label, k=k)
self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册