未验证 提交 474c918b 编写于 作者: G gaotingquan

fix: fix bug of batch_size statistics error

上级 c46189ba
...@@ -126,7 +126,8 @@ def classification_eval(engine, epoch_id=0): ...@@ -126,7 +126,8 @@ def classification_eval(engine, epoch_id=0):
for key in loss_dict: for key in loss_dict:
if key not in output_info: if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f') output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0], batch_size) output_info[key].update(loss_dict[key].numpy()[0],
current_samples)
# calc metric # calc metric
if engine.eval_metric_func is not None: if engine.eval_metric_func is not None:
metric_dict = engine.eval_metric_func(preds, labels) metric_dict = engine.eval_metric_func(preds, labels)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册