From 278f6d8050ee32b242b7fbdac4b417cfa19c1565 Mon Sep 17 00:00:00 2001 From: dongshuilong Date: Tue, 26 Oct 2021 11:56:30 +0000 Subject: [PATCH] fig goooglenet distributed eval bug --- ppcls/engine/evaluation/classification.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index 795a5e2a..53e7fe6b 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -73,14 +73,24 @@ def classification_eval(engine, epoch_id=0): # calc metric if engine.eval_metric_func is not None: if paddle.distributed.get_world_size() > 1: - pred_list = [] label_list = [] - if isinstance(out, dict): - out = out["logits"] - paddle.distributed.all_gather(pred_list, out) paddle.distributed.all_gather(label_list, batch[1]) - pred = paddle.concat(pred_list, 0) labels = paddle.concat(label_list, 0) + + if isinstance(out, dict): + out = out["logits"] + if isinstance(out, list): + pred = [] + for x in out: + pred_list = [] + paddle.distributed.all_gather(pred_list, x) + pred_x = paddle.concat(pred_list, 0) + pred.append(pred_x) + else: + pred_list = [] + paddle.distributed.all_gather(pred_list, out) + pred = paddle.concat(pred_list, 0) + if accum_samples > total_samples: pred = pred[:total_samples + current_samples - accum_samples] -- GitLab