diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index d59f2f41c57ec5d0d4bb5e9bf6ca29fd9c592b0e..795a5e2a164192967de5a43ef8592e96fd899e9d 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -34,6 +34,10 @@ def classification_eval(engine, epoch_id=0):
 
     metric_key = None
     tic = time.time()
+    accum_samples = 0
+    total_samples = len(
+        engine.eval_dataloader.
+        dataset) if not engine.use_dali else engine.eval_dataloader.size
     max_iter = len(engine.eval_dataloader) - 1 if platform.system(
     ) == "Windows" else len(engine.eval_dataloader)
     for iter_id, batch in enumerate(engine.eval_dataloader):
@@ -61,15 +65,31 @@ def classification_eval(engine, epoch_id=0):
             if key not in output_info:
                 output_info[key] = AverageMeter(key, '7.5f')
             output_info[key].update(loss_dict[key].numpy()[0], batch_size)
+
+        # just for DistributedBatchSampler issue: repeat sampling
+        current_samples = batch_size * paddle.distributed.get_world_size()
+        accum_samples += current_samples
+
         # calc metric
         if engine.eval_metric_func is not None:
-            metric_dict = engine.eval_metric_func(out, batch[1])
             if paddle.distributed.get_world_size() > 1:
-                for key in metric_dict:
-                    paddle.distributed.all_reduce(
-                        metric_dict[key], op=paddle.distributed.ReduceOp.SUM)
-                    metric_dict[key] = metric_dict[
-                        key] / paddle.distributed.get_world_size()
+                pred_list = []
+                label_list = []
+                if isinstance(out, dict):
+                    out = out["logits"]
+                paddle.distributed.all_gather(pred_list, out)
+                paddle.distributed.all_gather(label_list, batch[1])
+                pred = paddle.concat(pred_list, 0)
+                labels = paddle.concat(label_list, 0)
+                if accum_samples > total_samples:
+                    pred = pred[:total_samples + current_samples -
+                                accum_samples]
+                    labels = labels[:total_samples + current_samples -
+                                    accum_samples]
+                    current_samples = total_samples + current_samples - accum_samples
+                metric_dict = engine.eval_metric_func(pred, labels)
+            else:
+                metric_dict = engine.eval_metric_func(out, batch[1])
             for key in metric_dict:
                 if metric_key is None:
                     metric_key = key
@@ -77,7 +97,7 @@ def classification_eval(engine, epoch_id=0):
                 output_info[key] = AverageMeter(key, '7.5f')
 
                 output_info[key].update(metric_dict[key].numpy()[0],
-                                        batch_size)
+                                        current_samples)
 
         time_info["batch_cost"].update(time.time() - tic)
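
Reviewer note (not part of the patch): under multi-card evaluation, Paddle's DistributedBatchSampler pads the dataset so that every rank draws the same number of batches, which means the tail of the epoch contains repeated samples. The old all_reduce averaged per-rank metric values and so counted those repeats; the patch instead all_gathers predictions and labels, tracks accum_samples, and truncates the final gathered batch back to len(dataset) so each sample is scored exactly once. Below is a minimal, self-contained sketch of that truncation arithmetic; the world size, batch size, and dataset length are illustrative assumptions, not values from the patch.

# Sketch only: simulates the dedup arithmetic from the patch in plain Python.
# Assumptions: 2 ranks, batch size 4 per rank, a 10-sample dataset.
# The sampler pads the dataset to 12 samples so both ranks draw equal
# batches, so the last gathered step holds 2 duplicates that must be dropped.

world_size = 2
batch_size = 4
total_samples = 10                             # len(dataset), as in the patch
accum_samples = 0

padded = list(range(total_samples)) + [0, 1]   # sampler repeats the head to pad

for start in range(0, len(padded), batch_size * world_size):
    # Stand-in for paddle.distributed.all_gather: one combined batch per step.
    gathered = padded[start:start + batch_size * world_size]
    current_samples = len(gathered)
    accum_samples += current_samples

    if accum_samples > total_samples:
        # Same truncation as the patch: keep only the non-repeated prefix.
        keep = total_samples + current_samples - accum_samples
        gathered = gathered[:keep]
        current_samples = keep

    print(gathered, "weight:", current_samples)

# Prints:
# [0, 1, 2, 3, 4, 5, 6, 7] weight: 8
# [8, 9] weight: 2    <- the two padded duplicates were dropped

This is also why the AverageMeter.update weight in the last hunk changes from batch_size to current_samples: the running metric must be weighted by the number of unique samples actually scored in each step, not by the per-rank batch size.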