未验证 提交 53ed9239 编写于 作者: C cuicheng01 提交者: GitHub

Merge pull request #1320 from RainFrost1/develop

fix clas distributed eval bug
......@@ -34,6 +34,10 @@ def classification_eval(engine, epoch_id=0):
metric_key = None
tic = time.time()
accum_samples = 0
total_samples = len(
engine.eval_dataloader.
dataset) if not engine.use_dali else engine.eval_dataloader.size
max_iter = len(engine.eval_dataloader) - 1 if platform.system(
) == "Windows" else len(engine.eval_dataloader)
for iter_id, batch in enumerate(engine.eval_dataloader):
......@@ -61,15 +65,31 @@ def classification_eval(engine, epoch_id=0):
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0], batch_size)
# just for DistributedBatchSampler issue: repeat sampling
current_samples = batch_size * paddle.distributed.get_world_size()
accum_samples += current_samples
# calc metric
if engine.eval_metric_func is not None:
metric_dict = engine.eval_metric_func(out, batch[1])
if paddle.distributed.get_world_size() > 1:
for key in metric_dict:
paddle.distributed.all_reduce(
metric_dict[key], op=paddle.distributed.ReduceOp.SUM)
metric_dict[key] = metric_dict[
key] / paddle.distributed.get_world_size()
pred_list = []
label_list = []
if isinstance(out, dict):
out = out["logits"]
paddle.distributed.all_gather(pred_list, out)
paddle.distributed.all_gather(label_list, batch[1])
pred = paddle.concat(pred_list, 0)
labels = paddle.concat(label_list, 0)
if accum_samples > total_samples:
pred = pred[:total_samples + current_samples -
accum_samples]
labels = labels[:total_samples + current_samples -
accum_samples]
current_samples = total_samples + current_samples - accum_samples
metric_dict = engine.eval_metric_func(pred, labels)
else:
metric_dict = engine.eval_metric_func(out, batch[1])
for key in metric_dict:
if metric_key is None:
metric_key = key
......@@ -77,7 +97,7 @@ def classification_eval(engine, epoch_id=0):
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(metric_dict[key].numpy()[0],
batch_size)
current_samples)
time_info["batch_cost"].update(time.time() - tic)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册