From 80209e021848e3a0a8db9af84e86423c94b84c2e Mon Sep 17 00:00:00 2001
From: dongshuilong
Date: Wed, 20 Oct 2021 11:22:37 +0000
Subject: [PATCH] fix clas distributed eval bug

---
 ppcls/engine/evaluation/classification.py | 28 ++++++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index d59f2f41..d35b6a1e 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -34,6 +34,10 @@ def classification_eval(engine, epoch_id=0):
 
     metric_key = None
     tic = time.time()
+    accum_samples = 0
+    total_samples = len(
+        engine.eval_dataloader.
+        dataset) if not engine.use_dali else engine.eval_dataloader.size
     max_iter = len(engine.eval_dataloader) - 1 if platform.system(
     ) == "Windows" else len(engine.eval_dataloader)
     for iter_id, batch in enumerate(engine.eval_dataloader):
@@ -61,15 +65,37 @@
                 if key not in output_info:
                     output_info[key] = AverageMeter(key, '7.5f')
                 output_info[key].update(loss_dict[key].numpy()[0], batch_size)
+
+        # just for DistributedBatchSampler issue: repeat sampling
+        current_samples = batch_size * paddle.distributed.get_world_size()
+        accum_samples += current_samples
+
         # calc metric
         if engine.eval_metric_func is not None:
-            metric_dict = engine.eval_metric_func(out, batch[1])
             if paddle.distributed.get_world_size() > 1:
+                pred_list = []
+                label_list = []
+                if isinstance(out, dict):
+                    out = out["logits"]
+                paddle.distributed.all_gather(pred_list, out)
+                paddle.distributed.all_gather(label_list, batch[1])
+                pred = paddle.concat(pred_list, 0)
+                labels = paddle.concat(label_list, 0)
+                if accum_samples > total_samples:
+                    pred = pred[:total_samples + current_samples -
+                                accum_samples]
+                    labels = labels[:total_samples + current_samples -
+                                    accum_samples]
+                    current_samples = total_samples + current_samples - accum_samples
+                metric_dict = engine.eval_metric_func(pred, labels)
+
                 for key in metric_dict:
                     paddle.distributed.all_reduce(
                         metric_dict[key], op=paddle.distributed.ReduceOp.SUM)
                     metric_dict[key] = metric_dict[
                         key] / paddle.distributed.get_world_size()
+            else:
+                metric_dict = engine.eval_metric_func(out, batch[1])
             for key in metric_dict:
                 if metric_key is None:
                     metric_key = key
--
GitLab
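
Note (editor's addendum): Paddle's DistributedBatchSampler pads the dataset by repeating samples so that every rank receives the same number of batches. Before this patch, metrics were computed per rank on shards that include those padded duplicates, so the all-reduced average was biased whenever len(dataset) was not divisible by batch_size * world_size. The patch instead all-gathers predictions and labels across ranks and truncates the gathered tensors back to the true dataset size before calling eval_metric_func.

Below is a minimal standalone sketch of that gather-and-truncate step, assuming an initialized paddle.distributed environment. The helper name gather_without_padding and its signature are illustrative, not PaddleClas code; only the paddle.distributed.all_gather / paddle.concat calls come from the patch itself.

    import paddle
    import paddle.distributed as dist

    def gather_without_padding(pred, labels, accum_samples, total_samples):
        # Hypothetical helper, mirroring the patched eval loop.
        # all_gather appends one tensor per rank to the given lists.
        pred_list, label_list = [], []
        dist.all_gather(pred_list, pred)
        dist.all_gather(label_list, labels)
        pred = paddle.concat(pred_list, 0)    # [world_size * batch, ...]
        labels = paddle.concat(label_list, 0)
        # Once the running sample count overshoots the dataset size, the
        # tail of this gathered batch is repeated padding; drop it.
        if accum_samples > total_samples:
            keep = pred.shape[0] - (accum_samples - total_samples)
            pred = pred[:keep]
            labels = labels[:keep]
        return pred, labels

This mirrors the patched loop: current_samples = batch_size * world_size is accumulated each iteration, and on the final overshooting batch the slice bound total_samples + current_samples - accum_samples equals keep above (since the gathered batch holds current_samples rows), so every example contributes to the metric exactly once.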