#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
import pickle
import io
import datetime
import time
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import in_dygraph_mode
from ..fluid.framework import _non_static_mode
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.layers.tensor import fill_constant
import paddle
import paddle.fluid.core as core
from paddle import _legacy_C_ops
import contextlib
from .fleet.layers.mpu.mp_ops import split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_identity  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_concat  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _mp_allreduce  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_lookup_table  # noqa: F401
from .fleet.layers.mpu.mp_ops import _Linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _set_var_distributed  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy  # noqa: F401
from .fleet.layers.mpu.mp_ops import _linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_embedding  # noqa: F401
from .communication.group import Group, _add_new_group
from .communication.all_reduce import all_reduce  # noqa: F401
from .communication.reduce import _get_reduce_op, ReduceOp

__all__ = []

_global_env = None


def _get_global_env():
    global _global_env
    if not _global_env:
        _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# group map: the map of all groups, with key 0 for the global group
# Dict[int, Group]
_group_map = {}
_global_env_gid = 0

# group map by name: the map of all groups, keyed by their names
# Dict[name, Group]
_group_map_by_name = {}

# backend map by group: the map from each group to its backend
# Dict[group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl']
_default_store = None  # the default tcp store
_default_backend = None
_default_timeout = datetime.timedelta(seconds=1800)
_start_ring_id = 0


def _set_default_backend(backend):
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    global _default_store
    _default_store = store


def _get_group_map():
    global _group_map
    if _global_env_gid not in _group_map:
        genv = _get_global_env()
        _group_map[_global_env_gid] = Group(genv.rank, 0,
                                            list(range(genv.world_size)))
    return _group_map


def _get_global_group():
    return _get_group_map()[_global_env_gid]


def _get_group_map_by_name():
    global _group_map_by_name
    return _group_map_by_name


def _get_default_group():
    global _group_map_by_name
    assert is_initialized(), ("Call paddle.distributed.init_parallel_env first "
                              "to initialize the distributed environment.")
    return _get_group_map_by_name()[_default_group_name]


def _set_group_map(gid, group):
    global _group_map
    assert gid not in _group_map
    _group_map[gid] = group


def _set_group_map_by_name(name, group):
    global _group_map_by_name
    assert name not in _group_map_by_name
    _group_map_by_name[name] = group


def _set_group_map_backend(group, backend):
    global _group_map_backend
    assert group not in _group_map_backend
    _group_map_backend[group] = backend


def _new_ring_id():
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode rely on the previous syntax.
    if in_dygraph_mode():
        global _start_ring_id
        _start_ring_id += 1
        return _start_ring_id + max(_get_global_env().nrings, 9)
    else:
        return len(_get_group_map()) + max(_get_global_env().nrings, 9)


def get_group(id=0):
    """

    Get group instance by group id.

    Args:
        id (int): the group id. Default value is 0.

    Returns:
        Group: the group instance.

    Examples:
        .. code-block:: python

            ...
            gid = paddle.distributed.new_group([2,4,6])
            paddle.distributed.get_group(gid.id)

    """

    gm = _get_group_map()
    return gm[id] if id in gm else None


def _new_process_group_impl(backend,
                            store,
                            rank,
                            world_size,
                            group_name,
                            pg_options,
                            group_id=0,
                            src_rank=None,
                            dst_rank=None):
    pg = None
    genv = _get_global_env()
    if backend != 'heter':
        assert src_rank is None and dst_rank is None, (
            "src_rank and dst_rank "
            "can only be set for heter backend.")
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
    if backend == "gloo":
        place = core.CPUPlace()
        pg = core.ProcessGroupGloo(store, rank, world_size, place, group_id)
    elif backend == "nccl":
        place = core.CUDAPlace(genv.device_id)
        pg = core.ProcessGroupNCCL(store, rank, world_size, place, group_id)
    elif backend == "hccl":
        place = core.NPUPlace(genv.device_id)
        pg = core.ProcessGroupHCCL(store, rank, world_size, place, group_id)
    elif backend == "xccl":
        place = core.CustomPlace(genv.device_type, genv.device_id)
        pg = core.ProcessGroupCustom(store, rank, world_size, place, group_id)
    elif backend == "heter":
        place = None
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(genv.device_id)
        elif core.is_compiled_with_npu():
            place = core.NPUPlace(genv.device_id)
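        # The CLUSTER_* environment variables describe the heterogeneous setup:
        # CLUSTER_ID is this worker's cluster index, CLUSTER_SIZE lists the world
        # size of every cluster, and CLUSTER_SWITCH is the switch endpoint. The
        # cumulative sum of the cluster sizes turns the local rank into a global rank.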
        cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
        assert cluster_id >= 0, "please set the CLUSTER_ID variable."
        cluster_size = os.getenv("CLUSTER_SIZE", None)
        assert cluster_size, "please set the CLUSTER_SIZE variable."
        cluster_size = cluster_size.split(",")
        cluster_size = [int(s) for s in cluster_size]
        switch_ep = os.getenv("CLUSTER_SWITCH", None)
        assert switch_ep, "please set the CLUSTER_SWITCH variable."
        cluster_size_cumsum = np.cumsum(cluster_size)
        cluster_offset = 0 if cluster_id == 0 else cluster_size_cumsum[
            cluster_id - 1]
        global_rank = cluster_offset + rank
        global_world_size = cluster_size_cumsum[-1]
        global_rank, global_world_size = _get_global_config(backend, rank)
        pg = core.ProcessGroupHeter(store,
                                    rank=global_rank,
                                    world_size=global_world_size,
                                    place=place,
                                    gid=group_id,
                                    local_rank=rank,
                                    local_size=world_size,
                                    gloo_rank=cluster_id,
                                    gloo_size=len(cluster_size),
                                    with_switch=True,
                                    switch_endpoint=switch_ep,
                                    src_rank=src_rank,
                                    dst_rank=dst_rank)

    return pg


def barrier(group=None):
    """

    Barrier among all participators in the group.

    Args:
        group (Group): The group instance returned by new_group, or None for the global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        task = group.process_group.barrier()
        task.wait()
        return

    ring_id = 0 if group is None else group.id

    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        return _legacy_C_ops.barrier(temp, temp, 'ring_id', ring_id)

    op_type = 'barrier'

    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [temp]},
                     outputs={'Out': [temp]},
                     attrs={'ring_id': ring_id})


# _custom_gid provides a way for users to
# set the group id, which is usually useful
# to be compatible with the static mode.
_custom_gid = None
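# Illustrative sketch only (``_set_custom_gid`` is an internal helper and the id
# below is a hypothetical value): pinning the id before calling ``new_group`` makes
# the dynamic-graph group use an id that a static-mode program expects, e.g.
#
#     _set_custom_gid(100)        # assumed id chosen by the caller
#     group = new_group([0, 1])   # created with gid 100 instead of _new_ring_id()
#     _set_custom_gid(None)       # restore the default id allocation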


def _set_custom_gid(gid):
    global _custom_gid
    _custom_gid = gid


def _barrier_by_tcp_store(group_name, store, timeout):
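    """Best-effort barrier over the TCP store: every non-zero rank writes a
    per-rank key and returns immediately, while rank 0 polls the store until all
    keys are present (or ``timeout`` elapses), so the master is the last to leave.
    """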
    global_rank = paddle.distributed.get_rank()
    global_world_size = paddle.distributed.get_world_size()

    if global_world_size < 2:
        return

    barrier_prefix = "Barrier/" + group_name + "/"
    is_master = (global_rank == 0)

    def _check_keys_ready(wait_keys):
        start_time = time.time()
        while len(wait_keys) > 0:
            time.sleep(0.1)
            elapse_time = time.time() - start_time
            if datetime.timedelta(seconds=elapse_time) > timeout:
                raise RuntimeError(
                    "Timeout while initializing process group {}. "
                    "Keys {} are not ready since rank {} is waiting for them. "
                    "Two reasons may cause this error:\n 1. The API that creates the process group should be called by all ranks.\n"
                    " 2. Try to increase the waiting time.\n".format(
                        group_name, wait_keys, global_rank))
            wait_keys = list(
                filter(lambda key: int(store.get(key)) != 1, wait_keys))

    # All the workers set their exiting key and exit.
    # The master waits for all the workers' exiting keys, which ensures that it exits last.
    if is_master:
        wait_keys = [
            barrier_prefix + str(rank) for rank in range(1, global_world_size)
        ]
        _check_keys_ready(wait_keys)
    else:
        store.add(barrier_prefix + str(global_rank), 1)


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members.
        backend (str): The backend used to create the group; only nccl is supported now.
        timeout (datetime.timedelta, optional): The waiting timeout for store-related operations; the default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    global _custom_gid
    global _group_map
    if in_dygraph_mode():
        global _default_group_name
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group.")
        size = len(ranks)
        ranks = sorted(ranks)
        if backend == 'heter' or (size > 1 and global_rank in ranks):
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            src_rank = ranks[0] if backend == 'heter' else None
            dst_rank = ranks[1] if backend == 'heter' else None
            pg = _new_process_group_impl(backend,
                                         _default_store,
                                         rank,
                                         size,
                                         group_name,
                                         pg_options=None,
                                         group_id=gid,
                                         src_rank=src_rank,
                                         dst_rank=dst_rank)
        else:
            rank = -1
            pg = None
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        # TODO: The method below is a new method for group management and will replace
        # the previous three in the future.
        _add_new_group(group)

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        # NOTE(liyurui): All processes should hang and wait using the tcp store, in case the master exits before the sub-group is created.
        if backend != 'heter':
            _barrier_by_tcp_store(group_name, _default_store, timeout)
        else:
            print("Warning: store barrier is not supported for heter backend.")
        return group

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', ("backend other than nccl is not supported yet")

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            else:
                assert False, ("no cuda device found")
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = paddle.to_tensor(
        [1], dtype="int32") if _non_static_mode() else fill_constant(
            [0], dtype="int32", value="1")
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp


def is_initialized():
    """

    Check whether the distributed environment has been initialized

    Returns:
        `True` if distributed environment has been initialized, otherwise `False`.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle

            print(paddle.distributed.is_initialized())
            # False

            paddle.distributed.init_parallel_env()
            print(paddle.distributed.is_initialized())
            # True

    """
    global _group_map_by_name
    return _default_group_name in _group_map_by_name


def destroy_process_group(group=None):
    """
    Destroy a given group for communication

    Args:
        group (ProcessGroup, optional): The group to be destroyed. If None (the default), all process
                                        groups, including the default group, will be destroyed and the
                                        distributed environment will be deinitialized.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            group = dist.new_group([0, 1])

            dist.destroy_process_group(group)
            print(dist.is_initialized())
            # True
            dist.destroy_process_group()
            print(dist.is_initialized())
            # False

    """
    global _group_map
    global _group_map_by_name

    pg = _get_default_group() if group is None else group
    assert _group_map.get(pg.id, None) is not None, "Invalid group."

    if group is None:
        _group_map.clear()
        _group_map_by_name.clear()
        _group_map_backend.clear()
    else:
        del _group_map[pg.id]
        del _group_map_by_name[pg.name]
        del _group_map_backend[pg]


def wait(tensor, group=None, use_calc_stream=True):
    """

    Wait to synchronize the stream for the group.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False).
            Defaults to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, sync_op=True)
            paddle.distributed.wait(tindata)

    """

    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id

    if use_calc_stream:
        _sync_calc_stream(tensor)
    else:
        _sync_comm_stream(tensor, ring_id)


def _sync_calc_stream(tensor):
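    # Run (dygraph) or append (static graph) a c_sync_calc_stream op on ``tensor``
    # to synchronize the calculation stream.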

    if _non_static_mode():
        return _legacy_C_ops.c_sync_calc_stream(tensor, tensor)

    op_type = 'c_sync_calc_stream'

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
    )


def _sync_comm_stream(tensor, ring_id=0):
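    # Run (dygraph) or append (static graph) a c_sync_comm_stream op on ``tensor``
    # to synchronize the communication stream of the given ring.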
    if _non_static_mode():
        return _legacy_C_ops.c_sync_comm_stream([tensor], [tensor], 'ring_id',
                                                ring_id)

    op_type = 'c_sync_comm_stream'

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
599 600
        attrs={'ring_id': ring_id},
    )


def broadcast(tensor, src, group=None, sync_op=True):
    """

    Broadcast a tensor from the source to all others.
    As shown below, each process runs on one GPU and GPU0 owns data 0. Through the broadcast operator,
    data 0 will be sent to all GPUs from GPU0.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png
        :width: 800
        :alt: broadcast
        :align: center

    Args:
        tensor (Tensor): The Tensor to send if current rank is the source, or the Tensor to receive otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        src (int): The source rank.
        group (Group, optional): The group instance returned by new_group, or None for the global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.broadcast(data, src=1)
            print(data)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs)
    """

    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        assert gsrc >= 0, ("src rank out of group, need global rank")
        task = group.process_group.broadcast(tensor, gsrc)
        if sync_op:
            task.wait()
            return None
        else:
            return task

659
    use_calc_stream = sync_op
660
    ring_id = ring_id = 0 if group is None else group.id
K
kuizhiqing 已提交
661
    gsrc = src if group is None else group.get_group_rank(src)
K
kuizhiqing 已提交
662
    assert gsrc >= 0, ("src rank out of group, need global rank")
K
kuizhiqing 已提交
663

J
665 666 667
        return _legacy_C_ops.c_broadcast(tensor, tensor, 'root', gsrc,
                                         'use_calc_stream', use_calc_stream,
                                         'ring_id', ring_id)
668 669

    op_type = 'c_broadcast'
670 671 672 673
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'broadcast')
674 675

    helper = LayerHelper(op_type, **locals())
676 677 678 679 680 681 682 683
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'root': gsrc,
                         'use_calc_stream': use_calc_stream,
                         'ring_id': ring_id,
                     })
684 685


686
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
687 688
    """

689 690
    Reduce a tensor to the destination from all others. As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator,
691 692 693 694 695 696
    the GPU0 will owns the sum of all data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png
        :width: 800
        :alt: reduce
        :align: center
697 698 699

    Args:
        tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
700
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
701
        dst (int): The destination rank id.
702 703 704
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
        group (Group, optional): The group instance return by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.
705 706 707 708 709 710 711

    Returns:
        None.

    Examples:
        .. code-block:: python

712
            # required: distributed
713
            import paddle
714
            import paddle.distributed as dist
715

716 717 718
            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
719
            else:
720 721 722 723 724
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.reduce(data, dst=0)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs, out for rank 0)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs, out for rank 1)
725
    """
K
kuizhiqing 已提交
726 727 728
    if group is not None and not group.is_member():
        return

L
lilong12 已提交
729
    if in_dygraph_mode():
730
        op_type = _get_reduce_op(op, "reduce")
731 732 733 734
        group = _get_default_group() if group is None else group
        gdst = group.get_group_rank(dst)
        assert gdst >= 0, ("dst rank out of group, need global rank")
        task = group.process_group.reduce(tensor, gdst, op_type)
735
        if sync_op:
736 737 738 739
            task.wait()
            return None
        else:
            return task
K
kuizhiqing 已提交
740

741
    use_calc_stream = sync_op
K
kuizhiqing 已提交
742 743
    ring_id = 0 if group is None else group.id
    gdst = dst if group is None else group.get_group_rank(dst)
K
kuizhiqing 已提交
744
    assert gdst >= 0, ("dst rank out of group, need global rank")
K
kuizhiqing 已提交
745

J
747
        if op == ReduceOp.SUM:
748 749 750
            return _legacy_C_ops.c_reduce_sum(tensor, tensor, 'use_calc_stream',
                                              use_calc_stream, 'ring_id',
                                              ring_id, 'root_id', gdst)
751
        elif op == ReduceOp.MAX:
752 753 754
            return _legacy_C_ops.c_reduce_max(tensor, tensor, 'use_calc_stream',
                                              use_calc_stream, 'ring_id',
                                              ring_id, 'root_id', gdst)
755
        elif op == ReduceOp.MIN:
756 757 758
            return _legacy_C_ops.c_reduce_min(tensor, tensor, 'use_calc_stream',
                                              use_calc_stream, 'ring_id',
                                              ring_id, 'root_id', gdst)
759
        elif op == ReduceOp.PROD:
760 761 762 763
            return _legacy_C_ops.c_reduce_prod(tensor, tensor,
                                               'use_calc_stream',
                                               use_calc_stream, 'ring_id',
                                               ring_id, 'root_id', gdst)
764 765 766 767
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    op_type = 'c_reduce'
768 769 770 771
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'reduce')
772 773 774 775 776 777 778 779 780 781 782

    if op == ReduceOp.SUM:
        op_type = 'c_reduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_reduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_reduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_reduce_prod'

    helper = LayerHelper(op_type, **locals())
783 784 785 786 787 788 789 790
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream,
                         'root_id': gdst,
                     })
791 792


793
def all_gather(tensor_list, tensor, group=None, sync_op=True):
794 795
    """

796
    Gather tensors from all participators and all get the result. As shown
797 798
    below, one process is started with a GPU and the data of this process is represented
    by its group rank. Through the all_gather operator, each GPU will have data
799 800 801 802 803 804
    from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center
805 806 807

    Args:
        tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
808
            should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
809
        tensor (Tensor): The Tensor to send. Its data type
810
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
811 812
        group (Group, optional): The group instance return by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.
813 814 815 816 817 818 819

    Returns:
        None.

    Examples:
        .. code-block:: python

820
            # required: distributed
821
            import paddle
822
            import paddle.distributed as dist
823

824
            dist.init_parallel_env()
825
            tensor_list = []
826 827
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
828
            else:
829 830 831 832
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_gather(tensor_list, data)
            print(tensor_list)
            # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
833
    """
K
kuizhiqing 已提交
834 835 836
    if group is not None and not group.is_member():
        return

837 838 839 840 841 842 843 844 845 846 847
    def convert_to_complex(list_of_tensor):
        list_of_complex = []
        for tensor in list_of_tensor:
            list_of_complex.append(paddle.as_complex(tensor))
        return list_of_complex

    is_input_complex = (tensor.dtype == paddle.complex64
                        or tensor.dtype == paddle.complex128)
    if is_input_complex:
        tensor = paddle.as_real(tensor)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        if len(tensor_list) == 0:
            tensor_shape = list(tensor.shape)
            tensor_shape[0] *= group.nranks
            out = paddle.empty(tensor_shape, tensor.dtype)
        else:
            out = paddle.concat(tensor_list, axis=0)
        task = group.process_group.all_gather(tensor, out)
        task.wait()
        tensor_list.clear()
        list_of_tensor = paddle.split(out, group.nranks, 0)
        if is_input_complex:
            tensor_list.extend(convert_to_complex(list_of_tensor))
        else:
            tensor_list.extend(list_of_tensor)
        return

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    if _non_static_mode():
        out = _legacy_C_ops.c_allgather(tensor, 'use_calc_stream',
                                        use_calc_stream, 'ring_id', ring_id,
                                        'nranks', nranks)
    else:
        op_type = 'c_allgather'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        if not isinstance(tensor_list, list):
            raise ValueError("The type of 'tensor_list' for all_gather "
                             "should be list.")
        for elem in tensor_list:
            check_variable_and_dtype(elem, 'tensor_list', [
                'float16', 'float32', 'float64', 'int32', 'int64', 'bool',
                'int8', 'uint8', 'complex64', 'complex128'
            ], 'all_gather')
        check_variable_and_dtype(tensor, 'tensor', [
            'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'int8',
            'uint8', 'complex64', 'complex128'
        ], 'all_gather')
        helper.append_op(type=op_type,
                         inputs={'X': [tensor]},
                         outputs={'Out': [out]},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': use_calc_stream,
                             'nranks': nranks
                         })

    list_of_tensor = paddle.split(out, nranks, 0)
    if is_input_complex:
        tensor_list.extend(convert_to_complex(list_of_tensor))
    else:
        tensor_list.extend(list_of_tensor)


def _convert_object_to_tensor(obj):
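    # Pickle ``obj`` into an in-memory buffer and wrap the raw bytes in a 1-D
    # uint8 tensor; the second return value is the number of serialized bytes.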
    _pickler = pickle.Pickler
    f = io.BytesIO()
    _pickler(f).dump(obj)
    data = np.frombuffer(f.getvalue(), dtype=np.uint8)
    tensor = paddle.to_tensor(data)
    return tensor, tensor.numel()


def _convert_tensor_to_object(tensor, len_of_tensor):
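    """Inverse of ``_convert_object_to_tensor``: unpickle the first
    ``len_of_tensor`` bytes of the uint8 ``tensor`` back into a Python object.

    Illustrative round trip (dygraph only):

        t, length = _convert_object_to_tensor({"foo": [1, 2, 3]})
        _convert_tensor_to_object(t, int(length))  # -> {"foo": [1, 2, 3]}
    """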
    _unpickler = pickle.Unpickler
    return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()


def all_gather_object(object_list, obj, group=None):
    """

    Gather picklable objects from all participators and all get the result. Similar to all_gather(), but a Python object can be passed in.

    Args:
        object_list (list): A list of output objects. The data type of every element in the list is the same as the input obj.
        obj (Any): The picklable object to send.
        group (Group): The group instance returned by new_group, or None for the global default group.

    Returns:
        None.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            object_list = []
            if dist.get_rank() == 0:
                obj = {"foo": [1, 2, 3]}
            else:
                obj = {"bar": [4, 5, 6]}
            dist.all_gather_object(object_list, obj)
            print(object_list)
            # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
    """
    assert in_dygraph_mode(
    ), "all_gather_object doesn't support static graph mode."

    tensor, len_of_tensor = _convert_object_to_tensor(obj)

    # gather len_of_tensor from all ranks
    list_len_of_tensor = []
    all_gather(list_len_of_tensor, len_of_tensor, group)
    # get the max length from list
    max_len_of_tensor = int(max(list_len_of_tensor).item())
    # resize the input tensor to the max length to avoid a hang in all_gather
    # Note(liyurui): Maybe we should support variable-length all_gather?
    # For now this approach is acceptable, since we don't support resize in python.
    numpy_data = tensor.numpy()
    numpy_data = np.resize(numpy_data, [max_len_of_tensor])
    input_tensor = paddle.to_tensor(numpy_data)

    tensor_list = []
    all_gather(tensor_list, input_tensor, group)
    for i, tensor in enumerate(tensor_list):
        object_list.append(
            _convert_tensor_to_object(tensor, list_len_of_tensor[i]))


def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
    """

    Scatter a tensor to all participators. As shown below, each process runs on one GPU and the source of the scatter
    is GPU0. Through the scatter operator, the data in GPU0 will be evenly distributed to all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
        :width: 800
        :alt: scatter
        :align: center

    Args:
        tensor (Tensor): The output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Default value is None.
        src (int): The source rank id. Default value is 0.
        group (Group, optional): The group instance returned by new_group, or None for the global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([7, 8, 9])
                data2 = paddle.to_tensor([10, 11, 12])
                dist.scatter(data1, src=1)
            else:
                data1 = paddle.to_tensor([1, 2, 3])
                data2 = paddle.to_tensor([4, 5, 6])
                dist.scatter(data1, tensor_list=[data1, data2], src=1)
            print(data1, data2)
            # [1, 2, 3] [10, 11, 12] (2 GPUs, out for rank 0)
            # [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        rank = group.rank
        nranks = group.nranks
    else:
        ring_id = 0 if group is None else group.id
        gsrc = src if group is None else group.get_group_rank(src)
        rank = _get_global_group().rank if group is None else group.rank
        nranks = _get_global_group().nranks if group is None else group.nranks
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if rank != gsrc:
        tensor_list = []
        for _ in range(nranks):
            tensor_list.append(tensor)
    temp = paddle.concat(tensor_list, axis=0)
    if in_dygraph_mode():
        task = group.process_group.scatter(temp, tensor, gsrc)
        if sync_op:
            task.wait()
            return None
        else:
            return task

    use_calc_stream = sync_op
    if _non_static_mode():
        return _legacy_C_ops.c_scatter(temp, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'nranks', nranks, 'root', gsrc)
    op_type = 'c_scatter'
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'scatter')
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [temp]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'root': gsrc,
                         'use_calc_stream': use_calc_stream,
                         'nranks': nranks,
                     })


def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):
    """
    Scatter tensors in in_tensor_list to all participators evenly and gather the result tensors in out_tensor_list.
    As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
    Through the alltoall operator, 0_0 in GPU0 will be kept on GPU0 and 0_1 sent to GPU1, while 1_0 in GPU1 will be sent to GPU0 and 1_1 kept on GPU1.
    Finally the out_tensor_list in GPU0 includes 0_0 and 1_0, and GPU1 includes 0_1 and 1_1.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/alltoall.png
        :width: 800
        :alt: alltoall
        :align: center

    Args:
        in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
            data type of the input Tensors.
        group (Group, optional): The group instance returned by new_group, or None for the global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            out_tensor_list = []
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
                data2 = paddle.to_tensor([[7, 8, 9], [10, 11, 12]])
            else:
                data1 = paddle.to_tensor([[13, 14, 15], [16, 17, 18]])
                data2 = paddle.to_tensor([[19, 20, 21], [22, 23, 24]])
            dist.alltoall([data1, data2], out_tensor_list)
            print(out_tensor_list)
            # [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0)
            # [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
    else:
        ring_id = 0 if group is None else group.id

    temp = paddle.concat(in_tensor_list, axis=0)
    nranks = len(in_tensor_list)
    if in_dygraph_mode():
        if len(out_tensor_list) == 0:
            tensor_shape = list(in_tensor_list[0].shape)
            tensor_shape[0] *= nranks
            out = paddle.empty(tensor_shape, in_tensor_list[0].dtype)
        else:
            out = paddle.concat(out_tensor_list, axis=0)
        task = group.process_group.alltoall(temp, out)
        task.wait()
        out_tensor_list.clear()
        out_tensor_list.extend(paddle.split(out, nranks, 0))
        return

    use_calc_stream = sync_op
    if _non_static_mode():
        out = _legacy_C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id)
    else:
        op_type = 'alltoall'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=in_tensor_list[0].dtype)

        if not isinstance(in_tensor_list, list):
            raise ValueError("The type of 'in_tensor_list' for all_to_all "
                             "should be list.")
        for elem in in_tensor_list:
            check_variable_and_dtype(
                elem, 'in_tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_to_all')
        if not isinstance(out_tensor_list, list):
            raise ValueError("The type of 'out_tensor_list' for all_to_all "
                             "should be list.")
        if len(out_tensor_list) != 0:
            raise ValueError("The 'out_tensor_list' for all_to_all "
                             "must be an empty list.")
1165 1166 1167 1168 1169 1170 1171
        helper.append_op(type=op_type,
                         inputs={'X': [temp]},
                         outputs={'Out': [out]},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': use_calc_stream,
                         })
    out_tensor_list.extend(paddle.split(out, nranks, 0))


def alltoall_single(in_tensor,
                    out_tensor,
                    in_split_sizes=None,
                    out_split_sizes=None,
                    group=None,
                    sync_op=True):
    """
    Scatter a single input tensor to all participators and gather the received tensors in out_tensor.

    Note:
        ``alltoall_single`` is only supported in eager mode.

    Args:
        in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
        in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
            must be divisible by the group size and ``in_tensor`` will be scattered evenly to all participators. Default: None.
        out_split_sizes (list[int], optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
            must be divisible by the group size and ``out_tensor`` will be gathered evenly from all participators. Default: None.
        group (Group, optional): The group instance returned by ``new_group``, or None for the global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None, if ``sync_op`` is set to ``True``; ``Task`` of ``group``, if ``sync_op`` is set to ``False``.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            size = dist.get_world_size()

            # case 1 (2 GPUs)
            data = paddle.arange(2, dtype='int64') + rank * 2
            # data for rank 0: [0, 1]
            # data for rank 1: [2, 3]
            output = paddle.empty([2], dtype='int64')
            dist.alltoall_single(data, output)
            print(output)
            # output for rank 0: [0, 2]
            # output for rank 1: [1, 3]

            # case 2 (2 GPUs)
            in_split_sizes = [i + 1 for i in range(size)]
            # in_split_sizes for rank 0: [1, 2]
            # in_split_sizes for rank 1: [1, 2]
            out_split_sizes = [rank + 1 for i in range(size)]
            # out_split_sizes for rank 0: [1, 1]
            # out_split_sizes for rank 1: [2, 2]
            data = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank
            # data for rank 0: [[0., 0.], [0., 0.], [0., 0.]]
            # data for rank 1: [[1., 1.], [1., 1.], [1., 1.]]
            output = paddle.empty([(rank + 1) * size, size], dtype='float32')
            group = dist.new_group([0, 1])
            task = dist.alltoall_single(data,
                                        output,
                                        in_split_sizes,
                                        out_split_sizes,
                                        sync_op=False,
                                        group=group)
            task.wait()
            print(output)
            # output for rank 0: [[0., 0.], [1., 1.]]
            # output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]]

    """
    if group is not None and not group.is_member():
        return

    assert in_dygraph_mode(), "Only support alltoall_single in eager mode."
    # _check_single_tensor

    group = _get_default_group() if group is None else group
    backend = _group_map_backend[group]
    assert backend != 'gloo', ("backend gloo is not supported yet")

    in_split_sizes = [] if in_split_sizes is None else in_split_sizes
    out_split_sizes = [] if out_split_sizes is None else out_split_sizes

    task = group.process_group.alltoall_single(in_tensor, out_tensor,
                                               in_split_sizes, out_split_sizes)
    if sync_op:
        task.wait()
        return
    else:
        return task


S
    return global_rank if group is None else group.get_group_rank(global_rank)


1271
def send(tensor, dst=0, group=None, sync_op=True):
L
lilong12 已提交
1272 1273 1274 1275 1276
    """
    Send a tensor to the receiver.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
1277
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
L
lilong12 已提交
1278
        dst (int): The destination rank id.
L
lilong12 已提交
1279
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
1280
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.
1281

L
lilong12 已提交
1282 1283 1284 1285 1286
    Returns:
        None.

    Examples:
        .. code-block:: python
1287

L
lilong12 已提交
1288
            # required: distributed
L
lilong12 已提交
1289
            import paddle
1290
            import paddle.distributed as dist
1291

1292 1293
            dist.init_parallel_env()
            if dist.get_rank() == 0:
L
lilong12 已提交
1294
                data = paddle.to_tensor([7, 8, 9])
1295
                dist.send(data, dst=1)
L
lilong12 已提交
1296
            else:
1297 1298 1299 1300
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
L
lilong12 已提交
1301 1302 1303
    """
    if group is not None and not group.is_member():
        return
S
L
lilong12 已提交
1305
    if in_dygraph_mode():
1306
        group = _get_default_group() if group is None else group
1307 1308
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
S
1310
        if sync_op:
1311 1312 1313 1314 1315
            task.wait()
            return None
        else:
            return task

1316
    use_calc_stream = sync_op
L
lilong12 已提交
1317 1318
    ring_id = 0 if group is None else group.id

J
1320 1321
        return _legacy_C_ops.send_v2(tensor, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id, 'peer', dst)
W
wanghuancoder 已提交
1322
    op_type = 'send_v2'
L
lilong12 已提交
1323 1324 1325 1326 1327
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'send')

    helper = LayerHelper(op_type, **locals())
1328 1329 1330 1331 1332 1333 1334
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'peer': dst,
                         'use_calc_stream': use_calc_stream,
                     })


def recv(tensor, src=0, group=None, sync_op=True):
    """
    Receive a tensor from the sender.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
1343
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
L
lilong12 已提交
1344
        src (int): The source rank id.
L
lilong12 已提交
1345
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
1346
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.
1347

L
lilong12 已提交
1348 1349 1350 1351 1352
    Returns:
        None.

    Examples:
        .. code-block:: python
1353

L
lilong12 已提交
1354
            # required: distributed
L
lilong12 已提交
1355
            import paddle
1356
            import paddle.distributed as dist
1357

1358 1359
            dist.init_parallel_env()
            if dist.get_rank() == 0:
L
lilong12 已提交
1360
                data = paddle.to_tensor([7, 8, 9])
1361
                dist.send(data, dst=1)
L
lilong12 已提交
1362
            else:
1363 1364 1365 1366
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
L
lilong12 已提交
1367 1368 1369
    """
    if group is not None and not group.is_member():
        return

    src = _get_group_rank(src, group)
    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        task = group.process_group.recv(tensor, src)
        if sync_op:
            task.wait()
            return None
        else:
            return task

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _legacy_C_ops.recv_v2(tensor, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id, 'peer', src, 'dtype',
                                     tensor.dtype, 'out_shape', tensor.shape)
    op_type = 'recv_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'recv')
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'peer': src,
                         'out_shape': tensor.shape,
                         'dtype': tensor.dtype,
                         'use_calc_stream': use_calc_stream,
                     })


def _check_single_tensor(tensor, tensor_name):
    if not isinstance(tensor, (core.eager.Tensor, paddle.Tensor)):
        raise RuntimeError("Invalid function argument. Expected parameter {} "
                           "to be of type paddle.Tensor, but it's {}".format(
                               tensor_name, type(tensor)))


def _check_tensor_list(tensor_list, tensor_name):
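    # Validate that ``tensor_list`` is a Python list whose elements are all paddle
    # Tensors. Illustration (hypothetical values): ``[paddle.to_tensor([1])]`` passes,
    # while a bare ``paddle.to_tensor([1])`` or ``[1, 2]`` raises RuntimeError.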
    if not isinstance(tensor_list, list) or \
        not all(isinstance(t, (core.eager.Tensor, paddle.Tensor)) for t in tensor_list):
        raise RuntimeError("Invalid function argument. Expected parameter {} "
                           "to be a list of paddle.Tensor".format(tensor_name))


def isend(tensor, dst, group=None):
    """
    Send a tensor to the receiver asynchronously.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        dst (int): The destination rank.
        group (Group, optional): The group instance returned by new_group or None for the global default group. Default: None.

    Returns:
        A distributed task object.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)

    """
    _check_single_tensor(tensor, "tensor")
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        group_dst_rank = group.get_group_rank(dst)
        assert group_dst_rank >= 0, ("dst rank out of group, need global rank")
        return group.process_group.send(tensor, group_dst_rank)
    else:
        raise RuntimeError("Only the eager dygraph mode is supported.")


def irecv(tensor, src=None, group=None):
    """
    Receive a tensor from the sender asynchronously.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        src (int): The source rank id.
        group (Group, optional): The group instance returned by new_group or None for the global default group. Default: None.

    Returns:
        A distributed task object.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    _check_single_tensor(tensor, "tensor")
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        group_src_rank = group.get_group_rank(src)
        assert group_src_rank >= 0, ("src rank out of group, need global rank")
        return group.process_group.recv(tensor, group_src_rank)
    else:
        raise RuntimeError("Only the eager dygraph mode is supported.")


class P2POp(object):
    """
    A class that wraps a single point-to-point operation for ``batch_isend_irecv``.

    This class holds the type of P2P operation, the communication buffer, the peer rank
    and the group. Instances of this class will be passed to
    ``paddle.distributed.batch_isend_irecv`` for point-to-point communication.

    Args:
        op (callable): A function to send data to or receive data from a peer process.
            The type of ``op`` is either ``paddle.distributed.isend`` or ``paddle.distributed.irecv``.
        tensor (Tensor): Tensor to send or receive.
        peer (int): The destination or source rank.
        group (Group, optional): The group instance returned by new_group or None for the global
            default group. Default: None.
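
    Examples:
        An illustrative sketch (it assumes two trainers and mirrors the
        ``batch_isend_irecv`` example below):

        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            send_t = paddle.arange(2) + rank
            recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)

            send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
            recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)

            tasks = dist.batch_isend_irecv([send_op, recv_op])
            for task in tasks:
                task.wait()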

    """

    def __init__(self, op, tensor, peer, group=None):
        if op not in [isend, irecv]:
            raise RuntimeError("Invalid ``op`` function. Expected ``op`` "
                               "to be of type ``paddle.distributed.isend`` or "
                               "``paddle.distributed.irecv``.")
        _check_single_tensor(tensor, "tensor")

        self.op = op
        self.tensor = tensor
        self.peer = peer
        self.group = _get_default_group() if group is None else group


@contextlib.contextmanager
def _with_batch_p2p_guard(backend):
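    # For the NCCL backend, wrap the enclosed send/recv calls in a
    # group_start()/group_end() pair so they are issued together as one grouped
    # launch; other backends simply run the enclosed block unmodified.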
    if backend == "nccl":
        core.ProcessGroupNCCL.group_start()
    try:
        yield
    finally:
        if backend == "nccl":
            core.ProcessGroupNCCL.group_end()


def _check_p2p_op_list(p2p_op_list):
    """
    Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
    all ops use the same backend.
    """
    if not isinstance(p2p_op_list, list) or not all(
            isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list):
        raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected "
                           "to be of type ``paddle.distributed.P2POp``.")

    backend = _group_map_backend[p2p_op_list[0].group]
    if not all(backend == _group_map_backend[p2p_op.group]
               for p2p_op in p2p_op_list):
        raise RuntimeError("All groups need to use the same backend.")


def batch_isend_irecv(p2p_op_list):
    """
    Send or receive a batch of tensors asynchronously and return a list of requests.

    Process each of the point-to-point operations in ``p2p_op_list`` and return the
    corresponding tasks. Only the NCCL backend is currently supported.

    Args:
        p2p_op_list (List[P2POp]): A list of point-to-point operations (the type of each
            operator is ``paddle.distributed.P2POp``). The order of the isend/irecv in the
            list matters and needs to match the corresponding isend/irecv on the
            remote end.

    Returns:
        A list of distributed tasks returned by calling the corresponding
        op in the op_list.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed

            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            send_t = paddle.arange(2) + rank
            # paddle.tensor([0, 1])  # Rank-0
            # paddle.tensor([1, 2])  # Rank-1

            recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)

            send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
            recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)

            tasks = dist.batch_isend_irecv([send_op, recv_op])

            for task in tasks:
                task.wait()

            print(recv_t)
            # paddle.tensor([1, 2])     # Rank-0
            # paddle.tensor([0, 1])     # Rank-1
    """
    _check_p2p_op_list(p2p_op_list)
    group = p2p_op_list[0].group
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        tasks = []
        with _with_batch_p2p_guard(backend):
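            # Ops issued under the guard are batched: for NCCL they are launched
            # together when the guard exits, and the returned task handles are
            # collected so the caller can wait on them.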
            for p2p_op in p2p_op_list:
                op = p2p_op.op
                tensor = p2p_op.tensor
                peer = p2p_op.peer
                comm_group = p2p_op.group
                task = op(tensor, peer, comm_group)
                if task is not None:
                    tasks.append(task)
        return tasks
    else:
        raise RuntimeError("The static graph mode is not supported currently.")


def reduce_scatter(tensor,
                   tensor_list,
                   op=ReduceOp.SUM,
                   group=None,
                   sync_op=True):
    """
    Reduces, then scatters a list of tensors to all processes in a group.

    Args:
        tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        tensor_list (list[Tensor]): List of tensors to reduce and scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (Group, optional): The group instance returned by new_group or None for the global
            default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        Async task handle, if sync_op is set to False.
        None, if sync_op is True or if the current process is not part of the group.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([0, 1])
                data2 = paddle.to_tensor([2, 3])
            else:
                data1 = paddle.to_tensor([4, 5])
                data2 = paddle.to_tensor([6, 7])
            dist.reduce_scatter(data1, [data1, data2])
            print(data1)
            # [4, 6] (2 GPUs, out for rank 0)
            # [8, 10] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(tensor, "tensor")
    _check_tensor_list(tensor_list, "tensor_list")

    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce_scatter")
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")

        temp = paddle.concat(tensor_list, axis=0)
        task = group.process_group._reduce_scatter_base(tensor, temp, op_type)
        if sync_op:
            task.wait()
            return None
        else:
            return task
    else:
        raise RuntimeError("The static graph mode is not supported currently.")


def _reduce_scatter_base(output,
                         input,
                         op=ReduceOp.SUM,
                         group=None,
                         sync_op=True):
    """
    Reduces, then scatters a flattened tensor to all processes in a group.

    Args:
        output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        input (Tensor): Input tensor whose size equals the output tensor size times the world size. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        Async task handle, if sync_op is set to False.
        None, if sync_op is True or if the current process is not part of the group.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            data = paddle.arange(4) + rank
            # [0, 1, 2, 3] (2 GPUs, for rank 0)
            # [1, 2, 3, 4] (2 GPUs, for rank 1)
            output = paddle.empty(shape=[2], dtype=data.dtype)
            dist.collective._reduce_scatter_base(output, data)
            print(output)
            # [1, 3] (2 GPUs, out for rank 0)
            # [5, 7] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(output, "output")
    _check_single_tensor(input, "input")

    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "_reduce_scatter_base")
        group = _get_default_group() if group is None else group
        task = group.process_group._reduce_scatter_base(output, input, op_type)
        if sync_op:
            task.wait()
            return None
        else:
            return task
    else:
        raise RuntimeError("The static graph mode is not supported currently.")