collective.py 63.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
17 18
import pickle
import io
19 20
import datetime
import time
21
from ..fluid.layer_helper import LayerHelper
22
from ..fluid.framework import Variable
23
from ..fluid.framework import in_dygraph_mode
24
from ..fluid.framework import OpProtoHolder
J
Jiabin Yang 已提交
25
from ..fluid.framework import _non_static_mode
26
from ..fluid.framework import _in_legacy_dygraph
27
from ..fluid.framework import convert_np_dtype_to_dtype_
J
Jiangxinz 已提交
28
from ..fluid.framework import _varbase_creator
29 30 31 32
from ..fluid.data_feeder import convert_dtype
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.data_feeder import check_type
from ..fluid.data_feeder import check_dtype
33 34
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
B
Baibaifan 已提交
35
from ..fluid.dygraph import layers
36 37 38 39
from ..fluid.dygraph.parallel import prepare_context
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
40
from paddle import _C_ops, _legacy_C_ops
J
Jiangxinz 已提交
41
import paddle.fluid.dygraph_utils as dygraph_utils
42
import contextlib
W
wuhuachaocoding 已提交
43 44 45 46 47 48 49 50 51 52 53 54
from .fleet.layers.mpu.mp_ops import split
from .fleet.layers.mpu.mp_ops import _c_identity
from .fleet.layers.mpu.mp_ops import _c_concat
from .fleet.layers.mpu.mp_ops import _c_split
from .fleet.layers.mpu.mp_ops import _mp_allreduce
from .fleet.layers.mpu.mp_ops import _c_lookup_table
from .fleet.layers.mpu.mp_ops import _Linear
from .fleet.layers.mpu.mp_ops import _set_var_distributed
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy
from .fleet.layers.mpu.mp_ops import _linear
from .fleet.layers.mpu.mp_ops import _parallel_linear
from .fleet.layers.mpu.mp_ops import _parallel_embedding
55 56 57
from .communication.group import Group, _add_new_group
from .communication.all_reduce import all_reduce
from .communication.reduce import _get_reduce_op, ReduceOp
58

59
__all__ = []
60

K
kuizhiqing 已提交
61 62 63 64 65 66 67 68 69 70 71 72 73
# Cached ParallelEnv instance; filled in lazily by _get_global_env().
_global_env = None


def _get_global_env():
    """Return the cached ``ParallelEnv``, creating it on first use.

    The instance is kept in the module-level ``_global_env`` so that later
    calls hand back the same object.
    """
    global _global_env
    if _global_env:
        return _global_env
    _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# Registry of all groups keyed by integer group id; the global group is
# stored under ``_global_env_gid``.
# Dict[int, Group]
_group_map = {}
_global_env_gid = 0

# Registry of all groups keyed by their name.
# Dict[name, Group]
_group_map_by_name = {}

# Backend recorded for each group.
# Dict[group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

_valid_backend_list = ["nccl", "gloo", "hccl", "heter", "xccl"]
_default_store = None  # the default tcp store
_default_backend = None
_default_timeout = datetime.timedelta(minutes=30)
_start_ring_id = 0
92

K
kuizhiqing 已提交
93

L
lilong12 已提交
94 95 96 97 98 99 100 101 102 103
def _set_default_backend(backend):
    """Store ``backend`` in the module-level ``_default_backend``."""
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    """Store ``store`` in the module-level ``_default_store``."""
    global _default_store
    _default_store = store


K
kuizhiqing 已提交
104 105
def _get_group_map():
    """Return the id -> Group registry, registering the global group if needed.

    When no entry exists under ``_global_env_gid`` yet, a ``Group`` covering
    ranks ``0 .. world_size - 1`` of the global environment is created and
    stored there first. The live dict is returned, not a copy.
    """
    global _group_map
    if _global_env_gid in _group_map:
        return _group_map

    env = _get_global_env()
    own_rank = env.rank
    every_rank = list(range(env.world_size))
    _group_map[_global_env_gid] = Group(own_rank, 0, every_rank)
    return _group_map


def _get_global_group():
    """Return the group stored under ``_global_env_gid`` in the group registry."""
    registry = _get_group_map()
    return registry[_global_env_gid]
K
kuizhiqing 已提交
115 116


117 118 119 120 121 122
def _get_group_map_by_name():
    """Return the module-level name -> Group registry (the live dict, not a copy)."""
    return _group_map_by_name


def _get_default_group():
    """Return the group registered under ``_default_group_name``.

    Asserts that the distributed environment has been initialized first.
    """
    assert is_initialized(), ("Call paddle.distributed.init_parallel_env first "
                              "to initialize the distributed environment.")
    groups_by_name = _get_group_map_by_name()
    return groups_by_name[_default_group_name]


L
lilong12 已提交
129 130 131 132 133 134 135 136 137 138 139 140
def _set_group_map(gid, group):
    """Register ``group`` under ``gid``; asserts the id is not registered yet."""
    assert gid not in _group_map
    _group_map[gid] = group


def _set_group_map_by_name(name, group):
    """Register ``group`` under ``name``; asserts the name is not registered yet."""
    assert name not in _group_map_by_name
    _group_map_by_name[name] = group


141 142 143 144 145 146
def _set_group_map_backend(group, backend):
    """Record ``backend`` for ``group``; asserts the group has no entry yet."""
    assert group not in _group_map_backend
    _group_map_backend[group] = backend


K
kuizhiqing 已提交
147
def _new_ring_id():
    """Return a fresh ring id, offset by ``max(nrings, 9)`` of the global env.

    In dygraph mode the module-level counter ``_start_ring_id`` is bumped and
    used as the base; otherwise the current size of the group registry is used.
    """
    global _start_ring_id
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode rely on the previous syntax.
    if not in_dygraph_mode():
        registered = len(_get_group_map())
        return registered + max(_get_global_env().nrings, 9)

    _start_ring_id += 1
    return _start_ring_id + max(_get_global_env().nrings, 9)
K
kuizhiqing 已提交
155 156 157 158 159 160 161 162


def get_group(id=0):
    """

    Look up a group instance by its group id.

    Args:
        id (int): The id of the group to fetch. Default value is 0.

    Returns:
        Group: The group registered under ``id``, or None when no such group exists.

    Examples:
        .. code-block:: python

            ...
            gid = paddle.distributed.new_group([2,4,6])
            paddle.distributed.get_group(gid.id)

    """
    return _get_group_map().get(id)
K
kuizhiqing 已提交
179 180


181 182 183 184 185 186
def _new_process_group_impl(backend,
                            store,
                            rank,
                            world_size,
                            group_name,
                            pg_options,
                            group_id=0,
                            src_rank=None,
                            dst_rank=None):
    """Instantiate the ``core.ProcessGroup*`` object that matches ``backend``.

    Args:
        backend (str): One of the names in ``_valid_backend_list``.
        store: Store object handed to the process group constructor.
        rank (int): Rank of this process inside the new group.
        world_size (int): Number of ranks in the new group.
        group_name (str): Not read by this function.
        pg_options: Not read by this function.
        group_id (int, optional): Id passed on to the process group. Default is 0.
        src_rank (int, optional): May only be set for the 'heter' backend.
        dst_rank (int, optional): May only be set for the 'heter' backend.

    Returns:
        The created process group object.
    """
    genv = _get_global_env()
    is_heter = backend == 'heter'
    if not is_heter:
        assert src_rank is None and dst_rank is None, (
            "src_rank and dst_rank "
            "can only be set for heter backend.")
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend

    if backend == "gloo":
        place = core.CPUPlace()
        return core.ProcessGroupGloo(store, rank, world_size, place, group_id)
    if backend == "nccl":
        place = core.CUDAPlace(genv.device_id)
        return core.ProcessGroupNCCL(store, rank, world_size, place, group_id)
    if backend == "hccl":
        place = core.NPUPlace(genv.device_id)
        return core.ProcessGroupHCCL(store, rank, world_size, place, group_id)
    if backend == "xccl":
        place = core.CustomPlace(genv.device_type, genv.device_id)
        return core.ProcessGroupCustom(store, rank, world_size, place, group_id)
    if not is_heter:
        # Only reachable when the assertions above are disabled.
        return None

    # 'heter' backend: the cluster layout is read from environment variables.
    place = None
    if core.is_compiled_with_cuda():
        place = core.CUDAPlace(genv.device_id)
    elif core.is_compiled_with_npu():
        place = core.NPUPlace(genv.device_id)

    cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
    assert cluster_id >= 0, "please set the CLUSTER_ID variable."
    raw_cluster_size = os.getenv("CLUSTER_SIZE", None)
    assert raw_cluster_size, "please set the CLUSTER_SIZE variable."
    cluster_size = [int(part) for part in raw_cluster_size.split(",")]
    switch_ep = os.getenv("CLUSTER_SWITCH", None)
    assert switch_ep, "please set the CLUSTER_SWITCH variable."

    cluster_size_cumsum = np.cumsum(cluster_size)
    if cluster_id == 0:
        cluster_offset = 0
    else:
        cluster_offset = cluster_size_cumsum[cluster_id - 1]
    global_rank = cluster_offset + rank
    global_world_size = cluster_size_cumsum[-1]
    # NOTE(review): the two values computed just above are immediately
    # replaced by the result of _get_global_config, which is not defined in
    # the visible part of this file -- confirm it is in scope and whether the
    # local computation is still needed.
    global_rank, global_world_size = _get_global_config(backend, rank)

    return core.ProcessGroupHeter(store,
                                  rank=global_rank,
                                  world_size=global_world_size,
                                  place=place,
                                  gid=group_id,
                                  local_rank=rank,
                                  local_size=world_size,
                                  gloo_rank=cluster_id,
                                  gloo_size=len(cluster_size),
                                  with_switch=True,
                                  switch_endpoint=switch_ep,
                                  src_rank=src_rank,
                                  dst_rank=dst_rank)


S
ShenLiang 已提交
246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269
def barrier(group=None):
    """

    Synchronize all participants of the group at a barrier.

    Args:
        group (Group): The group instance returned by new_group, or None for the global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    # Processes outside the group have nothing to synchronize.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        if group is None:
            group = _get_default_group()
        group.process_group.barrier().wait()
        return

    if group is None:
        ring_id = 0
    else:
        ring_id = group.id

    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        return _legacy_C_ops.barrier(temp, temp, 'ring_id', ring_id)

    op_type = 'barrier'

    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    # locals() is forwarded as keyword arguments, so the set of local names
    # at this point is part of the call.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [temp]},
        attrs={'ring_id': ring_id},
    )
S
ShenLiang 已提交
291 292


L
lilong12 已提交
293 294 295 296 297 298 299
# _custom_gid lets callers choose the group id that new_group uses in
# dygraph mode instead of generating a new one, which is usually useful
# to stay compatible with the static mode.
_custom_gid = None


def _set_custom_gid(gid):
    """Store ``gid`` in the module-level ``_custom_gid`` override."""
    global _custom_gid
    _custom_gid = gid


304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340
def _barrier_by_tcp_store(group_name, store, timeout):
    """Store-based barrier in which global rank 0 waits for all other ranks.

    Every non-zero global rank increments its own key
    ``"Barrier/<group_name>/<rank>"`` in ``store`` and returns right away.
    Global rank 0 polls those keys every 0.1 seconds until each one reads 1,
    so it is the last to leave. Nothing is done when the world size is
    smaller than 2.

    Args:
        group_name (str): Name of the group; used to build the key prefix.
        store: Object providing ``add(key, value)`` and ``get(key)``.
        timeout (datetime.timedelta): Longest time rank 0 keeps polling.

    Raises:
        RuntimeError: If rank 0 is still waiting for keys after ``timeout``.
    """
    global_rank = paddle.distributed.get_rank()
    global_world_size = paddle.distributed.get_world_size()

    if global_world_size < 2:
        return

    barrier_prefix = "Barrier/" + group_name + "/"

    # all the workers set their exiting key and exit
    if global_rank != 0:
        store.add(barrier_prefix + str(global_rank), 1)
        return

    # the master will wait for all workers' exiting key, ensure to exit in the end
    wait_keys = [
        barrier_prefix + str(rank) for rank in range(1, global_world_size)
    ]
    # A monotonic clock keeps the elapsed time correct even if the system
    # clock is adjusted while waiting.
    start_time = time.monotonic()
    while wait_keys:
        time.sleep(0.1)
        elapsed = time.monotonic() - start_time
        if datetime.timedelta(seconds=elapsed) > timeout:
            raise RuntimeError(
                "Timeout while initializing process group {}. "
                "Keys {} are not ready since rank {} is waiting for them. "
                "Two reasons may cause this error:\n"
                " 1. The create process group api should be called by all ranks.\n"
                " 2. Try to increase the waiting time.\n".format(
                    group_name, wait_keys, global_rank))
        wait_keys = [key for key in wait_keys if int(store.get(key)) != 1]


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list, optional): The global ranks of group members. If None, all ranks are used
            (the ranks of the default group in dynamic graph mode, ``0 .. world_size - 1`` in static graph mode).
        backend (str, optional): The backend used to create group. In static graph mode only nccl is supported now.
        timeout (datetime.timedelta, optional): The waiting timeout for store relevant options, default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    if in_dygraph_mode():
        # A custom gid set through _set_custom_gid takes precedence when truthy.
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group.")
        size = len(ranks)
        ranks = sorted(ranks)
        # `global_rank` is only read when size > 1, which is exactly the case
        # in which the branch above has defined it.
        if backend == 'heter' or (size > 1 and global_rank in ranks):
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            src_rank = ranks[0] if backend == 'heter' else None
            dst_rank = ranks[1] if backend == 'heter' else None
            pg = _new_process_group_impl(backend,
                                         _default_store,
                                         rank,
                                         size,
                                         group_name,
                                         pg_options=None,
                                         group_id=gid,
                                         src_rank=src_rank,
                                         dst_rank=dst_rank)
        else:
            # This process is not part of the group (or the group is trivial).
            rank = -1
            pg = None
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        #TODO: The method below is a new method for group management, will replace the previous
        # three in the future.
        _add_new_group(group)

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        # NOTE(liyurui): All processes should hang and wait using tcp store, in case master exits before sub-group is created.
        if backend != 'heter':
            _barrier_by_tcp_store(group_name, _default_store, timeout)
        else:
            print("Warning: store barrier is not supported for heter backend.")
        return group

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', ("backend other than nccl is not supported yet")

    genv = _get_global_env()
    global_rank = genv.rank
    if ranks is None:
        # Same meaning as in the dynamic-graph branch: no explicit ranks
        # selects every rank. Without this the membership test below would
        # raise TypeError for the default argument.
        ranks = list(range(genv.world_size))

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            else:
                assert False, ("no cuda device found")
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = paddle.to_tensor(
        [1], dtype="int32") if _non_static_mode() else fill_constant(
            [0], dtype="int32", value="1")
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp

475

476 477 478 479 480
def is_initialized():
    """

    Tell whether the distributed environment has been initialized, i.e.
    whether the default group is registered.

    Returns:
        `True` if distributed environment has been initialized, otherwise `False`.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle

            print(paddle.distributed.is_initialized())
            # False

            paddle.distributed.init_parallel_env()
            print(paddle.distributed.is_initialized())
            # True

    """
    return _default_group_name in _group_map_by_name


def destroy_process_group(group=None):
    """
    Destroy a given group for communication

    Args:
        group (ProcessGroup, optional): The group to be destroyed. If None, all of process groups,
                                        including the default group, will be destroyed and the
                                        distributed environment will be deinitialized.

    Returns : None

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            group = dist.new_group([0, 1])

            dist.destroy_process_group(group)
            print(dist.is_initialized())
            # True
            dist.destroy_process_group()
            print(dist.is_initialized())
            # False

    """
    target = group if group is not None else _get_default_group()
    assert _group_map.get(target.id, None) is not None, "Invalid group."

    if group is not None:
        # Remove only this group's entries from the three registries.
        del _group_map[target.id]
        del _group_map_by_name[target.name]
        del _group_map_backend[target]
        return

    # No group given: empty every registry.
    for registry in (_group_map, _group_map_by_name, _group_map_backend):
        registry.clear()


K
kuizhiqing 已提交
547 548 549 550 551 552 553 554
def wait(tensor, group=None, use_calc_stream=True):
    """

    Wait for the stream of the group to be synchronized.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, sync_op=True)
            paddle.distributed.wait(tindata)

    """
    if group is None:
        ring_id = 0
    elif group.is_member():
        ring_id = group.id
    else:
        # Processes outside the group have nothing to wait for.
        return

    if not use_calc_stream:
        _sync_comm_stream(tensor, ring_id)
        return
    _sync_calc_stream(tensor)


def _sync_calc_stream(tensor):
    """Run the ``c_sync_calc_stream`` op with ``tensor`` as input and output.

    In non-static mode the legacy C op is invoked directly and its result is
    returned; otherwise the op is appended through a ``LayerHelper``.
    """
    if _non_static_mode():
        return _legacy_C_ops.c_sync_calc_stream(tensor, tensor)

    op_type = 'c_sync_calc_stream'
    # locals() (tensor, op_type) is forwarded to LayerHelper as keyword arguments.
    LayerHelper(op_type, **locals()).append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
    )
597

598

K
kuizhiqing 已提交
599
def _sync_comm_stream(tensor, ring_id=0):
    """Run the ``c_sync_comm_stream`` op on ``tensor`` for ring ``ring_id``.

    In non-static mode the legacy C op is invoked directly and its result is
    returned; otherwise the op is appended through a ``LayerHelper``.
    """
    if _non_static_mode():
        return _legacy_C_ops.c_sync_comm_stream([tensor], [tensor], 'ring_id',
                                                ring_id)

    op_type = 'c_sync_comm_stream'
    # locals() (tensor, ring_id, op_type) is forwarded to LayerHelper as
    # keyword arguments.
    LayerHelper(op_type, **locals()).append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={'ring_id': ring_id},
    )
K
kuizhiqing 已提交
614 615


616
def broadcast(tensor, src, group=None, sync_op=True):
    """

    Broadcast a tensor from the source rank to every other rank.
    As shown below, one process is started with a GPU and GPU0 owns data 0. Through the broadcast operator,
    data 0 will be sent to all GPUs from GPU0.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png
        :width: 800
        :alt: broadcast
        :align: center

    Args:
        tensor (Tensor): The Tensor to send if current rank is the source, or the Tensor to receive otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        src (int): The source rank.
        group (Group, optional): The group instance returned by new_group, or None for the global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.broadcast(data, src=1)
            print(data)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs)
    """
    # Processes outside the group do not take part in the broadcast.
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        if group is None:
            group = _get_default_group()
        gsrc = group.get_group_rank(src)
        assert gsrc >= 0, ("src rank out of group, need global rank")
        task = group.process_group.broadcast(tensor, gsrc)
        if not sync_op:
            return task
        task.wait()
        return None

    use_calc_stream = sync_op
    if group is None:
        ring_id = 0
        gsrc = src
    else:
        ring_id = group.id
        gsrc = group.get_group_rank(src)
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if _non_static_mode():
        return _legacy_C_ops.c_broadcast(tensor, tensor, 'root', gsrc,
                                         'use_calc_stream', use_calc_stream,
                                         'ring_id', ring_id)

    op_type = 'c_broadcast'
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'broadcast')

    # locals() is forwarded as keyword arguments, so the set of local names
    # at this point is part of the call.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'root': gsrc,
            'use_calc_stream': use_calc_stream,
            'ring_id': ring_id,
        },
    )
697 698


699
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
    """

    Reduce a tensor from all ranks onto the destination rank. As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The destination of the reduce operator is GPU0 and the operation is sum. Through the reduce operator,
    GPU0 will own the sum of all data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png
        :width: 800
        :alt: reduce
        :align: center

    Args:
        tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        dst (int): The destination rank id.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
        group (Group, optional): The group instance returned by new_group, or None for the global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.reduce(data, dst=0)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs, out for rank 0)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs, out for rank 1)
    """
    # Processes outside the group do not take part in the reduction.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce")
        if group is None:
            group = _get_default_group()
        gdst = group.get_group_rank(dst)
        assert gdst >= 0, ("dst rank out of group, need global rank")
        task = group.process_group.reduce(tensor, gdst, op_type)
        if not sync_op:
            return task
        task.wait()
        return None

    use_calc_stream = sync_op
    if group is None:
        ring_id = 0
        gdst = dst
    else:
        ring_id = group.id
        gdst = group.get_group_rank(dst)
    assert gdst >= 0, ("dst rank out of group, need global rank")

    if _non_static_mode():
        # Pick the legacy op for the requested reduction; all four take the
        # same argument list.
        if op == ReduceOp.SUM:
            legacy_op = _legacy_C_ops.c_reduce_sum
        elif op == ReduceOp.MAX:
            legacy_op = _legacy_C_ops.c_reduce_max
        elif op == ReduceOp.MIN:
            legacy_op = _legacy_C_ops.c_reduce_min
        elif op == ReduceOp.PROD:
            legacy_op = _legacy_C_ops.c_reduce_prod
        else:
            raise ValueError("Unknown parameter: {}.".format(op))
        return legacy_op(tensor, tensor, 'use_calc_stream', use_calc_stream,
                         'ring_id', ring_id, 'root_id', gdst)

    op_type = 'c_reduce'
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'reduce')

    # An unrecognized op leaves op_type at the generic 'c_reduce' set above.
    if op == ReduceOp.SUM:
        op_type = 'c_reduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_reduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_reduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_reduce_prod'

    # locals() is forwarded as keyword arguments, so the set of local names
    # at this point is part of the call.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
            'root_id': gdst,
        },
    )
804 805


806
def all_gather(tensor_list, tensor, group=None, sync_op=True):
    """

    Collect ``tensor`` from every participator and hand the complete set to each
    of them. As shown below, one process is started with a GPU and the data of
    this process is represented by its group rank. After the all_gather
    operator, each GPU holds the data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
        group (Group, optional): The group instance return by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.
            In eager dygraph mode the call always waits for the task; in the
            other modes the value is forwarded as ``use_calc_stream``.

    Returns:
        None. The gathered tensors are written into ``tensor_list``. A process
        that is not a member of ``group`` returns immediately.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            tensor_list = []
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_gather(tensor_list, data)
            print(tensor_list)
            # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    # Complex inputs travel as their real view and are turned back into
    # complex tensors once the gathered buffer has been split.
    is_input_complex = tensor.dtype in (paddle.complex64, paddle.complex128)
    if is_input_complex:
        tensor = paddle.as_real(tensor)

    def _store_result(gathered, num_parts):
        # Split the gathered buffer along axis 0 and append one piece per rank.
        pieces = paddle.split(gathered, num_parts, 0)
        if is_input_complex:
            pieces = [paddle.as_complex(piece) for piece in pieces]
        tensor_list.extend(pieces)

    if in_dygraph_mode():
        if group is None:
            group = _get_default_group()
        if len(tensor_list) != 0:
            # Reuse the caller-provided tensors as the output buffer.
            out = paddle.concat(tensor_list, axis=0)
        else:
            out_shape = list(tensor.shape)
            out_shape[0] *= group.nranks
            out = paddle.empty(out_shape, tensor.dtype)
        task = group.process_group.all_gather(tensor, out)
        task.wait()
        tensor_list.clear()
        _store_result(out, group.nranks)
        return

    use_calc_stream = sync_op
    if group is None:
        ring_id = 0
        nranks = _get_global_group().nranks
    else:
        ring_id = group.id
        nranks = group.nranks

    if _non_static_mode():
        out = _legacy_C_ops.c_allgather(tensor, 'use_calc_stream',
                                        use_calc_stream, 'ring_id', ring_id,
                                        'nranks', nranks)
    else:
        op_type = 'c_allgather'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        if not isinstance(tensor_list, list):
            raise ValueError("The type of 'tensor_list' for all_gather "
                             "should be list.")
        allowed_dtypes = [
            'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'int8',
            'uint8', 'complex64', 'complex128'
        ]
        for elem in tensor_list:
            check_variable_and_dtype(elem, 'tensor_list', allowed_dtypes,
                                     'all_gather')
        check_variable_and_dtype(tensor, 'tensor', allowed_dtypes,
                                 'all_gather')
        helper.append_op(type=op_type,
                         inputs={'X': [tensor]},
                         outputs={'Out': [out]},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': use_calc_stream,
                             'nranks': nranks
                         })

    _store_result(out, nranks)


def _convert_object_to_tensor(obj):
    """Pickle ``obj`` and wrap the bytes in a uint8 tensor.

    Returns:
        A ``(tensor, length)`` pair, where ``length`` is ``tensor.numel()``.
    """
    stream = io.BytesIO()
    pickle.Pickler(stream).dump(obj)
    raw_bytes = np.frombuffer(stream.getvalue(), dtype=np.uint8)
    tensor = paddle.to_tensor(raw_bytes)
    return tensor, tensor.numel()
926 927


928
def _convert_tensor_to_object(tensor, len_of_tensor):
    """Unpickle the object stored in the first ``len_of_tensor`` bytes of ``tensor``.

    Anything after ``len_of_tensor`` is ignored.
    """
    # NOTE(review): unpickling executes arbitrary code from the payload, so
    # this must only ever see data produced by trusted peers.
    payload = tensor.numpy()[:len_of_tensor]
    return pickle.Unpickler(io.BytesIO(payload)).load()
931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957


def all_gather_object(object_list, obj, group=None):
    """

    Gather picklable objects from all participators and all get the result. Similar to all_gather(), but python object can be passed in.

    Args:
        object_list (list): A list of output object. The datatype of every element in the list is same as the input obj.
        obj (Any): The picklable object to send.
        group (Group): The group instance return by new_group or None for global default group.

    Returns:
        None. The gathered objects are appended to ``object_list``. A process
        that is not a member of ``group`` returns immediately and leaves
        ``object_list`` untouched.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            object_list = []
            if dist.get_rank() == 0:
                obj = {"foo": [1, 2, 3]}
            else:
                obj = {"bar": [4, 5, 6]}
            dist.all_gather_object(object_list, obj)
            print(object_list)
            # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
    """
    assert in_dygraph_mode(
    ), "all_gather_object doesn't support static graph mode."

    # all_gather() returns early for non-members and would leave the length
    # list empty, making max() below fail; bail out the same way here.
    if group is not None and not group.is_member():
        return

    tensor, len_of_tensor = _convert_object_to_tensor(obj)

    # gather len_of_tensor from all ranks
    list_len_of_tensor = []
    all_gather(list_len_of_tensor, len_of_tensor, group)
    # get the max length from list
    max_len_of_tensor = int(max(list_len_of_tensor).item())
    # Resize the input tensor to the max length to avoid a hang in all_gather.
    # np.resize fills the extra space by repeating the data; that padding is
    # dropped again when each payload is cut back to its own length below.
    # Note(liyurui): Maybe we should support various length all_gather?
    # The resize goes through NumPy because resize isn't supported in python
    # on the tensor itself.
    numpy_data = np.resize(tensor.numpy(), [max_len_of_tensor])
    input_tensor = paddle.to_tensor(numpy_data)

    tensor_list = []
    all_gather(tensor_list, input_tensor, group)
    for gathered, true_len in zip(tensor_list, list_len_of_tensor):
        object_list.append(_convert_tensor_to_object(gathered, true_len))
988 989


990
def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
    """

    Scatter a tensor to all participators. As shown below, one process is started with a GPU and the source of the scatter
    is GPU0. Through scatter operator, the data in GPU0 will be sent to all GPUs averagely.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
        :width: 800
        :alt: scatter
        :align: center

    Args:
        tensor (Tensor): The output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Default value is None.
            It is only read on the source rank, where it must be non-empty;
            every other rank ignores it.
        src (int): The source rank id. Default value is 0.
        group (Group, optional): The group instance return by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None in most cases. In eager dygraph mode with ``sync_op=False`` the
        task of the group is returned so the caller can wait on it.

    Raises:
        ValueError: If ``src`` is not an int, or if ``tensor_list`` is missing
            or empty on the source rank.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([7, 8, 9])
                data2 = paddle.to_tensor([10, 11, 12])
                dist.scatter(data1, src=1)
            else:
                data1 = paddle.to_tensor([1, 2, 3])
                data2 = paddle.to_tensor([4, 5, 6])
                dist.scatter(data1, tensor_list=[data1, data2], src=1)
            print(data1, data2)
            # [1, 2, 3] [10, 11, 12] (2 GPUs, out for rank 0)
            # [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        rank = group.rank
        nranks = group.nranks
    else:
        ring_id = 0 if group is None else group.id
        gsrc = src if group is None else group.get_group_rank(src)
        rank = _get_global_group().rank if group is None else group.rank
        nranks = _get_global_group().nranks if group is None else group.nranks
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if rank != gsrc:
        # Non-source ranks only need an input of the right size; their own
        # output tensor repeated nranks times serves as that placeholder.
        tensor_list = [tensor] * nranks
    elif tensor_list is None or len(tensor_list) == 0:
        # Fail with a clear message instead of an unrelated error from
        # paddle.concat when the source rank has nothing to send.
        raise ValueError(
            "tensor_list should not be None or empty on the src rank.")
    temp = paddle.concat(tensor_list, axis=0)
    if in_dygraph_mode():
        task = group.process_group.scatter(temp, tensor, gsrc)
        if sync_op:
            task.wait()
            return None
        else:
            return task

    use_calc_stream = sync_op
    if _non_static_mode():
        return _legacy_C_ops.c_scatter(temp, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'nranks', nranks, 'root', gsrc)
    op_type = 'c_scatter'
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'scatter')
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [temp]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'root': gsrc,
                         'use_calc_stream': use_calc_stream,
                         'nranks': nranks,
                     })
1084 1085


1086
def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):
    """
    Scatter tensors in in_tensor_list to all participators averagely and gather the result tensors in out_tensor_list.
    As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
    Through alltoall operator, the 0_0 in GPU0 will be sent to GPU0 and 0_1 to GPU1, 1_0 in GPU1 sent to GPU0 and 1_1 to GPU1.
    Finally the out_tensor_list in GPU0 includes 0_0 and 1_0, and GPU1 includes 0_1 and 1_1.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/alltoall.png
        :width: 800
        :alt: alltoall
        :align: center

    Args:
        in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
            data type of the input Tensors.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.
            In eager dygraph mode the call always waits for the task; in the
            other modes the value is forwarded as ``use_calc_stream``.

    Returns:
        None. The received tensors are written into ``out_tensor_list``.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            out_tensor_list = []
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
                data2 = paddle.to_tensor([[7, 8, 9], [10, 11, 12]])
            else:
                data1 = paddle.to_tensor([[13, 14, 15], [16, 17, 18]])
                data2 = paddle.to_tensor([[19, 20, 21], [22, 23, 24]])
            dist.alltoall([data1, data2], out_tensor_list)
            print(out_tensor_list)
            # [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0)
            # [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        if group is None:
            group = _get_default_group()
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
    else:
        ring_id = group.id if group is not None else 0

    # All inputs are sent as one buffer; the number of inputs decides how the
    # received buffer is split back into pieces.
    temp = paddle.concat(in_tensor_list, axis=0)
    nranks = len(in_tensor_list)
    if in_dygraph_mode():
        if len(out_tensor_list) != 0:
            # Reuse the caller-provided tensors as the output buffer.
            out = paddle.concat(out_tensor_list, axis=0)
        else:
            out_shape = list(in_tensor_list[0].shape)
            out_shape[0] *= nranks
            out = paddle.empty(out_shape, in_tensor_list[0].dtype)
        task = group.process_group.alltoall(temp, out)
        task.wait()
        out_tensor_list.clear()
        out_tensor_list.extend(paddle.split(out, nranks, 0))
        return

    use_calc_stream = sync_op
    if _non_static_mode():
        out = _legacy_C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id)
    else:
        op_type = 'alltoall'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=in_tensor_list[0].dtype)

        if not isinstance(in_tensor_list, list):
            raise ValueError("The type of 'in_tensor_list' for all_to_all "
                             "should be list.")
        for elem in in_tensor_list:
            check_variable_and_dtype(
                elem, 'in_tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_to_all')
        if not isinstance(out_tensor_list, list):
            raise ValueError("The type of 'out_tensor_list' for all_to_all "
                             "should be list.")
        if len(out_tensor_list) != 0:
            raise ValueError("The 'out_tensor_list' for all_to_all "
                             "must be an empty list.")
        helper.append_op(type=op_type,
                         inputs={'X': [temp]},
                         outputs={'Out': [out]},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': use_calc_stream,
                         })
    out_tensor_list.extend(paddle.split(out, nranks, 0))


1188 1189 1190 1191 1192
def alltoall_single(in_tensor,
                    out_tensor,
                    in_split_sizes=None,
                    out_split_sizes=None,
                    group=None,
                    sync_op=True):
    """
    Scatter a single input tensor to all participators and gather the received tensors in out_tensor.

    .. note::
        ``alltoall_single`` is only supported in eager mode.

    Args:
        in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
        in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
            must be divisible by group size and ``in_tensor`` will be scattered averagely to all participators. Default: None.
        out_split_sizes (list[int], optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
            must be divisible by group size and ``out_tensor`` will be gathered averagely from all participators. Default: None.
        group (Group, optional): The group instance return by ``new_group`` or None for global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None, if ``sync_op`` is set to ``True``; ``Task`` of ``group``, if ``sync_op`` is set to ``False``.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            size = dist.get_world_size()

            # case 1 (2 GPUs)
            data = paddle.arange(2, dtype='int64') + rank * 2
            # data for rank 0: [0, 1]
            # data for rank 1: [2, 3]
            output = paddle.empty([2], dtype='int64')
            dist.alltoall_single(data, output)
            print(output)
            # output for rank 0: [0, 2]
            # output for rank 1: [1, 3]

            # case 2 (2 GPUs)
            in_split_sizes = [i + 1 for i in range(size)]
            # in_split_sizes for rank 0: [1, 2]
            # in_split_sizes for rank 1: [1, 2]
            out_split_sizes = [rank + 1 for i in range(size)]
            # out_split_sizes for rank 0: [1, 1]
            # out_split_sizes for rank 1: [2, 2]
            data = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank
            # data for rank 0: [[0., 0.], [0., 0.], [0., 0.]]
            # data for rank 1: [[1., 1.], [1., 1.], [1., 1.]]
            output = paddle.empty([(rank + 1) * size, size], dtype='float32')
            group = dist.new_group([0, 1])
            task = dist.alltoall_single(data,
                                        output,
                                        in_split_sizes,
                                        out_split_sizes,
                                        sync_op=False,
                                        group=group)
            task.wait()
            print(output)
            # output for rank 0: [[0., 0.], [1., 1.]]
            # output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]]

    """
    if group is not None and not group.is_member():
        return

    assert in_dygraph_mode(), "Only suppport alltoall_single in eager mode."
    # _check_single_tensor

    if group is None:
        group = _get_default_group()
    backend = _group_map_backend[group]
    assert backend != 'gloo', ("backend gloo is not supported yet")

    # Missing split sizes are passed on as empty lists.
    if in_split_sizes is None:
        in_split_sizes = []
    if out_split_sizes is None:
        out_split_sizes = []

    task = group.process_group.alltoall_single(in_tensor, out_tensor,
                                               in_split_sizes, out_split_sizes)
    if not sync_op:
        return task
    task.wait()
    return


S
ShenLiang 已提交
1280 1281 1282 1283
def _get_group_rank(global_rank, group=None):
    """Map ``global_rank`` through ``group``; with no group it is returned as is."""
    if group is None:
        return global_rank
    return group.get_group_rank(global_rank)


1284
def send(tensor, dst=0, group=None, sync_op=True):
    """
    Send a tensor to the receiver.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        dst (int): The destination rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None in most cases. In eager dygraph mode with ``sync_op=False`` the
        task of the group is returned so the caller can wait on it.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return
    # Translate the destination through the group, if one was given.
    dst = _get_group_rank(dst, group)
    if in_dygraph_mode():
        if group is None:
            group = _get_default_group()
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        task = group.process_group.send(tensor, dst)
        if not sync_op:
            return task
        task.wait()
        return None

    use_calc_stream = sync_op
    ring_id = group.id if group is not None else 0

    if _non_static_mode():
        return _legacy_C_ops.send_v2(tensor, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id, 'peer', dst)
    op_type = 'send_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'send')

    helper = LayerHelper(op_type, **locals())
    op_attrs = {
        'ring_id': ring_id,
        'peer': dst,
        'use_calc_stream': use_calc_stream,
    }
    helper.append_op(type=op_type, inputs={'X': [tensor]}, attrs=op_attrs)
L
lilong12 已提交
1348 1349


1350
def recv(tensor, src=0, group=None, sync_op=True):
    """
    Receive a tensor from the sender.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        src (int): The source rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None in most cases. In eager dygraph mode with ``sync_op=False`` the
        task of the group is returned so the caller can wait on it.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    # Translate the source through the group, if one was given.
    src = _get_group_rank(src, group)
    if in_dygraph_mode():
        if group is None:
            group = _get_default_group()
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        task = group.process_group.recv(tensor, src)
        if not sync_op:
            return task
        task.wait()
        return None

    use_calc_stream = sync_op
    ring_id = group.id if group is not None else 0

    if _non_static_mode():
        return _legacy_C_ops.recv_v2(tensor, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id, 'peer', src, 'dtype',
                                     tensor.dtype, 'out_shape', tensor.shape)
    op_type = 'recv_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'recv')
    helper = LayerHelper(op_type, **locals())
    # The op is told the shape and dtype of the tensor it writes into.
    op_attrs = {
        'ring_id': ring_id,
        'peer': src,
        'out_shape': tensor.shape,
        'dtype': tensor.dtype,
        'use_calc_stream': use_calc_stream,
    }
    helper.append_op(type=op_type, outputs={'Out': [tensor]}, attrs=op_attrs)
1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438


def _check_single_tensor(tensor, tensor_name):
    """Raise ``RuntimeError`` unless ``tensor`` is a paddle Tensor.

    Args:
        tensor: The object to validate.
        tensor_name (str): Parameter name used in the error message.
    """
    if not isinstance(tensor, (core.eager.Tensor, paddle.Tensor)):
        # The literals are joined by the parser, so the separating space has
        # to be written explicitly ("{} to be", not "{}to be").
        raise RuntimeError("Invalid function argument. Expected parameter {} "
                           "to be of type paddle.Tensor, but it's {}".format(
                               tensor_name, type(tensor)))


def _check_tensor_list(tensor_list, tensor_name):
    """Raise ``RuntimeError`` unless ``tensor_list`` is a list of paddle Tensors.

    An empty list passes the check.

    Args:
        tensor_list: The object to validate.
        tensor_name (str): Parameter name used in the error message.
    """
    if not isinstance(tensor_list, list) or \
        not all(isinstance(t, (core.eager.Tensor, paddle.Tensor)) for t in tensor_list):
        # The literals are joined by the parser, so the separating space has
        # to be written explicitly ("{} to be", not "{}to be").
        raise RuntimeError("Invalid function argument. Expected parameter {} "
                           "to be of type paddle.Tensor".format(tensor_name))


def isend(tensor, dst, group=None):
    """
    Send a tensor asynchronously.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        dst (int): The destination rank.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.

    Returns:
        A distributed task object. None is returned when the current process
        is not a member of ``group``.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)

    """
    _check_single_tensor(tensor, "tensor")
    if group is not None and not group.is_member():
        return

    if not in_dygraph_mode():
        raise RuntimeError("Only support eager dygraph mode.")

    if group is None:
        group = _get_default_group()
    backend = _group_map_backend[group]
    assert backend != 'gloo', ("backend gloo is not supported yet")
    group_dst_rank = group.get_group_rank(dst)
    assert group_dst_rank >= 0, ("dst rank out of group, need global rank")
    # The task is handed back without waiting on it.
    return group.process_group.send(tensor, group_dst_rank)
1481 1482 1483 1484 1485 1486 1487 1488


def irecv(tensor, src=None, group=None):
    """
    Receive a tensor from the sender asynchronously.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        src (int): The source rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.

    Returns:
        A distributed task object. None is returned when the current process
        is not a member of ``group``.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    _check_single_tensor(tensor, "tensor")
    if group is not None and not group.is_member():
        return

    if not in_dygraph_mode():
        raise RuntimeError("Only support eager dygraph mode.")

    if group is None:
        group = _get_default_group()
    backend = _group_map_backend[group]
    assert backend != 'gloo', ("backend gloo is not supported yet")
    group_src_rank = group.get_group_rank(src)
    assert group_src_rank >= 0, ("src rank out of group, need global rank")
    # The task is handed back without waiting on it.
    return group.process_group.recv(tensor, group_src_rank)
1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597


class P2POp(object):
    """
    Description of one point-to-point operation for ``batch_isend_irecv``.

    An instance bundles the communication function, the tensor used as the
    communication buffer, the peer rank and the communication group. Lists of
    such instances are handed to ``paddle.distributed.batch_isend_irecv``.

    Args:
        op (callable): The function used to send data to or receive data from
            the peer process. It must be either ``paddle.distributed.isend``
            or ``paddle.distributed.irecv``.
        tensor (Tensor): Tensor to send or receive.
        peer (int): The destination or source rank.
        group (Group, optional): The group instance return by new_group or None
            for global default group. Default: None.

    Raises:
        RuntimeError: If ``op`` is neither ``isend`` nor ``irecv``.

    """

    def __init__(self, op, tensor, peer, group=None):
        if op not in (isend, irecv):
            raise RuntimeError(
                "Invalid ``op`` function. Expected ``op`` to be of type "
                "``paddle.distributed.isend`` or ``paddle.distributed.irecv``.")
        _check_single_tensor(tensor, "tensor")

        # Fall back to the global default group when the caller gave none.
        if group is None:
            group = _get_default_group()

        self.op, self.tensor, self.peer, self.group = op, tensor, peer, group


@contextlib.contextmanager
def _with_batch_p2p_guard(backend):
    if backend == "nccl":
        core.ProcessGroupNCCL.group_start()
    try:
        yield
    finally:
        if backend == "nccl":
            core.ProcessGroupNCCL.group_end()


def _check_p2p_op_list(p2p_op_list):
    """
    Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
    all ops use the same backend.
    """
    if not isinstance(p2p_op_list, list) or not all(
            isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list):
        raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected to "
                           "to be of type ``paddle.distributed.P2POp``.")

    backend = _group_map_backend[p2p_op_list[0].group]
    if not all(backend == _group_map_backend[p2p_op.group]
               for p2p_op in p2p_op_list):
        raise RuntimeError("All groups need to use the same backend.")


def batch_isend_irecv(p2p_op_list):
    """
    Send or receive a batch of tensors asynchronously and return the tasks.

    Every point-to-point operation in ``p2p_op_list`` is launched in list
    order and the tasks they return are collected. The NCCL backend is
    currently supported.

    Args:
        p2p_op_list (List[P2POp]): A list of point-to-point operations (type of each operator is
            ``paddle.distributed.P2POp``). The order of the isend/irecv in the list
            matters and it needs to match with corresponding isend/irecv on the
            remote end.

    Returns:
        A list of distributed tasks returned by calling the corresponding
        op in the op_list. Ops that return None contribute no task. None is
        returned when the current process is not a member of the first op's
        group.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed

            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            send_t = paddle.arange(2) + rank
            # paddle.tensor([0, 1])  # Rank-0
            # paddle.tensor([1, 2])  # Rank-1

            recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)

            send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
            recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)

            tasks = dist.batch_isend_irecv([send_op, recv_op])

            for task in tasks:
                task.wait()

            print(recv_t)
            # paddle.tensor([1, 2])     # Rank-0
            # paddle.tensor([0, 1])     # Rank-1
    """
    _check_p2p_op_list(p2p_op_list)

    # The first op's group decides membership and which backend guard to use.
    group = p2p_op_list[0].group
    if group is not None and not group.is_member():
        return

    if not in_dygraph_mode():
        raise RuntimeError("Don't support static graph mode currently.")

    if group is None:
        group = _get_default_group()
    backend = _group_map_backend[group]

    with _with_batch_p2p_guard(backend):
        launched = (item.op(item.tensor, item.peer, item.group)
                    for item in p2p_op_list)
        # Keep only the calls that actually produced a task.
        tasks = [task for task in launched if task is not None]
    return tasks


def reduce_scatter(tensor,
                   tensor_list,
                   op=ReduceOp.SUM,
                   group=None,
                   sync_op=True):
    """
    Reduce a list of tensors across a group, then scatter the result to all
    processes in that group.

    Args:
        tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        tensor_list (list[Tensor]): List of tensors to reduce and scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (Group, optional): The group instance return by new_group or None for global
            default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        Async task handle, if sync_op is set to False.
        None, if sync_op is True or if the current process is not part of the group.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([0, 1])
                data2 = paddle.to_tensor([2, 3])
            else:
                data1 = paddle.to_tensor([4, 5])
                data2 = paddle.to_tensor([6, 7])
            dist.reduce_scatter(data1, [data1, data2])
            print(data1)
            # [4, 6] (2 GPUs, out for rank 0)
            # [8, 10] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(tensor, "tensor")
    _check_tensor_list(tensor_list, "tensor_list")

    # Processes outside an explicitly given group do nothing.
    if group is not None and not group.is_member():
        return

    if not in_dygraph_mode():
        raise RuntimeError("Don't support static graph mode currently.")

    op_type = _get_reduce_op(op, "reduce_scatter")
    if group is None:
        group = _get_default_group()
    backend = _group_map_backend[group]
    assert backend != 'gloo', ("backend gloo is not supported yet")

    # Join the inputs along axis 0 into the single tensor handed to the
    # process group.
    merged_input = paddle.concat(tensor_list, axis=0)
    task = group.process_group._reduce_scatter_base(tensor, merged_input,
                                                    op_type)
    if not sync_op:
        return task
    task.wait()
    return None


def _reduce_scatter_base(output,
                         input,
                         op=ReduceOp.SUM,
                         group=None,
                         sync_op=True):
    """
    Reduce a flattened tensor across a group, then scatter the result to all
    processes in that group.

    Args:
        output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        input (Tensor): Input tensor that is of size output tensor size times world size. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        Async task handle, if sync_op is set to False.
        None, if sync_op is True or if the current process is not part of the group.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            data = paddle.arange(4) + rank
            # [0, 1, 2, 3] (2 GPUs, for rank 0)
            # [1, 2, 3, 4] (2 GPUs, for rank 1)
            output = paddle.empty(shape=[2], dtype=data.dtype)
            dist.collective._reduce_scatter_base(output, data)
            print(output)
            # [1, 3] (2 GPUs, out for rank 0)
            # [5, 7] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(output, "output")
    _check_single_tensor(input, "input")

    # Processes outside an explicitly given group do nothing.
    if group is not None and not group.is_member():
        return

    if not in_dygraph_mode():
        raise RuntimeError("Don't support static graph mode currently.")

    op_type = _get_reduce_op(op, "_reduce_scatter_base")
    if group is None:
        group = _get_default_group()

    task = group.process_group._reduce_scatter_base(output, input, op_type)
    if not sync_op:
        return task
    task.wait()
    return None