collective.py 63.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
17 18
import pickle
import io
19 20
import datetime
import time
21
from ..fluid.layer_helper import LayerHelper
22
from ..fluid.framework import Variable
23
from ..fluid.framework import in_dygraph_mode
24
from ..fluid.framework import OpProtoHolder
J
Jiabin Yang 已提交
25
from ..fluid.framework import _non_static_mode
26
from ..fluid.framework import _in_legacy_dygraph
27
from ..fluid.framework import convert_np_dtype_to_dtype_
J
Jiangxinz 已提交
28
from ..fluid.framework import _varbase_creator
29 30 31 32
from ..fluid.data_feeder import convert_dtype
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.data_feeder import check_type
from ..fluid.data_feeder import check_dtype
33 34
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
B
Baibaifan 已提交
35
from ..fluid.dygraph import layers
36 37 38 39
from ..fluid.dygraph.parallel import prepare_context
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
40
from paddle import _C_ops, _legacy_C_ops
J
Jiangxinz 已提交
41
import paddle.fluid.dygraph_utils as dygraph_utils
42
import contextlib
W
wuhuachaocoding 已提交
43 44 45 46 47 48 49 50 51 52 53 54
from .fleet.layers.mpu.mp_ops import split
from .fleet.layers.mpu.mp_ops import _c_identity
from .fleet.layers.mpu.mp_ops import _c_concat
from .fleet.layers.mpu.mp_ops import _c_split
from .fleet.layers.mpu.mp_ops import _mp_allreduce
from .fleet.layers.mpu.mp_ops import _c_lookup_table
from .fleet.layers.mpu.mp_ops import _Linear
from .fleet.layers.mpu.mp_ops import _set_var_distributed
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy
from .fleet.layers.mpu.mp_ops import _linear
from .fleet.layers.mpu.mp_ops import _parallel_linear
from .fleet.layers.mpu.mp_ops import _parallel_embedding
55 56 57
from .communication.group import Group, _add_new_group
from .communication.all_reduce import all_reduce
from .communication.reduce import _get_reduce_op, ReduceOp
58

59
__all__ = []
60

K
kuizhiqing 已提交
61 62 63 64 65 66 67 68 69 70 71 72 73
# Lazily created ParallelEnv shared by this module; see _get_global_env.
_global_env = None


def _get_global_env():
    """
    Return the module-wide ``paddle.distributed.ParallelEnv``.

    The environment object is created on first use and cached in the
    module-level ``_global_env`` so later calls return the same instance.
    """
    global _global_env
    # Explicit None check: the cache is "empty" only when nothing was stored,
    # independent of how the cached object evaluates in a boolean context.
    if _global_env is None:
        _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# Registry of every known group keyed by group id; id 0 is the global group.
# Dict[int, Group]
_group_map = {}
# Id under which the global group is stored in _group_map.
_global_env_gid = 0

# Registry of groups keyed by their name (e.g. groups created by new_group
# in eager mode, and the default group).
# Dict[name, Group]
_group_map_by_name = {}

# Backend name recorded for each group.
# Dict[Group, backend]
_group_map_backend = {}

# Name of the default group created by init_parallel_env.
_default_group_name = "_default_pg"

# Backends accepted by _new_process_group_impl.
_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl']

_default_store = None  # the default TCP store
_default_backend = None

# Default timeout of store-related operations (1800 s, i.e. 30 minutes).
_default_timeout = datetime.timedelta(seconds=1800)
# Counter advanced by _new_ring_id in eager mode.
_start_ring_id = 0


def _set_default_backend(backend):
    """Store ``backend`` as the module-wide default communication backend."""
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    """Store ``store`` as the module-wide default store object."""
    global _default_store
    _default_store = store


def _get_group_map():
    """
    Return the id -> Group registry, seeding the global group on first use.

    The global group is built from the ParallelEnv: its group id is 0 and it
    contains ranks ``0 .. world_size - 1``.
    """
    registry = _group_map
    if _global_env_gid in registry:
        return registry
    env = _get_global_env()
    registry[_global_env_gid] = Group(env.rank, 0,
                                      list(range(env.world_size)))
    return registry


def _get_global_group():
    """Return the global group, i.e. the entry stored under ``_global_env_gid``."""
    group_map = _get_group_map()
    return group_map[_global_env_gid]


def _get_group_map_by_name():
    """Return the live name -> Group registry (not a copy)."""
    return _group_map_by_name


def _get_default_group():
    """
    Return the default group registered under ``_default_group_name``.

    Raises:
        AssertionError: if the distributed environment is not initialized.
    """
    assert is_initialized(), ("Call paddle.distributed.init_parallel_env first "
                              "to initialize the distributed environment.")
    return _group_map_by_name[_default_group_name]


def _set_group_map(gid, group):
    """Register ``group`` under id ``gid``; the id must not be taken yet."""
    assert gid not in _group_map
    _group_map[gid] = group


def _set_group_map_by_name(name, group):
    """Register ``group`` under ``name``; the name must not be taken yet."""
    assert name not in _group_map_by_name
    _group_map_by_name[name] = group


def _set_group_map_backend(group, backend):
    """Record ``backend`` for ``group``; the group must not be recorded yet."""
    assert group not in _group_map_backend
    _group_map_backend[group] = backend


def _new_ring_id():
    """
    Return a ring id for a new communication group.

    In eager (dygraph) mode the id comes from the ``_start_ring_id`` counter,
    which is advanced on every call. Otherwise it is derived from the number
    of registered groups. Both are offset by ``max(nrings, 9)``.
    """
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode
    # rely on the previous syntax.
    global _start_ring_id
    if not in_dygraph_mode():
        return len(_get_group_map()) + max(_get_global_env().nrings, 9)
    _start_ring_id += 1
    return _start_ring_id + max(_get_global_env().nrings, 9)


def get_group(id=0):
    """

    Get group instance by group id.

    Args:
        id (int): the group id. Default value is 0.

    Returns:
        Group: the group instance, or None if no group has this id.

    Examples:
        .. code-block:: python

            ...
            gid = paddle.distributed.new_group([2,4,6])
            paddle.distributed.get_group(gid.id)

    """
    return _get_group_map().get(id)


def _new_process_group_impl(backend,
                            store,
                            rank,
                            world_size,
                            group_name,
                            pg_options,
                            group_id=0,
                            src_rank=None,
                            dst_rank=None):
    """
    Build the core process group object for ``backend``.

    Args:
        backend (str): One of ``_valid_backend_list``.
        store: Store object handed to the process group constructor.
        rank (int): Rank of the current process inside the new group.
        world_size (int): Number of ranks in the new group.
        group_name (str): Group name (not used by this function).
        pg_options: Backend options (not used by this function).
        group_id (int, optional): Id of the group. Default 0.
        src_rank (int, optional): Only allowed for the ``heter`` backend.
        dst_rank (int, optional): Only allowed for the ``heter`` backend.

    Returns:
        The core ProcessGroup created for ``backend``.
    """
    genv = _get_global_env()
    if backend != 'heter':
        assert src_rank is None and dst_rank is None, (
            "src_rank and dst_rank "
            "can only be set for heter backend.")
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend

    if backend == "gloo":
        place = core.CPUPlace()
        return core.ProcessGroupGloo(store, rank, world_size, place, group_id)

    if backend == "nccl":
        place = core.CUDAPlace(genv.device_id)
        return core.ProcessGroupNCCL(store, rank, world_size, place, group_id)

    if backend == "hccl":
        place = core.NPUPlace(genv.device_id)
        return core.ProcessGroupHCCL(store, rank, world_size, place, group_id)

    if backend == "xccl":
        place = core.CustomPlace(genv.device_type, genv.device_id)
        return core.ProcessGroupCustom(store, rank, world_size, place, group_id)

    if backend == "heter":
        place = None
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(genv.device_id)
        elif core.is_compiled_with_npu():
            place = core.NPUPlace(genv.device_id)

        # The heter topology is read from environment variables:
        # CLUSTER_ID (index of this cluster), CLUSTER_SIZE (comma-separated
        # size of each cluster) and CLUSTER_SWITCH (the switch endpoint).
        cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
        assert cluster_id >= 0, "please set the CLUSTER_ID variable."
        sizes_env = os.getenv("CLUSTER_SIZE", None)
        assert sizes_env, "please set the CLUSTER_SIZE variable."
        cluster_sizes = [int(s) for s in sizes_env.split(",")]
        switch_ep = os.getenv("CLUSTER_SWITCH", None)
        assert switch_ep, "please set the CLUSTER_SWITCH variable."

        # Global rank = total size of all preceding clusters + local rank.
        sizes_cumsum = np.cumsum(cluster_sizes)
        if cluster_id == 0:
            cluster_offset = 0
        else:
            cluster_offset = sizes_cumsum[cluster_id - 1]
        global_rank = cluster_offset + rank
        global_world_size = sizes_cumsum[-1]
        # NOTE(review): the two values computed just above are immediately
        # overwritten here; _get_global_config is not defined or imported in
        # the visible part of this file. Confirm which computation is intended
        # and drop the other.
        global_rank, global_world_size = _get_global_config(backend, rank)

        return core.ProcessGroupHeter(store,
                                      rank=global_rank,
                                      world_size=global_world_size,
                                      place=place,
                                      gid=group_id,
                                      local_rank=rank,
                                      local_size=world_size,
                                      gloo_rank=cluster_id,
                                      gloo_size=len(cluster_sizes),
                                      with_switch=True,
                                      switch_endpoint=switch_ep,
                                      src_rank=src_rank,
                                      dst_rank=dst_rank)

    # Reached only for a backend not handled above (e.g. when assertions are
    # disabled and an unsupported backend slips through).
    return None


def barrier(group=None):
    """

    Barrier among all participators in the group.

    Args:
        group (Group): The group instance return by new_group or None for global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        target_group = _get_default_group() if group is None else group
        target_group.process_group.barrier().wait()
        return

    if group is None:
        ring_id = 0
    else:
        ring_id = group.id

    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        return _legacy_C_ops.barrier(temp, temp, 'ring_id', ring_id)

    op_type = 'barrier'
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    # LayerHelper receives every local variable as a keyword argument, so the
    # set of locals bound before this call is kept as it was.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [temp]},
                     outputs={'Out': [temp]},
                     attrs={'ring_id': ring_id})


# Optional user-chosen group id. When it is set (truthy), eager-mode
# new_group uses it instead of allocating a fresh ring id; this is mainly
# useful for compatibility with static mode.
_custom_gid = None


def _set_custom_gid(gid):
    """Set the group id used by subsequent eager-mode ``new_group`` calls."""
    global _custom_gid
    _custom_gid = gid


def _barrier_by_tcp_store(group_name, store, timeout):
    """
    Make global rank 0 wait until every other rank has checked in.

    Each non-master rank adds 1 to the key ``Barrier/<group_name>/<rank>`` in
    ``store`` and returns immediately. The master (global rank 0) polls those
    keys every 0.1 s until all of them read 1, so it returns last. Nothing is
    done when the world size is below 2.

    Args:
        group_name (str): Name of the group being created; used in the keys.
        store: Store providing ``add`` and ``get``.
        timeout (datetime.timedelta): Maximum time the master waits.

    Raises:
        RuntimeError: if some keys are still not ready after ``timeout``.
    """
    global_rank = paddle.distributed.get_rank()
    global_world_size = paddle.distributed.get_world_size()

    if global_world_size < 2:
        return

    barrier_prefix = "Barrier/" + group_name + "/"

    # Workers only announce themselves; the master waits for all of them so
    # that it exits last.
    if global_rank != 0:
        store.add(barrier_prefix + str(global_rank), 1)
        return

    wait_keys = [
        barrier_prefix + str(rank) for rank in range(1, global_world_size)
    ]
    start_time = time.time()
    while wait_keys:
        time.sleep(0.1)
        wait_keys = [key for key in wait_keys if int(store.get(key)) != 1]
        elapse_time = time.time() - start_time
        # Check the deadline only after polling, so keys that became ready
        # during the last interval do not trigger a spurious timeout.
        if wait_keys and datetime.timedelta(seconds=elapse_time) > timeout:
            raise RuntimeError(
                "Timeout while initializing process group {}. "
                "Keys {} are not ready since rank {} is waiting for them. "
                "Two reasons may cause this error:\n"
                " 1. The create process group api should be called by all ranks.\n"
                " 2. Try to increase the waiting time.\n".format(
                    group_name, wait_keys, global_rank))


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members. In eager mode, None
            means all ranks of the default group.
        backend (str): The backend used to create group. In eager mode it
            defaults to the default backend; otherwise only nccl is supported.
        timeout (datetime.timedelta, optional): The waiting timeout for store relevant options, default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    if in_dygraph_mode():
        return _new_group_in_eager_mode(ranks, backend, timeout)
    return _new_group_in_legacy_mode(ranks, backend)


def _new_group_in_eager_mode(ranks, backend, timeout):
    """Create and register a Group backed by a core ProcessGroup (eager mode)."""
    gid = _custom_gid if _custom_gid else _new_ring_id()
    group_name = _default_group_name + str(gid)

    if backend != 'heter' and (ranks is None or len(ranks) > 1):
        default_group = _get_default_group()
        global_rank = default_group.rank
        default_ranks = default_group.ranks
        backend = _default_backend if backend is None else backend
        if ranks is None:
            ranks = default_ranks
        assert len(ranks) <= len(default_ranks), (
            "Size of new group must be less than or "
            "equal to that of the default global group.")

    size = len(ranks)
    ranks = sorted(ranks)
    # global_rank is bound only when the branch above ran; the short-circuit
    # below reads it only in that case.
    if backend == 'heter' or (size > 1 and global_rank in ranks):
        if backend == 'heter':
            rank, src_rank, dst_rank = 0, ranks[0], ranks[1]
        else:
            rank, src_rank, dst_rank = ranks.index(global_rank), None, None
        pg = _new_process_group_impl(backend,
                                     _default_store,
                                     rank,
                                     size,
                                     group_name,
                                     pg_options=None,
                                     group_id=gid,
                                     src_rank=src_rank,
                                     dst_rank=dst_rank)
    else:
        rank, pg = -1, None

    group = Group(rank, gid, ranks, pg=pg, name=group_name)
    _group_map_by_name[group_name] = group
    _group_map[gid] = group
    _group_map_backend[group] = backend
    # TODO: _add_new_group is the new group-management method and will
    # replace the three registries above in the future.
    _add_new_group(group)

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by tcp
    paddle.distributed.barrier(group=group)
    # NOTE(liyurui): All processors should hang and wait using tcp store, in
    # case master exit before sub-group is created.
    if backend != 'heter':
        _barrier_by_tcp_store(group_name, _default_store, timeout)
    else:
        print("Warning: store barrier is not supported for heter backend.")
    return group


def _new_group_in_legacy_mode(ranks, backend):
    """Create a Group whose communicator is bound to a ring id (non-eager)."""
    backend = backend or 'nccl'
    assert backend == 'nccl', ("backend other than nccl is not supported yet")

    genv = _get_global_env()
    global_rank = genv.rank
    ring_id = _new_ring_id()

    if global_rank in ranks:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp
        if len(ranks) < 2:
            # A single-member group needs no communicator and skips the sync.
            return gp
        _init_legacy_ring_context(genv, ranks, group_rank, ring_id)
    else:
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    if _non_static_mode():
        tmp = paddle.to_tensor([1], dtype="int32")
    else:
        tmp = fill_constant([0], dtype="int32", value="1")
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp


def _init_legacy_ring_context(genv, ranks, group_rank, ring_id):
    """Initialize the device communicator of ring ``ring_id`` for ``ranks``."""
    strategy = core.ParallelStrategy()
    strategy.nranks = len(ranks)
    strategy.local_rank = group_rank
    strategy.trainer_endpoints = [genv.trainer_endpoints[i] for i in ranks]
    strategy.current_endpoint = genv.current_endpoint
    strategy.nrings = 1

    if core.is_compiled_with_cuda():
        place = core.CUDAPlace(genv.device_id)
        context = core.NCCLParallelContext(strategy, place)
    elif core.is_compiled_with_npu():
        place = core.NPUPlace(genv.device_id)
        context = core.HCCLParallelContext(strategy, place)
    elif core.is_compiled_with_mlu():
        place = core.MLUPlace(genv.device_id)
        context = core.CNCLParallelContext(strategy, place)
    elif core.is_compiled_with_xpu():
        place = core.XPUPlace(genv.device_id)
        context = core.BKCLParallelContext(strategy, place)
    else:
        assert False, ("no cuda device found")
        # Only reached when assertions are disabled: skip initialization.
        return
    context.init_with_ring_id(ring_id)


def is_initialized():
    """

    Check whether the distributed environment has been initialized

    Returns (bool): `True` if the default group is registered (i.e. the
    distributed environment has been initialized), otherwise `False`.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle

            print(paddle.distributed.is_initialized())
            # False

            paddle.distributed.init_parallel_env()
            print(paddle.distributed.is_initialized())
            # True

    """
    registered_names = _group_map_by_name
    return _default_group_name in registered_names


def destroy_process_group(group=None):
    """
    Destroy a given group for communication

    Args:
        group (Group, optional): The group to be destroyed. If None (the
            default), all process groups, including the default group, are
            destroyed and the distributed environment is deinitialized.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            group = dist.new_group([0, 1])

            dist.destroy_process_group(group)
            print(dist.is_initialized())
            # True
            dist.destroy_process_group()
            print(dist.is_initialized())
            # False

    """
    target = _get_default_group() if group is None else group
    assert _group_map.get(target.id, None) is not None, "Invalid group."

    if group is None:
        _group_map.clear()
        _group_map_by_name.clear()
        _group_map_backend.clear()
    else:
        del _group_map[target.id]
        # Groups created outside eager mode are registered only in
        # _group_map (without a name or backend entry), so tolerate their
        # absence from the other two registries instead of raising KeyError.
        _group_map_by_name.pop(target.name, None)
        _group_map_backend.pop(target, None)


def wait(tensor, group=None, use_calc_stream=True):
    """

    wait to sync stream for group.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, sync_op=True)
            paddle.distributed.wait(tindata)

    """
    if group is not None and not group.is_member():
        return

    # Ring 0 is used when no group is given.
    ring_id = group.id if group is not None else 0
    if not use_calc_stream:
        _sync_comm_stream(tensor, ring_id)
        return
    _sync_calc_stream(tensor)


def _sync_calc_stream(tensor):
    """
    Run (dynamic mode) or append to the program (static mode) a
    ``c_sync_calc_stream`` op whose input and output are both ``tensor``.
    """
    if not _non_static_mode():
        op_type = 'c_sync_calc_stream'
        # LayerHelper receives every local variable; keep the locals bound
        # before this call unchanged.
        helper = LayerHelper(op_type, **locals())
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [tensor]},
        )
        return None
    return _legacy_C_ops.c_sync_calc_stream(tensor, tensor)


def _sync_comm_stream(tensor, ring_id=0):
    """
    Run (dynamic mode) or append to the program (static mode) a
    ``c_sync_comm_stream`` op on ``tensor`` for ring ``ring_id``.
    """
    if not _non_static_mode():
        op_type = 'c_sync_comm_stream'
        # LayerHelper receives every local variable; keep the locals bound
        # before this call unchanged.
        helper = LayerHelper(op_type, **locals())
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [tensor]},
            attrs={'ring_id': ring_id},
        )
        return None
    return _legacy_C_ops.c_sync_comm_stream([tensor], [tensor], 'ring_id',
                                            ring_id)


def broadcast(tensor, src, group=None, sync_op=True):
    """

    Broadcast a tensor from the source to all others.
    As shown below, one process is started with a GPU and GPU0 owns data 0. Through broadcast operator,
    data 0 will be sent to all GPUs from GPU0.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png
        :width: 800
        :alt: broadcast
        :align: center

    Args:
        tensor (Tensor): The Tensor to send if current rank is the source, or the Tensor to receive otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        src (int): The source rank.
        group (Group, optional): The group instance return by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None, except in eager mode with ``sync_op=False``, where the
        communication task is returned.

    Raises:
        ValueError: if ``src`` is not an int.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.broadcast(data, src=1)
            print(data)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs)
    """

    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        assert gsrc >= 0, ("src rank out of group, need global rank")
        task = group.process_group.broadcast(tensor, gsrc)
        if sync_op:
            task.wait()
            return None
        return task

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    # Translate the global source rank into a rank inside the group.
    gsrc = src if group is None else group.get_group_rank(src)
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if _non_static_mode():
        return _legacy_C_ops.c_broadcast(tensor, tensor, 'root', gsrc,
                                         'use_calc_stream', use_calc_stream,
                                         'ring_id', ring_id)

    op_type = 'c_broadcast'
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'broadcast')

    # LayerHelper receives every local variable as a keyword argument.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'root': gsrc,
                         'use_calc_stream': use_calc_stream,
                         'ring_id': ring_id,
                     })


def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
    """

    Reduce a tensor to the destination from all others. As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator,
    GPU0 will own the sum of all data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png
        :width: 800
        :alt: reduce
        :align: center

    Args:
        tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        dst (int): The destination rank id.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
        group (Group, optional): The group instance return by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None, except in eager mode with ``sync_op=False``, where the
        communication task is returned.

    Raises:
        ValueError: outside eager mode, if ``op`` is not one of
            ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN or ReduceOp.PROD.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.reduce(data, dst=0)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs, out for rank 0)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce")
        group = _get_default_group() if group is None else group
        gdst = group.get_group_rank(dst)
        assert gdst >= 0, ("dst rank out of group, need global rank")
        task = group.process_group.reduce(tensor, gdst, op_type)
        if sync_op:
            task.wait()
            return None
        return task

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    # Translate the global destination rank into a rank inside the group.
    gdst = dst if group is None else group.get_group_rank(dst)
    assert gdst >= 0, ("dst rank out of group, need global rank")

    # Resolved once for both the dynamic-legacy and static paths. An
    # unsupported op now raises in static mode as well, instead of silently
    # emitting a nonexistent 'c_reduce' operator.
    op_type = _legacy_reduce_op_type(op)

    if _non_static_mode():
        return getattr(_legacy_C_ops, op_type)(tensor, tensor,
                                               'use_calc_stream',
                                               use_calc_stream, 'ring_id',
                                               ring_id, 'root_id', gdst)

    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'reduce')

    # LayerHelper receives every local variable as a keyword argument.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream,
                         'root_id': gdst,
                     })


def _legacy_reduce_op_type(op):
    """
    Return the legacy operator name implementing ``op`` for ``reduce``.

    Raises:
        ValueError: if ``op`` is not SUM, MAX, MIN or PROD.
    """
    if op == ReduceOp.SUM:
        return 'c_reduce_sum'
    if op == ReduceOp.MAX:
        return 'c_reduce_max'
    if op == ReduceOp.MIN:
        return 'c_reduce_min'
    if op == ReduceOp.PROD:
        return 'c_reduce_prod'
    raise ValueError("Unknown parameter: {}.".format(op))


def all_gather(tensor_list, tensor, group=None, sync_op=True):
    """

    Gather tensors from all participators; every participator receives the
    full result. As shown below, one process is started per GPU and the data
    of each process is represented by its group rank. Through the all_gather
    operator, each GPU ends up holding the data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
            In eager mode a non-empty list is concatenated to build the receive buffer, after which the list
            is cleared and refilled with one Tensor per rank; in the other modes the per-rank Tensors are
            appended to the list.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
        group (Group, optional): The group instance return by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.
            In eager mode the call always waits for the communication to finish.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            tensor_list = []
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_gather(tensor_list, data)
            print(tensor_list)
            # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
    """
    # Ranks outside the group take no part in the collective.
    if group is not None and not group.is_member():
        return

    def convert_to_complex(list_of_tensor):
        # Turn real-valued views produced by paddle.as_real back into
        # complex tensors.
        return [paddle.as_complex(part) for part in list_of_tensor]

    # Complex inputs travel as their real-valued view and are converted back
    # once the gather is done.
    is_input_complex = (tensor.dtype == paddle.complex64
                        or tensor.dtype == paddle.complex128)
    if is_input_complex:
        tensor = paddle.as_real(tensor)

    def append_gathered(gathered, num_parts):
        # Split the gathered tensor along dim 0 into one piece per rank and
        # append the pieces (restored to complex when needed) to tensor_list.
        pieces = paddle.split(gathered, num_parts, 0)
        if is_input_complex:
            pieces = convert_to_complex(pieces)
        tensor_list.extend(pieces)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        if tensor_list:
            # Build the receive buffer from the caller-provided tensors.
            out = paddle.concat(tensor_list, axis=0)
        else:
            out_shape = list(tensor.shape)
            out_shape[0] *= group.nranks
            out = paddle.empty(out_shape, tensor.dtype)
        task = group.process_group.all_gather(tensor, out)
        task.wait()
        tensor_list.clear()
        append_gathered(out, group.nranks)
        return

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    if _non_static_mode():
        out = _legacy_C_ops.c_allgather(tensor, 'use_calc_stream',
                                        use_calc_stream, 'ring_id', ring_id,
                                        'nranks', nranks)
    else:
        op_type = 'c_allgather'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        if not isinstance(tensor_list, list):
            raise ValueError("The type of 'tensor_list' for all_gather "
                             "should be list.")
        supported_dtypes = [
            'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'int8',
            'uint8', 'complex64', 'complex128'
        ]
        for elem in tensor_list:
            check_variable_and_dtype(elem, 'tensor_list', supported_dtypes,
                                     'all_gather')
        check_variable_and_dtype(tensor, 'tensor', supported_dtypes,
                                 'all_gather')
        helper.append_op(type=op_type,
                         inputs={'X': [tensor]},
                         outputs={'Out': [out]},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': use_calc_stream,
                             'nranks': nranks
                         })

    append_gathered(out, nranks)


def _convert_object_to_tensor(obj):
    """Pickle ``obj`` and wrap the resulting bytes in a uint8 tensor.

    Returns:
        A ``(tensor, numel)`` pair: the 1-D uint8 tensor holding the pickled
        bytes, and the value of ``tensor.numel()``.
    """
    buffer = io.BytesIO()
    pickle.Pickler(buffer).dump(obj)
    raw_bytes = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
    byte_tensor = paddle.to_tensor(raw_bytes)
    return byte_tensor, byte_tensor.numel()
925 926


927
def _convert_tensor_to_object(tensor, len_of_tensor):
928
    _unpickler = pickle.Unpickler
929
    return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()
930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956


def all_gather_object(object_list, obj, group=None):
    """

    Gather picklable objects from all participators so that every participator gets the result. Similar to all_gather(), but python objects can be passed in.

    Args:
        object_list (list): A list of output object. The datatype of every element in the list is same as the input obj.
            The gathered objects are appended to it in rank order.
        obj (Any): The picklable object to send.
        group (Group): The group instance return by new_group or None for global default group.
            Ranks that are not members of ``group`` return immediately without touching ``object_list``.

    Returns:
        None.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            object_list = []
            if dist.get_rank() == 0:
                obj = {"foo": [1, 2, 3]}
            else:
                obj = {"bar": [4, 5, 6]}
            dist.all_gather_object(object_list, obj)
            print(object_list)
            # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
    """
    assert in_dygraph_mode(
    ), "all_gather_object doesn't support static graph mode."

    # Non-members take no part in the collective. Without this early return
    # the inner all_gather() calls return without filling their output lists
    # and max() below would raise on an empty sequence.
    if group is not None and not group.is_member():
        return

    tensor, len_of_tensor = _convert_object_to_tensor(obj)

    # gather len_of_tensor from all ranks
    list_len_of_tensor = []
    all_gather(list_len_of_tensor, len_of_tensor, group)
    # get the max length from list
    max_len_of_tensor = int(max(list_len_of_tensor).item())
    # Resize every rank's payload to the max length so all ranks contribute
    # equally-sized tensors (avoids a hang in all_gather); np.resize fills the
    # extra space by repeating the data, and receivers truncate it away using
    # the gathered lengths.
    # Note(liyurui): Maybe we should support various length all_gather?
    # Now this operation is inefficient because we don't support resize in python.
    numpy_data = np.resize(tensor.numpy(), [max_len_of_tensor])
    input_tensor = paddle.to_tensor(numpy_data)

    tensor_list = []
    all_gather(tensor_list, input_tensor, group)
    # NOTE(review): this unpickles bytes sent by other ranks; it is only safe
    # when every participating process is trusted.
    for gathered, length in zip(tensor_list, list_len_of_tensor):
        object_list.append(_convert_tensor_to_object(gathered, length))
987 988


989
def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
    """

    Scatter a tensor to all participators. As shown below, one process is started per GPU and the source of
    the scatter is GPU0. Through the scatter operator, the data in GPU0 is split evenly and sent to all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
        :width: 800
        :alt: scatter
        :align: center

    Args:
        tensor (Tensor): The output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool. Default value is None.
            Only the source rank's list is used; on other ranks it is ignored.
        src (int): The source rank id. Default value is 0.
        group (Group, optional): The group instance return by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None. In eager mode with ``sync_op=False`` the communication task is returned instead.

    Raises:
        ValueError: If ``src`` is not an int.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([7, 8, 9])
                data2 = paddle.to_tensor([10, 11, 12])
                dist.scatter(data1, src=1)
            else:
                data1 = paddle.to_tensor([1, 2, 3])
                data2 = paddle.to_tensor([4, 5, 6])
                dist.scatter(data1, tensor_list=[data1, data2], src=1)
            print(data1, data2)
            # [1, 2, 3] [10, 11, 12] (2 GPUs, out for rank 0)
            # [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        rank = group.rank
        nranks = group.nranks
    elif group is None:
        ring_id = 0
        gsrc = src
        rank = _get_global_group().rank
        nranks = _get_global_group().nranks
    else:
        ring_id = group.id
        gsrc = group.get_group_rank(src)
        rank = group.rank
        nranks = group.nranks

    assert gsrc >= 0, ("src rank out of group, need global rank")

    if rank != gsrc:
        # Ranks other than the source have nothing to scatter; they
        # concatenate nranks references to their own output tensor instead.
        tensor_list = [tensor] * nranks
    temp = paddle.concat(tensor_list, axis=0)

    if in_dygraph_mode():
        task = group.process_group.scatter(temp, tensor, gsrc)
        if not sync_op:
            return task
        task.wait()
        return None

    use_calc_stream = sync_op
    if _non_static_mode():
        return _legacy_C_ops.c_scatter(temp, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'nranks', nranks, 'root', gsrc)

    op_type = 'c_scatter'
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'scatter')
    helper = LayerHelper(op_type, **locals())
    scatter_attrs = {
        'ring_id': ring_id,
        'root': gsrc,
        'use_calc_stream': use_calc_stream,
        'nranks': nranks,
    }
    helper.append_op(type=op_type,
                     inputs={'X': [temp]},
                     outputs={'Out': [tensor]},
                     attrs=scatter_attrs)
1083 1084


1085
def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):
    """
    Scatter tensors in in_tensor_list to all participators evenly and gather the result tensors in out_tensor_list.
    As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
    Through the alltoall operator, 0_0 in GPU0 is sent to GPU0 and 0_1 to GPU1, 1_0 in GPU1 is sent to GPU0 and 1_1 to GPU1.
    Finally the out_tensor_list in GPU0 includes 0_0 and 1_0, and GPU1 includes 0_1 and 1_1.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/alltoall.png
        :width: 800
        :alt: alltoall
        :align: center

    Args:
        in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
            data type of the input Tensors. In eager mode a non-empty list is concatenated to build the receive
            buffer and the list is then cleared and refilled; otherwise the results are appended, and static
            graph mode requires the list to be empty.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.
            In eager mode the call always waits for the communication to finish.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            out_tensor_list = []
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
                data2 = paddle.to_tensor([[7, 8, 9], [10, 11, 12]])
            else:
                data1 = paddle.to_tensor([[13, 14, 15], [16, 17, 18]])
                data2 = paddle.to_tensor([[19, 20, 21], [22, 23, 24]])
            dist.alltoall([data1, data2], out_tensor_list)
            print(out_tensor_list)
            # [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0)
            # [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if not in_dygraph_mode():
        ring_id = 0 if group is None else group.id
    else:
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")

    temp = paddle.concat(in_tensor_list, axis=0)
    # One input chunk per participator.
    nranks = len(in_tensor_list)

    if in_dygraph_mode():
        if out_tensor_list:
            out = paddle.concat(out_tensor_list, axis=0)
        else:
            first_input = in_tensor_list[0]
            out_shape = list(first_input.shape)
            out_shape[0] *= nranks
            out = paddle.empty(out_shape, first_input.dtype)
        task = group.process_group.alltoall(temp, out)
        task.wait()
        out_tensor_list.clear()
        out_tensor_list.extend(paddle.split(out, nranks, 0))
        return

    use_calc_stream = sync_op
    if _non_static_mode():
        out = _legacy_C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id)
    else:
        op_type = 'alltoall'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=in_tensor_list[0].dtype)

        if not isinstance(in_tensor_list, list):
            raise ValueError("The type of 'in_tensor_list' for all_to_all "
                             "should be list.")
        for elem in in_tensor_list:
            check_variable_and_dtype(
                elem, 'in_tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_to_all')
        if not isinstance(out_tensor_list, list):
            raise ValueError("The type of 'out_tensor_list' for all_to_all "
                             "should be list.")
        if out_tensor_list:
            raise ValueError("The 'out_tensor_list' for all_to_all "
                             "must be an empty list.")
        helper.append_op(type=op_type,
                         inputs={'X': [temp]},
                         outputs={'Out': [out]},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': use_calc_stream,
                         })
    out_tensor_list.extend(paddle.split(out, nranks, 0))


1187 1188 1189 1190 1191
def alltoall_single(in_tensor,
                    out_tensor,
                    in_split_sizes=None,
                    out_split_sizes=None,
                    group=None,
                    sync_op=True):
    """
    Scatter a single input tensor to all participators and gather the received tensors in out_tensor.

    Note:
        ``alltoall_single`` is only supported in eager mode.

    Args:
        in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
        in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
            must be divisible by group size and ``in_tensor`` will be scattered evenly to all participators. Default: None.
        out_split_sizes (list[int], optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
            must be divisible by group size and ``out_tensor`` will be gathered evenly from all participators. Default: None.
        group (Group, optional): The group instance return by ``new_group`` or None for global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None, if ``sync_op`` is set to ``True``; ``Task`` of ``group``, if ``sync_op`` is set to ``False``.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            size = dist.get_world_size()

            # case 1 (2 GPUs)
            data = paddle.arange(2, dtype='int64') + rank * 2
            # data for rank 0: [0, 1]
            # data for rank 1: [2, 3]
            output = paddle.empty([2], dtype='int64')
            dist.alltoall_single(data, output)
            print(output)
            # output for rank 0: [0, 2]
            # output for rank 1: [1, 3]

            # case 2 (2 GPUs)
            in_split_sizes = [i + 1 for i in range(size)]
            # in_split_sizes for rank 0: [1, 2]
            # in_split_sizes for rank 1: [1, 2]
            out_split_sizes = [rank + 1 for i in range(size)]
            # out_split_sizes for rank 0: [1, 1]
            # out_split_sizes for rank 1: [2, 2]
            data = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank
            # data for rank 0: [[0., 0.], [0., 0.], [0., 0.]]
            # data for rank 1: [[1., 1.], [1., 1.], [1., 1.]]
            output = paddle.empty([(rank + 1) * size, size], dtype='float32')
            group = dist.new_group([0, 1])
            task = dist.alltoall_single(data,
                                        output,
                                        in_split_sizes,
                                        out_split_sizes,
                                        sync_op=False,
                                        group=group)
            task.wait()
            print(output)
            # output for rank 0: [[0., 0.], [1., 1.]]
            # output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]]

    """
    if group is not None and not group.is_member():
        return

    assert in_dygraph_mode(), "Only suppport alltoall_single in eager mode."

    if group is None:
        group = _get_default_group()
    backend = _group_map_backend[group]
    assert backend != 'gloo', ("backend gloo is not supported yet")

    # An empty split list asks the process group for an even split.
    task = group.process_group.alltoall_single(
        in_tensor, out_tensor,
        [] if in_split_sizes is None else in_split_sizes,
        [] if out_split_sizes is None else out_split_sizes)
    if not sync_op:
        return task
    task.wait()
    return None


S
ShenLiang 已提交
1279 1280 1281 1282
def _get_group_rank(global_rank, group=None):
    return global_rank if group is None else group.get_group_rank(global_rank)


1283
def send(tensor, dst=0, group=None, sync_op=True):
    """
    Send a tensor to the receiver.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool
            (static graph mode only accepts float16, float32, float64, int32 and int64).
        dst (int): The destination rank id (a global rank; it is mapped into ``group``).
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None. In eager mode with ``sync_op=False`` the communication task is returned instead.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    dst = _get_group_rank(dst, group)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        task = group.process_group.send(tensor, dst)
        if not sync_op:
            return task
        task.wait()
        return None

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _legacy_C_ops.send_v2(tensor, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id, 'peer', dst)

    op_type = 'send_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'send')

    helper = LayerHelper(op_type, **locals())
    send_attrs = {
        'ring_id': ring_id,
        'peer': dst,
        'use_calc_stream': use_calc_stream,
    }
    helper.append_op(type=op_type, inputs={'X': [tensor]}, attrs=send_attrs)
L
lilong12 已提交
1347 1348


1349
def recv(tensor, src=0, group=None, sync_op=True):
    """
    Receive a tensor from the sender.

    Args:
        tensor (Tensor): The Tensor that receives the data. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool
            (static graph mode only accepts float16, float32, float64, int32 and int64).
        src (int): The source rank id (a global rank; it is mapped into ``group``).
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None. In eager mode with ``sync_op=False`` the communication task is returned instead.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    src = _get_group_rank(src, group)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        task = group.process_group.recv(tensor, src)
        if not sync_op:
            return task
        task.wait()
        return None

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _legacy_C_ops.recv_v2(tensor, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id, 'peer', src, 'dtype',
                                     tensor.dtype, 'out_shape', tensor.shape)

    op_type = 'recv_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'recv')
    helper = LayerHelper(op_type, **locals())
    recv_attrs = {
        'ring_id': ring_id,
        'peer': src,
        'out_shape': tensor.shape,
        'dtype': tensor.dtype,
        'use_calc_stream': use_calc_stream,
    }
    helper.append_op(type=op_type, outputs={'Out': [tensor]}, attrs=recv_attrs)
1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437


def _check_single_tensor(tensor, tensor_name):
    """Raise ``RuntimeError`` unless ``tensor`` is a Paddle tensor.

    Args:
        tensor: The object to validate; accepted if it is a
            ``core.eager.Tensor`` or ``paddle.Tensor``.
        tensor_name (str): Parameter name reported in the error message.

    Raises:
        RuntimeError: If ``tensor`` is not a Paddle tensor.
    """
    if not isinstance(tensor, (core.eager.Tensor, paddle.Tensor)):
        # The trailing space keeps the name and "to be" apart in the message.
        raise RuntimeError("Invalid function argument. Expected parameter {} "
                           "to be of type paddle.Tensor, but it's {}".format(
                               tensor_name, type(tensor)))


def _check_tensor_list(tensor_list, tensor_name):
    """Raise ``RuntimeError`` unless ``tensor_list`` is a list of Paddle tensors.

    Args:
        tensor_list: The object to validate; it must be a ``list`` whose
            elements are all ``core.eager.Tensor`` or ``paddle.Tensor``.
        tensor_name (str): Parameter name reported in the error message.

    Raises:
        RuntimeError: If ``tensor_list`` is not a list or holds a non-tensor.
    """
    is_valid = isinstance(tensor_list, list) and all(
        isinstance(t, (core.eager.Tensor, paddle.Tensor)) for t in tensor_list)
    if not is_valid:
        # The trailing space keeps the name and "to be" apart in the message.
        raise RuntimeError("Invalid function argument. Expected parameter {} "
                           "to be of type paddle.Tensor".format(tensor_name))


def isend(tensor, dst, group=None):
    """
    Send a tensor asynchronously.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        dst (int): The destination rank (a global rank; it is mapped into ``group``).
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.

    Returns:
        A distributed task object.

    Raises:
        RuntimeError: If ``tensor`` is not a Tensor, or when not running in eager dygraph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)

    """
    _check_single_tensor(tensor, "tensor")
    if group is not None and not group.is_member():
        return

    if not in_dygraph_mode():
        raise RuntimeError("Only support eager dygraph mode.")

    group = _get_default_group() if group is None else group
    backend = _group_map_backend[group]
    assert backend != 'gloo', ("backend gloo is not supported yet")
    peer_rank = group.get_group_rank(dst)
    assert peer_rank >= 0, ("dst rank out of group, need global rank")
    return group.process_group.send(tensor, peer_rank)
1480 1481 1482 1483 1484 1485 1486 1487


def irecv(tensor, src=None, group=None):
    """
    Receive a tensor asynchronously from the sender.

    Args:
        tensor (Tensor): The Tensor that receives the data. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        src (int): The source rank id (a global rank; it is mapped into ``group``).
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.

    Returns:
        A distributed task object.

    Raises:
        RuntimeError: If ``tensor`` is not a Tensor, or when not running in eager dygraph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    _check_single_tensor(tensor, "tensor")
    if group is not None and not group.is_member():
        return

    if not in_dygraph_mode():
        raise RuntimeError("Only support eager dygraph mode.")

    group = _get_default_group() if group is None else group
    backend = _group_map_backend[group]
    assert backend != 'gloo', ("backend gloo is not supported yet")
    # NOTE(review): with the default src=None, get_group_rank receives None;
    # presumably that yields a negative rank and trips the assertion below.
    # Confirm, and pass an explicit src.
    peer_rank = group.get_group_rank(src)
    assert peer_rank >= 0, ("src rank out of group, need global rank")
    return group.process_group.recv(tensor, peer_rank)
1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543


class P2POp(object):
    """
    Describe a single point-to-point operation for ``batch_isend_irecv``.

    An instance bundles the P2P function to run, the communication buffer,
    the peer rank and the communication group. A list of such instances is
    handed to ``paddle.distributed.batch_isend_irecv``, which invokes
    ``op(tensor, peer, group)`` for each of them.

    Args:
        op (callable): Either ``paddle.distributed.isend`` or
            ``paddle.distributed.irecv``; any other value raises RuntimeError.
        tensor (Tensor): Tensor to send or receive.
        peer (int): The destination rank (for ``isend``) or the source rank
            (for ``irecv``).
        group (Group, optional): The group instance return by new_group or None for global
            default group. Default: None.

    """

    def __init__(self, op, tensor, peer, group=None):
        allowed_ops = (isend, irecv)
        if op not in allowed_ops:
            raise RuntimeError("Invalid ``op`` function. Expected ``op`` "
                               "to be of type ``paddle.distributed.isend`` or "
                               "``paddle.distributed.irecv``.")
        _check_single_tensor(tensor, "tensor")

        # Resolve the default group eagerly so every op carries a concrete
        # group (batch_isend_irecv looks up its backend).
        if group is None:
            group = _get_default_group()

        self.op = op
        self.tensor = tensor
        self.peer = peer
        self.group = group


@contextlib.contextmanager
def _with_batch_p2p_guard(backend):
    if backend == "nccl":
        core.ProcessGroupNCCL.group_start()
    try:
        yield
    finally:
        if backend == "nccl":
            core.ProcessGroupNCCL.group_end()


def _check_p2p_op_list(p2p_op_list):
    """
    Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
    all ops use the same backend.
    """
    if not isinstance(p2p_op_list, list) or not all(
            isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list):
        raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected to "
                           "to be of type ``paddle.distributed.P2POp``.")

    backend = _group_map_backend[p2p_op_list[0].group]
    if not all(backend == _group_map_backend[p2p_op.group]
               for p2p_op in p2p_op_list):
        raise RuntimeError("All groups need to use the same backend.")


def batch_isend_irecv(p2p_op_list):
    """
    Send or Receive a batch of tensors asynchronously and return a list of requests.

    Every point-to-point operation in ``p2p_op_list`` is issued in order and
    the resulting tasks are collected. When the ops use the NCCL backend they
    are issued between ``ProcessGroupNCCL.group_start()`` and ``group_end()``.

    Args:
        p2p_op_list (list[P2POp]): A list of point-to-point operations (each
            element is a ``paddle.distributed.P2POp``). The order of the
            isend/irecv in the list matters and it needs to match the
            corresponding isend/irecv on the remote end.

    Returns:
        A list of the distributed tasks returned by the individual ops; ops
        that return None (for example because the current process is not in
        their group) are skipped. Returns None when the current process is
        not a member of the first op's group.

    Raises:
        RuntimeError: If ``p2p_op_list`` is invalid or when called in static
            graph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed

            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            send_t = paddle.arange(2) + rank
            # paddle.tensor([0, 1])  # Rank-0
            # paddle.tensor([1, 2])  # Rank-1

            recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)

            send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
            recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)

            tasks = dist.batch_isend_irecv([send_op, recv_op])

            for task in tasks:
                task.wait()

            print(recv_t)
            # paddle.tensor([1, 2])     # Rank-0
            # paddle.tensor([0, 1])     # Rank-1
    """
    _check_p2p_op_list(p2p_op_list)
    lead_group = p2p_op_list[0].group
    # Non-members of the leading op's group do nothing.
    if lead_group is not None and not lead_group.is_member():
        return None

    if not in_dygraph_mode():
        raise RuntimeError("Don't support static graph mode currently.")

    if lead_group is None:
        lead_group = _get_default_group()

    collected = []
    with _with_batch_p2p_guard(_group_map_backend[lead_group]):
        for item in p2p_op_list:
            handle = item.op(item.tensor, item.peer, item.group)
            if handle is not None:
                collected.append(handle)
    return collected


def reduce_scatter(tensor,
                   tensor_list,
                   op=ReduceOp.SUM,
                   group=None,
                   sync_op=True):
    """
    Reduces, then scatters a list of tensors to all processes in a group

    Args:
        tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        tensor_list (list[Tensor]): List of tensors to reduce and scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (Group, optional): The group instance return by new_group or None for global
            default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        Async task handle, if sync_op is set to False.
        None, if sync_op is True or if the current process is not part of the group.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([0, 1])
                data2 = paddle.to_tensor([2, 3])
            else:
                data1 = paddle.to_tensor([4, 5])
                data2 = paddle.to_tensor([6, 7])
            dist.reduce_scatter(data1, [data1, data2])
            print(data1)
            # [4, 6] (2 GPUs, out for rank 0)
            # [8, 10] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(tensor, "tensor")
    _check_tensor_list(tensor_list, "tensor_list")

    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce_scatter")
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")

        # Concatenate the inputs along axis 0 into the single input tensor
        # that ``_reduce_scatter_base`` takes.
        temp = paddle.concat(tensor_list, axis=0)
        task = group.process_group._reduce_scatter_base(tensor, temp, op_type)
        if sync_op:
            task.wait()
            return None
        else:
            return task
    else:
        raise RuntimeError("Don't support static graph mode currently.")


def _reduce_scatter_base(output,
                         input,
                         op=ReduceOp.SUM,
                         group=None,
                         sync_op=True):
    """
    Reduces, then scatters a flattened tensor to all processes in a group.

    Args:
        output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        input (Tensor): Input tensor that is of size output tensor size times world size. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (Group, optional): The group instance return by new_group or None for global
            default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        Async task handle, if sync_op is set to False.
        None, if sync_op is True or if the current process is not part of the group.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            data = paddle.arange(4) + rank
            # [0, 1, 2, 3] (2 GPUs, for rank 0)
            # [1, 2, 3, 4] (2 GPUs, for rank 1)
            output = paddle.empty(shape=[2], dtype=data.dtype)
            dist.collective._reduce_scatter_base(output, data)
            print(output)
            # [1, 3] (2 GPUs, out for rank 0)
            # [5, 7] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(output, "output")
    _check_single_tensor(input, "input")

    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "_reduce_scatter_base")
        group = _get_default_group() if group is None else group
        task = group.process_group._reduce_scatter_base(output, input, op_type)
        if sync_op:
            task.wait()
            return None
        else:
            return task
    else:
        raise RuntimeError("Don't support static graph mode currently.")