collective.py 63.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
17 18
import pickle
import io
19 20
import datetime
import time
21
from ..fluid.layer_helper import LayerHelper
22
from ..fluid.framework import Variable
23
from ..fluid.framework import in_dygraph_mode
24
from ..fluid.framework import OpProtoHolder
J
Jiabin Yang 已提交
25
from ..fluid.framework import _non_static_mode
26
from ..fluid.framework import _in_legacy_dygraph
27
from ..fluid.framework import convert_np_dtype_to_dtype_
J
Jiangxinz 已提交
28
from ..fluid.framework import _varbase_creator
29 30 31 32
from ..fluid.data_feeder import convert_dtype
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.data_feeder import check_type
from ..fluid.data_feeder import check_dtype
33 34
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
B
Baibaifan 已提交
35
from ..fluid.dygraph import layers
36 37 38 39
from ..fluid.dygraph.parallel import prepare_context
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
40
from paddle import _C_ops, _legacy_C_ops
J
Jiangxinz 已提交
41
import paddle.fluid.dygraph_utils as dygraph_utils
42
import contextlib
W
wuhuachaocoding 已提交
43 44 45 46 47 48 49 50 51 52 53 54
from .fleet.layers.mpu.mp_ops import split
from .fleet.layers.mpu.mp_ops import _c_identity
from .fleet.layers.mpu.mp_ops import _c_concat
from .fleet.layers.mpu.mp_ops import _c_split
from .fleet.layers.mpu.mp_ops import _mp_allreduce
from .fleet.layers.mpu.mp_ops import _c_lookup_table
from .fleet.layers.mpu.mp_ops import _Linear
from .fleet.layers.mpu.mp_ops import _set_var_distributed
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy
from .fleet.layers.mpu.mp_ops import _linear
from .fleet.layers.mpu.mp_ops import _parallel_linear
from .fleet.layers.mpu.mp_ops import _parallel_embedding
55 56 57
from .communication.group import Group, _add_new_group
from .communication.all_reduce import all_reduce
from .communication.reduce import _get_reduce_op, ReduceOp
58

59
__all__ = []
60

K
kuizhiqing 已提交
61 62 63 64 65 66 67 68 69 70 71 72 73
# Module-wide ParallelEnv, created on first request by _get_global_env().
_global_env = None


def _get_global_env():
    """Return the cached ParallelEnv, constructing it the first time.

    The cached object is rebuilt whenever the cached value is falsy
    (initially None).
    """
    global _global_env
    if _global_env:
        return _global_env
    _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# Group registry keyed by group id. Id ``_global_env_gid`` (0) is reserved
# for the global group, which _get_group_map() creates lazily.
# Dict[int, Group]
_group_map = {}
_global_env_gid = 0

# Group registry keyed by group name.
# Dict[name, Group]
_group_map_by_name = {}

# Backend name recorded for each group.
# Dict[group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

# Backend names accepted by _new_process_group_impl().
_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl']

_default_store = None  # the default tcp store
_default_backend = None

# Default timeout for store-based waits: 1800 seconds (30 minutes).
_default_timeout = datetime.timedelta(seconds=1800)
# Counter incremented by _new_ring_id() to hand out ring ids in eager mode.
_start_ring_id = 0
92

K
kuizhiqing 已提交
93

L
lilong12 已提交
94 95 96 97 98 99 100 101 102 103
def _set_default_backend(backend):
    """Store ``backend`` as the module-level default backend name.

    new_group() falls back to this value when called without a backend in
    eager mode.
    """
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    """Store ``store`` as the module-level default store.

    new_group() hands this store to newly created process groups and to the
    store-based barrier in eager mode.
    """
    global _default_store
    _default_store = store


K
kuizhiqing 已提交
104 105
def _get_group_map():
    """Return the id -> Group registry, registering the global group lazily.

    On first use, a Group spanning every rank of the global ParallelEnv
    (ranks ``0 .. world_size - 1``) is stored under ``_global_env_gid``.

    Returns:
        dict: the live module-level registry (not a copy).
    """
    if _global_env_gid not in _group_map:
        genv = _get_global_env()
        _group_map[_global_env_gid] = Group(
            genv.rank, 0, list(range(genv.world_size))
        )
    return _group_map


def _get_global_group():
    """Return the global group (id ``_global_env_gid``), creating it if needed."""
    return _get_group_map()[_global_env_gid]
K
kuizhiqing 已提交
116 117


118 119 120 121 122 123
def _get_group_map_by_name():
    """Expose the name -> Group registry itself (callers may mutate it)."""
    return _group_map_by_name


def _get_default_group():
    """Return the group registered under ``_default_group_name``.

    Raises:
        AssertionError: if the distributed environment has not been
            initialized (see is_initialized()).
    """
    assert is_initialized(), (
        "Call paddle.distributed.init_parallel_env first "
        "to initialize the distributed environment."
    )
    return _get_group_map_by_name()[_default_group_name]


L
lilong12 已提交
132 133 134 135 136 137 138 139 140 141 142 143
def _set_group_map(gid, group):
    """Register ``group`` under ``gid`` in the id registry; ``gid`` must be new."""
    registry = _group_map
    assert gid not in registry
    registry[gid] = group


def _set_group_map_by_name(name, group):
    """Register ``group`` under ``name`` in the name registry; ``name`` must be new."""
    registry = _group_map_by_name
    assert name not in registry
    registry[name] = group


144 145 146 147 148 149
def _set_group_map_backend(group, backend):
    """Record ``backend`` for ``group``; the group must not be recorded yet."""
    registry = _group_map_backend
    assert group not in registry
    registry[group] = backend


K
kuizhiqing 已提交
150
def _new_ring_id():
    """Return a ring id for a new communication group.

    Ids are offset by ``max(nrings, 9)`` of the global ParallelEnv, presumably
    to stay clear of rings used by the global environment (TODO confirm).
    In eager (dygraph) mode a module counter is incremented on every call;
    otherwise the id is derived from the current size of the group map.
    """
    global _start_ring_id
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode
    # rely on the previous syntax.
    if in_dygraph_mode():
        _start_ring_id += 1
        return _start_ring_id + max(_get_global_env().nrings, 9)
    return len(_get_group_map()) + max(_get_global_env().nrings, 9)
K
kuizhiqing 已提交
158 159 160 161 162 163 164 165


def get_group(id=0):
    """

    Get group instance by group id.

    Args:
        id (int): the group id. Default value is 0.

    Returns:
        Group: the group instance, or None if no group has this id.

    Examples:
        .. code-block:: python

            ...
            group = paddle.distributed.new_group([2,4,6])
            paddle.distributed.get_group(group.id)

    """
    # _get_group_map() lazily registers the global group under id 0, so
    # get_group() with the default id always finds a group.
    return _get_group_map().get(id)
K
kuizhiqing 已提交
182 183


L
Ligoml 已提交
184 185 186 187 188 189 190 191 192 193 194
def _new_process_group_impl(
    backend,
    store,
    rank,
    world_size,
    group_name,
    pg_options,
    group_id=0,
    src_rank=None,
    dst_rank=None,
):
    """Create the core ProcessGroup object for ``backend``.

    Args:
        backend (str): one of ``_valid_backend_list``.
        store: store passed through to the ProcessGroup constructor.
        rank (int): rank of the current process inside the new group.
        world_size (int): number of ranks in the new group.
        group_name (str): group name (not used by this function).
        pg_options: backend options (not used by this function).
        group_id (int): group id passed to the ProcessGroup constructor.
        src_rank, dst_rank: only allowed for the ``heter`` backend.

    Returns:
        The ProcessGroup created for ``backend``.

    Raises:
        AssertionError: if ``backend`` is unsupported, if src/dst ranks are
            given for a non-heter backend, or if the CLUSTER_* environment
            variables required by the heter backend are missing.
    """
    pg = None
    genv = _get_global_env()
    if backend != 'heter':
        assert src_rank is None and dst_rank is None, (
            "src_rank and dst_rank " "can only be set for heter backend."
        )
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
    if backend == "gloo":
        place = core.CPUPlace()
        pg = core.ProcessGroupGloo(store, rank, world_size, place, group_id)
    elif backend == "nccl":
        place = core.CUDAPlace(genv.device_id)
        pg = core.ProcessGroupNCCL(store, rank, world_size, place, group_id)
    elif backend == "hccl":
        place = core.NPUPlace(genv.device_id)
        pg = core.ProcessGroupHCCL(store, rank, world_size, place, group_id)
    elif backend == "xccl":
        place = core.CustomPlace(genv.device_type, genv.device_id)
        pg = core.ProcessGroupCustom(store, rank, world_size, place, group_id)
    elif backend == "heter":
        place = None
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(genv.device_id)
        elif core.is_compiled_with_npu():
            place = core.NPUPlace(genv.device_id)
        # The heter topology comes from environment variables: CLUSTER_ID
        # (index of this cluster), CLUSTER_SIZE (comma-separated integers,
        # presumably the rank count of each cluster) and CLUSTER_SWITCH
        # (the switch endpoint).
        cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
        assert cluster_id >= 0, "please set the CLUSTER_ID variable."
        cluster_size = os.getenv("CLUSTER_SIZE", None)
        assert cluster_size, "please set the CLUSTER_SIZE variable."
        cluster_size = cluster_size.split(",")
        cluster_size = [int(s) for s in cluster_size]
        switch_ep = os.getenv("CLUSTER_SWITCH", None)
        assert switch_ep, "please set the CLUSTER_SWITCH variable."
        cluster_size_cumsum = np.cumsum(cluster_size)
        cluster_offset = (
            0 if cluster_id == 0 else cluster_size_cumsum[cluster_id - 1]
        )
        global_rank = cluster_offset + rank
        global_world_size = cluster_size_cumsum[-1]
        # NOTE(review): the two values computed above are overwritten right
        # away by _get_global_config, which is not defined or imported in the
        # visible part of this file -- confirm it exists. The computation is
        # kept because its indexing also rejects some out-of-range CLUSTER_ID
        # values.
        global_rank, global_world_size = _get_global_config(backend, rank)
        pg = core.ProcessGroupHeter(
            store,
            rank=global_rank,
            world_size=global_world_size,
            place=place,
            gid=group_id,
            local_rank=rank,
            local_size=world_size,
            gloo_rank=cluster_id,
            gloo_size=len(cluster_size),
            with_switch=True,
            switch_endpoint=switch_ep,
            src_rank=src_rank,
            dst_rank=dst_rank,
        )

    return pg


S
ShenLiang 已提交
254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277
def barrier(group=None):
    """

    Barrier among all participators in the group.

    Args:
        group (Group): The group instance returned by new_group or None for global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    # Ranks outside the group do not take part in the barrier.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        task = group.process_group.barrier()
        task.wait()
        return

    ring_id = 0 if group is None else group.id

    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        return _legacy_C_ops.barrier(temp, temp, 'ring_id', ring_id)

    op_type = 'barrier'

    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    # NOTE: **locals() forwards every local variable to LayerHelper, so no
    # extra locals are introduced before this call.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [temp]},
        attrs={'ring_id': ring_id},
    )
S
ShenLiang 已提交
301 302


L
lilong12 已提交
303 304 305 306 307 308 309
# _custom_gid provides a way for users to
# set the group id, which is usually useful
# to be compatible with the static mode.
# When truthy, new_group() uses it instead of allocating a new ring id.
_custom_gid = None


def _set_custom_gid(gid):
    """Set the group id that new_group() will use in eager mode."""
    global _custom_gid
    _custom_gid = gid


314 315 316 317 318 319 320 321
def _wait_for_store_keys(store, wait_keys, timeout, group_name, global_rank):
    """Poll ``store`` until every key in ``wait_keys`` holds the value 1.

    Args:
        store: object with a ``get(key)`` method whose result ``int()`` accepts.
        wait_keys (list): keys that must all reach 1.
        timeout (datetime.timedelta): maximum total waiting time.
        group_name (str): group name, used only in the error message.
        global_rank (int): rank that is waiting, used only in the error message.

    Raises:
        RuntimeError: if some keys are still not ready after ``timeout``.
    """
    start_time = time.time()
    while True:
        # Re-check readiness before looking at the clock, so keys that became
        # ready right at the deadline are not reported as a timeout, and the
        # error lists only the keys that are actually still missing.
        wait_keys = [key for key in wait_keys if int(store.get(key)) != 1]
        if not wait_keys:
            return
        elapse_time = time.time() - start_time
        if datetime.timedelta(seconds=elapse_time) > timeout:
            raise RuntimeError(
                "Timeout while initializing process group {}. "
                "Keys {} are not ready since rank {} is waiting for them. "
                "Two reasons may cause this error:\n"
                " 1. The create process group api should be called by all ranks.\n"
                " 2. Try to increase the waiting time.\n".format(
                    group_name, wait_keys, global_rank
                )
            )
        time.sleep(0.1)


def _barrier_by_tcp_store(group_name, store, timeout):
    """Block until every rank has reached this point, using ``store``.

    Each non-master rank adds 1 to its key ``Barrier/<group_name>/<rank>``
    and returns. Rank 0 (the master) waits for all of those keys, so it leaves
    last. This is a no-op when the world size is below 2.

    Args:
        group_name (str): group name used to build the barrier keys.
        store: store shared by all ranks (needs ``add`` and ``get``).
        timeout (datetime.timedelta): maximum time the master waits.

    Raises:
        RuntimeError: on the master, if some ranks do not check in in time.
    """
    global_rank = paddle.distributed.get_rank()
    global_world_size = paddle.distributed.get_world_size()

    if global_world_size < 2:
        return

    barrier_prefix = "Barrier/" + group_name + "/"
    is_master = global_rank == 0

    # All the workers set their exiting key and exit; the master waits for all
    # workers' exiting keys, which ensures it exits last.
    if is_master:
        wait_keys = [
            barrier_prefix + str(rank) for rank in range(1, global_world_size)
        ]
        _wait_for_store_keys(store, wait_keys, timeout, group_name, global_rank)
    else:
        store.add(barrier_prefix + str(global_rank), 1)


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members. In dynamic graph
            (eager) mode, None means all ranks of the default global group;
            otherwise it must be given.
        backend (str): The backend used to create group. In eager mode it
            defaults to the module default backend and may be any of 'nccl',
            'gloo', 'hccl', 'heter' or 'xccl'; otherwise only nccl is supported.
        timeout (datetime.timedelta, optional): The waiting timeout for store relevant options, default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    global _custom_gid
    global _group_map
    if in_dygraph_mode():
        global _default_group_name
        # A user-provided id (see _set_custom_gid) takes precedence over a
        # freshly allocated ring id.
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group."
            )
        size = len(ranks)
        ranks = sorted(ranks)
        # Heter groups, and multi-rank groups containing this rank, get a real
        # process group; other ranks record a placeholder with rank -1.
        if backend == 'heter' or (size > 1 and global_rank in ranks):
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            src_rank = ranks[0] if backend == 'heter' else None
            dst_rank = ranks[1] if backend == 'heter' else None
            pg = _new_process_group_impl(
                backend,
                _default_store,
                rank,
                size,
                group_name,
                pg_options=None,
                group_id=gid,
                src_rank=src_rank,
                dst_rank=dst_rank,
            )
        else:
            rank = -1
            pg = None
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        # TODO: The method below is a new method for group management, will replace the previous
        # three in the future.
        _add_new_group(group)

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        # NOTE(liyurui): All processors should hang and wait using tcp store, in case master exit before sub-group is created.
        if backend != 'heter':
            _barrier_by_tcp_store(group_name, _default_store, timeout)
        else:
            print("Warning: store barrier is not supported for heter backend.")
        return group

    # Legacy (non-eager) path: only NCCL-style rings are supported.
    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', "backend other than nccl is not supported yet"

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            # Initialize the communication context for the device type this
            # build was compiled with.
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            else:
                assert False, "no cuda device found"
        else:
            # A single-member group needs no communication context or sync.
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = (
        paddle.to_tensor([1], dtype="int32")
        if _non_static_mode()
        else fill_constant([0], dtype="int32", value="1")
    )
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp

497

498 499 500 501 502
def is_initialized():
    """

    Check whether the distributed environment has been initialized

    Returns:
        `True` if distributed environment has been initialized, otherwise `False`.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle

            print(paddle.distributed.is_initialized())
            # False

            paddle.distributed.init_parallel_env()
            print(paddle.distributed.is_initialized())
            # True

    """
    # Initialization is signalled by the default group being registered by name.
    return _default_group_name in _group_map_by_name


def destroy_process_group(group=None):
    """
    Destroy a given group for communication

    Args:
        group (Group, optional): The group to be destroyed. If None (the
            default), all process groups, including the default group, are
            destroyed and the distributed environment is deinitialized.

    Returns : None

    Raises:
        AssertionError: if the group is not registered, or if ``group`` is
            None and the distributed environment is not initialized.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            group = dist.new_group([0, 1])

            dist.destroy_process_group(group)
            print(dist.is_initialized())
            # True
            dist.destroy_process_group()
            print(dist.is_initialized())
            # False

    """
    global _group_map
    global _group_map_by_name

    pg = _get_default_group() if group is None else group
    assert _group_map.get(pg.id, None) is not None, "Invalid group."

    if group is None:
        _group_map.clear()
        _group_map_by_name.clear()
        _group_map_backend.clear()
    else:
        del _group_map[pg.id]
        # Not every group in the id map is also registered by name/backend
        # (e.g. the global group created lazily by _get_group_map), so these
        # removals must tolerate a missing entry.
        _group_map_by_name.pop(pg.name, None)
        _group_map_backend.pop(pg, None)


K
kuizhiqing 已提交
569 570 571 572 573 574 575 576
def wait(tensor, group=None, use_calc_stream=True):
    """

    Wait to synchronize the stream for the group.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, sync_op=True)
            paddle.distributed.wait(tindata)

    """

    # Ranks outside the group have nothing to synchronize.
    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id

    if use_calc_stream:
        _sync_calc_stream(tensor)
    else:
        _sync_comm_stream(tensor, ring_id)


def _sync_calc_stream(tensor):
    """Emit or run a ``c_sync_calc_stream`` op on ``tensor``.

    In non-static (dygraph) mode the legacy op is called directly and its
    result returned. Otherwise the op is appended through LayerHelper.
    """
    if _non_static_mode():
        return _legacy_C_ops.c_sync_calc_stream(tensor, tensor)

    op_type = 'c_sync_calc_stream'

    # NOTE: **locals() forwards every local variable (tensor, op_type) to
    # LayerHelper, so no extra locals are introduced before this call.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
    )
619

620

K
kuizhiqing 已提交
621
def _sync_comm_stream(tensor, ring_id=0):
    """Emit or run a ``c_sync_comm_stream`` op on ``tensor`` for ``ring_id``.

    In non-static (dygraph) mode the legacy op is called directly and its
    result returned. Otherwise the op is appended through LayerHelper.
    """
    if _non_static_mode():
        return _legacy_C_ops.c_sync_comm_stream(
            [tensor], [tensor], 'ring_id', ring_id
        )

    op_type = 'c_sync_comm_stream'

    # NOTE: **locals() forwards every local variable (tensor, ring_id,
    # op_type) to LayerHelper, so no extra locals are introduced before it.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={'ring_id': ring_id},
    )
K
kuizhiqing 已提交
637 638


639
def broadcast(tensor, src, group=None, sync_op=True):
    """

    Broadcast a tensor from the source to all others.
    As shown below, one process is started with a GPU and GPU0 owns data 0. Through broadcast operator,
    data 0 will be sent to all GPUs from GPU0.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png
        :width: 800
        :alt: broadcast
        :align: center

    Args:
        tensor (Tensor): The Tensor to send if current rank is the source, or the Tensor to receive otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        src (int): The source rank.
        group (Group, optional): The group instance returned by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None. In eager dynamic graph mode with ``sync_op=False``, the
        communication task is returned instead so the caller can wait on it.

    Raises:
        ValueError: if ``src`` is not an int.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.broadcast(data, src=1)
            print(data)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs)
    """

    # Ranks outside the group do not take part in the broadcast.
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        # Translate the global source rank into its rank inside the group.
        gsrc = group.get_group_rank(src)
        assert gsrc >= 0, "src rank out of group, need global rank"
        task = group.process_group.broadcast(tensor, gsrc)
        if sync_op:
            task.wait()
            return None
        else:
            return task

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    gsrc = src if group is None else group.get_group_rank(src)
    assert gsrc >= 0, "src rank out of group, need global rank"

    if _non_static_mode():
        return _legacy_C_ops.c_broadcast(
            tensor,
            tensor,
            'root',
            gsrc,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
        )

    op_type = 'c_broadcast'
    check_variable_and_dtype(
        tensor,
        'tensor',
        [
            'float16',
            'float32',
            'float64',
            'int32',
            'int64',
            'int8',
            'uint8',
            'bool',
        ],
        'broadcast',
    )

    # NOTE: **locals() forwards every local variable to LayerHelper, so no
    # extra locals are introduced before this call.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'root': gsrc,
            'use_calc_stream': use_calc_stream,
            'ring_id': ring_id,
        },
    )
740 741


742
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
    """

    Reduce a tensor to the destination from all others. As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator,
    GPU0 will own the sum of all data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png
        :width: 800
        :alt: reduce
        :align: center

    Args:
        tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        dst (int): The destination rank id.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
        group (Group, optional): The group instance returned by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None. In eager dynamic graph mode with ``sync_op=False``, the
        communication task is returned instead so the caller can wait on it.

    Raises:
        ValueError: if ``op`` is not one of the supported ReduceOp values
            (legacy dygraph and static graph modes).

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.reduce(data, dst=0)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs, out for rank 0)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs, out for rank 1)
    """
    # Ranks outside the group do not take part in the reduction.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce")
        group = _get_default_group() if group is None else group
        # Translate the global destination rank into its rank inside the group.
        gdst = group.get_group_rank(dst)
        assert gdst >= 0, "dst rank out of group, need global rank"
        task = group.process_group.reduce(tensor, gdst, op_type)
        if sync_op:
            task.wait()
            return None
        else:
            return task

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    gdst = dst if group is None else group.get_group_rank(dst)
    assert gdst >= 0, "dst rank out of group, need global rank"

    # Resolve the op name once for both the legacy dygraph and the static
    # graph paths, so an unknown operation raises in both instead of the
    # static path emitting an op with a meaningless type.
    if op == ReduceOp.SUM:
        op_type = 'c_reduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_reduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_reduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_reduce_prod'
    else:
        raise ValueError("Unknown parameter: {}.".format(op))

    if _non_static_mode():
        return getattr(_legacy_C_ops, op_type)(
            tensor,
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'root_id',
            gdst,
        )

    check_variable_and_dtype(
        tensor,
        'tensor',
        [
            'float16',
            'float32',
            'float64',
            'int32',
            'int64',
            'int8',
            'uint8',
            'bool',
        ],
        'reduce',
    )

    # NOTE: **locals() forwards every local variable to LayerHelper, so no
    # extra locals are introduced before this call.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
            'root_id': gdst,
        },
    )
887 888


889
def all_gather(tensor_list, tensor, group=None, sync_op=True):
    """

    Gather tensors from all participators and all get the result. As shown
    below, one process is started with a GPU and the data of this process is represented
    by its group rank. Through the all_gather operator, each GPU will have data
    from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
            In eager mode a non-empty list is concatenated along axis 0 and used as the receive buffer, then the
            list is cleared and refilled with one Tensor per rank. In the other modes the gathered Tensors are
            appended to the list.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
        group (Group, optional): The group instance return by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            tensor_list = []
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_gather(tensor_list, data)
            print(tensor_list)
            # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    def is_complex(t):
        return t.dtype == paddle.complex64 or t.dtype == paddle.complex128

    def convert_to_complex(list_of_tensor):
        # Undo the ``paddle.as_real`` conversion applied to complex inputs.
        return [paddle.as_complex(t) for t in list_of_tensor]

    is_input_complex = is_complex(tensor)
    if is_input_complex:
        # Complex data is communicated through its real-valued view.
        tensor = paddle.as_real(tensor)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        if len(tensor_list) == 0:
            tensor_shape = list(tensor.shape)
            tensor_shape[0] *= group.nranks
            out = paddle.empty(tensor_shape, tensor.dtype)
        else:
            # The receive buffer must use the same real-valued view as the
            # converted input; otherwise complex buffers would be
            # concatenated into a buffer whose dtype/layout differs from
            # ``tensor``.
            if is_input_complex:
                buffers = [
                    paddle.as_real(t) if is_complex(t) else t
                    for t in tensor_list
                ]
            else:
                buffers = tensor_list
            out = paddle.concat(buffers, axis=0)
        task = group.process_group.all_gather(tensor, out)
        task.wait()
        tensor_list.clear()
        list_of_tensor = paddle.split(out, group.nranks, 0)
        if is_input_complex:
            tensor_list.extend(convert_to_complex(list_of_tensor))
        else:
            tensor_list.extend(list_of_tensor)
        return

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    if _non_static_mode():
        out = _legacy_C_ops.c_allgather(
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'nranks',
            nranks,
        )
    else:
        op_type = 'c_allgather'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        if not isinstance(tensor_list, list):
            raise ValueError(
                "The type of 'tensor_list' for all_gather " "should be list."
            )
        for elem in tensor_list:
            check_variable_and_dtype(
                elem,
                'tensor_list',
                [
                    'float16',
                    'float32',
                    'float64',
                    'int32',
                    'int64',
                    'bool',
                    'int8',
                    'uint8',
                    'complex64',
                    'complex128',
                ],
                'all_gather',
            )
        check_variable_and_dtype(
            tensor,
            'tensor',
            [
                'float16',
                'float32',
                'float64',
                'int32',
                'int64',
                'bool',
                'int8',
                'uint8',
                'complex64',
                'complex128',
            ],
            'all_gather',
        )
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
                'nranks': nranks,
            },
        )

    list_of_tensor = paddle.split(out, nranks, 0)
    if is_input_complex:
        tensor_list.extend(convert_to_complex(list_of_tensor))
    else:
        tensor_list.extend(list_of_tensor)


def _convert_object_to_tensor(obj):
    """Pickle ``obj`` into a uint8 paddle Tensor.

    Args:
        obj: Any picklable Python object.

    Returns:
        tuple: ``(tensor, numel)`` where ``tensor`` holds the pickled bytes
        as uint8 values and ``numel`` is ``tensor.numel()``.
    """
    buffer = io.BytesIO()
    pickle.Pickler(buffer).dump(obj)
    byte_view = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
    packed = paddle.to_tensor(byte_view)
    return packed, packed.numel()
1045 1046


1047
def _convert_tensor_to_object(tensor, len_of_tensor):
1048
    _unpickler = pickle.Unpickler
1049
    return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()
1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076


def all_gather_object(object_list, obj, group=None):
    """

    Gather picklable objects from all participators and all get the result. Similar to all_gather(), but python object can be passed in.

    Args:
        object_list (list): A list of output object. The datatype of every element in the list is same as the input obj.
            One gathered object per rank is appended to it.
        obj (Any): The picklable object to send.
        group (Group, optional): The group instance return by new_group or None for global default group.

    Returns:
        None.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            object_list = []
            if dist.get_rank() == 0:
                obj = {"foo": [1, 2, 3]}
            else:
                obj = {"bar": [4, 5, 6]}
            dist.all_gather_object(object_list, obj)
            print(object_list)
            # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
    """
    assert (
        in_dygraph_mode()
    ), "all_gather_object doesn't support static graph mode."

    # Ranks outside ``group`` do not take part. Without this early return the
    # inner all_gather would return immediately, leave ``list_len_of_tensor``
    # empty, and ``max`` below would raise ValueError.
    if group is not None and not group.is_member():
        return

    tensor, len_of_tensor = _convert_object_to_tensor(obj)

    # gather len_of_tensor from all ranks
    list_len_of_tensor = []
    all_gather(list_len_of_tensor, len_of_tensor, group)
    # get the max length from list
    max_len_of_tensor = int(max(list_len_of_tensor).item())
    # resize the input tensor to max length avoid hang in all gather
    # Note(liyurui): Maybe we should support various length all_gather?
    # Now this operation is efficient for we don't support resize in python.
    numpy_data = tensor.numpy()
    numpy_data = np.resize(numpy_data, [max_len_of_tensor])
    input_tensor = paddle.to_tensor(numpy_data)

    tensor_list = []
    all_gather(tensor_list, input_tensor, group)
    # Each rank's buffer was padded to the max length; only the first
    # ``list_len_of_tensor[i]`` bytes are that rank's pickled payload.
    for i, tensor in enumerate(tensor_list):
        object_list.append(
            _convert_tensor_to_object(tensor, list_len_of_tensor[i])
        )
1109 1110


1111
def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
    """

    Scatter a tensor to all participators. As shown below, one process is started with a GPU and the source of the scatter
    is GPU0. Through scatter operator, the data in GPU0 will be sent to all GPUs averagely.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
        :width: 800
        :alt: scatter
        :align: center

    Args:
        tensor (Tensor): The output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. It is only read on the
            source rank; other ranks ignore it. Default value is None.
        src (int): The source rank id. Default value is 0.
        group (Group, optional): The group instance return by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None. In eager mode with ``sync_op=False`` the communication task is returned instead.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([7, 8, 9])
                data2 = paddle.to_tensor([10, 11, 12])
                dist.scatter(data1, src=1)
            else:
                data1 = paddle.to_tensor([1, 2, 3])
                data2 = paddle.to_tensor([4, 5, 6])
                dist.scatter(data1, tensor_list=[data1, data2], src=1)
            print(data1, data2)
            # [1, 2, 3] [10, 11, 12] (2 GPUs, out for rank 0)
            # [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    eager = in_dygraph_mode()
    if eager:
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        rank, nranks = group.rank, group.nranks
    elif group is None:
        ring_id = 0
        gsrc = src
        rank, nranks = _get_global_group().rank, _get_global_group().nranks
    else:
        ring_id = group.id
        gsrc = group.get_group_rank(src)
        rank, nranks = group.rank, group.nranks

    assert gsrc >= 0, "src rank out of group, need global rank"

    # On non-source ranks the caller's tensor_list is replaced by ``nranks``
    # references to ``tensor`` so that ``temp`` can be built on every rank.
    if rank != gsrc:
        tensor_list = [tensor] * nranks
    temp = paddle.concat(tensor_list, axis=0)

    if eager:
        task = group.process_group.scatter(temp, tensor, gsrc)
        if not sync_op:
            return task
        task.wait()
        return None

    use_calc_stream = sync_op
    if _non_static_mode():
        return _legacy_C_ops.c_scatter(
            temp,
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'nranks',
            nranks,
            'root',
            gsrc,
        )

    op_type = 'c_scatter'
    check_variable_and_dtype(
        tensor,
        'tensor',
        [
            'float16',
            'float32',
            'float64',
            'int32',
            'int64',
            'int8',
            'uint8',
            'bool',
        ],
        'scatter',
    )
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'root': gsrc,
            'use_calc_stream': use_calc_stream,
            'nranks': nranks,
        },
    )
1227 1228


1229
def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):
    """
    Scatter tensors in in_tensor_list to all participators averagely and gather the result tensors in out_tensor_list.
    As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
    Through alltoall operator, the 0_0 in GPU0 will be sent to GPU0 and 0_1 to GPU1, 1_0 in GPU1 sent to GPU0 and 1_1 to GPU1.
    Finally the out_tensor_list in GPU0 includes 0_0 and 1_0, and GPU1 includes 0_1 and 1_1.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/alltoall.png
        :width: 800
        :alt: alltoall
        :align: center

    Args:
        in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
            data type of the input Tensors. In eager mode a non-empty list is used as the receive buffer and is
            then cleared and refilled; in static mode it must be empty.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            out_tensor_list = []
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
                data2 = paddle.to_tensor([[7, 8, 9], [10, 11, 12]])
            else:
                data1 = paddle.to_tensor([[13, 14, 15], [16, 17, 18]])
                data2 = paddle.to_tensor([[19, 20, 21], [22, 23, 24]])
            dist.alltoall([data1, data2], out_tensor_list)
            print(out_tensor_list)
            # [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0)
            # [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    eager = in_dygraph_mode()
    if eager:
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', "backend gloo is not supported yet"
    else:
        ring_id = 0 if group is None else group.id

    temp = paddle.concat(in_tensor_list, axis=0)
    nranks = len(in_tensor_list)

    if eager:
        if len(out_tensor_list) > 0:
            out = paddle.concat(out_tensor_list, axis=0)
        else:
            # No caller buffer: allocate one sized like the stacked inputs.
            buffer_shape = list(in_tensor_list[0].shape)
            buffer_shape[0] *= nranks
            out = paddle.empty(buffer_shape, in_tensor_list[0].dtype)
        task = group.process_group.alltoall(temp, out)
        task.wait()
        out_tensor_list.clear()
        out_tensor_list.extend(paddle.split(out, nranks, 0))
        return

    use_calc_stream = sync_op
    if _non_static_mode():
        out = _legacy_C_ops.alltoall(
            temp, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id
        )
    else:
        op_type = 'alltoall'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=in_tensor_list[0].dtype
        )

        if not isinstance(in_tensor_list, list):
            raise ValueError(
                "The type of 'in_tensor_list' for all_to_all " "should be list."
            )
        for elem in in_tensor_list:
            check_variable_and_dtype(
                elem,
                'in_tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_to_all',
            )
        if not isinstance(out_tensor_list, list):
            raise ValueError(
                "The type of 'out_tensor_list' for all_to_all "
                "should be list."
            )
        if len(out_tensor_list) != 0:
            raise ValueError(
                "The 'out_tensor_list' for all_to_all " "must be an empty list."
            )
        helper.append_op(
            type=op_type,
            inputs={'X': [temp]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
            },
        )
    out_tensor_list.extend(paddle.split(out, nranks, 0))


L
Ligoml 已提交
1341 1342 1343 1344 1345 1346 1347 1348
def alltoall_single(
    in_tensor,
    out_tensor,
    in_split_sizes=None,
    out_split_sizes=None,
    group=None,
    sync_op=True,
):
    """
    Scatter a single input tensor to all participators and gather the received tensors in out_tensor.

    .. note::
        ``alltoall_single`` is only supported in eager mode.

    Args:
        in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
        in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
            must be divisible by group size and ``in_tensor`` will be scattered averagely to all participators. Default: None.
        out_split_sizes (list[int], optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
            must be divisible by group size and ``out_tensor`` will be gathered averagely from all participators. Default: None.
        group (Group, optional): The group instance return by ``new_group`` or None for global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None, if ``sync_op`` is set to ``True``; ``Task`` of ``group``, if ``sync_op`` is set to ``False``.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            size = dist.get_world_size()

            # case 1 (2 GPUs)
            data = paddle.arange(2, dtype='int64') + rank * 2
            # data for rank 0: [0, 1]
            # data for rank 1: [2, 3]
            output = paddle.empty([2], dtype='int64')
            dist.alltoall_single(data, output)
            print(output)
            # output for rank 0: [0, 2]
            # output for rank 1: [1, 3]

            # case 2 (2 GPUs)
            in_split_sizes = [i + 1 for i in range(size)]
            # in_split_sizes for rank 0: [1, 2]
            # in_split_sizes for rank 1: [1, 2]
            out_split_sizes = [rank + 1 for i in range(size)]
            # out_split_sizes for rank 0: [1, 1]
            # out_split_sizes for rank 1: [2, 2]
            data = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank
            # data for rank 0: [[0., 0.], [0., 0.], [0., 0.]]
            # data for rank 1: [[1., 1.], [1., 1.], [1., 1.]]
            output = paddle.empty([(rank + 1) * size, size], dtype='float32')
            group = dist.new_group([0, 1])
            task = dist.alltoall_single(data,
                                        output,
                                        in_split_sizes,
                                        out_split_sizes,
                                        sync_op=False,
                                        group=group)
            task.wait()
            print(output)
            # output for rank 0: [[0., 0.], [1., 1.]]
            # output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]]

    """
    if group is not None and not group.is_member():
        return

    assert in_dygraph_mode(), "Only suppport alltoall_single in eager mode."

    group = _get_default_group() if group is None else group
    backend = _group_map_backend[group]
    assert backend != 'gloo', "backend gloo is not supported yet"

    # An empty split list tells the process group to split evenly.
    task = group.process_group.alltoall_single(
        in_tensor,
        out_tensor,
        in_split_sizes if in_split_sizes is not None else [],
        out_split_sizes if out_split_sizes is not None else [],
    )
    if not sync_op:
        return task
    task.wait()


S
ShenLiang 已提交
1436 1437 1438 1439
def _get_group_rank(global_rank, group=None):
    return global_rank if group is None else group.get_group_rank(global_rank)


1440
def send(tensor, dst=0, group=None, sync_op=True):
    """
    Send a tensor to the receiver.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        dst (int): The destination rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None. In eager mode with ``sync_op=False`` the communication task is returned instead.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    # Translate the global destination rank into the group-local rank.
    dst = dst if group is None else group.get_group_rank(dst)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', "backend gloo is not supported yet"
        task = group.process_group.send(tensor, dst)
        if not sync_op:
            return task
        task.wait()
        return None

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _legacy_C_ops.send_v2(
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'peer',
            dst,
        )

    op_type = 'send_v2'
    check_variable_and_dtype(
        tensor,
        'tensor',
        ['float16', 'float32', 'float64', 'int32', 'int64'],
        'send',
    )
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        attrs={
            'ring_id': ring_id,
            'peer': dst,
            'use_calc_stream': use_calc_stream,
        },
    )
L
lilong12 已提交
1516 1517


1518
def recv(tensor, src=0, group=None, sync_op=True):
    """
    Receive a tensor from the sender.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        src (int): The source rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None. In eager mode with ``sync_op=False`` the communication task is returned instead.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    # Translate the global source rank into the group-local rank.
    src = src if group is None else group.get_group_rank(src)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', "backend gloo is not supported yet"
        task = group.process_group.recv(tensor, src)
        if not sync_op:
            return task
        task.wait()
        return None

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _legacy_C_ops.recv_v2(
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'peer',
            src,
            'dtype',
            tensor.dtype,
            'out_shape',
            tensor.shape,
        )

    op_type = 'recv_v2'
    check_variable_and_dtype(
        tensor,
        'tensor',
        ['float16', 'float32', 'float64', 'int32', 'int64'],
        'recv',
    )
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'peer': src,
            'out_shape': tensor.shape,
            'dtype': tensor.dtype,
            'use_calc_stream': use_calc_stream,
        },
    )
1600 1601 1602 1603


def _check_single_tensor(tensor, tensor_name):
    """Raise ``RuntimeError`` unless ``tensor`` is a paddle Tensor.

    Args:
        tensor: The object to validate.
        tensor_name (str): Name of the parameter, used in the error message.

    Raises:
        RuntimeError: If ``tensor`` is neither ``core.eager.Tensor`` nor
            ``paddle.Tensor``.
    """
    if not isinstance(tensor, (core.eager.Tensor, paddle.Tensor)):
        # Note the trailing space after "{}": the two literals are joined
        # by implicit concatenation.
        raise RuntimeError(
            "Invalid function argument. Expected parameter {} "
            "to be of type paddle.Tensor, but it's {}".format(
                tensor_name, type(tensor)
            )
        )
1610 1611 1612


def _check_tensor_list(tensor_list, tensor_name):
    """Raise ``RuntimeError`` unless ``tensor_list`` is a list of paddle Tensors.

    Args:
        tensor_list: The object to validate; must be a ``list`` whose
            elements are all ``core.eager.Tensor`` or ``paddle.Tensor``.
        tensor_name (str): Name of the parameter, used in the error message.

    Raises:
        RuntimeError: If ``tensor_list`` is not a list, or any element is not
            a paddle Tensor.
    """
    if not isinstance(tensor_list, list) or not all(
        isinstance(t, (core.eager.Tensor, paddle.Tensor)) for t in tensor_list
    ):
        # The space after "{}" keeps the parameter name separate from the
        # rest of the implicitly concatenated message.
        raise RuntimeError(
            "Invalid function argument. Expected parameter {} "
            "to be a list of paddle.Tensor".format(tensor_name)
        )
1620 1621 1622 1623 1624 1625 1626 1627


def isend(tensor, dst, group=None):
    """
    Sends a tensor asynchronously

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        dst (int): The destination rank.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.

    Returns:
        A distributed task object.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)

    """
    _check_single_tensor(tensor, "tensor")
    if group is not None and not group.is_member():
        return

    if not in_dygraph_mode():
        raise RuntimeError("Only support eager dygraph mode.")

    group = _get_default_group() if group is None else group
    backend = _group_map_backend[group]
    assert backend != 'gloo', "backend gloo is not supported yet"
    group_dst_rank = group.get_group_rank(dst)
    assert group_dst_rank >= 0, "dst rank out of group, need global rank"
    return group.process_group.send(tensor, group_dst_rank)
1670 1671 1672 1673 1674 1675 1676 1677


def irecv(tensor, src=None, group=None):
    """
    Receives a tensor from the sender asynchronously.

    Args:
        tensor (Tensor): The Tensor to receive into. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        src (int): The source rank id, given as a global rank; it is mapped to
            the corresponding rank inside ``group`` before receiving.
        group (Group, optional): The group instance returned by new_group or None for the global default group. Default: None.

    Returns:
        A distributed task object, or None if the current process is not a member of ``group``.

    Raises:
        AssertionError: If the group's backend is gloo, or ``src`` is not a rank of ``group``.
        RuntimeError: If not running in eager dygraph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    _check_single_tensor(tensor, "tensor")
    # Processes outside the group take no part in the communication.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', "backend gloo is not supported yet"
        # NOTE(review): ``src`` defaults to None, which is passed straight to
        # get_group_rank; presumably callers must always supply it — confirm.
        group_src_rank = group.get_group_rank(src)
        assert group_src_rank >= 0, "src rank out of group, need global rank"
        return group.process_group.recv(tensor, group_src_rank)
    else:
        raise RuntimeError("Only support eager dygraph mode.")


class P2POp(object):
    """
    A class that makes point-to-point operations for "batch_isend_irecv".

    This class creates the type of P2P operation, communication buffer, peer rank,
    Group. Instances of this class will be passed to
    ``paddle.distributed.batch_isend_irecv`` for point-to-point communication.

    Args:
        op (callable): A function to send data to or receive data from a peer process.
            The type of ``op`` is either ``paddle.distributed.isend`` or ``paddle.distributed.irecv``.
        tensor (Tensor): Tensor to send or receive.
        peer (int): The destination or source rank.
        group (Group, optional): The group instance returned by new_group or None for the global
            default group. Default: None.

    Raises:
        RuntimeError: If ``op`` is neither ``isend`` nor ``irecv``.

    """

    def __init__(self, op, tensor, peer, group=None):
        # Only the two asynchronous P2P primitives of this module can be batched.
        if op not in [isend, irecv]:
            raise RuntimeError(
                "Invalid ``op`` function. Expected ``op`` "
                "to be of type ``paddle.distributed.isend`` or "
                "``paddle.distributed.irecv``."
            )
        _check_single_tensor(tensor, "tensor")

        self.op = op
        self.tensor = tensor
        self.peer = peer
        # Resolve the default group eagerly so every op carries a concrete group,
        # which _check_p2p_op_list relies on for its backend lookup.
        self.group = _get_default_group() if group is None else group


@contextlib.contextmanager
def _with_batch_p2p_guard(backend):
    if backend == "nccl":
        core.ProcessGroupNCCL.group_start()
    try:
        yield
    finally:
        if backend == "nccl":
            core.ProcessGroupNCCL.group_end()


def _check_p2p_op_list(p2p_op_list):
    """
    Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
    all ops use the same backend.
    """
    if not isinstance(p2p_op_list, list) or not all(
L
Ligoml 已提交
1771 1772 1773 1774 1775 1776
        isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list
    ):
        raise RuntimeError(
            "Invalid ``p2p_op_list``. Each op is expected to "
            "to be of type ``paddle.distributed.P2POp``."
        )
1777 1778

    backend = _group_map_backend[p2p_op_list[0].group]
L
Ligoml 已提交
1779 1780 1781
    if not all(
        backend == _group_map_backend[p2p_op.group] for p2p_op in p2p_op_list
    ):
1782 1783 1784 1785 1786 1787 1788
        raise RuntimeError("All groups need to use the same backend.")


def batch_isend_irecv(p2p_op_list):
    """
    Send or Receive a batch of tensors asynchronously and return a list of requests.

    Process each of the point-to-point operations in ``p2p_op_list`` and return the
    corresponding tasks. Only NCCL is currently supported.

    Args:
        p2p_op_list (List[P2POp]): A list of point-to-point operations (the type of each operator is
            ``paddle.distributed.P2POp``). The order of the isend/irecv in the list
            matters and it needs to match with corresponding isend/irecv on the
            remote end.

    Returns:
        A list of distributed tasks returned by calling the corresponding
        op in the op_list. Ops that return None are not included.
        None is returned if the current process is not a member of the
        group of the first op.

    Raises:
        RuntimeError: If ``p2p_op_list`` is invalid (see the P2POp checks), or
            when not running in dygraph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed

            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            send_t = paddle.arange(2) + rank
            # paddle.tensor([0, 1])  # Rank-0
            # paddle.tensor([1, 2])  # Rank-1

            recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)

            send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
            recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)

            tasks = dist.batch_isend_irecv([send_op, recv_op])

            for task in tasks:
                task.wait()

            print(recv_t)
            # paddle.tensor([1, 2])     # Rank-0
            # paddle.tensor([0, 1])     # Rank-1
    """
    _check_p2p_op_list(p2p_op_list)
    group = p2p_op_list[0].group
    # Processes outside the group take no part in the communication.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        tasks = []
        # For NCCL, all ops issued inside the guard are bracketed by
        # group_start()/group_end().
        with _with_batch_p2p_guard(backend):
            for p2p_op in p2p_op_list:
                op = p2p_op.op
                tensor = p2p_op.tensor
                peer = p2p_op.peer
                comm_group = p2p_op.group
                task = op(tensor, peer, comm_group)
                if task is not None:
                    tasks.append(task)
        return tasks
    else:
        raise RuntimeError("Don't support static graph mode currently.")


def reduce_scatter(
    tensor, tensor_list, op=ReduceOp.SUM, group=None, sync_op=True
):
    """
    Reduces, then scatters a list of tensors to all processes in a group.

    Args:
        tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        tensor_list (list[Tensor]): List of tensors to reduce and scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (Group, optional): The group instance returned by new_group or None for the global
            default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        Async task handle, if sync_op is set to False.
        None, if sync_op is True or if the current process is not part of the group.

    Raises:
        AssertionError: If the group's backend is gloo.
        RuntimeError: If not running in dygraph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([0, 1])
                data2 = paddle.to_tensor([2, 3])
            else:
                data1 = paddle.to_tensor([4, 5])
                data2 = paddle.to_tensor([6, 7])
            dist.reduce_scatter(data1, [data1, data2])
            print(data1)
            # [4, 6] (2 GPUs, out for rank 0)
            # [8, 10] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(tensor, "tensor")
    _check_tensor_list(tensor_list, "tensor_list")

    # Processes outside the group take no part in the communication.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce_scatter")
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', "backend gloo is not supported yet"

        # Flatten the list into one input along axis 0 and delegate to the
        # flat-buffer reduce-scatter primitive.
        temp = paddle.concat(tensor_list, axis=0)
        task = group.process_group._reduce_scatter_base(tensor, temp, op_type)
        if sync_op:
            task.wait()
            return None
        else:
            return task
    else:
        raise RuntimeError("Don't support static graph mode currently.")


def _reduce_scatter_base(
    output, input, op=ReduceOp.SUM, group=None, sync_op=True
):
    """
    Reduces, then scatters a flattened tensor to all processes in a group.

    Args:
        output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        input (Tensor): Input tensor that is of size output tensor size times world size. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (Group, optional): The group instance returned by new_group or None for the global
            default group. Default: None.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        Async task handle, if sync_op is set to False.
        None, if sync_op is True or if the current process is not part of the group.

    Raises:
        RuntimeError: If not running in dygraph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            data = paddle.arange(4) + rank
            # [0, 1, 2, 3] (2 GPUs, for rank 0)
            # [1, 2, 3, 4] (2 GPUs, for rank 1)
            output = paddle.empty(shape=[2], dtype=data.dtype)
            dist.collective._reduce_scatter_base(output, data)
            print(output)
            # [1, 3] (2 GPUs, out for rank 0)
            # [5, 7] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(output, "output")
    _check_single_tensor(input, "input")

    # Processes outside the group take no part in the communication.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "_reduce_scatter_base")
        group = _get_default_group() if group is None else group
        task = group.process_group._reduce_scatter_base(output, input, op_type)
        if sync_op:
            task.wait()
            return None
        else:
            return task
    else:
        raise RuntimeError("Don't support static graph mode currently.")