# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.distributed as dist
import paddle.fluid.framework as framework
import paddle.fluid.data_feeder as data_feeder
import paddle.fluid.layer_helper as layer_helper
from paddle.distributed.communication.group import _get_global_group


def _all_gather_into_tensor_in_dygraph(
    out_tensor, in_tensor, group, sync_op, use_calc_stream
):
    """Gather ``in_tensor`` from every rank into the single ``out_tensor``.

    Dynamic-graph helper used by ``all_gather`` when the output is a tensor.

    Args:
        out_tensor (Tensor): Output tensor handed to the process group.
        in_tensor (Tensor): This rank's input tensor.
        group (Group|None): Communication group; ``None`` selects the
            global group.
        sync_op (bool): If true, wait on the returned task before returning.
        use_calc_stream (bool): If true, call the calculation-stream variant
            and return its result directly.

    Returns:
        The task returned by ``process_group.all_gather_into_tensor``, or the
        return value of ``all_gather_into_tensor_on_calc_stream`` when
        ``use_calc_stream`` is true.
    """
    group = _get_global_group() if group is None else group
    process_group = group.process_group

    if use_calc_stream:
        # The calc-stream variant takes no ``sync_op`` and no task is waited on.
        return process_group.all_gather_into_tensor_on_calc_stream(
            out_tensor, in_tensor
        )

    task = process_group.all_gather_into_tensor(out_tensor, in_tensor, sync_op)
    if sync_op:
        task.wait()
    return task


def _all_gather_in_dygraph(
    tensor_list, tensor, group, sync_op, use_calc_stream
):
    """Gather ``tensor`` from every rank into ``tensor_list``.

    Dynamic-graph helper used by ``all_gather`` when the output is a list.

    Args:
        tensor_list (List[Tensor]): Output list. If empty, it is filled in
            place with ``group.nranks`` tensors created by
            ``paddle.empty_like(tensor)`` before communicating.
        tensor (Tensor): This rank's input tensor.
        group (Group|None): Communication group; ``None`` selects the
            global group.
        sync_op (bool): If true, wait on the returned task before returning.
        use_calc_stream (bool): If true, call the calculation-stream variant
            and return its result directly.

    Returns:
        The task returned by ``process_group.all_gather``, or the return
        value of ``all_gather_on_calc_stream`` when ``use_calc_stream`` is
        true.
    """
    group = _get_global_group() if group is None else group
    process_group = group.process_group

    if not tensor_list:
        # Allocate one output slot per rank; ``+=`` mutates the caller's list.
        tensor_list += [paddle.empty_like(tensor) for _ in range(group.nranks)]

    if use_calc_stream:
        # The calc-stream variant takes no ``sync_op`` and no task is waited on.
        return process_group.all_gather_on_calc_stream(tensor_list, tensor)

    task = process_group.all_gather(tensor_list, tensor, sync_op)
    if sync_op:
        task.wait()
    return task


def _all_gather_in_static_mode(tensor_list, tensor, group, sync_op):
    """Append a ``c_allgather`` op and split its output into ``tensor_list``.

    Static-graph helper used by ``all_gather``.

    Args:
        tensor_list (List[Tensor]): Output list. Each existing element is
            dtype-checked, then the list is cleared and refilled in place
            with ``nranks`` slices of the gathered output (split on axis 0).
        tensor (Tensor): This rank's input variable.
        group (Group|None): Communication group. ``None`` uses ring 0 and the
            global world size; otherwise ``group.id`` and ``group.nranks``.
        sync_op (bool): Forwarded as the op's ``use_calc_stream`` attribute.

    Returns:
        None. The result is delivered through ``tensor_list``.
    """
    op_type = 'c_allgather'
    # NOTE: ``**locals()`` forwards the arguments above (plus ``op_type``) to
    # LayerHelper, so no other locals may be defined before this line.
    helper = layer_helper.LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    supported_dtypes = [
        'float16',
        'float32',
        'float64',
        'int32',
        'int64',
        'bool',
        'int8',
        'uint8',
    ]
    for elem in tensor_list:
        data_feeder.check_variable_and_dtype(
            elem, 'tensor_list', supported_dtypes, 'all_gather'
        )
    data_feeder.check_variable_and_dtype(
        tensor, 'tensor', supported_dtypes, 'all_gather'
    )

    # The number of output slices must match the ring the op runs on, so take
    # it from the group when one is given (as the dygraph helper does).
    if group is None:
        ring_id = 0
        nranks = dist.get_world_size()
    else:
        ring_id = group.id
        nranks = group.nranks

    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [out]},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': sync_op,
            'nranks': nranks,
        },
    )
    tensor_list.clear()
    tensor_list.extend(paddle.split(out, nranks, 0))


def all_gather(
    tensor_or_tensor_list,
    tensor,
    group=None,
    sync_op=True,
    use_calc_stream=False,
):
    """

    Gather tensors across devices to a correctly-sized tensor or a tensor list.

    Args:
        tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The output. If it is a tensor, it should be correctly-sized. If it is a list, it
            should be empty or contain correctly-sized tensors. The gathered result is written here.
        tensor (Tensor): The input tensor on each rank. Support float16, float32, float64, int32, int64, int8, uint8 or bool
            as the input data type.
        group (Group, optional): Communicate in which group. If none is given, use the global group as default.
        sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
        use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This
            option is designed for high performance demand; do not turn it on unless you clearly know its meaning.

    Returns:
        In dynamic graph mode, a task object (when ``use_calc_stream`` is true, the result of the calculation-stream call is
        returned instead). In static graph mode, None; the result is written to the output tensor list.

    Raises:
        RuntimeError: If the current rank is not a member of ``group``, if ``use_calc_stream`` is true while ``sync_op`` is
            false, or if a tensor (rather than a tensor list) output is passed in static graph mode.

    Warning:
        In static graph mode, only a tensor list output is supported and ``group`` must be None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            local_rank = dist.get_rank()
            tensor_list = []
            if local_rank == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            task = dist.stream.all_gather(tensor_list, data, sync_op=False)
            task.wait()
            print(tensor_list)
            # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        raise RuntimeError(
            "The group should not be None and all ranks which invoke this operation should be the member of this group."
        )

    if not sync_op and use_calc_stream:
        raise RuntimeError(
            "use_calc_stream can only be true in sync op behavior."
        )

    if framework.in_dygraph_mode():
        if paddle.is_tensor(tensor_or_tensor_list):
            return _all_gather_into_tensor_in_dygraph(
                tensor_or_tensor_list, tensor, group, sync_op, use_calc_stream
            )
        return _all_gather_in_dygraph(
            tensor_or_tensor_list, tensor, group, sync_op, use_calc_stream
        )

    # Static graph mode: kept as an ``assert`` so the exception type callers
    # may already rely on (AssertionError) is unchanged.
    assert group is None, "Group can not be used in static mode for now."
    if paddle.is_tensor(tensor_or_tensor_list):
        raise RuntimeError(
            "Only support passing a tensor list to `all_gather` in static mode now."
        )
    return _all_gather_in_static_mode(
        tensor_or_tensor_list, tensor, group, sync_op
    )