diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index 795a5e2a164192967de5a43ef8592e96fd899e9d..53e7fe6be4e51c038ec45e795ee98b7de9db95cf 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -73,14 +73,24 @@ def classification_eval(engine, epoch_id=0): # calc metric if engine.eval_metric_func is not None: if paddle.distributed.get_world_size() > 1: - pred_list = [] label_list = [] - if isinstance(out, dict): - out = out["logits"] - paddle.distributed.all_gather(pred_list, out) paddle.distributed.all_gather(label_list, batch[1]) - pred = paddle.concat(pred_list, 0) labels = paddle.concat(label_list, 0) + + if isinstance(out, dict): + out = out["logits"] + if isinstance(out, list): + pred = [] + for x in out: + pred_list = [] + paddle.distributed.all_gather(pred_list, x) + pred_x = paddle.concat(pred_list, 0) + pred.append(pred_x) + else: + pred_list = [] + paddle.distributed.all_gather(pred_list, out) + pred = paddle.concat(pred_list, 0) + if accum_samples > total_samples: pred = pred[:total_samples + current_samples - accum_samples]