collective.py 91.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
17
from datetime import timedelta
18
from ..fluid.layer_helper import LayerHelper
19
from ..fluid.framework import Variable
20
from ..fluid.framework import in_dygraph_mode
21
from ..fluid.framework import OpProtoHolder
J
Jiabin Yang 已提交
22
from ..fluid.framework import _non_static_mode
23
from ..fluid.framework import _in_legacy_dygraph
24
from ..fluid.framework import convert_np_dtype_to_dtype_
J
Jiangxinz 已提交
25
from ..fluid.framework import _varbase_creator
26 27 28 29
from ..fluid.data_feeder import convert_dtype
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.data_feeder import check_type
from ..fluid.data_feeder import check_dtype
30 31
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
B
Baibaifan 已提交
32
from ..fluid.dygraph import layers
33 34 35 36
from ..fluid.dygraph.parallel import prepare_context
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
W
wanghuancoder 已提交
37
from paddle import _C_ops
J
Jiangxinz 已提交
38
import paddle.fluid.dygraph_utils as dygraph_utils
39
import contextlib
40

41
__all__ = []
42 43 44


class ReduceOp:
    """
    Specify the type of operation used for element-wise reductions.
    It should be one of the following values:

        ReduceOp.SUM

        ReduceOp.MAX

        ReduceOp.MIN

        ReduceOp.PROD

    Note:
        ``ReduceOp.AVG`` is declared as well, but the reduce-op dispatch in
        this module (``_get_reduce_op`` and ``all_reduce``) only handles
        SUM, MAX, MIN and PROD.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            from paddle.distributed import ReduceOp
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.all_reduce(data, op=ReduceOp.SUM)
            out = data.numpy()
            # [[5, 7, 9], [5, 7, 9]]
    """
    SUM = 0
    MAX = 1
    MIN = 2
    PROD = 3
    AVG = 4
81 82


K
kuizhiqing 已提交
83 84 85 86
class Group:
    """
    The abstract representation of group.

    Args:
        rank (int): Rank of the current process inside the group. A negative
            value marks the process as not being a member.
        rank_num (int): Number of ranks in the group; stored as ``nranks``.
        id (int): The group id. Default value is 0.
        ranks (list): Global ranks of the group members. Defaults to a new
            empty list for every instance.
        pg: The underlying process group object, if any.
        name (str): Optional name of the group.
    """

    def __init__(self, rank, rank_num, id=0, ranks=None, pg=None, name=None):
        self.rank = rank
        self.nranks = rank_num
        self.id = id
        # A literal ``[]`` default would be one list shared by every Group
        # created without ``ranks``; build a fresh one per instance instead.
        self.ranks = [] if ranks is None else ranks
        self.pg = pg
        self.name = name

    def is_member(self):
        """Return True when this process has a non-negative rank in a group of at least two ranks."""
        if self.rank < 0:
            return False
        if self.nranks < 2:
            return False
        return True

    def get_group_rank(self, rank):
        """Return the position of global ``rank`` in ``self.ranks``, or -1 if absent or not a member."""
        if self.is_member() and rank in self.ranks:
            return self.ranks.index(rank)
        else:
            return -1

    @property
    def process_group(self):
        """The underlying process group object passed as ``pg``."""
        return self.pg

    def __repr__(self):
        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
            self.rank, self.nranks, self.id)
        debug_str += ", ".join(map(str, self.ranks))
        debug_str += "; name: "
        debug_str += self.name if self.name else "None"
        return debug_str

K
kuizhiqing 已提交
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135

_global_env = None


def _get_global_env():
    """Return the cached ``ParallelEnv``, creating it while the cache is still falsy."""
    global _global_env
    if _global_env:
        return _global_env
    _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# group map : the map of all groups by id, 0 for GlobalGroup
# Dict[int, Group]
_group_map = {}

# group map by name : the map of all groups from their names
# Dict[name, Group]
_group_map_by_name = {}

# backend map by group : the map of all backends from their groups
# Dict[group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter']
_default_store = None  # the default tcp store
_default_backend = None

K
kuizhiqing 已提交
151

L
lilong12 已提交
152 153 154 155 156 157 158 159 160 161
def _set_default_backend(backend):
    """Store ``backend`` in the module-level ``_default_backend``."""
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    """Store ``store`` in the module-level ``_default_store``."""
    global _default_store
    _default_store = store


K
kuizhiqing 已提交
162 163 164 165
def _get_group_map():
    """Return the id -> Group map, registering the global group under id 0 while the map is empty."""
    global _group_map
    if not _group_map:
        genv = _get_global_env()
        # The global group spans every rank of the current environment.
        _group_map[0] = Group(genv.rank,
                              genv.world_size,
                              ranks=list(range(genv.world_size)))
    return _group_map


def _get_global_group():
    """Return the group stored under id 0 in the group map."""
    group_map = _get_group_map()
    return group_map[0]


176 177 178 179 180 181
def _get_group_map_by_name():
    """Return the module-level name -> Group map."""
    # Read-only access: no ``global`` declaration is needed to return it.
    return _group_map_by_name


def _get_default_group():
    """
    Return the group registered under ``_default_group_name``.

    Raises:
        AssertionError: If ``is_initialized()`` reports that the default
            group has not been registered yet.
    """
    assert is_initialized(), ("Call paddle.distributed.init_parallel_env first "
                              "to initialize the distributed environment.")
    return _get_group_map_by_name()[_default_group_name]


L
lilong12 已提交
188 189 190 191 192 193 194 195 196 197 198 199
def _set_group_map(gid, group):
    """Register ``group`` under ``gid``; asserts that the id is not already present."""
    registry = _group_map
    assert gid not in registry
    registry[gid] = group


def _set_group_map_by_name(name, group):
    """Register ``group`` under ``name``; asserts that the name is not already present."""
    registry = _group_map_by_name
    assert name not in registry
    registry[name] = group


200 201 202 203 204 205
def _set_group_map_backend(group, backend):
    """Record ``backend`` for ``group``; asserts that the group has no entry yet."""
    registry = _group_map_backend
    assert group not in registry
    registry[group] = backend


K
kuizhiqing 已提交
206 207 208 209
def _new_ring_id():
    """Return the number of known groups plus ``max(nrings, 9)`` of the global environment."""
    known_groups = len(_get_group_map())
    ring_floor = max(_get_global_env().nrings, 9)
    return known_groups + ring_floor


210 211 212 213 214 215 216 217 218 219 220 221 222
def _get_reduce_op(reduce_op, func_name):
    """
    Translate a ``ReduceOp`` value into the matching ``core.ReduceOp`` member.

    Args:
        reduce_op: One of ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PROD.
        func_name (str): Name of the calling API, used in the error message.

    Raises:
        ValueError: If ``reduce_op`` matches none of the four values above.
    """
    # Candidates are compared with ``==`` in order; the core attribute is
    # only looked up for the candidate that matched.
    candidates = (
        (ReduceOp.SUM, 'SUM'),
        (ReduceOp.MAX, 'MAX'),
        (ReduceOp.MIN, 'MIN'),
        (ReduceOp.PROD, 'PRODUCT'),
    )
    for py_op, core_name in candidates:
        if reduce_op == py_op:
            return getattr(core.ReduceOp, core_name)
    raise ValueError("Unknown reduce_op type for {}.".format(func_name))


K
kuizhiqing 已提交
223 224 225 226 227 228
def get_group(id=0):
    """

    Get group instance by group id.

    Args:
        id (int): the group id. Default value is 0.

    Returns:
        Group: the group instance, or None when no group is registered under ``id``.

    Examples:
        .. code-block:: python

            ...
            gid = paddle.distributed.new_group([2,4,6])
            paddle.distributed.get_group(gid.id)

    """

    gm = _get_group_map()
    return gm.get(id)
K
kuizhiqing 已提交
245 246


247 248 249 250 251 252
def _new_process_group_impl(backend,
                            store,
                            rank,
                            world_size,
                            group_name,
                            pg_options,
                            group_id=0,
                            src_rank=None,
                            dst_rank=None):
    """
    Create the ``core`` process group object for the requested backend.

    Args:
        backend (str): One of the names in ``_valid_backend_list``.
        store: Store object forwarded unchanged to the process group constructor.
        rank (int): Rank of the current process inside the new group.
        world_size (int): Number of processes in the new group.
        group_name: Accepted for interface compatibility; not read by this function.
        pg_options: Accepted for interface compatibility; not read by this function.
        group_id (int): Group id forwarded to the process group. Default value is 0.
        src_rank: May only be set for the 'heter' backend.
        dst_rank: May only be set for the 'heter' backend.

    Returns:
        The process group object built for ``backend``.

    Raises:
        AssertionError: If ``src_rank``/``dst_rank`` are given for a non-heter
            backend, if ``backend`` is not supported, or if one of the
            CLUSTER_ID / CLUSTER_SIZE / CLUSTER_SWITCH environment variables
            required by the 'heter' backend is missing.
    """
    pg = None
    genv = _get_global_env()
    if backend != 'heter':
        assert src_rank is None and dst_rank is None, (
            "src_rank and dst_rank "
            "can only be set for heter backend.")
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
    if backend == "gloo":
        place = core.CPUPlace()
        pg = core.ProcessGroupGloo(store, rank, world_size, place, group_id)
    elif backend == "nccl":
        place = core.CUDAPlace(genv.device_id)
        pg = core.ProcessGroupNCCL(store, rank, world_size, place, group_id)
    elif backend == "hccl":
        place = core.NPUPlace(genv.device_id)
        pg = core.ProcessGroupHCCL(store, rank, world_size, place, group_id)
    elif backend == "heter":
        # NOTE(review): ``place`` stays None when the build has neither CUDA
        # nor NPU support; confirm ProcessGroupHeter accepts that.
        place = None
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(genv.device_id)
        elif core.is_compiled_with_npu():
            place = core.NPUPlace(genv.device_id)
        cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
        assert cluster_id >= 0, "please set the CLUSTER_ID variable."
        cluster_size = os.getenv("CLUSTER_SIZE", None)
        assert cluster_size, "please set the CLUSTER_SIZE variable."
        # CLUSTER_SIZE is a comma-separated list with one size per cluster.
        cluster_size = cluster_size.split(",")
        cluster_size = [int(s) for s in cluster_size]
        switch_ep = os.getenv("CLUSTER_SWITCH", None)
        assert switch_ep, "please set the CLUSTER_SWITCH variable."
        # The global rank is the local rank shifted by the total size of all
        # clusters that precede this one.
        cluster_size_cumsum = np.cumsum(cluster_size)
        cluster_offset = 0 if cluster_id == 0 else cluster_size_cumsum[
            cluster_id - 1]
        global_rank = cluster_offset + rank
        global_world_size = cluster_size_cumsum[-1]
        pg = core.ProcessGroupHeter(store,
                                    rank=global_rank,
                                    world_size=global_world_size,
                                    place=place,
                                    gid=group_id,
                                    local_rank=rank,
                                    local_size=world_size,
                                    gloo_rank=cluster_id,
                                    gloo_size=len(cluster_size),
                                    with_switch=True,
                                    switch_endpoint=switch_ep,
                                    src_rank=src_rank,
                                    dst_rank=dst_rank)

    return pg


S
ShenLiang 已提交
308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331
def barrier(group=None):
    """

    Barrier among all participators in the group.

    Args:
        group (Group): The group instance return by new_group or None for global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    # Processes outside the group have nothing to synchronize with.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        task = group.process_group.barrier()
        task.wait()
        return

    ring_id = 0 if group is None else group.id

    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        return _C_ops.barrier(temp, temp, 'ring_id', ring_id)

    op_type = 'barrier'

    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    # locals() is forwarded to LayerHelper, so the local names above are kept as-is.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [temp]},
                     outputs={'Out': [temp]},
                     attrs={'ring_id': ring_id})
S
ShenLiang 已提交
353 354


L
lilong12 已提交
355 356 357 358 359 360 361
# _custom_gid provides a way for users to
# set the group id, which is usually useful
# to be compatible with the static mode.
_custom_gid = None


def _set_custom_gid(gid):
    """Store ``gid`` in the module-level ``_custom_gid``, which ``new_group`` reads in dygraph mode."""
    global _custom_gid
    _custom_gid = gid


K
kuizhiqing 已提交
366 367 368
def new_group(ranks=None, backend=None):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members.
        backend (str): The backend used to create group, only nccl is supported now.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)

    """
    if in_dygraph_mode():
        # A falsy _custom_gid (None or 0) falls back to a generated ring id.
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group.")
        size = len(ranks)
        ranks = sorted(ranks)
        # ``global_rank`` is only bound when the branch above ran; the
        # short-circuit on 'heter' / ``size > 1`` keeps it from being read
        # otherwise.
        if backend == 'heter' or (size > 1 and global_rank in ranks):
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            src_rank = ranks[0] if backend == 'heter' else None
            dst_rank = ranks[1] if backend == 'heter' else None
            pg = _new_process_group_impl(backend,
                                         _default_store,
                                         rank,
                                         size,
                                         group_name,
                                         pg_options=None,
                                         group_id=gid,
                                         src_rank=src_rank,
                                         dst_rank=dst_rank)
        else:
            # This process does not take part in the new group.
            rank = -1
            pg = None
        group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        return group

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', ("backend other than nccl is not supported yet")

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    # NOTE(review): with the default ``ranks=None`` the membership test below
    # raises TypeError on this path; callers must pass an explicit list.
    if global_rank not in ranks:
        gp = Group(-1, -1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, group_size, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            else:
                assert False, ("no cuda device found")
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    if _non_static_mode():
        tmp = paddle.to_tensor([1], dtype="int32")
    else:
        # NOTE(review): shape [0] differs from the [1] used in barrier();
        # confirm an empty tensor is intended here.
        tmp = fill_constant([0], dtype="int32", value="1")
    paddle.distributed.all_reduce(tmp, use_calc_stream=True)
    paddle.distributed.wait(tmp)
    return gp

488

489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557
def is_initialized():
    """

    Report whether the default communication group has been registered,
    i.e. whether the distributed environment has been initialized.

    Returns (bool): `True` if distributed environment has been initialized, otherwise `False`.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle

            print(paddle.distributed.is_initialized())
            # False

            paddle.distributed.init_parallel_env()
            print(paddle.distributed.is_initialized())
            # True

    """
    registered_names = _group_map_by_name
    return _default_group_name in registered_names


def destroy_process_group(group=None):
    """
    Destroy a given group for communication

    Args:
        group (Group, optional): The group to be destroyed. If None (the default),
                                 all process groups, including the default group,
                                 are destroyed and the distributed environment
                                 is deinitialized.

    Returns : None

    Raises:
        AssertionError: If the group (or the default group when ``group`` is None)
            is not registered.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle

            paddle.distributed.init_parallel_env()
            group = paddle.distributed.new_group([0, 1])

            paddle.distributed.destroy_process_group(group)
            print(paddle.distributed.is_initialized())
            # True
            paddle.distributed.destroy_process_group()
            print(paddle.distributed.is_initialized())
            # False

    """
    pg = _get_default_group() if group is None else group
    assert _group_map.get(pg.id, None) is not None, "Invalid group."

    if group is None:
        _group_map.clear()
        _group_map_by_name.clear()
        _group_map_backend.clear()
    else:
        # Groups built by the static-graph path of new_group() are only stored
        # in _group_map (their name is None and no backend is recorded), so the
        # secondary maps are cleaned up without requiring an entry.
        _group_map.pop(pg.id, None)
        _group_map_by_name.pop(pg.name, None)
        _group_map_backend.pop(pg, None)


K
kuizhiqing 已提交
558 559 560 561 562 563 564 565
def wait(tensor, group=None, use_calc_stream=True):
    """

    wait to sync stream for group.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, use_calc_stream=True)
            paddle.distributed.wait(tindata)

    """

    # Processes outside the group have nothing to wait for.
    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id

    if use_calc_stream:
        _sync_calc_stream(tensor)
    else:
        _sync_comm_stream(tensor, ring_id)


def _sync_calc_stream(tensor):
    """Run (non-static mode) or append (static mode) the ``c_sync_calc_stream`` op with ``tensor`` as input and output."""

    if _non_static_mode():
        return _C_ops.c_sync_calc_stream(tensor, tensor)

    op_type = 'c_sync_calc_stream'

    # locals() is forwarded to LayerHelper, so no extra local names are introduced here.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
    )
608

609

K
kuizhiqing 已提交
610
def _sync_comm_stream(tensor, ring_id=0):
    """Run (non-static mode) or append (static mode) the ``c_sync_comm_stream`` op for ``ring_id`` with ``tensor`` as input and output."""

    if _non_static_mode():
        return _C_ops.c_sync_comm_stream([tensor], [tensor], 'ring_id', ring_id)

    op_type = 'c_sync_comm_stream'

    # locals() is forwarded to LayerHelper, so no extra local names are introduced here.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={'ring_id': ring_id},
    )
K
kuizhiqing 已提交
624 625 626


def broadcast(tensor, src, group=None, use_calc_stream=True):
    """

    Broadcast a tensor from the source to all others.
    As shown below, 4 GPUs each start 4 processes and GPU0 owns data 0. Through broadcast operator,
    the data 0 will be sent to all GPUs from GPU0.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png
        :width: 800
        :alt: broadcast
        :align: center

    Args:
        tensor (Tensor): The Tensor to send if current rank is the source, or the tensor to receive otherwise. Its data type
            should be float16, float32, float64, int32 or int64.
        src (int): The source rank.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None, except that in dygraph mode with ``use_calc_stream=False`` the task
        returned by the process group is handed back, and in legacy dygraph mode
        the result of the underlying op call is returned.

    Raises:
        ValueError: If ``src`` is not an int.
        AssertionError: If ``src`` is not a rank of the group.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.broadcast(data, 1)
            out = data.numpy()
            # [[1, 2, 3], [1, 2, 3]]
    """

    # Processes outside the group do not take part in the broadcast.
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        # The process group expects the rank relative to the group.
        gsrc = group.get_group_rank(src)
        assert gsrc >= 0, ("src rank out of group, need global rank")
        task = group.process_group.broadcast(tensor, gsrc)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    gsrc = src if group is None else group.get_group_rank(src)
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if _non_static_mode():
        return _C_ops.c_broadcast(tensor, tensor, 'root', gsrc,
                                  'use_calc_stream', use_calc_stream, 'ring_id',
                                  ring_id)

    op_type = 'c_broadcast'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'broadcast')

    # locals() is forwarded to LayerHelper, so the local names above are kept as-is.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'root': gsrc,
                         'use_calc_stream': use_calc_stream,
                         'ring_id': ring_id,
                     })
709 710


K
kuizhiqing 已提交
711
def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor over all ranks so that all get the result.
    As shown below, 4 GPUs each start 4 processes and the data on each GPU is represented
    by the GPU number. The reduce operator is sum. Through all_reduce operator,
    each GPU will have the sum of the data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
        :width: 800
        :alt: all_reduce
        :align: center

    Args:
        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None, except that in dygraph mode with ``use_calc_stream=False`` the task
        returned by the process group is handed back, and in legacy dygraph mode
        the result of the underlying op call is returned.

    Raises:
        ValueError: If ``op`` is not one of the four supported reduce operations.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import ReduceOp
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.all_reduce(data)
            out = data.numpy()
            # [[5, 7, 9], [5, 7, 9]]
    """
    # Processes outside the group do not take part in the reduction.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "all_reduce")
        group = _get_default_group() if group is None else group
        task = group.process_group.allreduce(tensor, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    if _non_static_mode():
        if op == ReduceOp.SUM:
            return _C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.MAX:
            return _C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.MIN:
            return _C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.PROD:
            return _C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
                                            use_calc_stream, 'ring_id', ring_id)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'all_reduce')
    if op == ReduceOp.SUM:
        op_type = 'c_allreduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_allreduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_allreduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_allreduce_prod'
    else:
        # Without this branch ``op_type`` would be unbound below; fail the
        # same way the non-static path does.
        raise ValueError("Unknown parameter: {}.".format(op))
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'ring_id' for all_reduce should be int.")
    # locals() is forwarded to LayerHelper, so the local names above are kept as-is.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream
                     })
806 807


K
kuizhiqing 已提交
808
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor to the destination from all others. As shown below, 4 GPUs each start 4 processes and the data on each GPU is represented
    by the GPU number. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator,
    the GPU0 will own the sum of all data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png
        :width: 800
        :alt: reduce
        :align: center

    Args:
        tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
            should be float16, float32, float64, int32 or int64.
        dst (int): The destination rank id.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance returned by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None in static graph mode, when the current process is not a member
        of ``group``, or in eager mode with ``use_calc_stream=True``. In eager
        mode with ``use_calc_stream=False`` the pending communication task is
        returned; in legacy dygraph mode the result of the underlying
        ``c_reduce_*`` op is returned.

    Raises:
        ValueError: If ``op`` is not one of the supported ReduceOp values
            (legacy dygraph and static graph modes).

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.reduce(data, 0)
            out = data.numpy()
            # [[5, 7, 9], [5, 7, 9]]
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce")
        group = _get_default_group() if group is None else group
        gdst = group.get_group_rank(dst)
        assert gdst >= 0, ("dst rank out of group, need global rank")
        task = group.process_group.reduce(tensor, gdst, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    gdst = dst if group is None else group.get_group_rank(dst)
    assert gdst >= 0, ("dst rank out of group, need global rank")

    # Resolve the concrete op name once so that the legacy dygraph and the
    # static graph paths share the same validation: an unsupported `op` is
    # rejected in both instead of silently producing a bogus op type.
    if op == ReduceOp.SUM:
        op_type = 'c_reduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_reduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_reduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_reduce_prod'
    else:
        raise ValueError("Unknown parameter: {}.".format(op))

    if _non_static_mode():
        # The tensor is passed as both input and output of the op.
        return getattr(_C_ops, op_type)(tensor, tensor, 'use_calc_stream',
                                        use_calc_stream, 'ring_id', ring_id,
                                        'root_id', gdst)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'reduce')

    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream,
                         'root_id': gdst,
                     })
913 914


K
kuizhiqing 已提交
915
def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
    """

    Gather tensors from all participators and all get the result. As shown
    below, 4 GPUs each start 4 processes and the data on each GPU is represented
    by the GPU number. Through the all_gather operator, each GPU will have data
    from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64. In eager mode the list is cleared and
            refilled with the gathered pieces; otherwise the gathered pieces are appended to it.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The group instance returned by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True. The eager-mode path always waits for the task and does not read this flag.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            tensor_list = []
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data1 = np.array([[4, 5, 6], [4, 5, 6]])
                np_data2 = np.array([[4, 5, 6], [4, 5, 6]])
                data1 = paddle.to_tensor(np_data1)
                data2 = paddle.to_tensor(np_data2)
                paddle.distributed.all_gather(tensor_list, data1)
            else:
                np_data1 = np.array([[1, 2, 3], [1, 2, 3]])
                np_data2 = np.array([[1, 2, 3], [1, 2, 3]])
                data1 = paddle.to_tensor(np_data1)
                data2 = paddle.to_tensor(np_data2)
                paddle.distributed.all_gather(tensor_list, data2)
    """
    # Processes outside the group take no part in the collective.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        if group is None:
            group = _get_default_group()
        if len(tensor_list) > 0:
            # Reuse the caller-provided tensors as the output buffer.
            gathered = paddle.concat(tensor_list, axis=0)
        else:
            # No buffer supplied: allocate one with the leading dimension
            # scaled by the number of ranks in the group.
            gathered_shape = list(tensor.shape)
            gathered_shape[0] *= group.nranks
            gathered = paddle.empty(gathered_shape, tensor.dtype)
        group.process_group.all_gather(tensor, gathered).wait()
        tensor_list.clear()
        tensor_list.extend(paddle.split(gathered, group.nranks, 0))
        return

    if group is None:
        ring_id = 0
        nranks = _get_global_group().nranks
    else:
        ring_id = group.id
        nranks = group.nranks

    if _non_static_mode():
        gathered = _C_ops.c_allgather(tensor, 'use_calc_stream',
                                      use_calc_stream, 'ring_id', ring_id,
                                      'nranks', nranks)
        tensor_list.extend(paddle.split(gathered, nranks, 0))
        return

    # Static graph: validate the arguments and append the op to the program.
    op_type = 'c_allgather'
    helper = LayerHelper(op_type, **locals())
    gathered = helper.create_variable_for_type_inference(dtype=tensor.dtype)
    if not isinstance(tensor_list, list):
        raise ValueError("The type of 'tensor_list' for all_gather "
                         "should be list.")
    supported_dtypes = ['float16', 'float32', 'float64', 'int32', 'int64']
    for elem in tensor_list:
        check_variable_and_dtype(elem, 'tensor_list', supported_dtypes,
                                 'all_gather')
    check_variable_and_dtype(tensor, 'tensor', supported_dtypes, 'all_gather')
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [gathered]},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream,
                         'nranks': nranks
                     })

    tensor_list.extend(paddle.split(gathered, nranks, 0))
1012 1013


K
kuizhiqing 已提交
1014
def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
    """

    Scatter a tensor to all participators. As shown below, 4 GPUs each start 4 processes and the source of the scatter
    is GPU0. Through scatter operator, the data in GPU0 will be sent to all GPUs evenly.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
        :width: 800
        :alt: scatter
        :align: center

    Args:
        tensor (Tensor): The output Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64. Default value is None.
            It is only read on the source rank, where it must be non-empty.
        src (int): The source rank id. Default value is 0.
        group (Group): The group instance returned by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None in static graph mode, when the current process is not a member
        of ``group``, or in eager mode with ``use_calc_stream=True``. In eager
        mode with ``use_calc_stream=False`` the pending communication task is
        returned; in legacy dygraph mode the result of the ``c_scatter`` op
        is returned.

    Raises:
        ValueError: If ``src`` is not an int, or if the current rank is the
            source rank and ``tensor_list`` is None or empty.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data1 = np.array([7, 8, 9])
                np_data2 = np.array([10, 11, 12])
            else:
                np_data1 = np.array([1, 2, 3])
                np_data2 = np.array([4, 5, 6])
            data1 = paddle.to_tensor(np_data1)
            data2 = paddle.to_tensor(np_data2)
            if paddle.distributed.ParallelEnv().local_rank == 0:
                paddle.distributed.scatter(data1, src=1)
            else:
                paddle.distributed.scatter(data1, tensor_list=[data1, data2], src=1)
            out = data1.numpy()
    """
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        rank = group.rank
        nranks = group.nranks
    else:
        ring_id = 0 if group is None else group.id
        gsrc = src if group is None else group.get_group_rank(src)
        rank = _get_global_group().rank if group is None else group.rank
        nranks = _get_global_group().nranks if group is None else group.nranks
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if rank != gsrc:
        # Non-source ranks ignore the caller's tensor_list and feed the op
        # with `nranks` references to their own output tensor instead
        # (presumably only the shape/dtype matter there -- TODO confirm).
        tensor_list = [tensor for _ in range(nranks)]
    elif tensor_list is None or len(tensor_list) == 0:
        # Fail early with a clear message instead of letting paddle.concat
        # choke on a missing/empty input on the source rank.
        raise ValueError(
            "tensor_list should be a non-empty list/tuple of Tensors on the "
            "source rank.")
    temp = paddle.concat(tensor_list, axis=0)

    if in_dygraph_mode():
        task = group.process_group.scatter(temp, tensor, gsrc)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    if _non_static_mode():
        return _C_ops.c_scatter(temp, tensor, 'use_calc_stream',
                                use_calc_stream, 'ring_id', ring_id, 'nranks',
                                nranks, 'root', gsrc)

    op_type = 'c_scatter'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'scatter')
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [temp]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'root': gsrc,
                         'use_calc_stream': use_calc_stream,
                         'nranks': nranks,
                     })
1111 1112


1113
def _c_identity(tensor, group=None):
    """
    Apply the ``c_identity`` op to the tensor, mainly used with model parallel.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The process group to work on, or None to use
            ring id 0.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.
    """
    if group is None:
        ring_id = 0
    elif group.is_member():
        ring_id = group.id
    else:
        # Not part of this group: nothing to do.
        return

    if _non_static_mode():
        return _C_ops.c_identity(tensor, 'use_calc_stream', True, 'ring_id',
                                 ring_id, 'use_model_parallel', True)

    # Static graph: declare the output variable and append the op.
    op_type = 'c_identity'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    accepted_dtypes = ['float16', 'float32', 'float64', 'int32', 'int64']
    check_variable_and_dtype(tensor, 'tensor', accepted_dtypes, '_c_identity')

    op_attrs = {
        'ring_id': ring_id,
        'use_calc_stream': True,
        'use_model_parallel': True,
    }
    helper.append_op(type=op_type,
                     inputs={'X': tensor},
                     outputs={'Out': out},
                     attrs=op_attrs)
    return out


1151
def _c_concat(tensor, group=None):
    """
    Return allgather of the tensor, mainly used with model parallel.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The process group to work on, or None to use the
            default group.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.
    """
    if group is None:
        group = _get_default_group()
    elif not group.is_member():
        return
    ring_id = group.id

    # NOTE(review): the value is not used below; the lookup is kept as in the
    # original code in case _get_global_env() has side effects -- confirm.
    global_rank = _get_global_env().rank
    rank, nranks = group.rank, group.nranks

    if _non_static_mode():
        return _C_ops.c_concat(tensor, 'ring_id', ring_id, 'use_calc_stream',
                               True, 'rank', rank, 'nranks', nranks,
                               'use_model_parallel', True)

    # Static graph: declare the output variable and append the op.
    op_type = 'c_concat'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    accepted_dtypes = ['float16', 'float32', 'float64', 'int32', 'int64']
    check_variable_and_dtype(tensor, 'tensor', accepted_dtypes, '_c_concat')

    op_attrs = {
        'ring_id': ring_id,
        'use_calc_stream': True,
        'use_model_parallel': True,
        'nranks': nranks,
        'rank': rank
    }
    helper.append_op(type=op_type,
                     inputs={'X': tensor},
                     outputs={'Out': out},
                     attrs=op_attrs)
    return out


1198
def _c_split(tensor, group=None):
    """
    Split tensor evenly among all members, mainly used with model parallel.

    The rank of the current process and the number of ranks are derived from
    ``group`` when it is given, and from the global environment otherwise.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The process group to work on, or None to use
            ring id 0 with the global rank and world size.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return

    global_rank = _get_global_env().rank
    if group is None:
        ring_id = 0
        rank = global_rank
        nranks = _get_global_env().world_size
    else:
        ring_id = group.id
        rank = group.get_group_rank(global_rank)
        nranks = group.nranks

    if _non_static_mode():
        return _C_ops.c_split(tensor, 'use_calc_stream', True, 'ring_id',
                              ring_id, 'rank', rank, 'nranks', nranks,
                              'use_model_parallel', True)

    # Static graph: declare the output variable and append the op.
    op_type = 'c_split'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    accepted_dtypes = ['float16', 'float32', 'float64', 'int32', 'int64']
    check_variable_and_dtype(tensor, 'tensor', accepted_dtypes, '_c_split')

    op_attrs = {
        'ring_id': ring_id,
        'use_calc_stream': True,
        'rank': rank,
        'nranks': nranks,
        'use_model_parallel': True,
    }
    helper.append_op(type=op_type,
                     inputs={'X': tensor},
                     outputs={'Out': out},
                     attrs=op_attrs)
    return out


1245 1246 1247 1248 1249
def _mp_allreduce(tensor,
                  op=ReduceOp.SUM,
                  group=None,
                  use_calc_stream=True,
                  use_model_parallel=True):
    """
    All-reduce variant used with model parallel; only ReduceOp.SUM is supported.

    In the dygraph modes the reduction is done in place through
    ``c_allreduce_sum_``; in eager mode it is wrapped in a PyLayer whose
    backward applies ``c_identity`` to the incoming gradient. In static graph
    mode a ``c_allreduce_sum`` op writing to a new output variable is appended.

    Args:
        tensor (Tensor): The input Tensor. In static graph mode its data type
            is checked against float16, float32, float64, int32 and int64.
        op (ReduceOp.SUM): The reduction to apply. Default value is ReduceOp.SUM.
        group (Group): The process group to work on, or None to use ring id 0.
        use_calc_stream (bool): Forwarded to the op as the ``use_calc_stream``
            attribute. Default to True.
        use_model_parallel (bool): Forwarded to the op as the
            ``use_model_parallel`` attribute. Default to True.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.

    Raises:
        AssertionError: In eager mode, if ``op`` is not ReduceOp.SUM.
        ValueError: In legacy dygraph or static graph mode, if ``op`` is not
            ReduceOp.SUM.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    if in_dygraph_mode():
        assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op)

        from paddle.autograd import PyLayer

        class mp_allreduce_eager(PyLayer):

            @staticmethod
            def forward(ctx, tensor, use_calc_stream, ring_id,
                        use_model_parallel):
                # Remember the ring so backward uses the same one.
                ctx.ring_id = ring_id
                return _C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                               use_calc_stream, 'ring_id',
                                               ring_id, "use_model_parallel",
                                               use_model_parallel)

            @staticmethod
            def backward(ctx, dy):
                return _C_ops.c_identity(dy, 'use_calc_stream', True, 'ring_id',
                                         ctx.ring_id, 'use_model_parallel',
                                         True)

        return mp_allreduce_eager.apply(tensor, use_calc_stream, ring_id,
                                        use_model_parallel)

    elif _in_legacy_dygraph():
        if op == ReduceOp.SUM:
            return _C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
                                           "use_model_parallel",
                                           use_model_parallel)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    # Static graph: only a sum op is emitted below, so reject any other
    # reduction instead of silently summing.
    if op != ReduceOp.SUM:
        raise ValueError("Unknown parameter: {}.".format(op))

    op_type = 'c_allreduce_sum'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        op_type)

    helper.append_op(type=op_type,
                     inputs={'X': tensor},
                     outputs={'Out': out},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream,
                         'use_model_parallel': use_model_parallel,
                     })
    return out
1307 1308


1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322
def _c_lookup_table(table, index, start_index=0, name=None):
    """
    Lookup table according to index.

    Args:
        table (Tensor): The input Tensor. Its data type
            should be float16, float32, float64.
        index (Tensor): The index to lookup table. In static graph mode its
            data type must be int32 or int64.
        start_index (int): The initial index for table range.
        name (string): The name of the api

    Returns:
        Tensor.
    """
    if _non_static_mode():
        return _C_ops.c_embedding(table, index, "start_index", start_index)

    op_type = 'c_embedding'
    # The helper receives this function's locals as keyword arguments;
    # input_dtype() below looks the table up by its parameter name, so the
    # `table` parameter must stay in scope under that name.
    helper = LayerHelper(op_type, **locals())
    dtype = helper.input_dtype(input_param_name='table')
    # Report the argument under its real parameter name so that a dtype
    # error points the caller at `index`.
    check_variable_and_dtype(index, 'index', ['int32', 'int64'], op_type)
    tmp = helper.create_variable_for_type_inference(dtype)
    helper.append_op(type=op_type,
                     inputs={
                         'Ids': index,
                         'W': table
                     },
                     outputs={'Out': tmp},
                     attrs={"start_index": start_index})
    return tmp

1340

B
Baibaifan 已提交
1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355
class _Linear(layers.Layer):
    """
    Linear layer whose forward pass is computed by ``_linear``.

    It owns a weight of shape ``[in_features, out_features]`` and a bias of
    shape ``[out_features]``, both created with the helper's default dtype.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        super().__init__()
        self._dtype = self._helper.get_default_dtype()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr

        weight_shape = [in_features, out_features]
        self.weight = self.create_parameter(shape=weight_shape,
                                            attr=self._weight_attr,
                                            dtype=self._dtype,
                                            is_bias=False)
        bias_shape = [out_features]
        self.bias = self.create_parameter(shape=bias_shape,
                                          attr=self._bias_attr,
                                          dtype=self._dtype,
                                          is_bias=True)
        self.name = name

    def forward(self, input):
        # Delegate to the functional implementation with this layer's parameters.
        return _linear(x=input,
                       weight=self.weight,
                       bias=self.bias,
                       name=self.name)

    def extra_repr(self):
        # The name is only mentioned when one was given.
        name_str = ''
        if self.name:
            name_str = ', name={}'.format(self.name)
        in_features = self.weight.shape[0]
        out_features = self.weight.shape[1]
        return 'in_features={}, out_features={}, dtype={}{}'.format(
            in_features, out_features, self._dtype, name_str)


1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398
def _c_softmax_with_cross_entropy(logits,
                                  label,
                                  group=None,
                                  return_softmax=False):
    """
    Apply the ``c_softmax_with_cross_entropy`` op to logits and label.

    Args:
        logits (Tensor): The input logits.
        label (Tensor): The labels. Its rank must either equal the rank of
            ``logits`` or be one less; in the latter case a trailing
            dimension is added before the op is applied.
        group (Group): The process group to work on, or None to use ring id 0
            with the global rank and world size.
        return_softmax (bool): Whether to also return the softmax output.
            Default to False.

    Returns:
        The loss Tensor, or a tuple ``(loss, softmax)`` when
        ``return_softmax`` is True. None when the current process is not a
        member of ``group``.

    Raises:
        ValueError: If the rank of ``label`` is neither equal to the rank of
            ``logits`` nor one less than it.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id
    global_rank = _get_global_env().rank
    rank = global_rank if group is None else group.get_group_rank(global_rank)
    nranks = _get_global_env().world_size if group is None else group.nranks

    input_dims = len(list(logits.shape))
    label_dims = len(list(label.shape))
    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            'Expected input_dims - 1 = label_dims or input_dims == label_dims '
            '(got input_dims {}, label_dims {})'.format(input_dims, label_dims))
    if input_dims - 1 == label_dims:
        # Give the label the same rank as the logits.
        label = paddle.unsqueeze(label, axis=-1)

    if _non_static_mode():
        softmax, loss = _C_ops.c_softmax_with_cross_entropy(
            logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks)
        if not return_softmax:
            return loss
        else:
            return loss, softmax

    attrs = {
        'ring_id': ring_id,
        'rank': rank,
        'nranks': nranks,
    }
    helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
    helper.append_op(type='c_softmax_with_cross_entropy',
                     inputs={
                         'Logits': logits,
                         'Label': label
                     },
                     outputs={
                         'Softmax': softmax,
                         'Loss': loss
                     },
                     attrs=attrs)

    if return_softmax:
        return loss, softmax

    return loss

1431

B
Baibaifan 已提交
1432 1433 1434 1435
def _linear(x, weight, bias=None, name=None):
    """
    Functional linear: multiply ``x`` by ``weight`` and optionally add ``bias``.

    Args:
        x (Tensor): The input. In static graph mode it must have fewer than
            4 dimensions and a float16, float32 or float64 data type.
        weight (Tensor): The weight used as the right operand of the matmul.
        bias (Tensor, optional): Added along the last axis of ``x`` when given.
        name (string, optional): Passed on to the layer helper in static
            graph mode.

    Returns:
        Tensor.
    """
    if _non_static_mode():
        pre_bias = _varbase_creator(dtype=x.dtype)
        _C_ops.matmul(x, weight, pre_bias, 'transpose_X', False, 'transpose_Y',
                      False, "alpha", 1)
        return dygraph_utils._append_bias_in_dygraph(pre_bias,
                                                     bias,
                                                     axis=len(x.shape) - 1)

    # Static graph path.
    helper = LayerHelper('linear', **locals())
    dtype = x.dtype
    assert len(
        x.shape) < 4, "X latitude is not supported greater than 3 now."

    float_dtypes = ['float16', 'float32', 'float64']
    check_variable_and_dtype(x, 'x', float_dtypes, 'linear')
    check_dtype(dtype, 'dtype', float_dtypes, 'linear')

    matmul_inputs = {'X': [x], 'Y': [weight]}
    matmul_attrs = {
        'transpose_X': False,
        'transpose_Y': False,
        'alpha': 1,
    }
    matmul_out = helper.create_variable_for_type_inference(dtype)
    helper.append_op(type='matmul_v2',
                     inputs=matmul_inputs,
                     outputs={'Out': matmul_out},
                     attrs=matmul_attrs)
    if bias is None:
        return matmul_out

    # Add the bias along the last axis of x.
    biased_out = helper.create_variable_for_type_inference(dtype)
    helper.append_op(type='elementwise_add',
                     inputs={
                         'X': [matmul_out],
                         'Y': [bias]
                     },
                     outputs={'Out': [biased_out]},
                     attrs={'axis': len(x.shape) - 1})
    return biased_out


1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490
def _set_var_distributed(var):
    """
    Flag ``var`` as distributed, together with the variables of the same name
    found from the current blocks of the default startup and main programs.

    Does nothing when ``var`` is None.
    """
    if var is None:
        return

    var.is_distributed = True

    # NOTE: use current_block and find_var_recursive to support while_loop
    programs = (paddle.static.default_startup_program(),
                paddle.static.default_main_program())
    for program in programs:
        block = program.current_block()
        block._find_var_recursive(var.name).is_distributed = True

L
lilong12 已提交
1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501
def _parallel_linear(x,
                     num_rows,
                     num_cols,
                     axis,
                     param_attr,
                     bias_attr,
                     gather_out,
                     inner_rank,
                     nranks,
                     split_tensor,
                     name,
1502
                     group=None):
1503 1504
    """
    Parallel Linear
1505 1506 1507

    axis the dimension of the parameter of linear layer. 
    axis = 0: the row dimension
1508
    axis = 1: the col dimension
1509
    
1510
    """
1511 1512 1513 1514
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

L
lilong12 已提交
1515 1516
    if axis == 0:
        if split_tensor:
1517
            x = _c_split(x, group=group)
1518
    else:
L
lilong12 已提交
1519 1520
        x = _c_identity(x, group=group)

1521 1522 1523 1524 1525
    linear = paddle.nn.Linear(num_rows,
                              num_cols,
                              weight_attr=param_attr,
                              bias_attr=bias_attr,
                              name=name)
1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537

    # NOTE: npu linear function use matmul_v2 but linear use matmul
    linear_function = _linear if core.is_compiled_with_npu()\
        else paddle.nn.functional.linear
    linear_out = linear_function(
        x,
        linear.weight,
        # NOTE(wangxi): row split, bias need add after allreduce
        None if axis == 0 else linear.bias,
        linear.name)

    _set_var_distributed(linear.weight)
1538 1539 1540 1541
    # set is_distributed for splited bias
    # if a linear layer is splited by row, each rank would hold a complete bias and they should be the same in each rank.
    # if a linear layer is splited by col, the bias would also be split into each rank as its weight
    if axis == 1 and linear._bias_attr != False:
1542
        _set_var_distributed(linear.bias)
L
lilong12 已提交
1543 1544 1545 1546 1547

    if not gather_out: return linear_out

    out_shape = list(linear_out.shape)
    out_shape[0] *= 1 if axis == 0 else nranks
1548
    main_block = paddle.static.default_main_program().current_block()
L
lilong12 已提交
1549 1550 1551 1552 1553 1554 1555 1556 1557
    out = main_block.create_var(
        shape=out_shape,
        dtype=linear_out.dtype,
        type=linear_out.type,
        lod_level=linear_out.lod_level,
        persistable=False,
        is_data=False,
        need_check_feed=linear_out.desc.need_check_feed())
    if axis == 0:
1558 1559 1560 1561 1562 1563 1564 1565
        main_block.append_op(type='c_allreduce_sum',
                             inputs={'X': linear_out},
                             outputs={'Out': out},
                             attrs={
                                 'ring_id': ring_id,
                                 'use_calc_stream': True,
                                 'use_model_parallel': True
                             })
1566 1567
        if linear.bias is not None:
            out = out + linear.bias
L
lilong12 已提交
1568
    else:
1569 1570 1571 1572 1573 1574 1575 1576 1577 1578
        main_block.append_op(type='c_concat',
                             inputs={'X': linear_out},
                             outputs={'Out': out},
                             attrs={
                                 'rank': inner_rank,
                                 'ring_id': ring_id,
                                 'nranks': nranks,
                                 'use_calc_stream': True,
                                 'use_model_parallel': True
                             })
L
lilong12 已提交
1579
    return out
1580 1581


L
lilong12 已提交
1582 1583 1584 1585 1586 1587 1588
def _parallel_embedding(x,
                        per_part_embeddings,
                        origin_size,
                        param_attr,
                        inner_rank,
                        num_partitions,
                        name,
1589
                        group=None):
1590 1591 1592
    """
    Parallel Embedding
    """
1593 1594 1595 1596
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

1597 1598 1599 1600 1601 1602 1603 1604 1605
    helper = LayerHelper("_parallel_embedding", **locals())

    per_part_size = per_part_embeddings
    rank = inner_rank

    vocab_start_index = rank * per_part_size
    dtype = helper.get_default_dtype()
    size = [per_part_size, origin_size[1]]

1606 1607 1608 1609
    weight = helper.create_parameter(attr=param_attr,
                                     shape=size,
                                     dtype=dtype,
                                     is_bias=False)
1610 1611

    if num_partitions == 1:
1612 1613 1614 1615 1616
        return paddle.nn.functional.embedding(x,
                                              weight=weight,
                                              padding_idx=None,
                                              sparse=False,
                                              name=name)
1617

1618 1619
    startup_block = paddle.static.default_startup_program().global_block()
    main_block = paddle.static.default_main_program().global_block()
1620 1621 1622 1623 1624
    startup_block.vars[weight.name].is_distributed = True
    main_block.vars[weight.name].is_distributed = True

    output_parallel = paddle.distributed.collective._c_lookup_table(
        weight, x, start_index=vocab_start_index, name=name)
1625 1626 1627 1628
    out = paddle.distributed.collective._mp_allreduce(output_parallel,
                                                      group=group,
                                                      use_calc_stream=True,
                                                      use_model_parallel=True)
L
lilong12 已提交
1629
    return out
1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652


def split(x,
          size,
          operation,
          axis=0,
          num_partitions=1,
          gather_out=True,
          weight_attr=None,
          bias_attr=None,
          name=None):
    """

    Split the weight of the specified operation into multiple devices
    and do the computation in parallel.

    Now the following three cases are supported.

    Case 1: Parallel Embedding
        The weight of the embedding operation is a NxM matrix with N rows and M columns.
        With parallel embedding, the weight is split into num_partitions partitions, each
        of which is a matrix with (N/num_partitions + 1) rows and M column where the last
        row as the padding idx.

        Suppose we split the NxM weight into two partitions on device_0 and device_1
        respectively. Then, on each device, the final weight has (N/2 + 1) rows with the
        index range from 0 to N/2. On device_0, all values in the input within [0, N/2 -1]
        keep unchanged and all other values are changed to N/2 which is the padding index and
        are mapped to all zeros after embedding. In the same way, on device_1, the value V in the
        input within [N/2, N-1] will be changed to (V - N/2), and all other values are changed
        to N/2 and are mapped to all zeros after embedding. Finally, the results on the two
        devices are sum-reduced.

        The Embedding put on single card is as shown below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_single.png
            :width: 800
            :height: 350
            :alt: single_embedding
            :align: center

        Parallel Embedding is shown as below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_split.png
            :width: 800
            :alt: split_embedding
            :align: center

    Case 2: Row Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With row parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N/num_partitions rows and M columns.

        The linear layer put on single card is shown as below, the input variable is represented by X,
        the weight matrix is represented by W and the output variable is O. The linear layer on single card is
        simple matrix multiplication operation, O = X * W.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_single.png
            :width: 800
            :alt: single_linear
            :align: center

        Row Parallel Linear is shown as below. As the name suggests, Row Parallel Linear splits the weight matrix W into
        [[W_row1], [W_row2]] along the row. And accordingly the input is split along the column into [X_col1, X_col2] and multiply their
        respective weight matrices. Finally apply AllReduce on the output from each card to get the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_row.png
            :width: 800
            :alt: split_row
            :align: center

    Case 3: Column Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With column parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N rows and M/num_partitions columns.

        The linear layer put on single card has been illustrated on case 2 and Column Parallel Linear
        is shown as below. The Column Parallel Linear splits the weight matrix W into [W_col1, W_col2] along the column and
        these split matrices respectively multiply the input. Finally apply AllGather on the output from each card to get the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col.png
            :width: 800
            :alt: split_col
            :align: center

    As observed, the column parallel linear and row parallel linear can be combined to skip one ALLGATHER communication
    operator. Furthermore the Attention and MLP can be combined to improve the performance as shown below.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col_row.png
            :width: 800
            :alt: split_col_row
            :align: center

    Args:
        x (Tensor): Input tensor. Its data type should be float16, float32, float64, int32 or int64.
        size (list|tuple): A list or tuple with two elements indicating the shape of the weight.
        operation (str): The name of the operation. The supported operations are 'linear' and 'embedding'.
        axis (int, Optional): Indicate along which axis to split the weight. Default: 0.
        num_partitions (int, Optional): How many parts the weight is partitioned. Default: 1.
        gather_out (bool, Optional): Whether to gather the output after computation. By default, the output
            on each partitions will be gathered after computation. Default: True.
        weight_attr (ParamAttr, Optional): The parameter attribute for the learnable
            weights(Parameter) of the specified operation. Default: None.
        bias_attr (ParamAttr, Optional): The parameter attribute for the bias
            of the specified operation. Default: None.
        name (str, Optional): The default value is None. Normally there is no need for user to set this
            property. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor.

    Raises:
        AssertionError: If ``size`` or ``operation`` is malformed, if
            ``fleet.init()`` has not been called, or if the split dimension is
            not divisible by ``num_partitions``.
        ValueError: If called in dynamic graph mode, or if ``axis`` is neither
            0 nor 1 for a linear operation.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed.fleet as fleet

            paddle.enable_static()
            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            fleet.init(is_collective=True)
            data = paddle.randint(0, 8, shape=[10,4])
            emb_out = paddle.distributed.split(
                data,
                (8, 8),
                operation="embedding",
                num_partitions=2)

    """
    assert isinstance(
        size,
        (list, tuple)), ("The type of size for "
                         "paddle.distributed.split must be list or tuple.")
    assert len(size) == 2, ("Number of elements in size of "
                            "paddle.distributed.split must be two.")
    assert isinstance(operation, str), ("The type of operation for "
                                        "paddle.distributed.split must be str.")
    supported_operations = [
        'linear',
        'embedding',
    ]
    assert operation in supported_operations, (
        "The operation for "
        "paddle.distributed.split must be one of {}.".format(
            supported_operations))
    if _non_static_mode():
        raise ValueError(
            "paddle.distributed.split cannot be used in dynamic "
            "graph mode, plese use ParallelEmbedding, ParallelRowLinear, "
            "ParallelColumnLinear instead.")

    # Imported lazily, only on the static-graph path.
    from .fleet import fleet
    assert fleet._role_maker, ("To use paddle.distributed.split, "
                               "you must call fleet.init() firstly.")
    rank = fleet.worker_index()

    # rank within a model parallel group
    inner_rank = rank % num_partitions

    if operation == "embedding":
        assert axis == 0, ("We only support to split the weight of embedding "
                           "along the first axis now.")
        assert size[0] % num_partitions == 0, \
            "The length of the vocabulary must be divisible by num_partitions " \
            "but received vocabulary={} num_partitions={}".format(size[0], num_partitions)

        per_part_size = size[0] // num_partitions
        emb_out = _parallel_embedding(x,
                                      per_part_size,
                                      size,
                                      weight_attr,
                                      inner_rank,
                                      num_partitions,
                                      name,
                                      group=None)
        return emb_out

    # operation == "linear"
    should_split = False
    if axis == 0:
        assert size[0] % num_partitions == 0, (
            "Number of rows of the weight for linear ({}) must be"
            " divisible by num_partitions ({})".format(
                size[0], num_partitions))
        per_part_size = size[0] // num_partitions
        linear_size = (per_part_size, size[1])
        # The input still has the full feature width, so it has to be split
        # to match the row-sharded weight.
        if x.shape[-1] == size[0]:
            should_split = True

    elif axis == 1:
        assert size[1] % num_partitions == 0, (
            "Number of column of the weight for linear ({}) must be"
            " divisible by num_partitions ({})".format(
                size[1], num_partitions))
        per_part_size = size[1] // num_partitions
        linear_size = (size[0], per_part_size)
    else:
        raise ValueError("The value of axis must be 0 or 1, but the value "
                         "given is {}.".format(axis))

    linear_out = _parallel_linear(x,
                                  linear_size[0],
                                  linear_size[1],
                                  axis,
                                  weight_attr,
                                  bias_attr,
                                  gather_out,
                                  inner_rank,
                                  num_partitions,
                                  should_split,
                                  name=name,
                                  group=None)
    return linear_out
L
lilong12 已提交
1842 1843


L
lilong12 已提交
1844 1845
def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
    """
    Scatter tensors in in_tensor_list to all participators averagely and gather the result tensors in out_tensor_list.
    As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
    Through alltoall operator, the 0_0 in GPU0 will be sent to GPU0 and 0_1 to GPU1, 1_0 in GPU1 sent to GPU0 and 1_1 to GPU1.
    Finally the out_tensor_list in GPU0 includes 0_0 and 1_0, and GPU1 includes 0_1 and 1_1.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/alltoall.png
        :width: 800
        :alt: alltoall
        :align: center

    Args:
        in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64.
        out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
            data type of the input Tensors. In static graph mode it must be an empty list.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.

    Returns:
        None.

    Raises:
        ValueError: In static graph mode, if ``in_tensor_list`` or
            ``out_tensor_list`` is not a list, or if ``out_tensor_list`` is
            not empty.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            init_parallel_env()
            out_tensor_list = []
            if paddle.distributed.ParallelEnv().rank == 0:
                np_data1 = np.array([[1, 2, 3], [4, 5, 6]])
                np_data2 = np.array([[7, 8, 9], [10, 11, 12]])
            else:
                np_data1 = np.array([[13, 14, 15], [16, 17, 18]])
                np_data2 = np.array([[19, 20, 21], [22, 23, 24]])
            data1 = paddle.to_tensor(np_data1)
            data2 = paddle.to_tensor(np_data2)
            paddle.distributed.alltoall([data1, data2], out_tensor_list)
            # out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]]
            # out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]]
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        temp = paddle.concat(in_tensor_list, axis=0)
        nranks = len(in_tensor_list)
        if len(out_tensor_list) == 0:
            tensor_shape = list(in_tensor_list[0].shape)
            tensor_shape[0] *= nranks
            out = paddle.empty(tensor_shape, in_tensor_list[0].dtype)
        else:
            out = paddle.concat(out_tensor_list, axis=0)
        task = group.process_group.alltoall(temp, out)
        task.wait()
        out_tensor_list.clear()
        out_tensor_list.extend(paddle.split(out, nranks, 0))
        return

    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        temp = paddle.concat(in_tensor_list, axis=0)
        nranks = len(in_tensor_list)
        out = _C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream,
                              'ring_id', ring_id)
    else:
        # Validate the arguments before anything is appended to the program,
        # so that invalid input does not leave a dangling concat op behind.
        if not isinstance(in_tensor_list, list):
            raise ValueError("The type of 'in_tensor_list' for all_to_all "
                             "should be list.")
        for elem in in_tensor_list:
            check_variable_and_dtype(
                elem, 'in_tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_to_all')
        if not isinstance(out_tensor_list, list):
            raise ValueError("The type of 'out_tensor_list' for all_to_all "
                             "should be list.")
        if len(out_tensor_list) != 0:
            raise ValueError("The 'out_tensor_list' for all_to_all "
                             "must be an empty list.")

        temp = paddle.concat(in_tensor_list, axis=0)
        nranks = len(in_tensor_list)
        op_type = 'alltoall'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=in_tensor_list[0].dtype)
        helper.append_op(type=op_type,
                         inputs={'X': [temp]},
                         outputs={'Out': [out]},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': use_calc_stream,
                         })
    out_tensor_list.extend(paddle.split(out, nranks, 0))


1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032
def alltoall_single(in_tensor,
                    out_tensor,
                    in_split_sizes=None,
                    out_split_sizes=None,
                    group=None,
                    use_calc_stream=True):
    """
    Scatter one input tensor to every participator and collect the pieces received from them in out_tensor.

    .. note::
        ``alltoall_single`` is only supported in eager mode.

    Args:
        in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32 or int64.
        out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
        in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
            must be divisible by group size and ``in_tensor`` will be scattered averagely to all participators. Default: None.
        out_split_sizes (list[int], optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
            must be divisible by group size and ``out_tensor`` will be gathered averagely from all participators. Default: None.
        group (Group, optional): The group instance return by ``new_group`` or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.

    Returns:
        None, if ``use_calc_stream`` is set to ``True``; ``Task`` of ``group``, if ``use_calc_stream`` is set to ``False``.
        None is also returned when the current process is not a member of ``group``.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            size = dist.get_world_size()

            # case 1
            input = paddle.arange(2, dtype='int64') + rank * 2
            # input for rank 0: [0, 1]
            # input for rank 1: [2, 3]

            output = paddle.empty([2], dtype='int64')
            dist.alltoall_single(input, output)
            # output for rank 0: [0, 2]
            # output for rank 1: [1, 3]

            # case 2
            in_split_sizes = [i + 1 for i in range(size)]
            # in_split_sizes for rank 0: [1, 2] and for rank 1: [1, 2]
            out_split_sizes = [rank + 1 for i in range(size)]
            # out_split_sizes for rank 0: [1, 1] and for rank 1: [2, 2]

            input = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank
            # input for rank 0: [[0., 0.], [0., 0.], [0., 0.]]
            # input for rank 1: [[1., 1.], [1., 1.], [1., 1.]]
            output = paddle.empty([(rank + 1) * size, size], dtype='float32')

            group = dist.new_group([0, 1])
            task = dist.alltoall_single(input,
                                        output,
                                        in_split_sizes,
                                        out_split_sizes,
                                        use_calc_stream=False,
                                        group=group)
            task.wait()
            # output for rank 0: [[0., 0.], [1., 1.]]
            # output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]]

    """
    # Processes outside the given group have nothing to do.
    if not (group is None or group.is_member()):
        return

    assert in_dygraph_mode(), "Only suppport alltoall_single in eager mode."
    # NOTE(review): in_tensor / out_tensor are not type-checked here.

    if group is None:
        group = _get_default_group()
    # Empty split lists stand for "not given".
    if in_split_sizes is None:
        in_split_sizes = []
    if out_split_sizes is None:
        out_split_sizes = []

    process_group = group.process_group
    task = process_group.alltoall_single(in_tensor, out_tensor, in_split_sizes,
                                         out_split_sizes)
    if not use_calc_stream:
        # Caller is responsible for waiting on the task.
        return task
    task.wait()
    return


L
lilong12 已提交
2033 2034 2035 2036 2037 2038 2039 2040
def send(tensor, dst=0, group=None, use_calc_stream=True):
    """
    Send a tensor to the receiver.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32 or int64.
        dst (int): The destination rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        None in static graph mode, and in eager mode when ``use_calc_stream``
        is True. In eager mode with ``use_calc_stream`` set to False the task
        of the process group is returned; in legacy dygraph mode the result of
        the underlying ``send_v2`` call is returned.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.distributed import init_parallel_env

            init_parallel_env()
            if paddle.distributed.ParallelEnv().rank == 0:
                data = paddle.to_tensor([7, 8, 9])
                paddle.distributed.send(data, dst=1)
            else:
                data = paddle.to_tensor([1,2,3])
                paddle.distributed.recv(data, src=0)
            out = data.numpy()
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        # The process group addresses peers by their rank inside the group.
        group_dst_rank = group.get_group_rank(dst)
        task = group.process_group.send(tensor, group_dst_rank)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _C_ops.send_v2(tensor, 'use_calc_stream', use_calc_stream,
                              'ring_id', ring_id, 'peer', dst)
    op_type = 'send_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'send')

    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'peer': dst,
                         'use_calc_stream': use_calc_stream,
                     })
L
lilong12 已提交
2094 2095 2096 2097 2098 2099 2100 2101 2102 2103


def recv(tensor, src=0, group=None, use_calc_stream=True):
    """
    Receive a tensor from the sender.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
            should be float16, float32, float64, int32 or int64.
        src (int): The source rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        None in static graph mode, and in eager mode when ``use_calc_stream``
        is True. In eager mode with ``use_calc_stream`` set to False the task
        of the process group is returned; in legacy dygraph mode the result of
        the underlying ``recv_v2`` call is returned.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.distributed import init_parallel_env

            init_parallel_env()
            if paddle.distributed.ParallelEnv().rank == 0:
                data = paddle.to_tensor([7, 8, 9])
                paddle.distributed.send(data, dst=1)
            else:
                data = paddle.to_tensor([1,2,3])
                paddle.distributed.recv(data, src=0)
            out = data.numpy()
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        # The process group addresses peers by their rank inside the group.
        group_src_rank = group.get_group_rank(src)
        task = group.process_group.recv(tensor, group_src_rank)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _C_ops.recv_v2(tensor, 'use_calc_stream', use_calc_stream,
                              'ring_id', ring_id, 'peer', src, 'dtype',
                              tensor.dtype, 'out_shape', tensor.shape)
    op_type = 'recv_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'recv')
    helper = LayerHelper(op_type, **locals())
    # The received data is written into ``tensor``, whose shape and dtype
    # are passed to the op as attributes.
    helper.append_op(type=op_type,
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'peer': src,
                         'out_shape': tensor.shape,
                         'dtype': tensor.dtype,
                         'use_calc_stream': use_calc_stream,
                     })
2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 2222 2223 2224 2225 2226 2227 2228 2229 2230 2231 2232 2233 2234 2235 2236 2237 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2254 2255 2256 2257 2258 2259 2260 2261 2262 2263 2264 2265 2266 2267 2268 2269 2270 2271 2272 2273 2274 2275 2276 2277 2278 2279 2280 2281 2282 2283 2284 2285 2286 2287 2288 2289 2290 2291 2292 2293 2294 2295 2296 2297 2298 2299 2300 2301 2302 2303 2304 2305 2306 2307 2308 2309 2310 2311 2312 2313 2314 2315 2316 2317 2318 2319 2320 2321 2322 2323 2324 2325 2326 2327 2328 2329 2330 2331 2332 2333 2334 2335 2336 2337 2338 2339 2340 2341 2342 2343 2344 2345 2346 2347 2348 2349 2350 2351 2352 2353 2354 2355 2356 2357 2358 2359 2360 2361 2362 2363 2364 2365 2366 2367 2368 2369 2370 2371 2372 2373 2374 2375 2376 2377 2378 2379 2380 2381 2382 2383 2384 2385 2386 2387 2388 2389 2390 2391 2392 2393 2394 2395 2396 2397 2398 2399 2400 2401 2402 2403 2404 2405 2406 2407 2408 2409 2410 2411 2412 2413 2414 2415 2416 2417 2418 2419 2420 2421 2422 2423 2424 2425 2426 2427 2428 2429 2430 2431 2432 2433 2434 2435 2436 2437 2438 2439 2440 2441 2442 2443 2444 2445 2446 2447 2448 2449 2450 2451 2452 2453 2454 2455 2456 2457 2458 2459 2460 2461 2462 2463 2464 2465 2466 2467 2468 2469 2470 2471 2472 2473 2474 2475 2476 2477 2478 2479 2480 2481 2482 2483 2484 2485 2486 2487 2488 2489 2490 2491 2492 2493 2494 2495 2496 2497 2498 2499 2500 2501 2502 2503 2504 2505 2506 2507 2508 2509 2510 2511 2512 2513 2514 2515 2516 2517 2518 2519 2520 2521 2522 2523 2524 2525 2526 2527 2528 2529 2530 2531 2532 2533 2534 2535 2536 2537 2538 2539 2540 2541 2542 2543 2544 2545


def _check_single_tensor(tensor, tensor_name):
    """
    Validate that ``tensor`` is a single Paddle tensor.

    Args:
        tensor: The object to validate.
        tensor_name (str): The parameter name reported in the error message.

    Raises:
        RuntimeError: If ``tensor`` is neither a ``core.eager.Tensor`` nor a
            ``paddle.Tensor``.
    """
    if not isinstance(tensor, (core.eager.Tensor, paddle.Tensor)):
        # The space after "{}" matters: the two literals are concatenated, and
        # without it the parameter name runs straight into "to be".
        raise RuntimeError("Invalid function argument. Expected parameter {} "
                           "to be of type paddle.Tensor, but it's {}".format(
                               tensor_name, type(tensor)))


def _check_tensor_list(tensor_list, tensor_name):
    """
    Validate that ``tensor_list`` is a ``list`` whose items are all Paddle tensors.

    Args:
        tensor_list: The object to validate.
        tensor_name (str): The parameter name reported in the error message.

    Raises:
        RuntimeError: If ``tensor_list`` is not a ``list``, or if any item is
            neither a ``core.eager.Tensor`` nor a ``paddle.Tensor``.
    """
    is_tensor_list = isinstance(tensor_list, list) and all(
        isinstance(t, (core.eager.Tensor, paddle.Tensor)) for t in tensor_list)
    if not is_tensor_list:
        # The space after "{}" matters: the two literals are concatenated, and
        # without it the parameter name runs straight into "to be".
        raise RuntimeError("Invalid function argument. Expected parameter {} "
                           "to be a list of paddle.Tensor".format(tensor_name))


def isend(tensor, dst, group=None):
    """
    Send a tensor asynchronously to the rank ``dst``.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32 or int64.
        dst (int): The destination rank. It is translated with
            ``group.get_group_rank`` and must map to a rank inside the group.
        group (Group, optional): The group instance returned by new_group, or None
            for the global default group. Default: None.

    Returns:
        A distributed task object, or None when ``group`` is given and the
        current process is not a member of it.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            if rank == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = paddle.distributed.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = paddle.distributed.irecv(data, src=0)

            task.wait()

            print(data)
            # paddle.tensor([7, 8, 9])     # Rank-0
            # paddle.tensor([7, 8, 9])     # Rank-1

    """
    _check_single_tensor(tensor, "tensor")

    # Processes outside an explicitly given group have nothing to send.
    if not (group is None or group.is_member()):
        return None

    if not in_dygraph_mode():
        raise RuntimeError("Don't support static graph mode currently.")

    comm_group = group if group is not None else _get_default_group()
    dst_in_group = comm_group.get_group_rank(dst)
    assert dst_in_group >= 0, ("dst rank out of group, need global rank")
    return comm_group.process_group.send(tensor, dst_in_group)


def irecv(tensor, src=None, group=None):
    """
    Receive a tensor asynchronously from the rank ``src``.

    Args:
        tensor (Tensor): The Tensor that receives the data. Its data type
            should be float16, float32, float64, int32 or int64.
        src (int): The source rank id. The signature defaults it to None, but
            the value is translated with ``group.get_group_rank`` and the
            result is asserted to be non-negative, so callers are expected to
            pass a rank that belongs to the group.
        group (Group, optional): The group instance returned by new_group, or None
            for the global default group. Default: None.

    Returns:
        A distributed task object, or None when ``group`` is given and the
        current process is not a member of it.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            if rank == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = paddle.distributed.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = paddle.distributed.irecv(data, src=0)

            task.wait()

            print(data)
            # paddle.tensor([7, 8, 9])     # Rank-0
            # paddle.tensor([7, 8, 9])     # Rank-1
    """
    _check_single_tensor(tensor, "tensor")

    # Processes outside an explicitly given group have nothing to receive.
    if not (group is None or group.is_member()):
        return None

    if not in_dygraph_mode():
        raise RuntimeError("Don't support static graph mode currently.")

    comm_group = group if group is not None else _get_default_group()
    src_in_group = comm_group.get_group_rank(src)
    assert src_in_group >= 0, ("src rank out of group, need global rank")
    return comm_group.process_group.recv(tensor, src_in_group)


class P2POp(object):
    """
    Description of one point-to-point operation for ``batch_isend_irecv``.

    An instance bundles the kind of P2P operation, the communication buffer,
    the peer rank and the group. Instances are handed to
    ``paddle.distributed.batch_isend_irecv``, which performs the communication.

    Args:
        op (callable): The function used to send data to or receive data from a peer
            process. It must be either ``paddle.distributed.isend`` or
            ``paddle.distributed.irecv``.
        tensor (Tensor): Tensor to send or receive.
        peer (int): The destination or source rank.
        group (Group, optional): The group instance returned by new_group, or None
            for the global default group. Default: None.

    Raises:
        RuntimeError: If ``op`` is not ``isend``/``irecv`` or ``tensor`` is not
            a Paddle tensor.

    """

    def __init__(self, op, tensor, peer, group=None):
        supported_ops = (isend, irecv)
        if op not in supported_ops:
            raise RuntimeError("Invalid ``op`` function. Expected ``op`` "
                               "to be of type ``paddle.distributed.isend`` or "
                               "``paddle.distributed.irecv``.")
        _check_single_tensor(tensor, "tensor")

        self.op, self.tensor, self.peer = op, tensor, peer
        # The default group is resolved at construction time, so ``self.group``
        # is never stored as None by this constructor.
        if group is None:
            group = _get_default_group()
        self.group = group


@contextlib.contextmanager
def _with_batch_p2p_guard(backend):
    """
    Context manager that brackets a batch of P2P calls for the given backend.

    When ``backend`` equals ``"nccl"``, ``core.ProcessGroupNCCL.group_start()``
    is called on entry and ``core.ProcessGroupNCCL.group_end()`` on exit (also
    when the body raises). For any other backend it does nothing.
    """
    uses_nccl = backend == "nccl"
    if uses_nccl:
        core.ProcessGroupNCCL.group_start()
    try:
        yield
    finally:
        if uses_nccl:
            core.ProcessGroupNCCL.group_end()


def _check_p2p_op_list(p2p_op_list):
    """
    Helper to check that the ``p2p_op_list`` is a non-empty list of P2POp
    instances and all ops use the same backend.

    Args:
        p2p_op_list: The object to validate.

    Raises:
        RuntimeError: If ``p2p_op_list`` is not a list, is empty, contains an
            item that is not a ``P2POp``, or if the groups of the ops do not
            all map to the same backend in ``_group_map_backend``.
    """
    if not isinstance(p2p_op_list, list) or not all(
            isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list):
        raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected "
                           "to be of type ``paddle.distributed.P2POp``.")

    # An empty list passes the type check above; reject it explicitly rather
    # than failing with a bare IndexError on ``p2p_op_list[0]`` below.
    if not p2p_op_list:
        raise RuntimeError("Invalid ``p2p_op_list``. Expected a non-empty "
                           "list of ``paddle.distributed.P2POp``.")

    backend = _group_map_backend[p2p_op_list[0].group]
    if not all(backend == _group_map_backend[p2p_op.group]
               for p2p_op in p2p_op_list):
        raise RuntimeError("All groups need to use the same backend.")


def batch_isend_irecv(p2p_op_list):
    """
    Send or receive a batch of tensors asynchronously and return a list of tasks.

    Every point-to-point operation in ``p2p_op_list`` is issued in list order and
    the resulting tasks are collected. NCCL is currently supported: for the
    ``"nccl"`` backend the calls are issued between a NCCL group start and group end.

    Args:
        p2p_op_list: A list of point-to-point operations (each item is a
            ``paddle.distributed.P2POp``). The order of the isend/irecv in the list
            matters and it needs to match with the corresponding isend/irecv on the
            remote end.

    Returns:
        A list of the distributed tasks returned by the ops in ``p2p_op_list``
        (ops that return None contribute no entry), or None when the current
        process is not a member of the first op's group.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed

            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            send_t = paddle.arange(2) + rank
            # paddle.tensor([0, 1])  # Rank-0
            # paddle.tensor([1, 2])  # Rank-1

            recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)

            send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
            recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)

            tasks = dist.batch_isend_irecv([send_op, recv_op])

            for task in tasks:
                task.wait()

            print(recv_t)
            # paddle.tensor([1, 2])     # Rank-0
            # paddle.tensor([0, 1])     # Rank-1
    """
    _check_p2p_op_list(p2p_op_list)

    # Membership and the backend are decided from the first op's group.
    group = p2p_op_list[0].group
    if not (group is None or group.is_member()):
        return None

    if not in_dygraph_mode():
        raise RuntimeError("Don't support static graph mode currently.")

    if group is None:
        group = _get_default_group()
    backend = _group_map_backend[group]

    with _with_batch_p2p_guard(backend):
        # Each op is called with its own tensor, peer and group; ops that
        # return None are left out of the result.
        issued = (item.op(item.tensor, item.peer, item.group)
                  for item in p2p_op_list)
        tasks = [task for task in issued if task is not None]
    return tasks


def reduce_scatter(tensor,
                   tensor_list,
                   op=ReduceOp.SUM,
                   group=None,
                   use_calc_stream=True):
    """
    Reduces, then scatters a list of tensors to all processes in a group.

    The tensors in ``tensor_list`` are concatenated along axis 0 and the result
    is passed to the process group's ``_reduce_scatter_base`` together with
    ``tensor`` as the output.

    Args:
        tensor (Tensor): Output tensor.
        tensor_list (list[Tensor]): List of tensors to reduce and scatter.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (Group, optional): The group instance returned by new_group, or None
            for the global default group. Default: None.
        use_calc_stream (bool, optional): If True, wait for the task to finish
            before returning; if False, return the task handle without waiting.
            Default: True.

    Returns:
        Async task handle, if use_calc_stream is set to False.
        None, if use_calc_stream is True or if the current process is not part of the group.

    Warning:
        This API only supports the dygraph mode.


    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            if rank == 0:
                t1 = paddle.to_tensor([0, 1])
                t2 = paddle.to_tensor([2, 3])
            else:
                t1 = paddle.to_tensor([4, 5])
                t2 = paddle.to_tensor([6, 7])

            tensor_list = [t1, t2]

            output = paddle.empty(shape=[2], dtype=tensor_list[0].dtype)
            dist.reduce_scatter(output, tensor_list)

            print(output)
            # [4, 6]     # Rank-0
            # [8, 10]     # Rank-1

    """
    _check_single_tensor(tensor, "tensor")
    _check_tensor_list(tensor_list, "tensor_list")

    # Processes outside an explicitly given group take no part in the collective.
    if not (group is None or group.is_member()):
        return None

    if not in_dygraph_mode():
        raise RuntimeError("Don't support static graph mode currently.")

    reduce_type = _get_reduce_op(op, "reduce_scatter")
    if group is None:
        group = _get_default_group()

    # The underlying primitive takes one flat input, so join the pieces first.
    flat_input = paddle.concat(tensor_list, axis=0)
    task = group.process_group._reduce_scatter_base(tensor, flat_input,
                                                    reduce_type)
    if not use_calc_stream:
        return task
    task.wait()
    return None


def _reduce_scatter_base(output,
                         input,
                         op=ReduceOp.SUM,
                         group=None,
                         use_calc_stream=True):
    """
    Reduces, then scatters a flattened tensor to all processes in a group.

    Args:
        output (Tensor): Output tensor.
        input (Tensor): Input tensor whose size is the output tensor size times the world size.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (Group, optional): The group to work on. If None,
            the default group will be used.
        use_calc_stream (bool, optional): Whether to use the calculation stream (True) or the
            communication stream (False). When True the call waits for the task before
            returning. Default: True.
    Returns:
        Async task handle, if use_calc_stream is set to False.
        None, if use_calc_stream is True or if the current process is not part of the group.

    Examples:
        .. code-block:: python

            # required: distributed

            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            input = paddle.arange(4) + rank
            # [0, 1, 2, 3]  # Rank-0
            # [1, 2, 3, 4]  # Rank-1

            output = paddle.empty(shape=[2], dtype=input.dtype)
            paddle.distributed.collective._reduce_scatter_base(output, input)
            print(output)
            # [1, 3]     # Rank-0
            # [5, 7]     # Rank-1

    """
    _check_single_tensor(output, "output")
    _check_single_tensor(input, "input")

    # Processes outside an explicitly given group take no part in the collective.
    if not (group is None or group.is_member()):
        return None

    if not in_dygraph_mode():
        raise RuntimeError("Don't support static graph mode currently.")

    reduce_type = _get_reduce_op(op, "_reduce_scatter_base")
    if group is None:
        group = _get_default_group()

    task = group.process_group._reduce_scatter_base(output, input, reduce_type)
    if not use_calc_stream:
        return task
    task.wait()
    return None