# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib

import paddle.distributed as dist
from paddle import framework
from paddle.distributed.communication.group import (
    _get_global_group,
    _warn_cur_rank_not_in_group,
)


class P2POp:
    """
    A class that describes a point-to-point operation for ``batch_isend_irecv``.

    It stores the type of P2P operation, the communication buffer, the peer rank,
    and the communication group. Instances of this class will be passed to
    ``paddle.distributed.batch_isend_irecv`` for point-to-point communication.

    Args:
        op (callable): A function to send data to or receive data from a peer process.
            The type of ``op`` is either ``paddle.distributed.isend`` or ``paddle.distributed.irecv``.
        tensor (Tensor): Tensor to send or receive.
        peer (int): The destination or source rank.
        group (Group, optional): The group instance returned by ``new_group`` or None for the
            global default group. Default: None.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env: DISTRIBUTED)

            >>> import paddle
            >>> import paddle.distributed as dist

            >>> dist.init_parallel_env()
            >>> rank = dist.get_rank()
            >>> world_size = dist.get_world_size()

            >>> send_t = paddle.arange(2) + rank
            >>> # paddle.tensor([0, 1])  # Rank-0
            >>> # paddle.tensor([1, 2])  # Rank-1

            >>> recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)

            >>> send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
            >>> recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)

    """

    def __init__(self, op, tensor, peer, group=None):
        if op not in [dist.isend, dist.irecv]:
            raise RuntimeError(
                "Invalid ``op`` function. Expected ``op`` "
                "to be of type ``paddle.distributed.isend`` or "
                "``paddle.distributed.irecv``."
            )

        self.op = op
        self.tensor = tensor
        self.peer = peer
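        # Fall back to the global default communication group when none is given.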
        self.group = _get_global_group() if group is None else group


@contextlib.contextmanager
def _with_batch_p2p_guard(backend):
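    """Wrap a block in ``group_start()``/``group_end()`` calls when the backend
    is NCCL, so that the point-to-point ops issued inside it are batched.
    """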
    if backend == "NCCL":
        framework.core.ProcessGroupNCCL.group_start()
    try:
        yield
    finally:
        if backend == "NCCL":
            framework.core.ProcessGroupNCCL.group_end()


def _check_p2p_op_list(p2p_op_list):
    """
    Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
    all ops use the same backend.
    """
    if not isinstance(p2p_op_list, list) or not all(
        isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list
    ):
        raise RuntimeError(
            "Invalid ``p2p_op_list``. Each op is expected to "
            "to be of type ``paddle.distributed.P2POp``."
        )

    backend = p2p_op_list[0].group.backend
    if not all(backend == p2p_op.group.backend for p2p_op in p2p_op_list):
        raise RuntimeError("All groups need to use the same backend.")


def batch_isend_irecv(p2p_op_list):
    """
    Send or receive a batch of tensors asynchronously and return a list of tasks.

    Process each of the point-to-point operations in ``p2p_op_list`` and return the
    corresponding tasks. Only the NCCL backend is currently supported.

    Args:
        p2p_op_list (List[P2POp]): A list of point-to-point operations (the type of each operation is
            ``paddle.distributed.P2POp``). The order of the isend/irecv operations in the list
            matters, and it must match the corresponding isend/irecv calls on the
            remote end.

    Returns:
        A list of distributed tasks returned by calling the corresponding
        ops in ``p2p_op_list``.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env: DISTRIBUTED)

            >>> import paddle
            >>> import paddle.distributed as dist

            >>> dist.init_parallel_env()
            >>> rank = dist.get_rank()
            >>> world_size = dist.get_world_size()

            >>> send_t = paddle.arange(2) + rank
            >>> # paddle.tensor([0, 1])  # Rank-0
            >>> # paddle.tensor([1, 2])  # Rank-1

            >>> recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)

            >>> send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
            >>> recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)

            >>> tasks = dist.batch_isend_irecv([send_op, recv_op])

            >>> for task in tasks:
            ...     task.wait()

            >>> print(recv_t)
            >>> # paddle.tensor([1, 2])     # Rank-0
            >>> # paddle.tensor([0, 1])     # Rank-1
    """
    _check_p2p_op_list(p2p_op_list)
    group = p2p_op_list[0].group
    if _warn_cur_rank_not_in_group(group):
        return

    if framework.in_dynamic_mode():
        group = _get_global_group() if group is None else group
        backend = group.backend
        tasks = []
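        # Issue every isend/irecv inside one group_start()/group_end() window
        # so the backend (NCCL) can batch the point-to-point calls together.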
        with _with_batch_p2p_guard(backend):
            for p2p_op in p2p_op_list:
                op = p2p_op.op
                tensor = p2p_op.tensor
                peer = p2p_op.peer
                comm_group = p2p_op.group
                task = op(tensor, peer, comm_group)
                if task is not None:
                    tasks.append(task)
        return tasks
    else:
        raise RuntimeError("Don't support static graph mode currently.")