#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
import pickle
import io
import datetime
import time
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import in_dygraph_mode
from ..fluid.framework import _non_static_mode
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.layers.tensor import fill_constant
import paddle
import paddle.fluid.core as core
from paddle import _legacy_C_ops
import contextlib
from .fleet.layers.mpu.mp_ops import split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_identity  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_concat  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _mp_allreduce  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_lookup_table  # noqa: F401
from .fleet.layers.mpu.mp_ops import _Linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _set_var_distributed  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy  # noqa: F401
from .fleet.layers.mpu.mp_ops import _linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_embedding  # noqa: F401
from .communication.group import Group, _add_new_group
from .communication.all_reduce import all_reduce  # noqa: F401
from .communication.reduce import _get_reduce_op, ReduceOp

__all__ = []

_global_env = None


def _get_global_env():
    global _global_env
    if not _global_env:
        _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# group map: the map of all groups, 0 for GlobalGroup
# Dict[int, Group]
_group_map = {}
_global_env_gid = 0

# group map by name: the map from group name to group
# Dict[name, Group]
_group_map_by_name = {}

# backend map by group: the map from group to its backend
# Dict[group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl']
_default_store = None  # the default tcp store
_default_backend = None
_default_timeout = datetime.timedelta(seconds=1800)
_start_ring_id = 0


def _set_default_backend(backend):
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    global _default_store
    _default_store = store


def _get_group_map():
    global _group_map
    if _global_env_gid not in _group_map:
        genv = _get_global_env()
        _group_map[_global_env_gid] = Group(
            genv.rank, 0, list(range(genv.world_size))
        )
    return _group_map


def _get_global_group():
    return _get_group_map()[_global_env_gid]


def _get_group_map_by_name():
    global _group_map_by_name
    return _group_map_by_name


def _get_default_group():
    global _group_map_by_name
    assert is_initialized(), (
        "Call paddle.distributed.init_parallel_env first "
        "to initialize the distributed environment."
    )
    return _get_group_map_by_name()[_default_group_name]


def _set_group_map(gid, group):
    global _group_map
    assert gid not in _group_map
    _group_map[gid] = group


def _set_group_map_by_name(name, group):
    global _group_map_by_name
    assert name not in _group_map_by_name
    _group_map_by_name[name] = group


def _set_group_map_backend(group, backend):
    global _group_map_backend
    assert group not in _group_map_backend
    _group_map_backend[group] = backend


def _new_ring_id():
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode rely on the previous syntax.
    if in_dygraph_mode():
        global _start_ring_id
        _start_ring_id += 1
        return _start_ring_id + max(_get_global_env().nrings, 9)
    else:
        return len(_get_group_map()) + max(_get_global_env().nrings, 9)


def get_group(id=0):
    """

    Get group instance by group id.

    Args:
        id (int): the group id. Default value is 0.

    Returns:
        Group: the group instance.

    Examples:
        .. code-block:: python

            ...
            gid = paddle.distributed.new_group([2,4,6])
            paddle.distributed.get_group(gid.id)

    """

    gm = _get_group_map()
    return gm[id] if id in gm else None


def _new_process_group_impl(
    backend,
    store,
    rank,
    world_size,
    group_name,
    pg_options,
    group_id=0,
    src_rank=None,
    dst_rank=None,
):
    pg = None
    genv = _get_global_env()
    if backend != 'heter':
        assert src_rank is None and dst_rank is None, (
            "src_rank and dst_rank " "can only be set for heter backend."
        )
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
    if backend == "gloo":
        place = core.CPUPlace()
        pg = core.ProcessGroupGloo(store, rank, world_size, place, group_id)
    elif backend == "nccl":
        place = core.CUDAPlace(genv.device_id)
        pg = core.ProcessGroupNCCL(store, rank, world_size, place, group_id)
    elif backend == "hccl":
        place = core.NPUPlace(genv.device_id)
        pg = core.ProcessGroupHCCL(store, rank, world_size, place, group_id)
    elif backend == "xccl":
        place = core.CustomPlace(genv.device_type, genv.device_id)
        pg = core.ProcessGroupCustom(store, rank, world_size, place, group_id)
    elif backend == "heter":
        place = None
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(genv.device_id)
        elif core.is_compiled_with_npu():
            place = core.NPUPlace(genv.device_id)
        cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
        assert cluster_id >= 0, "please set the CLUSTER_ID variable."
        cluster_size = os.getenv("CLUSTER_SIZE", None)
        assert cluster_size, "please set the CLUSTER_SIZE variable."
        cluster_size = cluster_size.split(",")
        cluster_size = [int(s) for s in cluster_size]
        switch_ep = os.getenv("CLUSTER_SWITCH", None)
        assert switch_ep, "please set the CLUSTER_SWITCH variable."
        cluster_size_cumsum = np.cumsum(cluster_size)
        cluster_offset = (
            0 if cluster_id == 0 else cluster_size_cumsum[cluster_id - 1]
        )
        global_rank = cluster_offset + rank
        global_world_size = cluster_size_cumsum[-1]
        global_rank, global_world_size = _get_global_config(backend, rank)
        pg = core.ProcessGroupHeter(
            store,
            rank=global_rank,
            world_size=global_world_size,
            place=place,
            gid=group_id,
            local_rank=rank,
            local_size=world_size,
            gloo_rank=cluster_id,
            gloo_size=len(cluster_size),
            with_switch=True,
            switch_endpoint=switch_ep,
            src_rank=src_rank,
            dst_rank=dst_rank,
        )

    return pg
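# Illustrative note, not executed: the 'heter' branch of _new_process_group_impl
# expects its cluster topology from environment variables (the values below are examples):
#   export CLUSTER_ID=1                  # index of the local sub-cluster
#   export CLUSTER_SIZE=8,8              # comma-separated rank count per sub-cluster
#   export CLUSTER_SWITCH=10.0.0.1:6070  # endpoint of the switch service (example value)
# With CLUSTER_SIZE=8,8, CLUSTER_ID=1 and local rank 3, the cumulative sum [8, 16]
# gives cluster_offset = 8, so the global rank becomes 8 + 3 = 11 and the global
# world size is 16.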


def barrier(group=None):
    """

    Barrier among all participators in the group.

    Args:
        group (Group, optional): The group instance returned by new_group or None for the global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        task = group.process_group.barrier()
        task.wait()
        return

    ring_id = 0 if group is None else group.id

    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        return _legacy_C_ops.barrier(temp, temp, 'ring_id', ring_id)

    op_type = 'barrier'

    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [temp]},
        attrs={'ring_id': ring_id},
    )


# _custom_gid provides a way for users to
# set the group id, which is usually useful
# to be compatible with the static mode.
_custom_gid = None


def _set_custom_gid(gid):
    global _custom_gid
    _custom_gid = gid


def _barrier_by_tcp_store(group_name, store, timeout):
    global_rank = paddle.distributed.get_rank()
    global_world_size = paddle.distributed.get_world_size()

    if global_world_size < 2:
        return

    barrier_prefix = "Barrier/" + group_name + "/"
    is_master = global_rank == 0

    def _check_keys_ready(wait_keys):
        start_time = time.time()
        while len(wait_keys) > 0:
            time.sleep(0.1)
            elapse_time = time.time() - start_time
            if datetime.timedelta(seconds=elapse_time) > timeout:
                raise RuntimeError(
                    "Timeout while initializing process group {}."
                    "Keys {} are not ready sinck rank {} is waiting them."
                    "Two reason may cause this error:\n 1. The create process group api should be called by all ranks.\n"
                    " 2. Try to increase the waiting time.\n".format(
322 323 324
                        group_name, wait_keys, global_rank
                    )
                )
325
            wait_keys = list(
326 327
                filter(lambda key: int(store.get(key)) != 1, wait_keys)
            )
328 329 330 331 332 333 334 335 336 337 338 339 340

    # all the workers set their exiting key and exit
    # the master will wait for all workers' exiting key, ensure to exit in the end
    if is_master:
        wait_keys = [
            barrier_prefix + str(rank) for rank in range(1, global_world_size)
        ]
        _check_keys_ready(wait_keys)
    else:
        store.add(barrier_prefix + str(global_rank), 1)
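# Illustrative sketch of the store-based barrier above (not part of the runtime path):
# every non-master rank writes one key under "Barrier/<group_name>/", and rank 0 polls
# the store until all keys are present. For a hypothetical 4-rank job and group name
# "_default_pg3", rank 0 waits for
#   ["Barrier/_default_pg3/1", "Barrier/_default_pg3/2", "Barrier/_default_pg3/3"]
# while ranks 1-3 each call store.add("Barrier/_default_pg3/<rank>", 1).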


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members.
        backend (str): The backend used to create the group; only nccl is supported now.
        timeout (datetime.timedelta, optional): The waiting timeout for store-related operations, default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    global _custom_gid
    global _group_map
    if in_dygraph_mode():
        global _default_group_name
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group."
            )
        size = len(ranks)
        ranks = sorted(ranks)
        if backend == 'heter' or (size > 1 and global_rank in ranks):
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            src_rank = ranks[0] if backend == 'heter' else None
            dst_rank = ranks[1] if backend == 'heter' else None
            pg = _new_process_group_impl(
                backend,
                _default_store,
                rank,
                size,
                group_name,
                pg_options=None,
                group_id=gid,
                src_rank=src_rank,
                dst_rank=dst_rank,
            )
        else:
            rank = -1
            pg = None
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        # TODO: The method below is a new method for group management, will replace the previous
        # three in the future.
        _add_new_group(group)

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        # NOTE(liyurui): All processes should hang and wait using the tcp store, in case the master exits before the sub-group is created.
        if backend != 'heter':
            _barrier_by_tcp_store(group_name, _default_store, timeout)
        else:
            print("Warning: store barrier is not supported for heter backend.")
        return group

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', "backend other than nccl is not supported yet"

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            else:
                assert False, "no supported device (cuda/npu/mlu/xpu) found"
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = (
        paddle.to_tensor([1], dtype="int32")
        if _non_static_mode()
        else fill_constant([0], dtype="int32", value="1")
    )
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp


def is_initialized():
    """

    Check whether the distributed environment has been initialized

    Returns:
        `True` if distributed environment has been initialized, otherwise `False`.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle

            print(paddle.distributed.is_initialized())
            # False

            paddle.distributed.init_parallel_env()
            print(paddle.distributed.is_initialized())
            # True

    """
    global _group_map_by_name
    return _default_group_name in _group_map_by_name


def destroy_process_group(group=None):
    """
    Destroy a given group for communication

    Args:
        group (ProcessGroup, optional): The group to be destroyed. If None, all process
                                        groups, including the default group, will be destroyed
                                        and the distributed environment will be deinitialized.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            group = dist.new_group([0, 1])

            dist.destroy_process_group(group)
            print(dist.is_initialized())
            # True
            dist.destroy_process_group()
            print(dist.is_initialized())
            # False

    """
    global _group_map
    global _group_map_by_name

    pg = _get_default_group() if group is None else group
    assert _group_map.get(pg.id, None) is not None, "Invalid group."

    if group is None:
        _group_map.clear()
        _group_map_by_name.clear()
        _group_map_backend.clear()
    else:
        del _group_map[pg.id]
        del _group_map_by_name[pg.name]
        del _group_map_backend[pg]


def wait(tensor, group=None, use_calc_stream=True):
    """

    Wait to sync the stream for the given group.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, sync_op=True)
            paddle.distributed.wait(tindata)

    """

    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id

    if use_calc_stream:
        _sync_calc_stream(tensor)
    else:
        _sync_comm_stream(tensor, ring_id)


def _sync_calc_stream(tensor):

    if _non_static_mode():
        return _legacy_C_ops.c_sync_calc_stream(tensor, tensor)

    op_type = 'c_sync_calc_stream'

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
    )


def _sync_comm_stream(tensor, ring_id=0):

    if _non_static_mode():
        return _legacy_C_ops.c_sync_comm_stream(
            [tensor], [tensor], 'ring_id', ring_id
        )

    op_type = 'c_sync_comm_stream'

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={'ring_id': ring_id},
    )


def broadcast(tensor, src, group=None, sync_op=True):
    """

    Broadcast a tensor from the source to all others.
    As shown below, one process is started with a GPU and GPU0 owns data 0. Through the broadcast operator,
    data 0 will be sent to all GPUs from GPU0.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png
        :width: 800
        :alt: broadcast
        :align: center

    Args:
        tensor (Tensor): The Tensor to send if current rank is the source, or the Tensor to receive otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        src (int): The source rank.
        group (Group, optional): The group instance returned by new_group or None for the global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.broadcast(data, src=1)
            print(data)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs)
    """

    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        assert gsrc >= 0, "src rank out of group, need global rank"
        task = group.process_group.broadcast(tensor, gsrc)
        if sync_op:
            task.wait()
            return None
        else:
            return task

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    gsrc = src if group is None else group.get_group_rank(src)
    assert gsrc >= 0, "src rank out of group, need global rank"

    if _non_static_mode():
        return _legacy_C_ops.c_broadcast(
            tensor,
            tensor,
            'root',
            gsrc,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
        )

    op_type = 'c_broadcast'
    check_variable_and_dtype(
        tensor,
        'tensor',
        [
            'float16',
            'float32',
            'float64',
            'int32',
            'int64',
            'int8',
            'uint8',
            'bool',
        ],
        'broadcast',
    )

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'root': gsrc,
            'use_calc_stream': use_calc_stream,
            'ring_id': ring_id,
        },
    )


def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
    """

    Reduce a tensor to the destination from all others. As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The destination of the reduce operator is GPU0 and the reduce operation is sum. Through the reduce operator,
    GPU0 will own the sum of all data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png
        :width: 800
        :alt: reduce
        :align: center

    Args:
        tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        dst (int): The destination rank id.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
        group (Group, optional): The group instance returned by new_group or None for the global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.reduce(data, dst=0)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs, out for rank 0)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce")
        group = _get_default_group() if group is None else group
        gdst = group.get_group_rank(dst)
        assert gdst >= 0, "dst rank out of group, need global rank"
        task = group.process_group.reduce(tensor, gdst, op_type)
        if sync_op:
            task.wait()
            return None
        else:
            return task

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    gdst = dst if group is None else group.get_group_rank(dst)
    assert gdst >= 0, "dst rank out of group, need global rank"

    if _non_static_mode():
        if op == ReduceOp.SUM:
            return _legacy_C_ops.c_reduce_sum(
                tensor,
                tensor,
                'use_calc_stream',
                use_calc_stream,
                'ring_id',
                ring_id,
                'root_id',
                gdst,
            )
        elif op == ReduceOp.MAX:
            return _legacy_C_ops.c_reduce_max(
                tensor,
                tensor,
                'use_calc_stream',
                use_calc_stream,
                'ring_id',
                ring_id,
                'root_id',
                gdst,
            )
        elif op == ReduceOp.MIN:
            return _legacy_C_ops.c_reduce_min(
                tensor,
                tensor,
                'use_calc_stream',
                use_calc_stream,
                'ring_id',
                ring_id,
                'root_id',
                gdst,
            )
        elif op == ReduceOp.PROD:
            return _legacy_C_ops.c_reduce_prod(
                tensor,
                tensor,
                'use_calc_stream',
                use_calc_stream,
                'ring_id',
                ring_id,
                'root_id',
                gdst,
            )
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    op_type = 'c_reduce'
    check_variable_and_dtype(
        tensor,
        'tensor',
        [
            'float16',
            'float32',
            'float64',
            'int32',
            'int64',
            'int8',
            'uint8',
            'bool',
        ],
        'reduce',
    )

    if op == ReduceOp.SUM:
        op_type = 'c_reduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_reduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_reduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_reduce_prod'

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
            'root_id': gdst,
        },
    )


def all_gather(tensor_list, tensor, group=None, sync_op=True):
    """

    Gather tensors from all participators and all get the result. As shown
    below, one process is started with a GPU and the data of this process is represented
    by its group rank. Through the all_gather operator, each GPU will have data
    from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
        group (Group, optional): The group instance returned by new_group or None for the global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            tensor_list = []
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_gather(tensor_list, data)
            print(tensor_list)
            # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    def convert_to_complex(list_of_tensor):
        list_of_complex = []
        for tensor in list_of_tensor:
            list_of_complex.append(paddle.as_complex(tensor))
        return list_of_complex

    is_input_complex = (
        tensor.dtype == paddle.complex64 or tensor.dtype == paddle.complex128
    )
    if is_input_complex:
        tensor = paddle.as_real(tensor)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        if len(tensor_list) == 0:
            tensor_shape = list(tensor.shape)
            tensor_shape[0] *= group.nranks
            out = paddle.empty(tensor_shape, tensor.dtype)
        else:
            out = paddle.concat(tensor_list, axis=0)
        task = group.process_group.all_gather(tensor, out)
        task.wait()
        tensor_list.clear()
        list_of_tensor = paddle.split(out, group.nranks, 0)
        if is_input_complex:
            tensor_list.extend(convert_to_complex(list_of_tensor))
        else:
            tensor_list.extend(list_of_tensor)
        return

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    if _non_static_mode():
        out = _legacy_C_ops.c_allgather(
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'nranks',
            nranks,
        )
    else:
        op_type = 'c_allgather'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        if not isinstance(tensor_list, list):
            raise ValueError(
                "The type of 'tensor_list' for all_gather " "should be list."
            )
        for elem in tensor_list:
            check_variable_and_dtype(
                elem,
                'tensor_list',
                [
                    'float16',
                    'float32',
                    'float64',
                    'int32',
                    'int64',
                    'bool',
                    'int8',
                    'uint8',
                    'complex64',
                    'complex128',
                ],
                'all_gather',
            )
        check_variable_and_dtype(
            tensor,
            'tensor',
            [
                'float16',
                'float32',
                'float64',
                'int32',
                'int64',
                'bool',
                'int8',
                'uint8',
                'complex64',
                'complex128',
            ],
            'all_gather',
        )
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
                'nranks': nranks,
            },
        )

    list_of_tensor = paddle.split(out, nranks, 0)
    if is_input_complex:
        tensor_list.extend(convert_to_complex(list_of_tensor))
    else:
        tensor_list.extend(list_of_tensor)


def _convert_object_to_tensor(obj):
    _pickler = pickle.Pickler
    f = io.BytesIO()
    _pickler(f).dump(obj)
    data = np.frombuffer(f.getvalue(), dtype=np.uint8)
    tensor = paddle.to_tensor(data)
    return tensor, tensor.numel()


def _convert_tensor_to_object(tensor, len_of_tensor):
    _unpickler = pickle.Unpickler
    return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()
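# Illustrative round trip for the two helpers above (a sketch, not executed here):
#   tensor, length = _convert_object_to_tensor({"lr": 0.1})  # 1-D uint8 tensor of pickled bytes
#   obj = _convert_tensor_to_object(tensor, int(length))     # -> {"lr": 0.1}
# all_gather_object() below relies on this pair: each rank's byte tensor is padded to the
# longest pickled payload before all_gather(), then trimmed back using the gathered lengths.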


def all_gather_object(object_list, obj, group=None):
    """

    Gather picklable objects from all participators and all get the result. Similar to all_gather(), but python objects can be passed in.

    Args:
        object_list (list): A list of output objects. The data type of every element in the list is the same as the input obj.
        obj (Any): The picklable object to send.
        group (Group): The group instance returned by new_group or None for the global default group.

    Returns:
        None.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            object_list = []
            if dist.get_rank() == 0:
                obj = {"foo": [1, 2, 3]}
            else:
                obj = {"bar": [4, 5, 6]}
            dist.all_gather_object(object_list, obj)
            print(object_list)
            # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
    """
    assert (
        in_dygraph_mode()
    ), "all_gather_object doesn't support static graph mode."

    tensor, len_of_tensor = _convert_object_to_tensor(obj)

    # gather len_of_tensor from all ranks
    list_len_of_tensor = []
    all_gather(list_len_of_tensor, len_of_tensor, group)
    # get the max length from list
    max_len_of_tensor = int(max(list_len_of_tensor).item())
    # resize the input tensor to max length avoid hang in all gather
    # Note(liyurui): Maybe we should support various length all_gather?
    # Now this operation is efficient for we don't support resize in python.
    numpy_data = tensor.numpy()
    numpy_data = np.resize(numpy_data, [max_len_of_tensor])
    input_tensor = paddle.to_tensor(numpy_data)

    tensor_list = []
    all_gather(tensor_list, input_tensor, group)
    for i, tensor in enumerate(tensor_list):
        object_list.append(
            _convert_tensor_to_object(tensor, list_len_of_tensor[i])
        )


def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
    """

    Scatter a tensor to all participators. As shown below, one process is started with a GPU and the source of the scatter
    is GPU0. Through the scatter operator, the data in GPU0 will be sent to all GPUs evenly.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
        :width: 800
        :alt: scatter
        :align: center

    Args:
        tensor (Tensor): The output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Default value is None.
        src (int): The source rank id. Default value is 0.
        group (Group, optional): The group instance returned by new_group or None for the global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([7, 8, 9])
                data2 = paddle.to_tensor([10, 11, 12])
                dist.scatter(data1, src=1)
            else:
                data1 = paddle.to_tensor([1, 2, 3])
                data2 = paddle.to_tensor([4, 5, 6])
                dist.scatter(data1, tensor_list=[data1, data2], src=1)
            print(data1, data2)
            # [1, 2, 3] [10, 11, 12] (2 GPUs, out for rank 0)
            # [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        rank = group.rank
        nranks = group.nranks
    else:
        ring_id = 0 if group is None else group.id
        gsrc = src if group is None else group.get_group_rank(src)
        rank = _get_global_group().rank if group is None else group.rank
        nranks = _get_global_group().nranks if group is None else group.nranks
    assert gsrc >= 0, "src rank out of group, need global rank"

    if rank != gsrc:
        tensor_list = []
        for _ in range(nranks):
            tensor_list.append(tensor)
    temp = paddle.concat(tensor_list, axis=0)
    if in_dygraph_mode():
        task = group.process_group.scatter(temp, tensor, gsrc)
        if sync_op:
            task.wait()
            return None
        else:
            return task

    use_calc_stream = sync_op
    if _non_static_mode():
        return _legacy_C_ops.c_scatter(
            temp,
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'nranks',
            nranks,
            'root',
            gsrc,
        )
    op_type = 'c_scatter'
    check_variable_and_dtype(
        tensor,
        'tensor',
        [
            'float16',
            'float32',
            'float64',
            'int32',
            'int64',
            'int8',
            'uint8',
            'bool',
        ],
        'scatter',
    )
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'root': gsrc,
            'use_calc_stream': use_calc_stream,
            'nranks': nranks,
        },
    )


def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):
    """
    Scatter tensors in in_tensor_list to all participators evenly and gather the result tensors in out_tensor_list.
    As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
    Through the alltoall operator, 0_0 in GPU0 will be sent to GPU0 and 0_1 to GPU1; 1_0 in GPU1 will be sent to GPU0 and 1_1 to GPU1.
    Finally, the out_tensor_list in GPU0 includes 0_0 and 1_0, and GPU1 includes 0_1 and 1_1.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/alltoall.png
        :width: 800
        :alt: alltoall
        :align: center

    Args:
        in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
            data type of the input Tensors.
        group (Group, optional): The group instance returned by new_group or None for the global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            out_tensor_list = []
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
                data2 = paddle.to_tensor([[7, 8, 9], [10, 11, 12]])
            else:
                data1 = paddle.to_tensor([[13, 14, 15], [16, 17, 18]])
                data2 = paddle.to_tensor([[19, 20, 21], [22, 23, 24]])
            dist.alltoall([data1, data2], out_tensor_list)
            print(out_tensor_list)
            # [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0)
            # [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', "backend gloo is not supported yet"
    else:
        ring_id = 0 if group is None else group.id

    temp = paddle.concat(in_tensor_list, axis=0)
    nranks = len(in_tensor_list)
    if in_dygraph_mode():
        if len(out_tensor_list) == 0:
            tensor_shape = list(in_tensor_list[0].shape)
            tensor_shape[0] *= nranks
            out = paddle.empty(tensor_shape, in_tensor_list[0].dtype)
        else:
            out = paddle.concat(out_tensor_list, axis=0)
        task = group.process_group.alltoall(temp, out)
        task.wait()
        out_tensor_list.clear()
        out_tensor_list.extend(paddle.split(out, nranks, 0))
        return

    use_calc_stream = sync_op
    if _non_static_mode():
        out = _legacy_C_ops.alltoall(
            temp, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id
        )
    else:
        op_type = 'alltoall'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=in_tensor_list[0].dtype
        )

        if not isinstance(in_tensor_list, list):
            raise ValueError(
                "The type of 'in_tensor_list' for all_to_all " "should be list."
            )
        for elem in in_tensor_list:
            check_variable_and_dtype(
                elem,
                'in_tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_to_all',
            )
        if not isinstance(out_tensor_list, list):
            raise ValueError(
                "The type of 'out_tensor_list' for all_to_all "
                "should be list."
            )
        if len(out_tensor_list) != 0:
            raise ValueError(
                "The 'out_tensor_list' for all_to_all " "must be an empty list."
            )
        helper.append_op(
            type=op_type,
            inputs={'X': [temp]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
            },
        )
    out_tensor_list.extend(paddle.split(out, nranks, 0))


def alltoall_single(
    in_tensor,
    out_tensor,
    in_split_sizes=None,
    out_split_sizes=None,
    group=None,
    sync_op=True,
):
    """
    Scatter a single input tensor to all participators and gather the received tensors in out_tensor.

    Note:
        ``alltoall_single`` is only supported in eager mode.

    Args:
        in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
        in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
            must be divisible by group size and ``in_tensor`` will be scattered evenly to all participators. Default: None.
        out_split_sizes (list[int], optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
            must be divisible by group size and ``out_tensor`` will be gathered evenly from all participators. Default: None.
        group (Group, optional): The group instance returned by ``new_group`` or None for the global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None, if ``sync_op`` is set to ``True``; ``Task`` of ``group``, if ``sync_op`` is set to ``False``.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            size = dist.get_world_size()

            # case 1 (2 GPUs)
            data = paddle.arange(2, dtype='int64') + rank * 2
            # data for rank 0: [0, 1]
            # data for rank 1: [2, 3]
            output = paddle.empty([2], dtype='int64')
            dist.alltoall_single(data, output)
            print(output)
            # output for rank 0: [0, 2]
            # output for rank 1: [1, 3]

            # case 2 (2 GPUs)
            in_split_sizes = [i + 1 for i in range(size)]
            # in_split_sizes for rank 0: [1, 2]
            # in_split_sizes for rank 1: [1, 2]
            out_split_sizes = [rank + 1 for i in range(size)]
            # out_split_sizes for rank 0: [1, 1]
            # out_split_sizes for rank 1: [2, 2]
            data = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank
            # data for rank 0: [[0., 0.], [0., 0.], [0., 0.]]
            # data for rank 1: [[1., 1.], [1., 1.], [1., 1.]]
            output = paddle.empty([(rank + 1) * size, size], dtype='float32')
            group = dist.new_group([0, 1])
            task = dist.alltoall_single(data,
                                        output,
                                        in_split_sizes,
                                        out_split_sizes,
                                        sync_op=False,
                                        group=group)
            task.wait()
            print(output)
            # output for rank 0: [[0., 0.], [1., 1.]]
            # output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]]

    """
    if group is not None and not group.is_member():
        return

    assert in_dygraph_mode(), "Only support alltoall_single in eager mode."
    # _check_single_tensor

    group = _get_default_group() if group is None else group
    backend = _group_map_backend[group]
    assert backend != 'gloo', "backend gloo is not supported yet"

    in_split_sizes = [] if in_split_sizes is None else in_split_sizes
    out_split_sizes = [] if out_split_sizes is None else out_split_sizes

    task = group.process_group.alltoall_single(
        in_tensor, out_tensor, in_split_sizes, out_split_sizes
    )
    if sync_op:
        task.wait()
        return
    else:
        return task


def _get_group_rank(global_rank, group=None):
    return global_rank if group is None else group.get_group_rank(global_rank)
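# Illustrative mapping (a sketch): for a group created via new_group([2, 4, 6]),
# _get_group_rank(4, group) returns 1, the index of global rank 4 inside the group,
# while _get_group_rank(4, None) returns the global rank 4 unchanged.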


def send(tensor, dst=0, group=None, sync_op=True):
    """
    Send a tensor to the receiver.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        dst (int): The destination rank id.
        group (Group, optional): The group instance returned by new_group or None for the global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return
    dst = _get_group_rank(dst, group)
    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', "backend gloo is not supported yet"
        task = group.process_group.send(tensor, dst)
        if sync_op:
            task.wait()
            return None
        else:
            return task

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _legacy_C_ops.send_v2(
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'peer',
            dst,
        )
    op_type = 'send_v2'
    check_variable_and_dtype(
        tensor,
        'tensor',
        ['float16', 'float32', 'float64', 'int32', 'int64'],
        'send',
    )

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        attrs={
            'ring_id': ring_id,
            'peer': dst,
            'use_calc_stream': use_calc_stream,
        },
    )


def recv(tensor, src=0, group=None, sync_op=True):
    """
    Receive a tensor from the sender.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        src (int): The source rank id.
        group (Group, optional): The group instance returned by new_group or None for the global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    src = _get_group_rank(src, group)
    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', "backend gloo is not supported yet"
        task = group.process_group.recv(tensor, src)
        if sync_op:
            task.wait()
            return None
        else:
            return task

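    # Legacy dygraph / static-graph path: ``sync_op`` maps onto the old
    # ``use_calc_stream`` attribute of the recv_v2 op.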
    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _legacy_C_ops.recv_v2(
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'peer',
            src,
            'dtype',
            tensor.dtype,
            'out_shape',
            tensor.shape,
        )
    op_type = 'recv_v2'
    check_variable_and_dtype(
        tensor,
        'tensor',
        ['float16', 'float32', 'float64', 'int32', 'int64'],
        'recv',
    )
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'peer': src,
            'out_shape': tensor.shape,
            'dtype': tensor.dtype,
            'use_calc_stream': use_calc_stream,
        },
    )


def _check_single_tensor(tensor, tensor_name):
    if not isinstance(tensor, (core.eager.Tensor, paddle.Tensor)):
        raise RuntimeError(
            "Invalid function argument. Expected parameter {} "
            "to be of type paddle.Tensor, but it's {}".format(
                tensor_name, type(tensor)
            )
        )


def _check_tensor_list(tensor_list, tensor_name):
    if not isinstance(tensor_list, list) or not all(
        isinstance(t, (core.eager.Tensor, paddle.Tensor)) for t in tensor_list
    ):
        raise RuntimeError(
            "Invalid function argument. Expected parameter {}"
            "to be of type paddle.Tensor".format(tensor_name)
        )


def isend(tensor, dst, group=None):
    """
    Send a tensor asynchronously.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        dst (int): The destination rank.
        group (Group, optional): The group instance returned by new_group or None for the global default group. Default: None.

    Returns:
        A distributed task object.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)

    """
    _check_single_tensor(tensor, "tensor")
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', "backend gloo is not supported yet"
        group_dst_rank = group.get_group_rank(dst)
        assert group_dst_rank >= 0, "dst rank out of group, need global rank"
        return group.process_group.send(tensor, group_dst_rank)
    else:
        raise RuntimeError("Only support eager dygraph mode.")


def irecv(tensor, src=None, group=None):
    """
    Receive a tensor asynchronously from the sender.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        src (int): The source rank id.
        group (Group, optional): The group instance returned by new_group or None for the global default group. Default: None.

    Returns:
        A distributed task object.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    _check_single_tensor(tensor, "tensor")
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', "backend gloo is not supported yet"
        group_src_rank = group.get_group_rank(src)
        assert group_src_rank >= 0, "src rank out of group, need global rank"
        return group.process_group.recv(tensor, group_src_rank)
    else:
        raise RuntimeError("Only support eager dygraph mode.")


class P2POp(object):
    """
    A class that describes a point-to-point operation for ``batch_isend_irecv``.

    This class bundles the type of P2P operation, the communication buffer, the peer rank,
    and the group. Instances of this class are passed to
    ``paddle.distributed.batch_isend_irecv`` for point-to-point communication.

    Args:
        op (callable): A function to send data to or receive data from a peer process.
            The type of ``op`` is either ``paddle.distributed.isend`` or ``paddle.distributed.irecv``.
        tensor (Tensor): Tensor to send or receive.
        peer (int): The destination or source rank.
        group (Group, optional): The group instance returned by new_group or None for the global
            default group. Default: None.
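
    Examples:
        .. code-block:: python

            # required: distributed

            # A minimal sketch (mirroring the ``batch_isend_irecv`` example in
            # this module) of how P2POp descriptors are built and consumed.
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            send_t = paddle.arange(2) + rank
            recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)

            send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
            recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)

            tasks = dist.batch_isend_irecv([send_op, recv_op])
            for task in tasks:
                task.wait()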

    """

    def __init__(self, op, tensor, peer, group=None):
        if op not in [isend, irecv]:
            raise RuntimeError(
                "Invalid ``op`` function. Expected ``op`` "
                "to be of type ``paddle.distributed.isend`` or "
                "``paddle.distributed.irecv``."
            )
        _check_single_tensor(tensor, "tensor")

        self.op = op
        self.tensor = tensor
        self.peer = peer
        self.group = _get_default_group() if group is None else group


@contextlib.contextmanager
def _with_batch_p2p_guard(backend):
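    # Delimit a NCCL group call (group_start/group_end) around the batched
    # point-to-point ops so they are issued together when the backend is NCCL.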
    if backend == "nccl":
        core.ProcessGroupNCCL.group_start()
    try:
        yield
    finally:
        if backend == "nccl":
            core.ProcessGroupNCCL.group_end()


def _check_p2p_op_list(p2p_op_list):
    """
    Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
    all ops use the same backend.
    """
    if not isinstance(p2p_op_list, list) or not all(
        isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list
    ):
        raise RuntimeError(
            "Invalid ``p2p_op_list``. Each op is expected to "
            "to be of type ``paddle.distributed.P2POp``."
        )

    backend = _group_map_backend[p2p_op_list[0].group]
    if not all(
        backend == _group_map_backend[p2p_op.group] for p2p_op in p2p_op_list
    ):
        raise RuntimeError("All groups need to use the same backend.")


def batch_isend_irecv(p2p_op_list):
    """
    Send or receive a batch of tensors asynchronously and return a list of requests.

    Process each of the point-to-point operations in ``p2p_op_list`` and return the
    corresponding tasks. Only the NCCL backend is currently supported.

    Args:
        p2p_op_list (List[P2POp]): A list of point-to-point operations (the type of each operator is
            ``paddle.distributed.P2POp``). The order of the isend/irecv in the list
            matters and it needs to match the corresponding isend/irecv on the
            remote end.

    Returns:
        A list of distributed tasks returned by calling the corresponding
        op in the op_list.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed

            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            send_t = paddle.arange(2) + rank
            # paddle.tensor([0, 1])  # Rank-0
            # paddle.tensor([1, 2])  # Rank-1

            recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)

            send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
            recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)

            tasks = dist.batch_isend_irecv([send_op, recv_op])

            for task in tasks:
                task.wait()

            print(recv_t)
            # paddle.tensor([1, 2])     # Rank-0
            # paddle.tensor([0, 1])     # Rank-1
    """
    _check_p2p_op_list(p2p_op_list)
    group = p2p_op_list[0].group
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        tasks = []
        with _with_batch_p2p_guard(backend):
            for p2p_op in p2p_op_list:
                op = p2p_op.op
                tensor = p2p_op.tensor
                peer = p2p_op.peer
                comm_group = p2p_op.group
                task = op(tensor, peer, comm_group)
                if task is not None:
                    tasks.append(task)
        return tasks
    else:
        raise RuntimeError("Don't support static graph mode currently.")


def reduce_scatter(
    tensor, tensor_list, op=ReduceOp.SUM, group=None, sync_op=True
):
    """
    Reduces, then scatters a list of tensors to all processes in a group.

    Args:
        tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        tensor_list (list[Tensor]): List of tensors to reduce and scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (Group, optional): The group instance returned by new_group or None for the global
            default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        Async task handle, if sync_op is set to False.
        None, if sync_op is True or if this rank is not part of the group.

    Warning:
        This API only supports the dygraph mode.


    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([0, 1])
                data2 = paddle.to_tensor([2, 3])
            else:
                data1 = paddle.to_tensor([4, 5])
                data2 = paddle.to_tensor([6, 7])
            dist.reduce_scatter(data1, [data1, data2])
            print(data1)
            # [4, 6] (2 GPUs, out for rank 0)
            # [8, 10] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(tensor, "tensor")
    _check_tensor_list(tensor_list, "tensor_list")

    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce_scatter")
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', "backend gloo is not supported yet"

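        # Reduce-scatter of a tensor list: flatten the list into one contiguous
        # buffer and reuse the base reduce-scatter primitive.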
        temp = paddle.concat(tensor_list, axis=0)
        task = group.process_group._reduce_scatter_base(tensor, temp, op_type)
        if sync_op:
            task.wait()
            return None
        else:
            return task
    else:
        raise RuntimeError("Don't support static graph mode currently.")


def _reduce_scatter_base(
    output, input, op=ReduceOp.SUM, group=None, sync_op=True
):
    """
    Reduces, then scatters a flattened tensor to all processes in a group.

    Args:
        output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        input (Tensor): Input tensor whose size equals the output tensor size times the world size. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        Async task handle, if sync_op is set to False.
        None, if sync_op is True or if this rank is not part of the group.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            data = paddle.arange(4) + rank
            # [0, 1, 2, 3] (2 GPUs, for rank 0)
            # [1, 2, 3, 4] (2 GPUs, for rank 1)
            output = paddle.empty(shape=[2], dtype=data.dtype)
            dist.collective._reduce_scatter_base(output, data)
            print(output)
            # [1, 3] (2 GPUs, out for rank 0)
            # [5, 7] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(output, "output")
    _check_single_tensor(input, "input")

    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "_reduce_scatter_base")
        group = _get_default_group() if group is None else group
        task = group.process_group._reduce_scatter_base(output, input, op_type)
        if sync_op:
            task.wait()
            return None
        else:
            return task
    else:
        raise RuntimeError("Don't support static graph mode currently.")