all_gather.py 6.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
16
import paddle.distributed as dist
17
from paddle import framework
18
from paddle.distributed.communication.group import _get_global_group
19
from paddle.fluid import data_feeder
20 21


22 23 24
def _all_gather_into_tensor_in_dygraph(
    out_tensor, in_tensor, group, sync_op, use_calc_stream
):
    """Gather ``in_tensor`` from every rank into the single ``out_tensor`` (dygraph).

    Args:
        out_tensor (Tensor): Destination tensor handed to the process group.
        in_tensor (Tensor): This rank's contribution.
        group (Group): Communication group; ``None`` selects the global group.
        sync_op (bool): When true, wait on the returned task before returning.
        use_calc_stream (bool): When true, issue the collective through
            ``all_gather_into_tensor_on_calc_stream`` and return its result
            directly.

    Returns:
        The task from ``process_group.all_gather_into_tensor`` (already waited
        on when ``sync_op`` is true), or the calc-stream call's result.
    """
    comm_group = group if group is not None else _get_global_group()
    process_group = comm_group.process_group

    if not use_calc_stream:
        task = process_group.all_gather_into_tensor(
            out_tensor, in_tensor, sync_op
        )
        if sync_op:
            # Synchronous request: block until the task reports completion.
            task.wait()
        return task

    # Calc-stream variant: its result is handed back as-is, with no wait().
    return process_group.all_gather_into_tensor_on_calc_stream(
        out_tensor, in_tensor
    )


42 43 44
def _all_gather_in_dygraph(
    tensor_list, tensor, group, sync_op, use_calc_stream
):
    """Gather ``tensor`` from every rank into ``tensor_list`` (dygraph).

    Args:
        tensor_list (List[Tensor]): Output list. If empty, it is extended in
            place with one ``paddle.empty_like(tensor)`` per rank of the group.
        tensor (Tensor): This rank's contribution.
        group (Group): Communication group; ``None`` selects the global group.
        sync_op (bool): When true, wait on the returned task before returning.
        use_calc_stream (bool): When true, issue the collective through
            ``all_gather_on_calc_stream`` and return its result directly.

    Returns:
        The task from ``process_group.all_gather`` (already waited on when
        ``sync_op`` is true), or the calc-stream call's result.
    """
    comm_group = group if group is not None else _get_global_group()

    if not tensor_list:
        # Grow the caller's own list object so the gathered tensors are
        # visible to the caller after the collective runs.
        tensor_list += [
            paddle.empty_like(tensor) for _ in range(comm_group.nranks)
        ]

    process_group = comm_group.process_group
    if use_calc_stream:
        return process_group.all_gather_on_calc_stream(tensor_list, tensor)

    pending = process_group.all_gather(tensor_list, tensor, sync_op)
    if sync_op:
        pending.wait()
    return pending


62 63
def _all_gather_in_static_mode(tensor_list, tensor, group, sync_op):
    """Append a ``c_allgather`` op to the current static program.

    Args:
        tensor_list (List[Tensor]): Output list. Every element is dtype-checked,
            then the list is cleared and refilled with ``nranks`` slices of the
            gathered output, split along axis 0.
        tensor (Tensor): This rank's input variable.
        group (Group, optional): Communication group. ``None`` selects ring 0
            and the global world size; otherwise ``group.id`` and
            ``group.nranks`` are used.
        sync_op (bool): Forwarded as the op's ``use_calc_stream`` attribute.

    Returns:
        None. Results are delivered through ``tensor_list``.
    """
    op_type = 'c_allgather'
    # Same keyword set that ``**locals()`` produced at this point, spelled out
    # so that locals added below can never leak into the helper.
    helper = framework.LayerHelper(
        op_type,
        tensor_list=tensor_list,
        tensor=tensor,
        group=group,
        sync_op=sync_op,
        op_type=op_type,
    )
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    supported_dtypes = [
        'float16',
        'float32',
        'float64',
        'int32',
        'int64',
        'bool',
        'int8',
        'uint8',
    ]
    for elem in tensor_list:
        data_feeder.check_variable_and_dtype(
            elem, 'tensor_list', supported_dtypes, 'all_gather'
        )
    data_feeder.check_variable_and_dtype(
        tensor, 'tensor', supported_dtypes, 'all_gather'
    )

    ring_id = 0 if group is None else group.id
    # Size the gather (and the split below) by the same group that ring_id
    # refers to, rather than always by the global world size.
    nranks = dist.get_world_size() if group is None else group.nranks
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [out]},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': sync_op,
            'nranks': nranks,
        },
    )
    # Replace the caller's list contents in place with one slice per rank.
    tensor_list.clear()
    tensor_list.extend(paddle.split(out, nranks, 0))


114 115 116 117 118 119 120
def all_gather(
    tensor_or_tensor_list,
    tensor,
    group=None,
    sync_op=True,
    use_calc_stream=False,
):
    """

    Gather tensors across devices to a correctly-sized tensor or a tensor list.

    Args:
        tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The output. If it is a tensor, it should be correctly-sized. If it is a list, it
            should be empty or contain correctly-sized tensors. In dynamic graph mode an empty list is filled in place; in static graph
            mode the list's contents are replaced with the gathered slices.
        tensor (Tensor): The input tensor on each rank. The gathered result is written to ``tensor_or_tensor_list``. Support
            float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type.
        group (Group, optional): Communicate in which group. If none is given, use the global group as default.
        sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
        use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default.
            Can only be true when ``sync_op`` is true, and is ignored in static graph mode. This option is designed for high-performance
            use; only turn it on if you clearly understand its meaning.

    Returns:
        In dynamic graph mode, the task object returned by the process group (or the result of the calculation-stream call when
        ``use_calc_stream`` is true). In static graph mode, None.

    Raises:
        RuntimeError: If ``group`` is given and the current rank is not a member of it, if ``use_calc_stream`` is true while
            ``sync_op`` is false, or if a single tensor output is passed in static graph mode.
        AssertionError: If ``group`` is not None in static graph mode.

    Warning:
        In static graph mode only a tensor-list output and the default group (``group=None``) are supported.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            local_rank = dist.get_rank()
            tensor_list = []
            if local_rank == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            task = dist.stream.all_gather(tensor_list, data, sync_op=False)
            task.wait()
            print(tensor_list)
            # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        raise RuntimeError(
            "The group should not be None and all ranks which invoke this operation should be the member of this group."
        )

    if not sync_op and use_calc_stream:
        raise RuntimeError(
            "use_calc_stream can only be true in sync op behavior."
        )

    if framework.in_dygraph_mode():
        # A tensor output uses the "into tensor" collective; a list output
        # uses the list-based collective.
        if paddle.is_tensor(tensor_or_tensor_list):
            return _all_gather_into_tensor_in_dygraph(
                tensor_or_tensor_list, tensor, group, sync_op, use_calc_stream
            )
        return _all_gather_in_dygraph(
            tensor_or_tensor_list, tensor, group, sync_op, use_calc_stream
        )

    # Static graph mode: only the default group and list outputs are handled.
    assert (
        group is None
    ), "Group can not be used in static graph mode for now."
    if paddle.is_tensor(tensor_or_tensor_list):
        raise RuntimeError(
            "Only support passing a tensor list to `all_gather` in static graph mode now."
        )
    return _all_gather_in_static_mode(
        tensor_or_tensor_list, tensor, group, sync_op
    )