diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index 5b305b0a0cb8f3c94561fd338631a2b3a4278687..647a821714f88ed849da71db21b4d98121134eb6 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -81,8 +81,9 @@ def classification_eval(engine, epoch_id=0): # gather Tensor when distributed if paddle.distributed.get_world_size() > 1: label_list = [] - - paddle.distributed.all_gather(label_list, batch[1]) + label = batch[1].cuda() if engine.config["Global"][ + "device"] == "gpu" else batch[1] + paddle.distributed.all_gather(label_list, label) labels = paddle.concat(label_list, 0) if isinstance(out, list):