collective.py 69.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
17 18
import pickle
import io
19 20
import datetime
import time
21
from ..fluid.layer_helper import LayerHelper
22
from ..fluid.framework import Variable
23
from ..fluid.framework import in_dygraph_mode
24
from ..fluid.framework import OpProtoHolder
J
Jiabin Yang 已提交
25
from ..fluid.framework import _non_static_mode
26
from ..fluid.framework import _in_legacy_dygraph
27
from ..fluid.framework import convert_np_dtype_to_dtype_
J
Jiangxinz 已提交
28
from ..fluid.framework import _varbase_creator
29 30 31 32
from ..fluid.data_feeder import convert_dtype
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.data_feeder import check_type
from ..fluid.data_feeder import check_dtype
33 34
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
B
Baibaifan 已提交
35
from ..fluid.dygraph import layers
36 37 38 39
from ..fluid.dygraph.parallel import prepare_context
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
40
from paddle import _C_ops, _legacy_C_ops
J
Jiangxinz 已提交
41
import paddle.fluid.dygraph_utils as dygraph_utils
42
import contextlib
W
wuhuachaocoding 已提交
43 44 45 46 47 48 49 50 51 52 53 54 55
from .fleet.layers.mpu.mp_ops import split
from .fleet.layers.mpu.mp_ops import _c_identity
from .fleet.layers.mpu.mp_ops import _c_concat
from .fleet.layers.mpu.mp_ops import _c_split
from .fleet.layers.mpu.mp_ops import _mp_allreduce
from .fleet.layers.mpu.mp_ops import _c_lookup_table
from .fleet.layers.mpu.mp_ops import _Linear
from .fleet.layers.mpu.mp_ops import _set_var_distributed
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy
from .fleet.layers.mpu.mp_ops import _linear
from .fleet.layers.mpu.mp_ops import _parallel_linear
from .fleet.layers.mpu.mp_ops import _parallel_embedding
from .communication.comm_utils import ReduceOp
56

57
# Empty, so `from <this module> import *` exports no names from this module.
__all__ = []
58 59


K
kuizhiqing 已提交
60 61 62 63
class Group():
    """
    The abstract representation of a communication group.

    Attributes:
        rank (int): Rank of the current process inside the group; a negative
            value means the current process is not a member.
        nranks (int): Number of ranks in the group.
        id (int): Group id; the static-graph collective ops use it as their
            ring id.
        ranks (list): Global ranks of the group members.
        pg: The backend process group object, or None.
        name (str): Optional group name.
    """

    def __init__(self, rank, rank_num, id=0, ranks=None, pg=None, name=None):
        self.rank = rank
        self.nranks = rank_num
        self.id = id
        # Build a fresh list per instance: a shared mutable default would let
        # one group's membership leak into every other group created without
        # explicit ranks.
        self.ranks = [] if ranks is None else ranks
        self.pg = pg
        self.name = name

    def is_member(self):
        """Return True if this process belongs to a group of 2+ ranks."""
        return self.rank >= 0 and self.nranks >= 2

    def get_group_rank(self, rank):
        """Map global ``rank`` to its index within this group.

        Returns -1 when this process is not a member, or when ``rank`` is not
        one of the group's ranks.
        """
        if self.is_member() and rank in self.ranks:
            return self.ranks.index(rank)
        return -1

    @property
    def process_group(self):
        """The backend process group object (``pg``), or None."""
        return self.pg

    @property
    def world_size(self):
        """Group size, or -1 when this process is not a member (rank < 0)."""
        return self.nranks if self.rank >= 0 else -1

    def __repr__(self):
        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
            self.rank, self.nranks, self.id)
        debug_str += ", ".join(map(str, self.ranks))
        debug_str += "; name: "
        debug_str += self.name if self.name else "None"
        return debug_str

K
kuizhiqing 已提交
102 103 104 105 106 107 108 109 110 111 112 113 114 115

# Lazily-created ParallelEnv shared by this module.
_global_env = None


def _get_global_env():
    """Return the module-wide ParallelEnv, creating it on first use."""
    global _global_env
    if _global_env:
        return _global_env
    _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# Registry of every group keyed by group id; the global group lives under
# _global_env_gid (0).
# Dict[int, Group]
_group_map = {}
_global_env_gid = 0

# Registry of groups keyed by their names.
# Dict[name, Group]
_group_map_by_name = {}

# Backend name recorded for each group.
# Dict[Group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl']

_default_store = None  # the default tcp store
_default_backend = None

# 30 minutes.
_default_timeout = datetime.timedelta(minutes=30)
_start_ring_id = 0
134

K
kuizhiqing 已提交
135

L
lilong12 已提交
136 137 138 139 140 141 142 143 144 145
def _set_default_backend(backend):
    """Record ``backend`` as the module-wide default backend name."""
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    """Record ``store`` as the module-wide default store."""
    global _default_store
    _default_store = store


K
kuizhiqing 已提交
146 147
def _get_group_map():
    """Return the id -> Group registry, seeding the global group on demand.

    The first call stores a Group under ``_global_env_gid`` that spans every
    rank reported by the global ParallelEnv.
    """
    global _group_map
    if _global_env_gid not in _group_map:
        env = _get_global_env()
        my_rank = env.rank
        nranks = env.world_size
        _group_map[_global_env_gid] = Group(my_rank,
                                            nranks,
                                            ranks=list(range(nranks)))
    return _group_map


def _get_global_group():
    """Return the Group registered under the global group id."""
    group_map = _get_group_map()
    return group_map[_global_env_gid]
K
kuizhiqing 已提交
158 159


160 161 162 163 164 165
def _get_group_map_by_name():
    """Return the live name -> Group registry (not a copy)."""
    return _group_map_by_name


def _get_default_group():
    """Return the group registered under ``_default_group_name``.

    Raises:
        AssertionError: If the distributed environment is not initialized.
    """
    assert is_initialized(), ("Call paddle.distributed.init_parallel_env first "
                              "to initialize the distributed environment.")
    by_name = _get_group_map_by_name()
    return by_name[_default_group_name]


L
lilong12 已提交
172 173 174 175 176 177 178 179 180 181 182 183
def _set_group_map(gid, group):
    """Register ``group`` under id ``gid``; the id must not be taken yet."""
    registry = _group_map
    assert gid not in registry
    registry[gid] = group


def _set_group_map_by_name(name, group):
    """Register ``group`` under ``name``; the name must not be taken yet."""
    registry = _group_map_by_name
    assert name not in registry
    registry[name] = group


184 185 186 187 188 189
def _set_group_map_backend(group, backend):
    """Record ``backend`` for ``group``; the group must not have one yet."""
    registry = _group_map_backend
    assert group not in registry
    registry[group] = backend


K
kuizhiqing 已提交
190
def _new_ring_id():
    """Return a fresh ring id for a new communication group.

    Ids are offset by ``max(nrings, 9)`` from the global ParallelEnv. In
    eager (dygraph) mode a module-level counter is bumped; otherwise the id
    is derived from the current size of the group registry.
    """
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode
    # rely on the previous syntax.
    global _start_ring_id
    if not in_dygraph_mode():
        return len(_get_group_map()) + max(_get_global_env().nrings, 9)
    _start_ring_id += 1
    return _start_ring_id + max(_get_global_env().nrings, 9)
K
kuizhiqing 已提交
198 199


200 201 202 203 204 205 206 207 208 209 210 211 212
def _get_reduce_op(reduce_op, func_name):
    """Translate a user-facing ReduceOp into the matching ``core.ReduceOp``.

    Args:
        reduce_op: One of ReduceOp.SUM / MAX / MIN / PROD.
        func_name (str): Caller name, used only in the error message.

    Raises:
        ValueError: If ``reduce_op`` is not one of the supported values.
    """
    translations = (
        (ReduceOp.SUM, core.ReduceOp.SUM),
        (ReduceOp.MAX, core.ReduceOp.MAX),
        (ReduceOp.MIN, core.ReduceOp.MIN),
        (ReduceOp.PROD, core.ReduceOp.PRODUCT),
    )
    for user_op, core_op in translations:
        if reduce_op == user_op:
            return core_op
    raise ValueError("Unknown reduce_op type for {}.".format(func_name))


K
kuizhiqing 已提交
213 214 215 216 217 218
def get_group(id=0):
    """

    Get group instance by group id.

    Args:
        id (int): the group id. Default value is 0.

    Returns:
        Group: the group instance, or None if no group has that id.

    Examples:
        .. code-block:: python

            ...
            gid = paddle.distributed.new_group([2,4,6])
            paddle.distributed.get_group(gid.id)

    """
    return _get_group_map().get(id)
K
kuizhiqing 已提交
235 236


237 238 239 240 241 242
def _new_process_group_impl(backend,
                            store,
                            rank,
                            world_size,
                            group_name,
                            pg_options,
                            group_id=0,
                            src_rank=None,
                            dst_rank=None):
    """Construct the backend-specific process group object for one group.

    Args:
        backend (str): Backend name; must be in ``_valid_backend_list``.
        store: Store passed through to the process group constructor.
        rank (int): Rank of this process inside the group.
        world_size (int): Number of ranks in the group.
        group_name (str): Group name (not read by this function).
        pg_options: Backend options (not read by this function).
        group_id (int): Group id forwarded to the constructor.
        src_rank, dst_rank: Only allowed for the 'heter' backend.

    Returns:
        The created process group (None only if no branch matched, which the
        backend assertion normally prevents).
    """
    genv = _get_global_env()
    if backend != 'heter':
        assert src_rank is None and dst_rank is None, (
            "src_rank and dst_rank "
            "can only be set for heter backend.")
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend

    if backend == "gloo":
        return core.ProcessGroupGloo(store, rank, world_size, core.CPUPlace(),
                                     group_id)
    if backend == "nccl":
        return core.ProcessGroupNCCL(store, rank, world_size,
                                     core.CUDAPlace(genv.device_id), group_id)
    if backend == "hccl":
        return core.ProcessGroupHCCL(store, rank, world_size,
                                     core.NPUPlace(genv.device_id), group_id)
    if backend == "xccl":
        return core.ProcessGroupCustom(
            store, rank, world_size,
            core.CustomPlace(genv.device_type, genv.device_id), group_id)
    if backend == "heter":
        # Device place is None when built with neither CUDA nor NPU.
        place = None
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(genv.device_id)
        elif core.is_compiled_with_npu():
            place = core.NPUPlace(genv.device_id)

        # The cluster layout comes from environment variables.
        cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
        assert cluster_id >= 0, "please set the CLUSTER_ID variable."
        raw_cluster_size = os.getenv("CLUSTER_SIZE", None)
        assert raw_cluster_size, "please set the CLUSTER_SIZE variable."
        cluster_size = [int(s) for s in raw_cluster_size.split(",")]
        switch_ep = os.getenv("CLUSTER_SWITCH", None)
        assert switch_ep, "please set the CLUSTER_SWITCH variable."

        # NOTE(review): the global rank/size derived from CLUSTER_SIZE here are
        # overwritten by _get_global_config(...) on the next statement, and
        # _get_global_config is not defined in the visible part of this file.
        # Confirm which computation is intended before relying on either.
        size_cumsum = np.cumsum(cluster_size)
        offset = 0 if cluster_id == 0 else size_cumsum[cluster_id - 1]
        global_rank = offset + rank
        global_world_size = size_cumsum[-1]
        global_rank, global_world_size = _get_global_config(backend, rank)

        return core.ProcessGroupHeter(store,
                                      rank=global_rank,
                                      world_size=global_world_size,
                                      place=place,
                                      gid=group_id,
                                      local_rank=rank,
                                      local_size=world_size,
                                      gloo_rank=cluster_id,
                                      gloo_size=len(cluster_size),
                                      with_switch=True,
                                      switch_endpoint=switch_ep,
                                      src_rank=src_rank,
                                      dst_rank=dst_rank)
    return None


S
ShenLiang 已提交
302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325
def barrier(group=None):
    """

    Barrier among all participators in the group.

    Args:
        group (Group): The group instance return by new_group or None for global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    # Processes outside the group have nothing to synchronize with.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = group if group is not None else _get_default_group()
        group.process_group.barrier().wait()
        return

    ring_id = group.id if group is not None else 0
    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        return _legacy_C_ops.barrier(temp, temp, 'ring_id', ring_id)

    # NOTE: LayerHelper receives every local variable via **locals(), so the
    # local names here are kept exactly as in the original implementation.
    op_type = 'barrier'
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [temp]},
                     outputs={'Out': [temp]},
                     attrs={'ring_id': ring_id})
S
ShenLiang 已提交
347 348


L
lilong12 已提交
349 350 351 352 353 354 355
# Optional user-chosen group id for the next new_group call. When set (truthy),
# new_group uses it instead of allocating a fresh ring id, which helps keep ids
# aligned with static mode.
_custom_gid = None


def _set_custom_gid(gid):
    """Set the group id that new_group should use (None to disable)."""
    global _custom_gid
    _custom_gid = gid


360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396
def _barrier_by_tcp_store(group_name, store, timeout):
    """Make global rank 0 wait until every other rank has checked in.

    Each non-master rank adds 1 to the key ``Barrier/<group_name>/<rank>`` in
    ``store`` and returns immediately; rank 0 polls those keys until each
    reads 1. Nothing happens when the world size is below 2.

    Args:
        group_name (str): Name used to namespace the barrier keys.
        store: Key-value store shared by all ranks (assumed to provide
            ``add``/``get`` like a TCP store -- confirm against callers).
        timeout (datetime.timedelta): How long rank 0 waits before failing.

    Raises:
        RuntimeError: If rank 0 times out waiting for the other ranks.
    """
    global_rank = paddle.distributed.get_rank()
    global_world_size = paddle.distributed.get_world_size()

    if global_world_size < 2:
        return

    barrier_prefix = "Barrier/" + group_name + "/"
    is_master = (global_rank == 0)

    def _check_keys_ready(wait_keys):
        # Poll every 0.1s, dropping keys once they have been set to 1.
        start_time = time.time()
        while len(wait_keys) > 0:
            time.sleep(0.1)
            elapse_time = time.time() - start_time
            if datetime.timedelta(seconds=elapse_time) > timeout:
                # Adjacent literals are joined before .format(), so each piece
                # must carry its own trailing space/newline.
                raise RuntimeError(
                    "Timeout while initializing process group {}. "
                    "Keys {} are not ready since rank {} is waiting for them. "
                    "Two reasons may cause this error:\n"
                    " 1. The create process group api should be called by all ranks.\n"
                    " 2. Try to increase the waiting time.\n".format(
                        group_name, wait_keys, global_rank))
            wait_keys = [
                key for key in wait_keys if int(store.get(key)) != 1
            ]

    # All the workers set their exiting key and exit; the master waits for
    # every worker's key so that it is the last one to leave.
    if is_master:
        wait_keys = [
            barrier_prefix + str(rank) for rank in range(1, global_world_size)
        ]
        _check_keys_ready(wait_keys)
    else:
        store.add(barrier_prefix + str(global_rank), 1)


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list, optional): The global ranks of group members. If None, all
            global ranks are used.
        backend (str, optional): The backend used to create group. In eager
            (dygraph) mode it must be one of 'nccl', 'gloo', 'hccl', 'heter' or
            'xccl' and defaults to the module's default backend; otherwise only
            'nccl' is supported.
        timeout (datetime.timedelta, optional): The waiting timeout for store relevant options, default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)

    """
    global _custom_gid
    global _group_map
    if in_dygraph_mode():
        global _default_group_name
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        # NOTE(review): with backend='heter', ranks must be passed explicitly;
        # ranks[0] and ranks[1] are used as src/dst below.
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group.")
        size = len(ranks)
        ranks = sorted(ranks)
        # `size > 1` short-circuits before global_rank is read, so global_rank
        # is only needed when the branch above ran.
        if backend == 'heter' or (size > 1 and global_rank in ranks):
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            src_rank = ranks[0] if backend == 'heter' else None
            dst_rank = ranks[1] if backend == 'heter' else None
            pg = _new_process_group_impl(backend,
                                         _default_store,
                                         rank,
                                         size,
                                         group_name,
                                         pg_options=None,
                                         group_id=gid,
                                         src_rank=src_rank,
                                         dst_rank=dst_rank)
        else:
            # Not a member (or a single-rank group): no process group object.
            rank = -1
            pg = None
        group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        # NOTE(liyurui): All processors should hang and wait using tcp store, in case master exit before sub-group is created.
        if backend != 'heter':
            _barrier_by_tcp_store(group_name, _default_store, timeout)
        else:
            print("Warning: store barrier is not supported for heter backend.")
        return group

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', ("backend other than nccl is not supported yet")

    genv = _get_global_env()
    global_rank = genv.rank
    if ranks is None:
        # Mirror the eager branch: no explicit ranks means every global rank.
        # Previously `global_rank not in None` raised a TypeError here.
        ranks = list(range(genv.world_size))

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, -1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, group_size, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            # Initialize the communicator for whichever device type Paddle
            # was compiled with.
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            else:
                assert False, ("no cuda device found")
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = paddle.to_tensor(
        [1], dtype="int32") if _non_static_mode() else fill_constant(
            [0], dtype="int32", value="1")
    paddle.distributed.all_reduce(tmp, use_calc_stream=True)
    paddle.distributed.wait(tmp)
    return gp

528

529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558
def is_initialized():
    """

    Check whether the distributed environment has been initialized

    Returns (bool): `True` if a group is registered under the default group
    name (i.e. the distributed environment has been initialized), otherwise `False`.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle

            print(paddle.distributed.is_initialized())
            # False

            paddle.distributed.init_parallel_env()
            print(paddle.distributed.is_initialized())
            # True

    """
    registered_by_name = _get_group_map_by_name()
    return _default_group_name in registered_by_name


def destroy_process_group(group=None):
    """
    Destroy a given group for communication

    Args:
        group (Group, optional): The group to be destroyed. If None, all process
            groups, including the default group, are destroyed and the
            distributed environment is deinitialized. Default is None.

    Returns : None

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            group = dist.new_group([0, 1])

            dist.destroy_process_group(group)
            print(dist.is_initialized())
            # True
            dist.destroy_process_group()
            print(dist.is_initialized())
            # False

    """
    global _group_map
    global _group_map_by_name

    pg = _get_default_group() if group is None else group
    assert _group_map.get(pg.id, None) is not None, "Invalid group."

    if group is None:
        _group_map.clear()
        _group_map_by_name.clear()
        _group_map_backend.clear()
    else:
        del _group_map[pg.id]
        # Groups created through new_group's non-eager path are registered
        # only by id (no name, no backend entry), so tolerate missing keys
        # instead of raising KeyError.
        _group_map_by_name.pop(pg.name, None)
        _group_map_backend.pop(pg, None)


K
kuizhiqing 已提交
599 600 601 602 603 604 605 606
def wait(tensor, group=None, use_calc_stream=True):
    """

    wait to sync stream for group.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, use_calc_stream=True)
            paddle.distributed.wait(tindata)

    """
    # Non-members of the group have nothing to wait for.
    if group is not None and not group.is_member():
        return

    if use_calc_stream:
        _sync_calc_stream(tensor)
        return

    # Communication-stream sync is scoped to the group's ring (0 = global).
    _sync_comm_stream(tensor, 0 if group is None else group.id)


def _sync_calc_stream(tensor):
    """Synchronize the calculation stream with respect to ``tensor``.

    In non-static (dygraph) mode the legacy ``c_sync_calc_stream`` op runs
    directly; otherwise the op is recorded through ``LayerHelper.append_op``.
    """
    if not _non_static_mode():
        # LayerHelper receives every local via **locals(); keep names as-is.
        op_type = 'c_sync_calc_stream'
        helper = LayerHelper(op_type, **locals())
        helper.append_op(type=op_type,
                         inputs={'X': [tensor]},
                         outputs={'Out': [tensor]})
        return None
    return _legacy_C_ops.c_sync_calc_stream(tensor, tensor)
649

650

K
kuizhiqing 已提交
651
def _sync_comm_stream(tensor, ring_id=0):
    """Synchronize the communication stream of ring ``ring_id`` for ``tensor``.

    In non-static (dygraph) mode the legacy ``c_sync_comm_stream`` op runs
    directly; otherwise the op is recorded through ``LayerHelper.append_op``.
    """
    if not _non_static_mode():
        # LayerHelper receives every local via **locals(); keep names as-is.
        op_type = 'c_sync_comm_stream'
        helper = LayerHelper(op_type, **locals())
        helper.append_op(type=op_type,
                         inputs={'X': [tensor]},
                         outputs={'Out': [tensor]},
                         attrs={'ring_id': ring_id})
        return None
    return _legacy_C_ops.c_sync_comm_stream([tensor], [tensor], 'ring_id',
                                            ring_id)
K
kuizhiqing 已提交
666 667 668


def broadcast(tensor, src, group=None, use_calc_stream=True):
    """

    Broadcast a tensor from the source to all others.
    As shown below, one process is started with a GPU and GPU0 owns data 0. Through broadcast operator,
    data 0 will be sent to all GPUs from GPU0.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png
        :width: 800
        :alt: broadcast
        :align: center

    Args:
        tensor (Tensor): The Tensor to send if current rank is the source, or the Tensor to receive otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        src (int): The source rank (global rank).
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None in the common case. In eager mode with ``use_calc_stream=False``
        the communication task is returned so the caller can wait on it.

    Raises:
        ValueError: If ``src`` is not an int.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.broadcast(data, src=1)
            print(data)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs)
    """

    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        # Process groups address ranks by their index within the group.
        gsrc = group.get_group_rank(src)
        assert gsrc >= 0, ("src rank out of group, need global rank")
        task = group.process_group.broadcast(tensor, gsrc)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    gsrc = src if group is None else group.get_group_rank(src)
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if _non_static_mode():
        return _legacy_C_ops.c_broadcast(tensor, tensor, 'root', gsrc,
                                         'use_calc_stream', use_calc_stream,
                                         'ring_id', ring_id)

    op_type = 'c_broadcast'
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'broadcast')

    # NOTE: LayerHelper receives every local variable via **locals().
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'root': gsrc,
                         'use_calc_stream': use_calc_stream,
                         'ring_id': ring_id,
                     })
749 750


K
kuizhiqing 已提交
751
def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor over all ranks so that all get the result.
    As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The reduce operator is sum. Through all_reduce operator,
    each GPU will have the sum of the data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
        :width: 800
        :alt: all_reduce
        :align: center

    Args:
        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None in the common case. In eager mode with ``use_calc_stream=False``
        the communication task is returned so the caller can wait on it.

    Raises:
        ValueError: If ``op`` is not one of the supported ReduceOp values.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_reduce(data)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "all_reduce")
        group = _get_default_group() if group is None else group
        task = group.process_group.allreduce(tensor, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    if _non_static_mode():
        if op == ReduceOp.SUM:
            return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id)
        elif op == ReduceOp.MAX:
            return _legacy_C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id)
        elif op == ReduceOp.MIN:
            return _legacy_C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id)
        elif op == ReduceOp.PROD:
            return _legacy_C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
                                                   use_calc_stream, 'ring_id',
                                                   ring_id)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'all_reduce')
    if op == ReduceOp.SUM:
        op_type = 'c_allreduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_allreduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_allreduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_allreduce_prod'
    else:
        # Without this branch an unsupported op would surface below as an
        # UnboundLocalError on op_type instead of a clear ValueError.
        raise ValueError("Unknown parameter: {}.".format(op))
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'ring_id' for all_reduce should be int.")
    # NOTE: LayerHelper receives every local variable via **locals().
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream
                     })
847 848


K
kuizhiqing 已提交
849
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor to the destination from all others. As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator,
    GPU0 will own the sum of all data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png
        :width: 800
        :alt: reduce
        :align: center

    Args:
        tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        dst (int): The destination rank id (a global rank; it is mapped to the rank inside ``group``).
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None. In eager mode with ``use_calc_stream=False`` the asynchronous task is
        returned instead; in legacy dygraph mode the result of the underlying op is returned.

    Raises:
        ValueError: If ``op`` is not one of the supported ReduceOp values
            (legacy dygraph and static graph modes).

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.reduce(data, dst=0)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs, out for rank 0)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce")
        group = _get_default_group() if group is None else group
        gdst = group.get_group_rank(dst)
        assert gdst >= 0, ("dst rank out of group, need global rank")
        task = group.process_group.reduce(tensor, gdst, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    gdst = dst if group is None else group.get_group_rank(dst)
    assert gdst >= 0, ("dst rank out of group, need global rank")

    if _non_static_mode():
        if op == ReduceOp.SUM:
            return _legacy_C_ops.c_reduce_sum(tensor, tensor, 'use_calc_stream',
                                              use_calc_stream, 'ring_id',
                                              ring_id, 'root_id', gdst)
        elif op == ReduceOp.MAX:
            return _legacy_C_ops.c_reduce_max(tensor, tensor, 'use_calc_stream',
                                              use_calc_stream, 'ring_id',
                                              ring_id, 'root_id', gdst)
        elif op == ReduceOp.MIN:
            return _legacy_C_ops.c_reduce_min(tensor, tensor, 'use_calc_stream',
                                              use_calc_stream, 'ring_id',
                                              ring_id, 'root_id', gdst)
        elif op == ReduceOp.PROD:
            return _legacy_C_ops.c_reduce_prod(tensor, tensor,
                                               'use_calc_stream',
                                               use_calc_stream, 'ring_id',
                                               ring_id, 'root_id', gdst)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'reduce')

    if op == ReduceOp.SUM:
        op_type = 'c_reduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_reduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_reduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_reduce_prod'
    else:
        # Previously an unknown op fell through with op_type 'c_reduce',
        # which silently appended a nonexistent op to the program.
        raise ValueError("Unknown parameter: {}.".format(op))

    # LayerHelper receives **locals(), so the local names bound here matter.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream,
                         'root_id': gdst,
                     })
954 955


K
kuizhiqing 已提交
956
def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
    """

    Gather tensors from all participators and all get the result. As shown
    below, one process is started with a GPU and the data of this process is represented
    by its group rank. Through the all_gather operator, each GPU will have data
    from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
            In eager mode the list is cleared and refilled with the gathered tensors; in the other
            modes the gathered tensors are appended to it.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True. In eager mode the call always waits for the communication to finish.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            tensor_list = []
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_gather(tensor_list, data)
            print(tensor_list)
            # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    def convert_to_complex(list_of_tensor):
        # Undo the paddle.as_real view applied to complex inputs below.
        return [paddle.as_complex(item) for item in list_of_tensor]

    is_input_complex = (tensor.dtype == paddle.complex64
                        or tensor.dtype == paddle.complex128)
    if is_input_complex:
        # Complex data is communicated through its real view.
        tensor = paddle.as_real(tensor)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        if len(tensor_list) == 0:
            tensor_shape = list(tensor.shape)
            tensor_shape[0] *= group.nranks
            out = paddle.empty(tensor_shape, tensor.dtype)
        else:
            # ``tensor`` was converted with as_real above, so caller-provided
            # complex buffers must be viewed the same way; otherwise ``out``
            # would not match ``tensor`` in dtype and shape.
            if is_input_complex:
                buffers = [
                    paddle.as_real(item)
                    if item.dtype in (paddle.complex64, paddle.complex128) else
                    item for item in tensor_list
                ]
            else:
                buffers = tensor_list
            out = paddle.concat(buffers, axis=0)
        task = group.process_group.all_gather(tensor, out)
        task.wait()
        tensor_list.clear()
        list_of_tensor = paddle.split(out, group.nranks, 0)
        if is_input_complex:
            tensor_list.extend(convert_to_complex(list_of_tensor))
        else:
            tensor_list.extend(list_of_tensor)
        return

    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    if _non_static_mode():
        out = _legacy_C_ops.c_allgather(tensor, 'use_calc_stream',
                                        use_calc_stream, 'ring_id', ring_id,
                                        'nranks', nranks)
    else:
        op_type = 'c_allgather'
        # LayerHelper receives **locals(), so the local names bound here matter.
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        if not isinstance(tensor_list, list):
            raise ValueError("The type of 'tensor_list' for all_gather "
                             "should be list.")
        for elem in tensor_list:
            check_variable_and_dtype(elem, 'tensor_list', [
                'float16', 'float32', 'float64', 'int32', 'int64', 'bool',
                'int8', 'uint8', 'complex64', 'complex128'
            ], 'all_gather')
        check_variable_and_dtype(tensor, 'tensor', [
            'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'int8',
            'uint8', 'complex64', 'complex128'
        ], 'all_gather')
        helper.append_op(type=op_type,
                         inputs={'X': [tensor]},
                         outputs={'Out': [out]},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': use_calc_stream,
                             'nranks': nranks
                         })

    list_of_tensor = paddle.split(out, nranks, 0)
    if is_input_complex:
        tensor_list.extend(convert_to_complex(list_of_tensor))
    else:
        tensor_list.extend(list_of_tensor)


def _convert_object_to_tensor(obj):
    # Pickle ``obj`` (default protocol) and expose the raw bytes as a uint8
    # tensor so it can travel through the tensor collectives. The element
    # count is returned too, so receivers can drop any padding before
    # unpickling.
    payload = np.frombuffer(pickle.dumps(obj), dtype=np.uint8)
    tensor = paddle.to_tensor(payload)
    return tensor, tensor.numel()
1076 1077


1078
def _convert_tensor_to_object(tensor, len_of_tensor):
    # Inverse of the pickling helper: keep only the first ``len_of_tensor``
    # bytes (anything after that is padding) and unpickle them.
    # NOTE(review): unpickling can run arbitrary code, so the bytes must come
    # from trusted peers only.
    raw = tensor.numpy()[:len_of_tensor]
    return pickle.loads(raw.tobytes())
1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107


def all_gather_object(object_list, obj, group=None):
    """

    Gather picklable objects from all participators and all get the result. Similar to all_gather(), but python object can be passed in.

    Args:
        object_list (list): A list of output object. The datatype of every element in the list is same as the input obj.
            The gathered objects are appended to it.
        obj (Any): The picklable object to send.
        group (Group): The group instance return by new_group or None for global default group.

    Returns:
        None.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            object_list = []
            if dist.get_rank() == 0:
                obj = {"foo": [1, 2, 3]}
            else:
                obj = {"bar": [4, 5, 6]}
            dist.all_gather_object(object_list, obj)
            print(object_list)
            # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
    """
    assert in_dygraph_mode(
    ), "all_gather_object doesn't support static graph mode."

    # Ranks outside ``group`` take no part in the collective. all_gather
    # returns early for them, which would leave the length list empty and
    # make max() below raise, so exit here like the other collectives do.
    if group is not None and not group.is_member():
        return

    tensor, len_of_tensor = _convert_object_to_tensor(obj)

    # gather len_of_tensor from all ranks
    list_len_of_tensor = []
    all_gather(list_len_of_tensor, len_of_tensor, group)
    # get the max length from list
    max_len_of_tensor = int(max(list_len_of_tensor).item())
    # resize the input tensor to max length avoid hang in all gather
    # Note(liyurui): Maybe we should support various length all_gather?
    # Now this operation is efficient for we don't support resize in python.
    numpy_data = tensor.numpy()
    numpy_data = np.resize(numpy_data, [max_len_of_tensor])
    input_tensor = paddle.to_tensor(numpy_data)

    tensor_list = []
    all_gather(tensor_list, input_tensor, group)
    # Each gathered payload is truncated back to its sender's real length.
    for gathered, length in zip(tensor_list, list_len_of_tensor):
        object_list.append(_convert_tensor_to_object(gathered, length))
1138 1139


K
kuizhiqing 已提交
1140
def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
    """

    Scatter a tensor to all participators. As shown below, one process is started with a GPU and the source of the scatter
    is GPU0. Through scatter operator, the data in GPU0 will be sent to all GPUs averagely.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
        :width: 800
        :alt: scatter
        :align: center

    Args:
        tensor (Tensor): The output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool. Required on the source rank;
            ignored on the other ranks. Default value is None.
        src (int): The source rank id (a global rank). Default value is 0.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None. In eager mode with ``use_calc_stream=False`` the asynchronous task is returned instead.

    Raises:
        ValueError: If ``src`` is not an int, or ``tensor_list`` is None on the source rank.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([7, 8, 9])
                data2 = paddle.to_tensor([10, 11, 12])
                dist.scatter(data1, src=1)
            else:
                data1 = paddle.to_tensor([1, 2, 3])
                data2 = paddle.to_tensor([4, 5, 6])
                dist.scatter(data1, tensor_list=[data1, data2], src=1)
            print(data1, data2)
            # [1, 2, 3] [10, 11, 12] (2 GPUs, out for rank 0)
            # [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        rank = group.rank
        nranks = group.nranks
    else:
        ring_id = 0 if group is None else group.id
        gsrc = src if group is None else group.get_group_rank(src)
        rank = _get_global_group().rank if group is None else group.rank
        nranks = _get_global_group().nranks if group is None else group.nranks
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if rank != gsrc:
        # Only the source rank supplies data. The other ranks feed ``nranks``
        # copies of their output tensor so the concatenated input has the
        # expected shape (its contents are presumably ignored by the op).
        tensor_list = [tensor] * nranks
    elif tensor_list is None:
        # Without this, paddle.concat(None) fails with an obscure error.
        raise ValueError(
            "tensor_list should be provided on the src rank for scatter.")
    temp = paddle.concat(tensor_list, axis=0)

    if in_dygraph_mode():
        task = group.process_group.scatter(temp, tensor, gsrc)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    if _non_static_mode():
        return _legacy_C_ops.c_scatter(temp, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'nranks', nranks, 'root', gsrc)
    op_type = 'c_scatter'
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'scatter')
    # LayerHelper receives **locals(), so the local names bound here matter.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [temp]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'root': gsrc,
                         'use_calc_stream': use_calc_stream,
                         'nranks': nranks,
                     })
1234 1235


L
lilong12 已提交
1236 1237
def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
    """
    Scatter tensors in in_tensor_list to all participators averagely and gather the result tensors in out_tensor_list.
    As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
    Through alltoall operator, the 0_0 in GPU0 will be sent to GPU0 and 0_1 to GPU1, 1_0 in GPU1 sent to GPU0 and 1_1 to GPU1.
    Finally the out_tensor_list in GPU0 includes 0_0 and 1_0, and GPU1 includes 0_1 and 1_1.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/alltoall.png
        :width: 800
        :alt: alltoall
        :align: center

    Args:
        in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
            data type of the input Tensors. In eager mode it is cleared and refilled; otherwise the results are
            appended, and in static graph mode it must be empty.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            out_tensor_list = []
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
                data2 = paddle.to_tensor([[7, 8, 9], [10, 11, 12]])
            else:
                data1 = paddle.to_tensor([[13, 14, 15], [16, 17, 18]])
                data2 = paddle.to_tensor([[19, 20, 21], [22, 23, 24]])
            dist.alltoall([data1, data2], out_tensor_list)
            print(out_tensor_list)
            # [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0)
            # [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        # Eager path: run through the process group and always wait.
        group = group if group is not None else _get_default_group()
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        temp = paddle.concat(in_tensor_list, axis=0)
        nranks = len(in_tensor_list)
        if len(out_tensor_list) > 0:
            out = paddle.concat(out_tensor_list, axis=0)
        else:
            first = in_tensor_list[0]
            out_shape = list(first.shape)
            out_shape[0] *= nranks
            out = paddle.empty(out_shape, first.dtype)
        task = group.process_group.alltoall(temp, out)
        task.wait()
        out_tensor_list.clear()
        out_tensor_list.extend(paddle.split(out, nranks, 0))
        return

    # Legacy dygraph / static graph path. LayerHelper below is built from
    # **locals(), so the names bound before it are kept as-is.
    ring_id = group.id if group is not None else 0
    temp = paddle.concat(in_tensor_list, axis=0)
    nranks = len(in_tensor_list)

    if _non_static_mode():
        out = _legacy_C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id)
    else:
        op_type = 'alltoall'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=in_tensor_list[0].dtype)

        if not isinstance(in_tensor_list, list):
            raise ValueError("The type of 'in_tensor_list' for all_to_all "
                             "should be list.")
        for elem in in_tensor_list:
            check_variable_and_dtype(
                elem, 'in_tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_to_all')
        if not isinstance(out_tensor_list, list):
            raise ValueError("The type of 'out_tensor_list' for all_to_all "
                             "should be list.")
        if len(out_tensor_list) != 0:
            raise ValueError("The 'out_tensor_list' for all_to_all "
                             "must be an empty list.")
        attrs = {
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
        }
        helper.append_op(type=op_type,
                         inputs={'X': [temp]},
                         outputs={'Out': [out]},
                         attrs=attrs)
    out_tensor_list.extend(paddle.split(out, nranks, 0))


1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349
def alltoall_single(in_tensor,
                    out_tensor,
                    in_split_sizes=None,
                    out_split_sizes=None,
                    group=None,
                    use_calc_stream=True):
    """
    Scatter a single input tensor to all participators and gather the received tensors in out_tensor.

    .. note::
        ``alltoall_single`` is only supported in eager mode.

    Args:
        in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
        in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
            must be divisible by group size and ``in_tensor`` will be scattered averagely to all participators. Default: None.
        out_split_sizes (list[int], optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
            must be divisible by group size and ``out_tensor`` will be gathered averagely from all participators. Default: None.
        group (Group, optional): The group instance return by ``new_group`` or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.

    Returns:
        None, if ``use_calc_stream`` is set to ``True``; ``Task`` of ``group``, if ``use_calc_stream`` is set to ``False``.

    Raises:
        RuntimeError: If ``in_tensor`` or ``out_tensor`` is not a paddle Tensor.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            size = dist.get_world_size()

            # case 1 (2 GPUs)
            data = paddle.arange(2, dtype='int64') + rank * 2
            # data for rank 0: [0, 1]
            # data for rank 1: [2, 3]
            output = paddle.empty([2], dtype='int64')
            dist.alltoall_single(data, output)
            print(output)
            # output for rank 0: [0, 2]
            # output for rank 1: [1, 3]

            # case 2 (2 GPUs)
            in_split_sizes = [i + 1 for i in range(size)]
            # in_split_sizes for rank 0: [1, 2]
            # in_split_sizes for rank 1: [1, 2]
            out_split_sizes = [rank + 1 for i in range(size)]
            # out_split_sizes for rank 0: [1, 1]
            # out_split_sizes for rank 1: [2, 2]
            data = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank
            # data for rank 0: [[0., 0.], [0., 0.], [0., 0.]]
            # data for rank 1: [[1., 1.], [1., 1.], [1., 1.]]
            output = paddle.empty([(rank + 1) * size, size], dtype='float32')
            group = dist.new_group([0, 1])
            task = dist.alltoall_single(data,
                                        output,
                                        in_split_sizes,
                                        out_split_sizes,
                                        use_calc_stream=False,
                                        group=group)
            task.wait()
            print(output)
            # output for rank 0: [[0., 0.], [1., 1.]]
            # output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]]

    """
    if group is not None and not group.is_member():
        return

    assert in_dygraph_mode(), "Only support alltoall_single in eager mode."
    # Validate the arguments before they reach the process group.
    _check_single_tensor(in_tensor, "in_tensor")
    _check_single_tensor(out_tensor, "out_tensor")

    group = _get_default_group() if group is None else group
    backend = _group_map_backend[group]
    assert backend != 'gloo', ("backend gloo is not supported yet")

    # An empty split list means "split evenly across the group".
    in_split_sizes = [] if in_split_sizes is None else in_split_sizes
    out_split_sizes = [] if out_split_sizes is None else out_split_sizes

    task = group.process_group.alltoall_single(in_tensor, out_tensor,
                                               in_split_sizes, out_split_sizes)
    if use_calc_stream:
        task.wait()
        return
    else:
        return task


S
ShenLiang 已提交
1429 1430 1431 1432
def _get_group_rank(global_rank, group=None):
    return global_rank if group is None else group.get_group_rank(global_rank)


L
lilong12 已提交
1433 1434 1435 1436 1437 1438
def send(tensor, dst=0, group=None, use_calc_stream=True):
    """
    Send a tensor to the receiver.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        dst (int): The destination rank id (a global rank; it is mapped to the rank inside ``group``).
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        None. In eager mode with ``use_calc_stream=False`` the asynchronous task is returned;
        in legacy dygraph mode the result of the underlying op is returned.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    # The peer is addressed by its rank inside the group.
    dst = _get_group_rank(dst, group)

    if in_dygraph_mode():
        group = group if group is not None else _get_default_group()
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        task = group.process_group.send(tensor, dst)
        if not use_calc_stream:
            return task
        task.wait()
        return None

    ring_id = group.id if group is not None else 0

    if _non_static_mode():
        return _legacy_C_ops.send_v2(tensor, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id, 'peer', dst)

    # LayerHelper is built from **locals(), so the local names bound up to
    # that call are kept unchanged.
    op_type = 'send_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'send')
    helper = LayerHelper(op_type, **locals())
    attrs = {
        'ring_id': ring_id,
        'peer': dst,
        'use_calc_stream': use_calc_stream,
    }
    helper.append_op(type=op_type, inputs={'X': [tensor]}, attrs=attrs)
L
lilong12 已提交
1496 1497 1498 1499 1500 1501 1502 1503


def recv(tensor, src=0, group=None, use_calc_stream=True):
    """
    Receive a tensor from the sender.

    Args:
        tensor (Tensor): The Tensor that receives the data. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        src (int): The source rank id (a global rank; it is mapped to the rank inside ``group``).
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        None. In eager mode with ``use_calc_stream=False`` the asynchronous task is returned;
        in legacy dygraph mode the result of the underlying op is returned.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    # The peer is addressed by its rank inside the group.
    src = _get_group_rank(src, group)

    if in_dygraph_mode():
        group = group if group is not None else _get_default_group()
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        task = group.process_group.recv(tensor, src)
        if not use_calc_stream:
            return task
        task.wait()
        return None

    ring_id = group.id if group is not None else 0

    if _non_static_mode():
        return _legacy_C_ops.recv_v2(tensor, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id, 'peer', src, 'dtype',
                                     tensor.dtype, 'out_shape', tensor.shape)

    # LayerHelper is built from **locals(), so the local names bound up to
    # that call are kept unchanged.
    op_type = 'recv_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'recv')
    helper = LayerHelper(op_type, **locals())
    # The received data is written into ``tensor``, whose shape and dtype
    # describe the expected message.
    attrs = {
        'ring_id': ring_id,
        'peer': src,
        'out_shape': tensor.shape,
        'dtype': tensor.dtype,
        'use_calc_stream': use_calc_stream,
    }
    helper.append_op(type=op_type, outputs={'Out': [tensor]}, attrs=attrs)
1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585


def _check_single_tensor(tensor, tensor_name):
    if not isinstance(tensor, (core.eager.Tensor, paddle.Tensor)):
        raise RuntimeError("Invalid function argument. Expected parameter {}"
                           "to be of type paddle.Tensor, but it's {}".format(
                               tensor_name, type(tensor)))


def _check_tensor_list(tensor_list, tensor_name):
    if not isinstance(tensor_list, list) or \
        not all(isinstance(t, (core.eager.Tensor, paddle.Tensor)) for t in tensor_list):
        raise RuntimeError("Invalid function argument. Expected parameter {}"
                           "to be of type paddle.Tensor".format(tensor_name))


def isend(tensor, dst, group=None):
    """
    Sends a tensor asynchronously

    Args:
        tensor (Tensor): The Tensor to send. Its data type
1586
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
1587 1588
        dst (int): The destination rank.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
1589

1590 1591 1592
    Returns:
        A distributed task object.

1593
    Warning:
1594 1595 1596 1597 1598 1599 1600 1601 1602 1603
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
1604
            if dist.get_rank() == 0:
1605
                data = paddle.to_tensor([7, 8, 9])
1606
                task = dist.isend(data, dst=1)
1607 1608
            else:
                data = paddle.to_tensor([1, 2, 3])
1609
                task = dist.irecv(data, src=0)
1610 1611
            task.wait()
            print(data)
1612
            # [7, 8, 9] (2 GPUs)
1613 1614 1615 1616 1617 1618 1619 1620

    """
    _check_single_tensor(tensor, "tensor")
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
1621 1622
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
1623 1624 1625 1626
        group_dst_rank = group.get_group_rank(dst)
        assert group_dst_rank >= 0, ("dst rank out of group, need global rank")
        return group.process_group.send(tensor, group_dst_rank)
    else:
1627
        raise RuntimeError("Only support eager dygraph mode.")
1628 1629 1630 1631 1632 1633 1634 1635


def irecv(tensor, src=None, group=None):
    """
    Receive a tensor to the sender.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
1636
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
1637 1638 1639 1640
        src (int): The source rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.

    Returns:
1641
        A distributed task object.
1642

1643
    Warning:
1644 1645 1646 1647 1648 1649 1650 1651 1652 1653
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
1654
            if dist.get_rank() == 0:
1655
                data = paddle.to_tensor([7, 8, 9])
1656
                task = dist.isend(data, dst=1)
1657 1658
            else:
                data = paddle.to_tensor([1, 2, 3])
1659
                task = dist.irecv(data, src=0)
1660 1661
            task.wait()
            print(data)
1662
            # [7, 8, 9] (2 GPUs)
1663 1664 1665 1666 1667 1668 1669
    """
    _check_single_tensor(tensor, "tensor")
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
1670 1671
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
1672 1673 1674 1675
        group_src_rank = group.get_group_rank(src)
        assert group_src_rank >= 0, ("src rank out of group, need global rank")
        return group.process_group.recv(tensor, group_src_rank)
    else:
1676
        raise RuntimeError("Only support eager dygraph mode.")
1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691


class P2POp(object):
    """
    A class that makes point-to-point operations for "batch_isend_irecv".

    This class creates the type of P2P operation, communication buffer, peer rank,
    Group. Instances of this class will be passed to
    ``paddle.distributed.batch_isend_irecv`` for point-to-point communication.

    Args:
        op (callable): A function to send data to or receive data from a peer process.
            The type of ``op`` is either ``paddle.distributed.isend`` or ``paddle.distributed.irecv``.
        tensor (Tensor): Tensor to send or receive.
        peer (int): The destination or source rank.
1692
        group (Group, optional): The group instance return by new_group or None for global
1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740
            default group. Default: None.

    """

    def __init__(self, op, tensor, peer, group=None):
        if op not in [isend, irecv]:
            raise RuntimeError("Invalid ``op`` function. Expected ``op`` "
                               "to be of type ``paddle.distributed.isend`` or "
                               "``paddle.distributed.irecv``.")
        _check_single_tensor(tensor, "tensor")

        self.op = op
        self.tensor = tensor
        self.peer = peer
        self.group = _get_default_group() if group is None else group


@contextlib.contextmanager
def _with_batch_p2p_guard(backend):
    if backend == "nccl":
        core.ProcessGroupNCCL.group_start()
    try:
        yield
    finally:
        if backend == "nccl":
            core.ProcessGroupNCCL.group_end()


def _check_p2p_op_list(p2p_op_list):
    """
    Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
    all ops use the same backend.
    """
    if not isinstance(p2p_op_list, list) or not all(
            isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list):
        raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected to "
                           "to be of type ``paddle.distributed.P2POp``.")

    backend = _group_map_backend[p2p_op_list[0].group]
    if not all(backend == _group_map_backend[p2p_op.group]
               for p2p_op in p2p_op_list):
        raise RuntimeError("All groups need to use the same backend.")


def batch_isend_irecv(p2p_op_list):
    """
    Send or Receive a batch of tensors asynchronously and return a list of requests.

1741
    Process each of the point-to-point operations in ``p2p_op_list`` and return the
1742 1743 1744 1745 1746 1747 1748 1749 1750 1751
    corresponding tasks. NCCL are currently supported.

    Args:
        p2p_op_list: A list of point-to-point operations(type of each operator is
            ``paddle.distributed.P2POp``). The order of the isend/irecv in the list
            matters and it needs to match with corresponding isend/irecv on the
            remote end.

    Returns:
        A list of distributed tasks returned by calling the corresponding
1752
        op in the op_list.
1753

1754
    Warning:
1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed

            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            send_t = paddle.arange(2) + rank
            # paddle.tensor([0, 1])  # Rank-0
            # paddle.tensor([1, 2])  # Rank-1

            recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)

            send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
            recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)

            tasks = dist.batch_isend_irecv([send_op, recv_op])

            for task in tasks:
                task.wait()
1782

1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818
            print(recv_t)
            # paddle.tensor([1, 2])     # Rank-0
            # paddle.tensor([0, 1])     # Rank-1
    """
    _check_p2p_op_list(p2p_op_list)
    group = p2p_op_list[0].group
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        tasks = []
        with _with_batch_p2p_guard(backend):
            for p2p_op in p2p_op_list:
                op = p2p_op.op
                tensor = p2p_op.tensor
                peer = p2p_op.peer
                comm_group = p2p_op.group
                task = op(tensor, peer, comm_group)
                if task is not None:
                    tasks.append(task)
        return tasks
    else:
        raise RuntimeError("Don't support static graph mode currently.")


def reduce_scatter(tensor,
                   tensor_list,
                   op=ReduceOp.SUM,
                   group=None,
                   use_calc_stream=True):
    """
    Reduces, then scatters a list of tensors to all processes in a group

    Args:
1819 1820 1821
        tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        tensor_list (list[Tensor]): List of tensors to reduce and scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
1822
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
1823
        group (Group, optional): The group instance return by new_group or None for global
1824 1825 1826 1827 1828 1829
            default group. Default: None.
        use_calc_stream (bool, optional): Whether this op should be an async op.

    Returns:
        Async task handle, if use_calc_stream is set to False.
        None, if use_calc_stream or if not part of the group.
1830 1831

    Warning:
1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842
        This API only supports the dygraph mode.


    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
1843 1844 1845
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([0, 1])
                data2 = paddle.to_tensor([2, 3])
1846
            else:
1847 1848 1849 1850 1851 1852
                data1 = paddle.to_tensor([4, 5])
                data2 = paddle.to_tensor([6, 7])
            dist.reduce_scatter(data1, [data1, data2])
            print(data1)
            # [4, 6] (2 GPUs, out for rank 0)
            # [8, 10] (2 GPUs, out for rank 1)
1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863

    """
    _check_single_tensor(tensor, "tensor")
    _check_tensor_list(tensor_list, "tensor_list")

    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce_scatter")
        group = _get_default_group() if group is None else group
1864 1865
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886

        temp = paddle.concat(tensor_list, axis=0)
        task = group.process_group._reduce_scatter_base(tensor, temp, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task
    else:
        raise RuntimeError("Don't support static graph mode currently.")


def _reduce_scatter_base(output,
                         input,
                         op=ReduceOp.SUM,
                         group=None,
                         use_calc_stream=True):
    """
    Reduces, then scatters a flattened tensor to all processes in a group.

    Args:
1887
        output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
1888
        input (Tensor): Input tensor that is of size output tensor size times world size. Its data type
1889
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
1890
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream (False).
            Default to True.
    Returns:
        Async task handle, if use_calc_stream is set to False.
        None, if use_calc_stream or if not part of the group.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
1908 1909 1910 1911 1912
            data = paddle.arange(4) + rank
            # [0, 1, 2, 3] (2 GPUs, for rank 0)
            # [1, 2, 3, 4] (2 GPUs, for rank 1)
            output = paddle.empty(shape=[2], dtype=data.dtype)
            dist.collective._reduce_scatter_base(output, data)
1913
            print(output)
1914 1915
            # [1, 3] (2 GPUs, out for rank 0)
            # [5, 7] (2 GPUs, out for rank 1)
1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934

    """
    _check_single_tensor(output, "output")
    _check_single_tensor(input, "input")

    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "_reduce_scatter_base")
        group = _get_default_group() if group is None else group
        task = group.process_group._reduce_scatter_base(output, input, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task
    else:
        raise RuntimeError("Don't support static graph mode currently.")