# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from paddle import framework
from paddle.distributed.communication import stream


def gather(tensor, gather_list=None, dst=0, group=None, sync_op=True):
    """

    Gather tensors from all participants.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        gather_list (list): A list of Tensors to hold the gathered tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Default value is None.
        dst (int): The dst rank id. Default value is 0.
        group (Group, optional): The group instance returned by new_group or None for the global default group. Default value is None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        Async work handle, which can be waited on, if sync_op is set to False.
        None, if sync_op is True.

    Examples:
        .. code-block:: python

            # required: distributed
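            # Run under a multi-process launcher, e.g. (assumed 2-GPU setup):
            #   python -m paddle.distributed.launch --gpus=0,1 demo.py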
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            gather_list = []
            if dist.get_rank() == 0:
                data = paddle.to_tensor([1, 2, 3])
                dist.gather(data, gather_list, dst=0)
            else:
                data = paddle.to_tensor([4, 5, 6])
                dist.gather(data, gather_list, dst=0)
            print(gather_list)
            # [[1, 2, 3], [4, 5, 6]] (2 GPUs, out for rank 0)
            # [] (2 GPUs, out for rank 1)
    """
    assert (
        framework.in_dygraph_mode()
    ), "gather doesn't support static graph mode yet."
    return stream.gather(tensor, gather_list, dst, group, sync_op)