collective.py 97.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
17 18
import pickle
import io
19 20
import datetime
import time
21
from ..fluid.layer_helper import LayerHelper
22
from ..fluid.framework import Variable
23
from ..fluid.framework import in_dygraph_mode
24
from ..fluid.framework import OpProtoHolder
J
Jiabin Yang 已提交
25
from ..fluid.framework import _non_static_mode
26
from ..fluid.framework import _in_legacy_dygraph
27
from ..fluid.framework import convert_np_dtype_to_dtype_
J
Jiangxinz 已提交
28
from ..fluid.framework import _varbase_creator
29 30 31 32
from ..fluid.data_feeder import convert_dtype
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.data_feeder import check_type
from ..fluid.data_feeder import check_dtype
33 34
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
B
Baibaifan 已提交
35
from ..fluid.dygraph import layers
36 37 38 39
from ..fluid.dygraph.parallel import prepare_context
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
40
from paddle import _C_ops, _legacy_C_ops
J
Jiangxinz 已提交
41
import paddle.fluid.dygraph_utils as dygraph_utils
42
import contextlib
43

44
__all__ = []
45 46 47


class ReduceOp:
    """
    Specify the type of operation used for element-wise reductions.
    It should be one of the following values:

        ReduceOp.SUM

        ReduceOp.MAX

        ReduceOp.MIN

        ReduceOp.PROD

        ReduceOp.AVG

    Note that not every collective accepts every value; for example the
    reduce-op mapping used by ``all_reduce`` has no entry for ``AVG``.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_reduce(data, op=dist.ReduceOp.SUM)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
    """
    SUM = 0
    MAX = 1
    MIN = 2
    PROD = 3
    AVG = 4
81 82


K
kuizhiqing 已提交
83 84 85 86
class Group:
    """
    The abstract representation of a communication group.

    Args:
        rank (int): Rank of the current process inside this group. A negative
            value means the current process is not a member.
        rank_num (int): Number of ranks in the group (stored as ``nranks``).
        id (int, optional): Group id. Default: 0.
        ranks (list, optional): Global ranks of the members; the position of a
            global rank in this list is its group rank. Default: None, which
            means an empty list.
        pg (optional): Backend process group object, exposed via
            ``process_group``. Default: None.
        name (str, optional): Group name. Default: None.
    """

    def __init__(self, rank, rank_num, id=0, ranks=None, pg=None, name=None):
        self.rank = rank
        self.nranks = rank_num
        self.id = id
        # Build a fresh list per instance: a shared ``[]`` default would be
        # mutated through one group and leak into every other defaulted group.
        self.ranks = [] if ranks is None else ranks
        self.pg = pg
        self.name = name

    def is_member(self):
        """
        Return True if the current process belongs to this group.

        A negative ``rank`` marks a non-member; groups with fewer than two
        ranks also report False.
        """
        if self.rank < 0:
            return False
        if self.nranks < 2:
            return False
        return True

    def get_group_rank(self, rank):
        """
        Map a global ``rank`` to its rank inside this group.

        Returns -1 when the current process is not a member or when ``rank``
        is not one of the group's global ranks.
        """
        if self.is_member() and rank in self.ranks:
            return self.ranks.index(rank)
        else:
            return -1

    @property
    def process_group(self):
        """The backend process group object passed as ``pg`` (may be None)."""
        return self.pg

    def __repr__(self):
        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
            self.rank, self.nranks, self.id)
        debug_str += ", ".join(map(str, self.ranks))
        debug_str += "; name: "
        debug_str += self.name if self.name else "None"
        return debug_str

K
kuizhiqing 已提交
121 122 123 124 125 126 127 128 129 130 131 132 133 134

# Lazily created ParallelEnv, shared by every helper that needs it.
_global_env = None


def _get_global_env():
    """Return the process-wide ParallelEnv, creating it on first use."""
    global _global_env
    if _global_env:
        return _global_env
    _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# Id under which the global (world) group is registered in ``_group_map``.
_global_env_gid = 0

# Registry of every known group keyed by group id: Dict[int, Group].
_group_map = {}

# Registry of groups keyed by group name: Dict[str, Group].
_group_map_by_name = {}

# Backend each group was created with: Dict[Group, backend].
_group_map_backend = {}

# Name of the default group for init_parallel_env.
_default_group_name = "_default_pg"

# Backends accepted by _new_process_group_impl.
_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl']

# Default TCP store and default backend name, assigned through
# _set_default_store / _set_default_backend.
_default_store = None
_default_backend = None

# Default timeout for store-related waits (30 minutes).
_default_timeout = datetime.timedelta(seconds=1800)

# Counter used by _new_ring_id to hand out ring ids in eager mode.
_start_ring_id = 0
153

K
kuizhiqing 已提交
154

L
lilong12 已提交
155 156 157 158 159 160 161 162 163 164
def _set_default_backend(backend):
    """Record ``backend`` as the module-wide default backend."""
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    """Record ``store`` as the module-wide default store object."""
    global _default_store
    _default_store = store


K
kuizhiqing 已提交
165 166
def _get_group_map():
    """
    Return the id -> Group registry.

    On first use the global group (all ranks of the ParallelEnv) is
    registered under ``_global_env_gid``.
    """
    global _group_map
    if _global_env_gid in _group_map:
        return _group_map
    genv = _get_global_env()
    _group_map[_global_env_gid] = Group(genv.rank,
                                        genv.world_size,
                                        ranks=list(range(genv.world_size)))
    return _group_map


def _get_global_group():
    """Return the Group registered under the global group id."""
    group_map = _get_group_map()
    return group_map[_global_env_gid]
K
kuizhiqing 已提交
177 178


179 180 181 182 183 184
def _get_group_map_by_name():
    """Return the live name -> Group registry (not a copy)."""
    return _group_map_by_name


def _get_default_group():
    """
    Return the default group stored under ``_default_group_name``.

    Raises:
        AssertionError: (via ``assert``) if the distributed environment has
            not been initialized.
    """
    assert is_initialized(), ("Call paddle.distributed.init_parallel_env first "
                              "to initialize the distributed environment.")
    groups_by_name = _get_group_map_by_name()
    return groups_by_name[_default_group_name]


L
lilong12 已提交
191 192 193 194 195 196 197 198 199 200 201 202
def _set_group_map(gid, group):
    """Register ``group`` under id ``gid``; the id must not be taken yet."""
    assert gid not in _group_map
    _group_map[gid] = group


def _set_group_map_by_name(name, group):
    """Register ``group`` under ``name``; the name must not be taken yet."""
    assert name not in _group_map_by_name
    _group_map_by_name[name] = group


203 204 205 206 207 208
def _set_group_map_backend(group, backend):
    """Record ``backend`` for ``group``; the group must not be recorded yet."""
    assert group not in _group_map_backend
    _group_map_backend[group] = backend


K
kuizhiqing 已提交
209
def _new_ring_id():
    """
    Return a ring id for a newly created group.

    Ids are offset by ``max(nrings, 9)`` of the global environment. In eager
    (dygraph) mode a module-level counter is bumped on every call; otherwise
    the id is derived from the number of registered groups.
    """
    # NOTE(liyurui): For compatible reason, auto parallel and eager mode relay on previous syntax.
    global _start_ring_id
    if not in_dygraph_mode():
        return len(_get_group_map()) + max(_get_global_env().nrings, 9)
    _start_ring_id += 1
    return _start_ring_id + max(_get_global_env().nrings, 9)
K
kuizhiqing 已提交
217 218


219 220 221 222 223 224 225 226 227 228 229 230 231
def _get_reduce_op(reduce_op, func_name):
    """
    Translate a Python ``ReduceOp`` value into the matching ``core.ReduceOp``.

    Args:
        reduce_op: One of ReduceOp.SUM / MAX / MIN / PROD. (ReduceOp.AVG is
            not mapped here and is rejected.)
        func_name (str): Name of the calling API, used in the error message.

    Raises:
        ValueError: if ``reduce_op`` is not one of the supported values.
    """
    # (python-level value, core enum attribute name), tested in order with ==.
    candidates = (
        (ReduceOp.SUM, "SUM"),
        (ReduceOp.MAX, "MAX"),
        (ReduceOp.MIN, "MIN"),
        (ReduceOp.PROD, "PRODUCT"),
    )
    for py_op, core_name in candidates:
        if reduce_op == py_op:
            return getattr(core.ReduceOp, core_name)
    raise ValueError("Unknown reduce_op type for {}.".format(func_name))


K
kuizhiqing 已提交
232 233 234 235 236 237
def get_group(id=0):
    """

    Get group instance by group id.

    Args:
        id (int): the group id. Default value is 0.

    Returns:
        Group: the group instance, or None when no group is registered under
        ``id``.

    Examples:
        .. code-block:: python

            ...
            group = paddle.distributed.new_group([2, 4, 6])
            paddle.distributed.get_group(group.id)

    """
    return _get_group_map().get(id)
K
kuizhiqing 已提交
254 255


256 257 258 259 260 261
def _new_process_group_impl(backend,
                            store,
                            rank,
                            world_size,
                            group_name,
                            pg_options,
                            group_id=0,
                            src_rank=None,
                            dst_rank=None):
    """
    Create the backend-specific process group object for one group.

    Args:
        backend (str): One of ``_valid_backend_list``.
        store: Store passed to the process group constructor.
        rank (int): Rank of the current process inside the group.
        world_size (int): Number of ranks in the group.
        group_name (str): Group name (not used by this function).
        pg_options: Backend options (not used by this function).
        group_id (int): Id passed to the process group constructor.
        src_rank, dst_rank: Only allowed for the ``heter`` backend.

    Returns:
        The created process group object.

    Raises:
        AssertionError: for an unsupported backend, for src/dst ranks given
            with a non-heter backend, or when the heter CLUSTER_* environment
            variables are missing or inconsistent.
    """
    pg = None
    genv = _get_global_env()
    if backend != 'heter':
        assert src_rank is None and dst_rank is None, (
            "src_rank and dst_rank "
            "can only be set for heter backend.")
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
    if backend == "gloo":
        place = core.CPUPlace()
        pg = core.ProcessGroupGloo(store, rank, world_size, place, group_id)
    elif backend == "nccl":
        place = core.CUDAPlace(genv.device_id)
        pg = core.ProcessGroupNCCL(store, rank, world_size, place, group_id)
    elif backend == "hccl":
        place = core.NPUPlace(genv.device_id)
        pg = core.ProcessGroupHCCL(store, rank, world_size, place, group_id)
    elif backend == "xccl":
        place = core.CustomPlace(genv.device_type, genv.device_id)
        pg = core.ProcessGroupCustom(store, rank, world_size, place, group_id)
    elif backend == "heter":
        place = None
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(genv.device_id)
        elif core.is_compiled_with_npu():
            place = core.NPUPlace(genv.device_id)
        cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
        assert cluster_id >= 0, "please set the CLUSTER_ID variable."
        cluster_size = os.getenv("CLUSTER_SIZE", None)
        assert cluster_size, "please set the CLUSTER_SIZE variable."
        # Comma-separated list of per-cluster sizes, e.g. "4,4".
        cluster_size = [int(s) for s in cluster_size.split(",")]
        switch_ep = os.getenv("CLUSTER_SWITCH", None)
        assert switch_ep, "please set the CLUSTER_SWITCH variable."
        # cluster_id is used as gloo_rank with gloo_size=len(cluster_size),
        # so it must index into the CLUSTER_SIZE list.
        assert cluster_id < len(cluster_size), (
            "CLUSTER_ID must be smaller than the number of entries in "
            "CLUSTER_SIZE.")
        # The global rank / world size of this process are taken from
        # _get_global_config.
        # NOTE(review): _get_global_config is not defined in the visible part
        # of this file -- confirm it exists at module level.
        global_rank, global_world_size = _get_global_config(backend, rank)
        pg = core.ProcessGroupHeter(store,
                                    rank=global_rank,
                                    world_size=global_world_size,
                                    place=place,
                                    gid=group_id,
                                    local_rank=rank,
                                    local_size=world_size,
                                    gloo_rank=cluster_id,
                                    gloo_size=len(cluster_size),
                                    with_switch=True,
                                    switch_endpoint=switch_ep,
                                    src_rank=src_rank,
                                    dst_rank=dst_rank)

    return pg


S
ShenLiang 已提交
321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344
def barrier(group=None):
    """

    Synchronize all participants of the group with a barrier.

    Args:
        group (Group): The group instance returned by new_group, or None for the global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        if group is None:
            group = _get_default_group()
        group.process_group.barrier().wait()
        return

    ring_id = group.id if group is not None else 0
    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        return _legacy_C_ops.barrier(temp, temp, 'ring_id', ring_id)

    # LayerHelper receives **locals(), so the set of local names bound at
    # this point (group, ring_id, temp, op_type) is deliberately unchanged.
    op_type = 'barrier'
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [temp]},
                     outputs={'Out': [temp]},
                     attrs={'ring_id': ring_id})
S
ShenLiang 已提交
366 367


L
lilong12 已提交
368 369 370 371 372 373 374
# Optional user-chosen group id. When truthy, the eager-mode path of
# new_group uses it instead of allocating a fresh ring id; this is mainly
# useful to stay compatible with static mode.
_custom_gid = None


def _set_custom_gid(gid):
    """Set the group id new_group should use (None/0 restores allocation)."""
    global _custom_gid
    _custom_gid = gid


379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415
def _barrier_by_tcp_store(group_name, store, timeout):
    """
    Store-based rendezvous used after a group has been created.

    Every global rank other than 0 adds 1 to the key
    ``Barrier/<group_name>/<rank>`` in ``store`` and returns immediately.
    Global rank 0 polls those keys until each of them reads as 1, so it is
    the last rank to leave. Nothing happens when the world size is below 2.

    Args:
        group_name (str): Group name, used to namespace the keys.
        store: Store object providing ``add(key, value)`` and ``get(key)``.
        timeout (datetime.timedelta): How long rank 0 waits before giving up.

    Raises:
        RuntimeError: if rank 0 times out waiting for the other ranks.
    """
    global_rank = paddle.distributed.get_rank()
    global_world_size = paddle.distributed.get_world_size()

    if global_world_size < 2:
        return

    barrier_prefix = "Barrier/" + group_name + "/"
    is_master = (global_rank == 0)

    def _check_keys_ready(wait_keys):
        # Poll every 0.1s until all keys are set, or raise once `timeout`
        # has elapsed.
        start_time = time.time()
        while wait_keys:
            time.sleep(0.1)
            elapse_time = time.time() - start_time
            if datetime.timedelta(seconds=elapse_time) > timeout:
                raise RuntimeError(
                    "Timeout while initializing process group {}. "
                    "Keys {} are not ready while rank {} is waiting for them. "
                    "Two reasons may cause this error:\n"
                    " 1. The create process group api should be called by all ranks.\n"
                    " 2. The waiting time is too short; try increasing the timeout.\n"
                    .format(group_name, wait_keys, global_rank))
            wait_keys = [key for key in wait_keys if int(store.get(key)) != 1]

    # all the workers set their exiting key and exit
    # the master will wait for all workers' exiting key, ensure to exit in the end
    if is_master:
        wait_keys = [
            barrier_prefix + str(rank) for rank in range(1, global_world_size)
        ]
        _check_keys_ready(wait_keys)
    else:
        store.add(barrier_prefix + str(global_rank), 1)


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members. In eager (dygraph)
            mode None means all ranks of the default group; the static-graph
            path requires a list.
        backend (str): The backend used to create group. In eager mode it may
            be any of 'nccl', 'gloo', 'hccl', 'xccl' or 'heter' and, for
            multi-rank groups, defaults to the module default backend. The
            static-graph path only supports 'nccl'.
        timeout (datetime.timedelta, optional): The waiting timeout for store relevant options, default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)

    """
    if in_dygraph_mode():
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group.")
        size = len(ranks)
        ranks = sorted(ranks)
        # global_rank is only read when size > 1 and backend != 'heter',
        # i.e. exactly when the branch above has defined it.
        if backend == 'heter' or (size > 1 and global_rank in ranks):
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            src_rank = ranks[0] if backend == 'heter' else None
            dst_rank = ranks[1] if backend == 'heter' else None
            pg = _new_process_group_impl(backend,
                                         _default_store,
                                         rank,
                                         size,
                                         group_name,
                                         pg_options=None,
                                         group_id=gid,
                                         src_rank=src_rank,
                                         dst_rank=dst_rank)
        else:
            # Non-members (and single-rank groups) get no process group.
            rank = -1
            pg = None
        group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        # NOTE(review): for single-rank, non-heter groups `backend` may still
        # be None here, since the default is only applied above.
        _group_map_backend[group] = backend

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        # NOTE(liyurui): All processors should hang and wait using tcp store, in case master exit before sub-group is created.
        if backend != 'heter':
            _barrier_by_tcp_store(group_name, _default_store, timeout)
        else:
            print("Warning: store barrier is not supported for heter backend.")
        return group

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', ("backend other than nccl is not supported yet")

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    # NOTE(review): ranks=None is not handled on this path (`in` on None
    # raises TypeError).
    if global_rank not in ranks:
        gp = Group(-1, -1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, group_size, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            else:
                assert False, ("no cuda device found")
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = paddle.to_tensor(
        [1], dtype="int32") if _non_static_mode() else fill_constant(
            [0], dtype="int32", value="1")
    paddle.distributed.all_reduce(tmp, use_calc_stream=True)
    paddle.distributed.wait(tmp)
    return gp

547

548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588
def is_initialized():
    """

    Check whether the distributed environment has been initialized, i.e.
    whether the default group is registered by name.

    Returns (bool): `True` if distributed environment has been initialized, otherwise `False`.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle

            print(paddle.distributed.is_initialized())
            # False

            paddle.distributed.init_parallel_env()
            print(paddle.distributed.is_initialized())
            # True

    """
    default_group_registered = _default_group_name in _group_map_by_name
    return default_group_registered


def destroy_process_group(group=None):
    """
    Destroy a given group for communication.

    Args:
        group (Group, optional): The group to be destroyed. If None, all
            process groups, including the default group, are destroyed and
            the distributed environment is deinitialized.

    Returns : None

    Raises:
        AssertionError: (via ``assert``) if the group is not registered.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            group = dist.new_group([0, 1])

            dist.destroy_process_group(group)
            print(dist.is_initialized())
            # True
            dist.destroy_process_group()
            print(dist.is_initialized())
            # False

    """
    global _group_map
    global _group_map_by_name

    pg = _get_default_group() if group is None else group
    assert _group_map.get(pg.id, None) is not None, "Invalid group."

    if group is None:
        _group_map.clear()
        _group_map_by_name.clear()
        _group_map_backend.clear()
    else:
        del _group_map[pg.id]
        # Groups created by the static-graph path of new_group are only
        # registered by id, so the name/backend entries may be absent.
        _group_map_by_name.pop(pg.name, None)
        _group_map_backend.pop(pg, None)


K
kuizhiqing 已提交
618 619 620 621 622 623 624 625
def wait(tensor, group=None, use_calc_stream=True):
    """

    Synchronize a stream for the group before ``tensor`` is used.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, use_calc_stream=True)
            paddle.distributed.wait(tindata)

    """
    if group is not None and not group.is_member():
        return

    ring_id = group.id if group is not None else 0

    if not use_calc_stream:
        _sync_comm_stream(tensor, ring_id)
    else:
        _sync_calc_stream(tensor)


def _sync_calc_stream(tensor):
    """
    Synchronize the calculation stream for ``tensor``.

    When ``_non_static_mode()`` is true the legacy op is called directly and
    its result returned; otherwise a ``c_sync_calc_stream`` op is appended to
    the program and None is returned.
    """
    if not _non_static_mode():
        # LayerHelper receives **locals(); the local names (tensor, op_type)
        # are kept exactly as before.
        op_type = 'c_sync_calc_stream'
        helper = LayerHelper(op_type, **locals())
        helper.append_op(type=op_type,
                         inputs={'X': [tensor]},
                         outputs={'Out': [tensor]})
        return
    return _legacy_C_ops.c_sync_calc_stream(tensor, tensor)
668

669

K
kuizhiqing 已提交
670
def _sync_comm_stream(tensor, ring_id=0):
    """
    Synchronize the communication stream of ring ``ring_id`` for ``tensor``.

    When ``_non_static_mode()`` is true the legacy op is called directly and
    its result returned; otherwise a ``c_sync_comm_stream`` op is appended to
    the program and None is returned.
    """
    if not _non_static_mode():
        # LayerHelper receives **locals(); the local names (tensor, ring_id,
        # op_type) are kept exactly as before.
        op_type = 'c_sync_comm_stream'
        helper = LayerHelper(op_type, **locals())
        helper.append_op(type=op_type,
                         inputs={'X': [tensor]},
                         outputs={'Out': [tensor]},
                         attrs={'ring_id': ring_id})
        return
    return _legacy_C_ops.c_sync_comm_stream([tensor], [tensor], 'ring_id',
                                            ring_id)
K
kuizhiqing 已提交
685 686 687


def broadcast(tensor, src, group=None, use_calc_stream=True):
    """

    Broadcast a tensor from the source to all others.
    As shown below, one process is started with a GPU and GPU0 owns data 0. Through broadcast operator,
    data 0 will be sent to all GPUs from GPU0.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png
        :width: 800
        :alt: broadcast
        :align: center

    Args:
        tensor (Tensor): The Tensor to send if current rank is the source, or the Tensor to receive otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        src (int): The global rank of the source; it must belong to ``group``.
        group (Group): The group instance returned by new_group, or None for the global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None. In eager mode with ``use_calc_stream=False`` the communication
        task is returned instead.

    Raises:
        ValueError: if ``src`` is not an int.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.broadcast(data, src=1)
            print(data)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs)
    """

    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        assert gsrc >= 0, ("src rank out of group, need global rank")
        task = group.process_group.broadcast(tensor, gsrc)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    # Translate the global source rank into a rank inside the group.
    gsrc = src if group is None else group.get_group_rank(src)
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if _non_static_mode():
        return _legacy_C_ops.c_broadcast(tensor, tensor, 'root', gsrc,
                                         'use_calc_stream', use_calc_stream,
                                         'ring_id', ring_id)

    op_type = 'c_broadcast'
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'broadcast')

    # LayerHelper receives **locals(), i.e. every local bound above.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'root': gsrc,
                         'use_calc_stream': use_calc_stream,
                         'ring_id': ring_id,
                     })
768 769


K
kuizhiqing 已提交
770
def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor over all ranks so that all get the result.
    As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The reduce operator is sum. Through all_reduce operator,
    each GPU will have the sum of the data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
        :width: 800
        :alt: all_reduce
        :align: center

    Args:
        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance returned by new_group, or None for the global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None. In eager mode with ``use_calc_stream=False`` the communication
        task is returned instead.

    Raises:
        ValueError: if ``op`` is not one of the supported reduce operations.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_reduce(data)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "all_reduce")
        group = _get_default_group() if group is None else group
        task = group.process_group.allreduce(tensor, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    if _non_static_mode():
        if op == ReduceOp.SUM:
            return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id)
        elif op == ReduceOp.MAX:
            return _legacy_C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id)
        elif op == ReduceOp.MIN:
            return _legacy_C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id)
        elif op == ReduceOp.PROD:
            return _legacy_C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
                                                   use_calc_stream, 'ring_id',
                                                   ring_id)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'all_reduce')
    if op == ReduceOp.SUM:
        op_type = 'c_allreduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_allreduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_allreduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_allreduce_prod'
    else:
        # Reject unsupported ops explicitly, as the dynamic-graph branch does;
        # otherwise op_type would be unbound below.
        raise ValueError("Unknown parameter: {}.".format(op))
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'ring_id' for all_reduce should be int.")
    # LayerHelper receives **locals(), i.e. every local bound above.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream
                     })
866 867


K
kuizhiqing 已提交
868
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor to the destination from all others. As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator,
    GPU0 will own the sum of all data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png
        :width: 800
        :alt: reduce
        :align: center

    Args:
        tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        dst (int): The destination rank id.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance returned by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None in static graph mode, or when the current rank is not a member of ``group``.
        In eager dygraph mode: None if ``use_calc_stream`` is True, otherwise the
        communication task. In legacy dygraph mode: the result of the underlying
        ``c_reduce_*`` op.

    Raises:
        ValueError: If ``op`` is not a supported reduce operation (legacy dygraph
            and static graph modes).

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.reduce(data, dst=0)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs, out for rank 0)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce")
        group = _get_default_group() if group is None else group
        gdst = group.get_group_rank(dst)
        assert gdst >= 0, ("dst rank out of group, need global rank")
        task = group.process_group.reduce(tensor, gdst, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    # ``dst`` is a global rank; the ops below expect the rank inside the group.
    gdst = dst if group is None else group.get_group_rank(dst)
    assert gdst >= 0, ("dst rank out of group, need global rank")

    if _non_static_mode():
        if op == ReduceOp.SUM:
            return _legacy_C_ops.c_reduce_sum(tensor, tensor, 'use_calc_stream',
                                              use_calc_stream, 'ring_id',
                                              ring_id, 'root_id', gdst)
        elif op == ReduceOp.MAX:
            return _legacy_C_ops.c_reduce_max(tensor, tensor, 'use_calc_stream',
                                              use_calc_stream, 'ring_id',
                                              ring_id, 'root_id', gdst)
        elif op == ReduceOp.MIN:
            return _legacy_C_ops.c_reduce_min(tensor, tensor, 'use_calc_stream',
                                              use_calc_stream, 'ring_id',
                                              ring_id, 'root_id', gdst)
        elif op == ReduceOp.PROD:
            return _legacy_C_ops.c_reduce_prod(tensor, tensor,
                                               'use_calc_stream',
                                               use_calc_stream, 'ring_id',
                                               ring_id, 'root_id', gdst)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'reduce')

    if op == ReduceOp.SUM:
        op_type = 'c_reduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_reduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_reduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_reduce_prod'
    else:
        # Previously an unknown op fell through to a nonexistent 'c_reduce'
        # op type; reject it like the dygraph branch does.
        raise ValueError("Unknown parameter: {}.".format(op))

    helper = LayerHelper(op_type, **locals())
    # In-place: the reduced result is written back into ``tensor``.
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream,
                         'root_id': gdst,
                     })
973 974


K
kuizhiqing 已提交
975
def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
    """

    Gather a tensor from every participant so that each participant ends up
    holding all of them. As shown below, one process is started per GPU and
    each process's data is identified by its group rank. After all_gather,
    every GPU has the data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): Output list that receives one Tensor per rank. Every element must be a Tensor whose data type
            is float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
        tensor (Tensor): The Tensor contributed by this rank. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
        group (Group): The group instance returned by new_group, or None for the global default group.
        use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False).
            Default to True. In eager mode the call always waits for completion.

    Returns:
        None. Results are appended to ``tensor_list`` (in eager mode the list
        is cleared first).

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            tensor_list = []
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_gather(tensor_list, data)
            print(tensor_list)
            # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    def _as_complex_list(real_views):
        # Turn each gathered real-valued view back into a complex tensor.
        return [paddle.as_complex(piece) for piece in real_views]

    is_input_complex = (tensor.dtype == paddle.complex64
                        or tensor.dtype == paddle.complex128)
    if is_input_complex:
        # Complex data is communicated through its real-valued view.
        tensor = paddle.as_real(tensor)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        if len(tensor_list) > 0:
            # Reuse the caller's tensors as the receive buffer.
            out = paddle.concat(tensor_list, axis=0)
        else:
            out_shape = list(tensor.shape)
            out_shape[0] *= group.nranks
            out = paddle.empty(out_shape, tensor.dtype)
        group.process_group.all_gather(tensor, out).wait()
        tensor_list.clear()
        pieces = paddle.split(out, group.nranks, 0)
        tensor_list.extend(
            _as_complex_list(pieces) if is_input_complex else pieces)
        return

    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    if not _non_static_mode():
        op_type = 'c_allgather'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        if not isinstance(tensor_list, list):
            raise ValueError("The type of 'tensor_list' for all_gather "
                             "should be list.")
        supported_dtypes = [
            'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'int8',
            'uint8', 'complex64', 'complex128'
        ]
        for item in tensor_list:
            check_variable_and_dtype(item, 'tensor_list', supported_dtypes,
                                     'all_gather')
        check_variable_and_dtype(tensor, 'tensor', supported_dtypes,
                                 'all_gather')
        helper.append_op(type=op_type,
                         inputs={'X': [tensor]},
                         outputs={'Out': [out]},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': use_calc_stream,
                             'nranks': nranks
                         })
    else:
        out = _legacy_C_ops.c_allgather(tensor, 'use_calc_stream',
                                        use_calc_stream, 'ring_id', ring_id,
                                        'nranks', nranks)

    pieces = paddle.split(out, nranks, 0)
    tensor_list.extend(_as_complex_list(pieces) if is_input_complex else pieces)


def _convert_object_to_tensor(obj):
    """Serialize ``obj`` with pickle into a 1-D uint8 tensor.

    Returns:
        tuple: ``(tensor, numel)`` where ``numel`` is the byte length of the
        pickled payload, as returned by ``tensor.numel()``.
    """
    payload = pickle.dumps(obj)
    tensor = paddle.to_tensor(np.frombuffer(payload, dtype=np.uint8))
    return tensor, tensor.numel()
1095 1096


1097
def _convert_tensor_to_object(tensor, len_of_tensor):
1098
    _unpickler = pickle.Unpickler
1099
    return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()
1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126


def all_gather_object(object_list, obj, group=None):
    """

    Gather picklable objects from all participators and all get the result. Similar to all_gather(), but python object can be passed in.

    Args:
        object_list (list): A list of output object. The datatype of every element in the list is same as the input obj.
        obj (Any): The picklable object to send.
        group (Group): The group instance returned by new_group or None for global default group.

    Returns:
        None. Results are appended to ``object_list``; nothing is appended when
        the current rank is not a member of ``group``.

    Warning:
        This API only supports the dygraph mode. Objects are exchanged with
        pickle, so only use it between trusted ranks.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            object_list = []
            if dist.get_rank() == 0:
                obj = {"foo": [1, 2, 3]}
            else:
                obj = {"bar": [4, 5, 6]}
            dist.all_gather_object(object_list, obj)
            print(object_list)
            # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
    """
    assert in_dygraph_mode(
    ), "all_gather_object doesn't support static graph mode."

    # Non-members take no part in the collective. Without this early return,
    # all_gather below returns immediately, list_len_of_tensor stays empty and
    # max() raises.
    if group is not None and not group.is_member():
        return

    tensor, len_of_tensor = _convert_object_to_tensor(obj)

    # gather len_of_tensor from all ranks
    list_len_of_tensor = []
    all_gather(list_len_of_tensor, len_of_tensor, group)
    # get the max length from list
    max_len_of_tensor = int(max(list_len_of_tensor).item())
    # Pad every payload to the max length so all ranks send equally sized
    # tensors and the all_gather cannot hang.
    # Note(liyurui): Maybe we should support various length all_gather?
    # Now this operation is efficient for we don't support resize in python.
    numpy_data = np.resize(tensor.numpy(), [max_len_of_tensor])
    input_tensor = paddle.to_tensor(numpy_data)

    tensor_list = []
    all_gather(tensor_list, input_tensor, group)
    # Each rank's true length strips its padding before unpickling.
    for gathered, length in zip(tensor_list, list_len_of_tensor):
        object_list.append(_convert_tensor_to_object(gathered, length))
1157 1158


K
kuizhiqing 已提交
1159
def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
    """

    Scatter a tensor to all participators. As shown below, one process is started with a GPU and the source of the scatter
    is GPU0. Through scatter operator, the data in GPU0 will be sent to all GPUs evenly.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
        :width: 800
        :alt: scatter
        :align: center

    Args:
        tensor (Tensor): The output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool. Default value is None.
            Required on the source rank; ignored on the other ranks.
        src (int): The source rank id. Default value is 0.
        group (Group): The group instance returned by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None, or the communication task in eager mode when ``use_calc_stream`` is False.

    Raises:
        ValueError: If ``src`` is not an int, or ``tensor_list`` is None on the source rank.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([7, 8, 9])
                data2 = paddle.to_tensor([10, 11, 12])
                dist.scatter(data1, src=1)
            else:
                data1 = paddle.to_tensor([1, 2, 3])
                data2 = paddle.to_tensor([4, 5, 6])
                dist.scatter(data1, tensor_list=[data1, data2], src=1)
            print(data1, data2)
            # [1, 2, 3] [10, 11, 12] (2 GPUs, out for rank 0)
            # [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        rank = group.rank
        nranks = group.nranks
    else:
        ring_id = 0 if group is None else group.id
        gsrc = src if group is None else group.get_group_rank(src)
        rank = _get_global_group().rank if group is None else group.rank
        nranks = _get_global_group().nranks if group is None else group.nranks
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if rank != gsrc:
        # Non-source ranks only receive. Their input is a placeholder built from
        # ``tensor`` repeated nranks times (presumably so every rank feeds an
        # input of matching shape to the scatter op).
        tensor_list = [tensor] * nranks
    elif tensor_list is None:
        raise ValueError(
            "tensor_list must be provided on the source rank of scatter.")
    temp = paddle.concat(tensor_list, axis=0)

    if in_dygraph_mode():
        task = group.process_group.scatter(temp, tensor, gsrc)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    if _non_static_mode():
        return _legacy_C_ops.c_scatter(temp, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'nranks', nranks, 'root', gsrc)
    op_type = 'c_scatter'
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'scatter')
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [temp]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'root': gsrc,
                         'use_calc_stream': use_calc_stream,
                         'nranks': nranks,
                     })
1253 1254


1255
def _c_identity(tensor, group=None):
    """
    Apply the ``c_identity`` op to ``tensor`` (returns a copy of it); used with model parallel.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The group instance to work on, or None to use ring id 0.

    Returns:
        Tensor, or None if the current rank is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = group.id if group is not None else 0

    if _non_static_mode():
        return _legacy_C_ops.c_identity(tensor, 'use_calc_stream', True,
                                        'ring_id', ring_id,
                                        'use_model_parallel', True)

    op_type = 'c_identity'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    allowed_dtypes = ['float16', 'float32', 'float64', 'int32', 'int64']
    check_variable_and_dtype(tensor, 'tensor', allowed_dtypes, '_c_identity')

    op_attrs = {
        'ring_id': ring_id,
        'use_calc_stream': True,
        'use_model_parallel': True,
    }
    helper.append_op(type=op_type,
                     inputs={'X': tensor},
                     outputs={'Out': out},
                     attrs=op_attrs)
    return out


1294
def _c_concat(tensor, group=None):
    """
    Return allgather of the tensor, mainly used with model parallel.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The group instance to work on, or None for the default
            group.

    Returns:
        Tensor, or None if the current rank is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return

    group = _get_default_group() if group is None else group
    ring_id = group.id
    # The op takes the rank inside the group (the previously computed, unused
    # global rank was dropped).
    rank = group.rank
    nranks = group.nranks

    if _non_static_mode():
        return _legacy_C_ops.c_concat(tensor, 'ring_id', ring_id,
                                      'use_calc_stream', True, 'rank', rank,
                                      'nranks', nranks, 'use_model_parallel',
                                      True)

    op_type = 'c_concat'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_concat')

    helper.append_op(type=op_type,
                     inputs={'X': tensor},
                     outputs={'Out': out},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': True,
                         'use_model_parallel': True,
                         'nranks': nranks,
                         'rank': rank
                     })
    return out


1342
def _c_split(tensor, group=None):
    """
    Split ``tensor`` evenly among the members of ``group`` (``c_split`` op);
    used with model parallel.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The group instance to work on, or None to use ring id 0
            with the global rank and world size.

    Returns:
        Tensor, or None if the current rank is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    global_rank = _get_global_env().rank
    if group is None:
        rank = global_rank
        nranks = _get_global_env().world_size
    else:
        # The op needs this process's rank within the group.
        rank = group.get_group_rank(global_rank)
        nranks = group.nranks

    if _non_static_mode():
        return _legacy_C_ops.c_split(tensor, 'use_calc_stream', True, 'ring_id',
                                     ring_id, 'rank', rank, 'nranks', nranks,
                                     'use_model_parallel', True)

    op_type = 'c_split'
    helper = LayerHelper(op_type, **locals())
    result = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_split')

    split_attrs = {
        'ring_id': ring_id,
        'use_calc_stream': True,
        'rank': rank,
        'nranks': nranks,
        'use_model_parallel': True,
    }
    helper.append_op(type=op_type,
                     inputs={'X': tensor},
                     outputs={'Out': result},
                     attrs=split_attrs)
    return result


1389 1390 1391 1392 1393
def _mp_allreduce(tensor,
                  op=ReduceOp.SUM,
                  group=None,
                  use_calc_stream=True,
                  use_model_parallel=True):
    """
    Sum-allreduce ``tensor`` for model parallel training.

    Same as ``all_reduce`` but passes ``use_model_parallel`` to the op; in
    dygraph modes it uses the in-place ``c_allreduce_sum_`` op. In eager mode
    the reduction is wrapped in a PyLayer whose backward applies ``c_identity``
    to the incoming gradient.

    Args:
        tensor (Tensor): The tensor to reduce.
        op (ReduceOp): Only ``ReduceOp.SUM`` is supported.
        group (Group): The group instance to work on, or None for the default
            group (ring id 0 outside eager mode).
        use_calc_stream (bool): Whether to use the calculation stream.
        use_model_parallel (bool): Forwarded to the op as an attribute.

    Returns:
        Tensor, or None if the current rank is not a member of ``group``.

    Raises:
        ValueError: If ``op`` is not ``ReduceOp.SUM`` (legacy dygraph mode).
        AssertionError: If ``op`` is not ``ReduceOp.SUM`` (eager mode).
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op)

        from paddle.autograd import PyLayer

        class mp_allreduce_eager(PyLayer):

            @staticmethod
            def forward(ctx, tensor, group, use_calc_stream,
                        use_model_parallel):
                ctx.ring_id = group.id

                if use_calc_stream:
                    op_type = _get_reduce_op(op, "_mp_allreduce")
                    group.process_group.allreduce_on_calc_stream(
                        tensor, op_type)
                    return tensor
                else:
                    # Use ctx.ring_id: the enclosing ``ring_id`` is only
                    # assigned after this eager branch returns, so reading it
                    # here raised NameError.
                    return _legacy_C_ops.c_allreduce_sum_(
                        tensor, 'use_calc_stream', use_calc_stream, 'ring_id',
                        ctx.ring_id, "use_model_parallel", use_model_parallel)

            @staticmethod
            def backward(ctx, dy):
                return _legacy_C_ops.c_identity(dy, 'use_calc_stream', True,
                                                'ring_id', ctx.ring_id,
                                                'use_model_parallel', True)

        return mp_allreduce_eager.apply(tensor, group, use_calc_stream,
                                        use_model_parallel)

    ring_id = 0 if group is None else group.id
    if _in_legacy_dygraph():
        if op == ReduceOp.SUM:
            return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id, "use_model_parallel",
                                                  use_model_parallel)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    op_type = 'c_allreduce_sum'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        op_type)

    helper.append_op(type=op_type,
                     inputs={'X': tensor},
                     outputs={'Out': out},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream,
                         'use_model_parallel': use_model_parallel,
                     })
    return out
1458 1459


1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473
def _c_lookup_table(table, index, start_index=0, name=None):
    """
    Look up rows of ``table`` by ``index`` using the ``c_embedding`` op.

    Args:
        table (Tensor): The input Tensor. Its data type
            should be float16, float32, float64.
        index (Tensor): The int32/int64 indices to look up.
        start_index (int): The initial index for the table range.
        name (string): The name of the api.

    Returns:
        Tensor.
    """
    if not _non_static_mode():
        op_type = 'c_embedding'
        # ``table`` and ``name`` reach the helper through locals(); the output
        # dtype is read from the ``table`` entry.
        helper = LayerHelper(op_type, **locals())
        table_dtype = helper.input_dtype(input_param_name='table')
        check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type)
        looked_up = helper.create_variable_for_type_inference(table_dtype)
        helper.append_op(type=op_type,
                         inputs={
                             'Ids': index,
                             'W': table
                         },
                         outputs={'Out': looked_up},
                         attrs={"start_index": start_index})
        return looked_up

    return _legacy_C_ops.c_embedding(table, index, "start_index", start_index)

1492

B
Baibaifan 已提交
1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507
class _Linear(layers.Layer):
    """
    Fully connected layer computing ``input @ weight + bias`` through ``_linear``.

    ``weight`` has shape ``[in_features, out_features]`` and ``bias`` has shape
    ``[out_features]``; both use the helper's default dtype.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        super().__init__()
        self._dtype = self._helper.get_default_dtype()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        weight_shape = [in_features, out_features]
        self.weight = self.create_parameter(shape=weight_shape,
                                            attr=self._weight_attr,
                                            dtype=self._dtype,
                                            is_bias=False)
        bias_shape = [out_features]
        self.bias = self.create_parameter(shape=bias_shape,
                                          attr=self._bias_attr,
                                          dtype=self._dtype,
                                          is_bias=True)
        self.name = name

    def forward(self, input):
        return _linear(x=input,
                       weight=self.weight,
                       bias=self.bias,
                       name=self.name)

    def extra_repr(self):
        in_features = self.weight.shape[0]
        out_features = self.weight.shape[1]
        name_suffix = ', name={}'.format(self.name) if self.name else ''
        return 'in_features={}, out_features={}, dtype={}{}'.format(
            in_features, out_features, self._dtype, name_suffix)


1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550
def _c_softmax_with_cross_entropy(logits,
                                  label,
                                  group=None,
                                  return_softmax=False):
    """
    Apply the ``c_softmax_with_cross_entropy`` op to ``logits`` and ``label``.

    Args:
        logits (Tensor): The input logits.
        label (Tensor): The labels. They must have either the same rank as
            ``logits`` or one less; in the latter case a trailing axis of
            size 1 is added.
        group (Group): The group instance to work on, or None to use ring id 0
            with the global rank and world size.
        return_softmax (bool): Whether to also return the softmax output.

    Returns:
        ``loss``, or ``(loss, softmax)`` when ``return_softmax`` is True.
        None if the current rank is not a member of ``group``.

    Raises:
        ValueError: If the ranks of ``logits`` and ``label`` are incompatible.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id
    global_rank = _get_global_env().rank
    rank = global_rank if group is None else group.get_group_rank(global_rank)
    nranks = _get_global_env().world_size if group is None else group.nranks

    input_dims = len(list(logits.shape))
    label_dims = len(list(label.shape))
    if input_dims - 1 != label_dims and input_dims != label_dims:
        # Implicit literal concatenation keeps stray indentation out of the
        # message (the old backslash continuation embedded it).
        raise ValueError(
            'Expected input_dims - 1 == label_dims or input_dims == label_dims '
            '(got input_dims {}, label_dims {})'.format(input_dims, label_dims))
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=-1)

    if _non_static_mode():
        softmax, loss = _legacy_C_ops.c_softmax_with_cross_entropy(
            logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks)
        if not return_softmax:
            return loss
        else:
            return loss, softmax

    attrs = {
        'ring_id': ring_id,
        'rank': rank,
        'nranks': nranks,
    }
    helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
    helper.append_op(type='c_softmax_with_cross_entropy',
                     inputs={
                         'Logits': logits,
                         'Label': label
                     },
                     outputs={
                         'Softmax': softmax,
                         'Loss': loss
                     },
                     attrs=attrs)

    if return_softmax:
        return loss, softmax

    return loss

1583

B
Baibaifan 已提交
1584 1585 1586 1587
def _linear(x, weight, bias=None, name=None):
    """
    Functional linear layer computing ``x @ weight`` plus an optional bias.

    ``_parallel_linear`` uses this function instead of
    ``paddle.nn.functional.linear`` on NPU builds.

    Args:
        x (Tensor): Input tensor. In static graph mode its dtype must be
            float16, float32 or float64 and it may have at most 3 dimensions.
        weight (Tensor): Weight matrix multiplied on the right of ``x``.
        bias (Tensor, optional): Bias added along the last axis of ``x``.
            Default: None (no bias is added).
        name (str, optional): Forwarded to ``LayerHelper`` in static graph
            mode. Default: None.

    Returns:
        Tensor: The product, with ``bias`` added when it is given.
    """
    if _non_static_mode():
        # Dygraph: run the (v1) matmul kernel directly, then add the bias.
        pre_bias = _varbase_creator(dtype=x.dtype)
        _legacy_C_ops.matmul(x, weight, pre_bias, 'transpose_X', False,
                             'transpose_Y', False, "alpha", 1)
        return dygraph_utils._append_bias_in_dygraph(pre_bias,
                                                     bias,
                                                     axis=len(x.shape) - 1)

    # LayerHelper is handed locals(); creating it before binding any other
    # local keeps the forwarded kwargs limited to the function arguments.
    helper = LayerHelper('linear', **locals())
    dtype = x.dtype

    assert len(x.shape) < 4, (
        "Input x with more than 3 dimensions is not supported now, "
        "but got {} dimensions.".format(len(x.shape)))

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'linear')
    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')

    tmp = helper.create_variable_for_type_inference(dtype)
    # matmul_v2 names its transpose flags trans_x / trans_y; the matmul (v1)
    # attribute names transpose_X / transpose_Y / alpha do not apply to it.
    helper.append_op(type='matmul_v2',
                     inputs={
                         'X': [x],
                         'Y': [weight]
                     },
                     outputs={'Out': tmp},
                     attrs={
                         'trans_x': False,
                         'trans_y': False,
                     })
    if bias is None:
        return tmp

    res = helper.create_variable_for_type_inference(dtype)
    helper.append_op(type='elementwise_add',
                     inputs={
                         'X': [tmp],
                         'Y': [bias]
                     },
                     outputs={'Out': [res]},
                     attrs={'axis': len(x.shape) - 1})
    return res


1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642
def _set_var_distributed(var):
    if var is None:
        return

    var.is_distributed = True

    # NOTE: use current_block and find_var_recursive to support while_loop
    startup_block = paddle.static.default_startup_program().current_block()
    main_block = paddle.static.default_main_program().current_block()
    startup_block._find_var_recursive(var.name).is_distributed = True
    main_block._find_var_recursive(var.name).is_distributed = True


L
lilong12 已提交
1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653
def _parallel_linear(x,
                     num_rows,
                     num_cols,
                     axis,
                     param_attr,
                     bias_attr,
                     gather_out,
                     inner_rank,
                     nranks,
                     split_tensor,
                     name,
1654
                     group=None):
1655 1656
    """
    Parallel Linear
1657 1658 1659

    axis the dimension of the parameter of linear layer. 
    axis = 0: the row dimension
1660
    axis = 1: the col dimension
1661
    
1662
    """
1663 1664 1665 1666
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

L
lilong12 已提交
1667 1668
    if axis == 0:
        if split_tensor:
1669
            x = _c_split(x, group=group)
1670
    else:
L
lilong12 已提交
1671 1672
        x = _c_identity(x, group=group)

1673 1674 1675 1676 1677
    linear = paddle.nn.Linear(num_rows,
                              num_cols,
                              weight_attr=param_attr,
                              bias_attr=bias_attr,
                              name=name)
1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689

    # NOTE: npu linear function use matmul_v2 but linear use matmul
    linear_function = _linear if core.is_compiled_with_npu()\
        else paddle.nn.functional.linear
    linear_out = linear_function(
        x,
        linear.weight,
        # NOTE(wangxi): row split, bias need add after allreduce
        None if axis == 0 else linear.bias,
        linear.name)

    _set_var_distributed(linear.weight)
1690 1691 1692 1693
    # set is_distributed for splited bias
    # if a linear layer is splited by row, each rank would hold a complete bias and they should be the same in each rank.
    # if a linear layer is splited by col, the bias would also be split into each rank as its weight
    if axis == 1 and linear._bias_attr != False:
1694
        _set_var_distributed(linear.bias)
L
lilong12 已提交
1695 1696 1697 1698 1699

    if not gather_out: return linear_out

    out_shape = list(linear_out.shape)
    out_shape[0] *= 1 if axis == 0 else nranks
1700
    main_block = paddle.static.default_main_program().current_block()
L
lilong12 已提交
1701 1702 1703 1704 1705 1706 1707 1708 1709
    out = main_block.create_var(
        shape=out_shape,
        dtype=linear_out.dtype,
        type=linear_out.type,
        lod_level=linear_out.lod_level,
        persistable=False,
        is_data=False,
        need_check_feed=linear_out.desc.need_check_feed())
    if axis == 0:
1710 1711 1712 1713 1714 1715 1716 1717
        main_block.append_op(type='c_allreduce_sum',
                             inputs={'X': linear_out},
                             outputs={'Out': out},
                             attrs={
                                 'ring_id': ring_id,
                                 'use_calc_stream': True,
                                 'use_model_parallel': True
                             })
1718 1719
        if linear.bias is not None:
            out = out + linear.bias
L
lilong12 已提交
1720
    else:
1721 1722 1723 1724 1725 1726 1727 1728 1729 1730
        main_block.append_op(type='c_concat',
                             inputs={'X': linear_out},
                             outputs={'Out': out},
                             attrs={
                                 'rank': inner_rank,
                                 'ring_id': ring_id,
                                 'nranks': nranks,
                                 'use_calc_stream': True,
                                 'use_model_parallel': True
                             })
L
lilong12 已提交
1731
    return out
1732 1733


L
lilong12 已提交
1734 1735 1736 1737 1738 1739 1740
def _parallel_embedding(x,
                        per_part_embeddings,
                        origin_size,
                        param_attr,
                        inner_rank,
                        num_partitions,
                        name,
1741
                        group=None):
1742 1743 1744
    """
    Parallel Embedding
    """
1745 1746 1747 1748
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

1749 1750 1751 1752 1753 1754 1755 1756 1757
    helper = LayerHelper("_parallel_embedding", **locals())

    per_part_size = per_part_embeddings
    rank = inner_rank

    vocab_start_index = rank * per_part_size
    dtype = helper.get_default_dtype()
    size = [per_part_size, origin_size[1]]

1758 1759 1760 1761
    weight = helper.create_parameter(attr=param_attr,
                                     shape=size,
                                     dtype=dtype,
                                     is_bias=False)
1762 1763

    if num_partitions == 1:
1764 1765 1766 1767 1768
        return paddle.nn.functional.embedding(x,
                                              weight=weight,
                                              padding_idx=None,
                                              sparse=False,
                                              name=name)
1769

1770 1771
    startup_block = paddle.static.default_startup_program().global_block()
    main_block = paddle.static.default_main_program().global_block()
1772 1773 1774 1775 1776
    startup_block.vars[weight.name].is_distributed = True
    main_block.vars[weight.name].is_distributed = True

    output_parallel = paddle.distributed.collective._c_lookup_table(
        weight, x, start_index=vocab_start_index, name=name)
1777 1778 1779 1780
    out = paddle.distributed.collective._mp_allreduce(output_parallel,
                                                      group=group,
                                                      use_calc_stream=True,
                                                      use_model_parallel=True)
L
lilong12 已提交
1781
    return out
1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804


def split(x,
          size,
          operation,
          axis=0,
          num_partitions=1,
          gather_out=True,
          weight_attr=None,
          bias_attr=None,
          name=None):
    """

    Split the weight of the specified operation into multiple devices
    and do the computation in parallel.

    Now the following three cases are supported.

    Case 1: Parallel Embedding
        The weight of the embedding operation is a NxM matrix with N rows and M columns.
        With parallel embedding, the weight is split into num_partitions partitions, each
        of which is a matrix with N/num_partitions rows and M columns.

        Suppose we split the NxM weight into two partitions on device_0 and device_1
        respectively. Then, on each device, the local weight has N/2 rows. On device_0,
        the values in the input within [0, N/2 - 1] are looked up in the local weight and
        all other values are mapped to all zeros after embedding. In the same way, on
        device_1, a value V in the input within [N/2, N-1] is looked up at row (V - N/2)
        of the local weight, and all other values are mapped to all zeros after embedding.
        Finally, the results on the two devices are sum-reduced.

        The Embedding put on single card is as shown below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_single.png
            :width: 800
            :height: 350
            :alt: single_embedding
            :align: center

        Parallel Embedding is shown as below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_split.png
            :width: 800
            :alt: split_embedding
            :align: center

    Case 2: Row Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With row parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N/num_partitions rows and M columns.

        The linear layer put on single card is shown as below, the input variable is represented by X,
        the weight matrix is represented by W and the output variable is O. The linear layer on single card is
        simple matrix multiplication operation, O = X * W.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_single.png
            :width: 800
            :alt: single_linear
            :align: center

        Row Parallel Linear is shown as below. As the name suggests, Row Parallel Linear splits the weight matrix W into
        [[W_row1], [W_row2]] along the row. And accordingly the input is split along the column into [X_col1, X_col2] and multiply their
        respective weight matrices. Finally apply AllReduce on the output from each card to get the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_row.png
            :width: 800
            :alt: split_row
            :align: center

    Case 3: Column Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With column parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N rows and M/num_partitions columns.

        The linear layer put on single card has been illustrated on case 2 and Column Parallel Linear
        is shown as below. The Column Parallel Linear splits the weight matrix W into [W_col1, W_col2] along the column and
        these split matrices respectively multiply the input. Finally apply AllGather on the output from each card to get the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col.png
            :width: 800
            :alt: split_col
            :align: center

    As observed, the column parallel linear and row parallel linear can be combined to skip one ALLGATHER communication
    operator. Furthermore the Attention and MLP can be combined to improve the performance as shown below.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col_row.png
            :width: 800
            :alt: split_col_row
            :align: center

    Args:
        x (Tensor): Input tensor. Its data type should be float16, float32, float64, int32 or int64.
        size (list|tuple): A list or tuple with two elements indicating the shape of the weight.
        operation (str): The name of the operation. The supported operations are 'linear' and 'embedding'.
        axis (int, Optional): Indicate along which axis to split the weight. Default: 0.
        num_partitions (int, Optional): How many parts the weight is partitioned. Must be a positive
            integer. Default: 1.
        gather_out (bool, Optional): Whether to gather the output after computation. By default, the output
            on each partitions will be gathered after computation. Default: True.
        weight_attr (ParamAttr, Optional): The parameter attribute for the learnable
            weights(Parameter) of the specified operation. Default: None.
        bias_attr (ParamAttr, Optional): The parameter attribute for the bias
            of the specified operation. Default: None.
        name (str, Optional): The default value is None. Normally there is no need for user to set this
            property. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed.fleet as fleet

            paddle.enable_static()
            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            fleet.init(is_collective=True)
            data = paddle.randint(0, 8, shape=[10,4])
            emb_out = paddle.distributed.split(
                data,
                (8, 8),
                operation="embedding",
                num_partitions=2)

    """
    assert isinstance(
        size,
        (list, tuple)), ("The type of size for "
                         "paddle.distributed.split must be list or tuple.")
    assert len(size) == 2, ("Number of elements in size of "
                            "paddle.distributed.split must be two.")
    assert isinstance(operation, str), ("The type of operation for "
                                        "paddle.distributed.split must be str.")
    supported_operations = [
        'linear',
        'embedding',
    ]
    assert operation in supported_operations, (
        "The operation for "
        "paddle.distributed.split must be one of {}.".format(
            supported_operations))
    # num_partitions is used as a divisor below (size // num_partitions and
    # rank % num_partitions), so reject non-positive values up front.
    assert num_partitions >= 1, (
        "num_partitions for paddle.distributed.split must be a positive "
        "integer, but received {}.".format(num_partitions))
    if _non_static_mode():
        raise ValueError(
            "paddle.distributed.split cannot be used in dynamic "
            "graph mode, plese use ParallelEmbedding, ParallelRowLinear, "
            "ParallelColumnLinear instead.")
    else:
        from .fleet import fleet
        assert fleet._role_maker, ("To use paddle.distributed.split, "
                                   "you must call fleet.init() firstly.")
        rank = fleet.worker_index()

    # rank within a model parallel group
    inner_rank = rank % num_partitions

    if operation == "embedding":
        assert axis == 0, ("We only support to split the weight of embedding "
                           "along the first axis now.")
        assert size[0] % num_partitions == 0, \
            "The length of the vocabulary must be divisible by num_partitions " \
            "but received vocabulary={} num_partitions={}".format(size[0], num_partitions)

        per_part_size = size[0] // num_partitions
        emb_out = _parallel_embedding(x,
                                      per_part_size,
                                      size,
                                      weight_attr,
                                      inner_rank,
                                      num_partitions,
                                      name,
                                      group=None)
        return emb_out
    else:
        should_split = False
        if axis == 0:
            assert size[0] % num_partitions == 0, (
                "Number of rows of the weight for linear ({}) must be"
                " divisible by num_partitions ({})".format(
                    size[0], num_partitions))
            per_part_size = size[0] // num_partitions
            linear_size = (per_part_size, size[1])
            # Only split x when it still carries the full input width.
            if x.shape[-1] == size[0]:
                should_split = True

        elif axis == 1:
            assert size[1] % num_partitions == 0, (
                "Number of column of the weight for linear ({}) must be"
                " divisible by num_partitions ({})".format(
                    size[1], num_partitions))
            per_part_size = size[1] // num_partitions
            linear_size = (size[0], per_part_size)
        else:
            raise ValueError("The value of axis must be 0 or 1, but the value "
                             "given is {}.".format(axis))

        linear_out = _parallel_linear(x,
                                      linear_size[0],
                                      linear_size[1],
                                      axis,
                                      weight_attr,
                                      bias_attr,
                                      gather_out,
                                      inner_rank,
                                      num_partitions,
                                      should_split,
                                      name=name,
                                      group=None)
        return linear_out
L
lilong12 已提交
1994 1995


L
lilong12 已提交
1996 1997
def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
    """
    Scatter tensors in in_tensor_list to all participators evenly and gather the result tensors in out_tensor_list.
    As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
    Through alltoall operator, the 0_0 in GPU0 will be sent to GPU0 and 0_1 to GPU1, 1_0 in GPU1 sent to GPU0 and 1_1 to GPU1.
    Finally the out_tensor_list in GPU0 includes 0_0 and 1_0, and GPU1 includes 0_1 and 1_1.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/alltoall.png
        :width: 800
        :alt: alltoall
        :align: center

    Args:
        in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool. In static graph mode only
            float16, float32, float64, int32 and int64 pass the dtype check.
        out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
            data type of the input Tensors. In eager mode a non-empty list is used as the output buffer and is
            then refilled with the results; in static graph mode it must be empty.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.
            In eager mode the call always waits for the communication task to finish.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            out_tensor_list = []
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
                data2 = paddle.to_tensor([[7, 8, 9], [10, 11, 12]])
            else:
                data1 = paddle.to_tensor([[13, 14, 15], [16, 17, 18]])
                data2 = paddle.to_tensor([[19, 20, 21], [22, 23, 24]])
            dist.alltoall([data1, data2], out_tensor_list)
            print(out_tensor_list)
            # [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0)
            # [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        # Eager mode: run through the group's process group and return.
        if group is None:
            group = _get_default_group()
        assert _group_map_backend[group] != 'gloo', (
            "backend gloo is not supported yet")

        stacked = paddle.concat(in_tensor_list, axis=0)
        parts = len(in_tensor_list)
        if out_tensor_list:
            # Reuse the caller-provided tensors as the output buffer.
            received = paddle.concat(out_tensor_list, axis=0)
        else:
            buffer_shape = list(in_tensor_list[0].shape)
            buffer_shape[0] *= parts
            received = paddle.empty(buffer_shape, in_tensor_list[0].dtype)
        group.process_group.alltoall(stacked, received).wait()
        out_tensor_list.clear()
        out_tensor_list.extend(paddle.split(received, parts, 0))
        return

    # Legacy dygraph and static graph share the ring-id based ops. The names
    # below are kept as-is because LayerHelper receives locals().
    ring_id = 0 if group is None else group.id
    temp = paddle.concat(in_tensor_list, axis=0)
    nranks = len(in_tensor_list)

    if _non_static_mode():
        out = _legacy_C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id)
    else:
        op_type = 'alltoall'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=in_tensor_list[0].dtype)

        if not isinstance(in_tensor_list, list):
            raise ValueError("The type of 'in_tensor_list' for all_to_all "
                             "should be list.")
        for item in in_tensor_list:
            check_variable_and_dtype(
                item, 'in_tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_to_all')
        if not isinstance(out_tensor_list, list):
            raise ValueError("The type of 'out_tensor_list' for all_to_all "
                             "should be list.")
        if out_tensor_list:
            raise ValueError("The 'out_tensor_list' for all_to_all "
                             "must be an empty list.")
        helper.append_op(type=op_type,
                         inputs={'X': [temp]},
                         outputs={'Out': [out]},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': use_calc_stream,
                         })
    out_tensor_list.extend(paddle.split(out, nranks, 0))


2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109
def alltoall_single(in_tensor,
                    out_tensor,
                    in_split_sizes=None,
                    out_split_sizes=None,
                    group=None,
                    use_calc_stream=True):
    """
    Scatter a single input tensor to all participators and gather the received tensors in out_tensor.

    .. note::
        ``alltoall_single`` is only supported in eager mode.

    Args:
        in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
        in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
            must be divisible by group size and ``in_tensor`` will be scattered evenly to all participators. Default: None.
        out_split_sizes (list[int], optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
            must be divisible by group size and ``out_tensor`` will be gathered evenly from all participators. Default: None.
        group (Group, optional): The group instance return by ``new_group`` or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.

    Returns:
        None, if ``use_calc_stream`` is set to ``True`` (the call waits for the task); ``Task`` of ``group``,
        if ``use_calc_stream`` is set to ``False``. Also None when the current rank is not a member of ``group``.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            size = dist.get_world_size()

            # case 1 (2 GPUs)
            data = paddle.arange(2, dtype='int64') + rank * 2
            # data for rank 0: [0, 1]
            # data for rank 1: [2, 3]
            output = paddle.empty([2], dtype='int64')
            dist.alltoall_single(data, output)
            print(output)
            # output for rank 0: [0, 2]
            # output for rank 1: [1, 3]

            # case 2 (2 GPUs)
            in_split_sizes = [i + 1 for i in range(size)]
            # in_split_sizes for rank 0: [1, 2]
            # in_split_sizes for rank 1: [1, 2]
            out_split_sizes = [rank + 1 for i in range(size)]
            # out_split_sizes for rank 0: [1, 1]
            # out_split_sizes for rank 1: [2, 2]
            data = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank
            # data for rank 0: [[0., 0.], [0., 0.], [0., 0.]]
            # data for rank 1: [[1., 1.], [1., 1.], [1., 1.]]
            output = paddle.empty([(rank + 1) * size, size], dtype='float32')
            group = dist.new_group([0, 1])
            task = dist.alltoall_single(data,
                                        output,
                                        in_split_sizes,
                                        out_split_sizes,
                                        use_calc_stream=False,
                                        group=group)
            task.wait()
            print(output)
            # output for rank 0: [[0., 0.], [1., 1.]]
            # output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]]

    """
    if group is not None and not group.is_member():
        return

    assert in_dygraph_mode(), "Only suppport alltoall_single in eager mode."

    if group is None:
        group = _get_default_group()
    assert _group_map_backend[group] != 'gloo', (
        "backend gloo is not supported yet")

    # The process group expects lists; an empty list means "split evenly".
    task = group.process_group.alltoall_single(
        in_tensor, out_tensor,
        in_split_sizes if in_split_sizes is not None else [],
        out_split_sizes if out_split_sizes is not None else [])
    if not use_calc_stream:
        return task
    task.wait()
    return None


S
ShenLiang 已提交
2189 2190 2191 2192
def _get_group_rank(global_rank, group=None):
    return global_rank if group is None else group.get_group_rank(global_rank)


L
lilong12 已提交
2193 2194 2195 2196 2197 2198
def send(tensor, dst=0, group=None, use_calc_stream=True):
    """
    Send a tensor to the receiver.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
            In static graph mode only float16, float32, float64, int32 and int64
            pass the dtype check.
        dst (int): The destination rank id. It is translated to the rank inside
            ``group`` before sending.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        None. In eager mode with ``use_calc_stream=False`` the ``Task`` of ``group``
        is returned instead.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return
    dst = _get_group_rank(dst, group)

    if in_dygraph_mode():
        if group is None:
            group = _get_default_group()
        assert _group_map_backend[group] != 'gloo', (
            "backend gloo is not supported yet")
        task = group.process_group.send(tensor, dst)
        if not use_calc_stream:
            return task
        task.wait()
        return None

    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        # Legacy dygraph: invoke the send_v2 kernel directly.
        return _legacy_C_ops.send_v2(tensor, 'use_calc_stream',
                                     use_calc_stream, 'ring_id', ring_id,
                                     'peer', dst)

    # Static graph: append a send_v2 op. LayerHelper receives locals(), so
    # no extra local is bound before it is created.
    op_type = 'send_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'send')

    helper = LayerHelper(op_type, **locals())
    send_attrs = {
        'peer': dst,
        'ring_id': ring_id,
        'use_calc_stream': use_calc_stream,
    }
    helper.append_op(type=op_type, inputs={'X': [tensor]}, attrs=send_attrs)
L
lilong12 已提交
2256 2257 2258 2259 2260 2261 2262 2263


def recv(tensor, src=0, group=None, use_calc_stream=True):
    """
    Receive a tensor from the sender.

    Args:
        tensor (Tensor): The Tensor that receives the data. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
            In static graph mode only float16, float32, float64, int32 and int64
            are accepted.
        src (int): The source rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        In eager dygraph mode, a task object when ``use_calc_stream`` is False,
        otherwise None. None is also returned when the current rank is not a
        member of ``group``.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    # Ranks outside ``group`` do not take part in the communication.
    if group is not None and not group.is_member():
        return

    # NOTE(review): presumably maps the global ``src`` rank to its rank inside
    # ``group``; see ``_get_group_rank`` for the exact semantics.
    src = _get_group_rank(src, group)
    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        task = group.process_group.recv(tensor, src)
        if use_calc_stream:
            # Synchronous mode: block until the receive has finished.
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _legacy_C_ops.recv_v2(tensor, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id, 'peer', src, 'dtype',
                                     tensor.dtype, 'out_shape', tensor.shape)
    op_type = 'recv_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'recv')
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'peer': src,
                         'out_shape': tensor.shape,
                         'dtype': tensor.dtype,
                         'use_calc_stream': use_calc_stream,
                     })


def _check_single_tensor(tensor, tensor_name):
    """
    Raise a RuntimeError unless ``tensor`` is a paddle Tensor.

    Args:
        tensor: The object to validate.
        tensor_name (str): Parameter name used in the error message.

    Raises:
        RuntimeError: If ``tensor`` is neither a ``core.eager.Tensor`` nor a
            ``paddle.Tensor``.
    """
    if not isinstance(tensor, (core.eager.Tensor, paddle.Tensor)):
        # The trailing space after ``{}`` keeps the parameter name separated
        # from the rest of the message.
        raise RuntimeError("Invalid function argument. Expected parameter {} "
                           "to be of type paddle.Tensor, but it's {}".format(
                               tensor_name, type(tensor)))


def _check_tensor_list(tensor_list, tensor_name):
    if not isinstance(tensor_list, list) or \
        not all(isinstance(t, (core.eager.Tensor, paddle.Tensor)) for t in tensor_list):
        raise RuntimeError("Invalid function argument. Expected parameter {}"
                           "to be of type paddle.Tensor".format(tensor_name))


def isend(tensor, dst, group=None):
    """
    Send a tensor to the rank ``dst`` without waiting for completion.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        dst (int): The destination rank.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.

    Returns:
        A distributed task object, or None when the current rank is not a
        member of ``group``.

    Raises:
        RuntimeError: If called outside eager dygraph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)

    """
    _check_single_tensor(tensor, "tensor")
    if group is not None and not group.is_member():
        return

    # Guard clause: only eager dygraph mode is implemented.
    if not in_dygraph_mode():
        raise RuntimeError("Only support eager dygraph mode.")

    comm_group = group if group is not None else _get_default_group()
    assert _group_map_backend[comm_group] != 'gloo', (
        "backend gloo is not supported yet")

    # ``dst`` is a global rank; the process group expects the in-group rank.
    peer_rank = comm_group.get_group_rank(dst)
    assert peer_rank >= 0, ("dst rank out of group, need global rank")
    return comm_group.process_group.send(tensor, peer_rank)


def irecv(tensor, src=None, group=None):
    """
    Receive a tensor from the sender asynchronously.

    Args:
        tensor (Tensor): The Tensor that receives the data. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        src (int): The source rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.

    Returns:
        A distributed task object, or None when the current rank is not a
        member of ``group``.

    Raises:
        RuntimeError: If called outside eager dygraph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    _check_single_tensor(tensor, "tensor")
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        # ``src`` is a global rank; the process group expects the in-group rank.
        group_src_rank = group.get_group_rank(src)
        assert group_src_rank >= 0, ("src rank out of group, need global rank")
        return group.process_group.recv(tensor, group_src_rank)
    else:
        raise RuntimeError("Only support eager dygraph mode.")


class P2POp(object):
    """
    Describe one point-to-point operation for ``batch_isend_irecv``.

    An instance bundles the operation (``isend`` or ``irecv``), the tensor
    used as the communication buffer, the peer rank and the group. A list of
    instances is passed to ``paddle.distributed.batch_isend_irecv``.

    Args:
        op (callable): Either ``paddle.distributed.isend`` or
            ``paddle.distributed.irecv``.
        tensor (Tensor): Tensor to send or receive.
        peer (int): The destination or source rank.
        group (Group, optional): The group instance return by new_group or None for global
            default group. Default: None.

    Raises:
        RuntimeError: If ``op`` is not ``isend``/``irecv`` or ``tensor`` is not
            a paddle Tensor.
    """

    def __init__(self, op, tensor, peer, group=None):
        supported_ops = [isend, irecv]
        if op not in supported_ops:
            raise RuntimeError("Invalid ``op`` function. Expected ``op`` "
                               "to be of type ``paddle.distributed.isend`` or "
                               "``paddle.distributed.irecv``.")
        _check_single_tensor(tensor, "tensor")

        self.op, self.tensor, self.peer = op, tensor, peer
        # Resolve the default group eagerly so ``self.group`` is never None.
        self.group = group if group is not None else _get_default_group()


@contextlib.contextmanager
def _with_batch_p2p_guard(backend):
    if backend == "nccl":
        core.ProcessGroupNCCL.group_start()
    try:
        yield
    finally:
        if backend == "nccl":
            core.ProcessGroupNCCL.group_end()


def _check_p2p_op_list(p2p_op_list):
    """
    Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
    all ops use the same backend.
    """
    if not isinstance(p2p_op_list, list) or not all(
            isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list):
        raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected to "
                           "to be of type ``paddle.distributed.P2POp``.")

    backend = _group_map_backend[p2p_op_list[0].group]
    if not all(backend == _group_map_backend[p2p_op.group]
               for p2p_op in p2p_op_list):
        raise RuntimeError("All groups need to use the same backend.")


def batch_isend_irecv(p2p_op_list):
    """
    Send or Receive a batch of tensors asynchronously and return a list of requests.

    Process each of the point-to-point operations in ``p2p_op_list`` and return the
    corresponding tasks. Only the NCCL backend is currently supported.

    Args:
        p2p_op_list: A list of point-to-point operations(type of each operator is
            ``paddle.distributed.P2POp``). The order of the isend/irecv in the list
            matters and it needs to match with corresponding isend/irecv on the
            remote end.

    Returns:
        A list of distributed tasks returned by calling the corresponding
        op in the op_list. Ops whose call returns None (e.g. because the
        current rank is not a member of that op's group) contribute no task.
        An empty list is returned when the current rank is not a member of
        the first op's group.

    Raises:
        RuntimeError: If ``p2p_op_list`` is invalid or the call is made in
            static graph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed

            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            send_t = paddle.arange(2) + rank
            # paddle.tensor([0, 1])  # Rank-0
            # paddle.tensor([1, 2])  # Rank-1

            recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)

            send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
            recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)

            tasks = dist.batch_isend_irecv([send_op, recv_op])

            for task in tasks:
                task.wait()

            print(recv_t)
            # paddle.tensor([1, 2])     # Rank-0
            # paddle.tensor([0, 1])     # Rank-1
    """
    _check_p2p_op_list(p2p_op_list)
    group = p2p_op_list[0].group
    if group is not None and not group.is_member():
        # Return an empty list (not None) so callers can always iterate the
        # result, as the documented example does.
        return []

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        tasks = []
        # For NCCL, all ops are issued between group_start()/group_end().
        with _with_batch_p2p_guard(backend):
            for p2p_op in p2p_op_list:
                task = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group)
                if task is not None:
                    tasks.append(task)
        return tasks
    else:
        raise RuntimeError("Don't support static graph mode currently.")


def reduce_scatter(tensor,
                   tensor_list,
                   op=ReduceOp.SUM,
                   group=None,
                   use_calc_stream=True):
    """
    Reduces, then scatters a list of tensors to all processes in a group.

    Args:
        tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        tensor_list (list[Tensor]): List of tensors to reduce and scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (Group, optional): The group instance return by new_group or None for global
            default group. Default: None.
        use_calc_stream (bool, optional): If True, wait for the operation to
            finish and return None; if False, return the async task handle
            without waiting. Default: True.

    Returns:
        Async task handle, if use_calc_stream is set to False.
        None, if use_calc_stream is True or if not part of the group.

    Raises:
        RuntimeError: If ``tensor``/``tensor_list`` have invalid types or the
            call is made in static graph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([0, 1])
                data2 = paddle.to_tensor([2, 3])
            else:
                data1 = paddle.to_tensor([4, 5])
                data2 = paddle.to_tensor([6, 7])
            dist.reduce_scatter(data1, [data1, data2])
            print(data1)
            # [4, 6] (2 GPUs, out for rank 0)
            # [8, 10] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(tensor, "tensor")
    _check_tensor_list(tensor_list, "tensor_list")

    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce_scatter")
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")

        # The list is flattened along axis 0 and handed to the flat-buffer
        # primitive, which writes this rank's reduced chunk into ``tensor``.
        temp = paddle.concat(tensor_list, axis=0)
        task = group.process_group._reduce_scatter_base(tensor, temp, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task
    else:
        raise RuntimeError("Don't support static graph mode currently.")


def _reduce_scatter_base(output,
                         input,
                         op=ReduceOp.SUM,
                         group=None,
                         use_calc_stream=True):
    """
    Reduces, then scatters a flattened tensor to all processes in a group.

    Args:
        output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        input (Tensor): Input tensor that is of size output tensor size times world size. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (Group, optional): The group instance return by new_group or
            None for the global default group. Default: None.
        use_calc_stream (bool, optional): If True, wait for the operation to
            finish and return None; if False, return the async task handle
            without waiting. Default: True.

    Returns:
        Async task handle, if use_calc_stream is set to False.
        None, if use_calc_stream is True or if not part of the group.

    Raises:
        RuntimeError: If ``output``/``input`` are not paddle Tensors or the
            call is made in static graph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            data = paddle.arange(4) + rank
            # [0, 1, 2, 3] (2 GPUs, for rank 0)
            # [1, 2, 3, 4] (2 GPUs, for rank 1)
            output = paddle.empty(shape=[2], dtype=data.dtype)
            dist.collective._reduce_scatter_base(output, data)
            print(output)
            # [1, 3] (2 GPUs, out for rank 0)
            # [5, 7] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(output, "output")
    _check_single_tensor(input, "input")

    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "_reduce_scatter_base")
        group = _get_default_group() if group is None else group
        task = group.process_group._reduce_scatter_base(output, input, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task
    else:
        raise RuntimeError("Don't support static graph mode currently.")