提交 fd6f1ad2 编写于 作者: D dongshuilong

fix clas distributed eval bug

上级 c93d638f
......@@ -88,12 +88,6 @@ def classification_eval(engine, epoch_id=0):
accum_samples]
current_samples = total_samples + current_samples - accum_samples
metric_dict = engine.eval_metric_func(pred, labels)
for key in metric_dict:
paddle.distributed.all_reduce(
metric_dict[key], op=paddle.distributed.ReduceOp.SUM)
metric_dict[key] = metric_dict[
key] / paddle.distributed.get_world_size()
else:
metric_dict = engine.eval_metric_func(out, batch[1])
for key in metric_dict:
......@@ -103,7 +97,7 @@ def classification_eval(engine, epoch_id=0):
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(metric_dict[key].numpy()[0],
batch_size)
current_samples)
time_info["batch_cost"].update(time.time() - tic)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册