diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index fbad470cb3f1334fb5e0ad559b3300c4b439bde9..d2bed171aa27ae87c8bd0720e67ef2ec23534e02 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -849,7 +849,9 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): if in_dygraph_mode(): group = _get_default_group() if group is None else group - out = paddle.concat(tensor_list) + tensor_shape = list(tensor.shape) + tensor_shape[0] *= group.nranks + out = paddle.empty(tensor_shape, tensor.dtype) task = group.process_group.all_gather(tensor, out) task.wait() tensor_list.clear()