collective.py 98.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
17 18
import pickle
import io
19 20
import datetime
import time
21
from ..fluid.layer_helper import LayerHelper
22
from ..fluid.framework import Variable
23
from ..fluid.framework import in_dygraph_mode
24
from ..fluid.framework import OpProtoHolder
J
Jiabin Yang 已提交
25
from ..fluid.framework import _non_static_mode
26
from ..fluid.framework import _in_legacy_dygraph
27
from ..fluid.framework import convert_np_dtype_to_dtype_
J
Jiangxinz 已提交
28
from ..fluid.framework import _varbase_creator
29 30 31 32
from ..fluid.data_feeder import convert_dtype
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.data_feeder import check_type
from ..fluid.data_feeder import check_dtype
33 34
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
B
Baibaifan 已提交
35
from ..fluid.dygraph import layers
36 37 38 39
from ..fluid.dygraph.parallel import prepare_context
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
40
from paddle import _C_ops, _legacy_C_ops
J
Jiangxinz 已提交
41
import paddle.fluid.dygraph_utils as dygraph_utils
42
import contextlib
43

44
__all__ = []
45 46 47


class ReduceOp:
    """
    Specify the type of operation used for element-wise reductions.
    It should be one of the following values:

        ReduceOp.SUM

        ReduceOp.MAX

        ReduceOp.MIN

        ReduceOp.PROD

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_reduce(data, op=dist.ReduceOp.SUM)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
    """
    # Plain integer tags; callers compare against them with ``==``.
    SUM = 0
    MAX = 1
    MIN = 2
    PROD = 3
    # NOTE(review): AVG is defined here but is not handled by _get_reduce_op
    # or by the op dispatch in all_reduce as visible in this file; confirm
    # which collectives actually accept it before documenting it as supported.
    AVG = 4
81 82


K
kuizhiqing 已提交
83 84 85 86
class Group():
    """
    The abstract representation of group.

    Attributes:
        rank: Rank of the current process inside this group; a negative
            value marks the process as not belonging to the group.
        nranks: Number of ranks in the group.
        id: Integer id of the group.
        ranks: List of ranks that make up the group.
        pg: The underlying process group object, or None.
        name: Optional name of the group.
    """

    def __init__(self, rank, rank_num, id=0, ranks=None, pg=None, name=None):
        self.rank = rank
        self.nranks = rank_num
        self.id = id
        # Build a fresh list per instance: a mutable default argument would
        # be shared by every Group created without an explicit ``ranks``.
        self.ranks = [] if ranks is None else ranks
        self.pg = pg
        self.name = name

    def is_member(self):
        """Return True when this process has a non-negative rank in a group of at least two ranks."""
        if self.rank < 0:
            return False
        if self.nranks < 2:
            return False
        return True

    def get_group_rank(self, rank):
        """Return the index of ``rank`` inside ``self.ranks``.

        Returns -1 when the current process is not a member of the group or
        when ``rank`` is not part of the group.
        """
        if self.is_member() and rank in self.ranks:
            return self.ranks.index(rank)
        else:
            return -1

    @property
    def process_group(self):
        """The underlying process group object (may be None)."""
        return self.pg

    @property
    def world_size(self):
        """Number of ranks in the group, or -1 when this process has a negative rank."""
        return self.nranks if self.rank >= 0 else -1

    def __repr__(self):
        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
            self.rank, self.nranks, self.id)
        debug_str += ", ".join(map(str, self.ranks))
        debug_str += "; name: "
        debug_str += self.name if self.name else "None"
        return debug_str

K
kuizhiqing 已提交
125 126 127 128 129 130 131 132 133 134 135 136 137 138

# Lazily created ParallelEnv instance shared by the helpers in this module.
_global_env = None


def _get_global_env():
    """Return the cached ``ParallelEnv``, creating it on first use."""
    global _global_env
    if _global_env:
        return _global_env
    _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# group map : the map of all groups keyed by group id, 0 for GlobalGroup
# Dict[int, Group]
_group_map = {}
# Key under which the global group is stored in _group_map.
_global_env_gid = 0

# group map by name : the map of all groups from their names
# Dict[name, Group]
_group_map_by_name = {}

# backend map by group : the map of all backend from their groups
# Dict[group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

# Backend names accepted by _new_process_group_impl.
_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl']
_default_store = None  # the default tcp store
_default_backend = None
# Default waiting timeout (30 minutes) used by new_group.
_default_timeout = datetime.timedelta(seconds=1800)
# Counter incremented by _new_ring_id in dygraph mode.
_start_ring_id = 0
157

K
kuizhiqing 已提交
158

L
lilong12 已提交
159 160 161 162 163 164 165 166 167 168
def _set_default_backend(backend):
    """Record ``backend`` as the module-wide default backend."""
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    """Record ``store`` as the module-wide default store."""
    global _default_store
    _default_store = store


K
kuizhiqing 已提交
169 170
def _get_group_map():
    """Return the id -> Group registry, registering the global group if absent.

    The global group is built from the cached ParallelEnv: it uses the env's
    rank and world size and contains every rank in ``range(world_size)``.
    """
    global _group_map
    if _global_env_gid in _group_map:
        return _group_map

    env = _get_global_env()
    my_rank = env.rank
    total = env.world_size
    _group_map[_global_env_gid] = Group(my_rank,
                                        total,
                                        ranks=list(range(env.world_size)))
    return _group_map


def _get_global_group():
    """Return the global group, i.e. the entry stored under ``_global_env_gid``."""
    registry = _get_group_map()
    return registry[_global_env_gid]
K
kuizhiqing 已提交
181 182


183 184 185 186 187 188
def _get_group_map_by_name():
    """Return the module-level name -> Group registry."""
    # Reading a module global needs no ``global`` declaration.
    return _group_map_by_name


def _get_default_group():
    """Return the default group registered under ``_default_group_name``.

    Asserts that the distributed environment has been initialized first.
    """
    assert is_initialized(), ("Call paddle.distributed.init_parallel_env first "
                              "to initialize the distributed environment.")
    by_name = _get_group_map_by_name()
    return by_name[_default_group_name]


L
lilong12 已提交
195 196 197 198 199 200 201 202 203 204 205 206
def _set_group_map(gid, group):
    """Register ``group`` under id ``gid``; the id must not be registered yet."""
    # Item assignment mutates the existing dict, so no ``global`` is required.
    assert gid not in _group_map
    _group_map[gid] = group


def _set_group_map_by_name(name, group):
    """Register ``group`` under ``name``; the name must not be registered yet."""
    # Item assignment mutates the existing dict, so no ``global`` is required.
    assert name not in _group_map_by_name
    _group_map_by_name[name] = group


207 208 209 210 211 212
def _set_group_map_backend(group, backend):
    """Record the backend used by ``group``; the group must not be recorded yet."""
    # Item assignment mutates the existing dict, so no ``global`` is required.
    assert group not in _group_map_backend
    _group_map_backend[group] = backend


K
kuizhiqing 已提交
213
def _new_ring_id():
    """Return a new ring id.

    In dygraph mode the id comes from a module-level counter that is
    incremented on every call; otherwise it is derived from the number of
    registered groups. In both cases ``max(nrings, 9)`` is added as an offset.
    """
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode
    # rely on the previous scheme.
    global _start_ring_id
    if not in_dygraph_mode():
        return len(_get_group_map()) + max(_get_global_env().nrings, 9)

    _start_ring_id += 1
    return _start_ring_id + max(_get_global_env().nrings, 9)
K
kuizhiqing 已提交
221 222


223 224 225 226 227 228 229 230 231 232 233 234 235
def _get_reduce_op(reduce_op, func_name):
    """Translate a ``ReduceOp`` value into the matching ``core.ReduceOp`` member.

    Args:
        reduce_op: One of ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN or
            ReduceOp.PROD.
        func_name (str): Name of the calling API, used in the error message.

    Raises:
        ValueError: If ``reduce_op`` is none of the four values above.
    """
    # (python-side value, attribute name on core.ReduceOp), checked in order.
    translation = (
        (ReduceOp.SUM, "SUM"),
        (ReduceOp.MAX, "MAX"),
        (ReduceOp.MIN, "MIN"),
        (ReduceOp.PROD, "PRODUCT"),
    )
    for candidate, core_attr in translation:
        if reduce_op == candidate:
            return getattr(core.ReduceOp, core_attr)
    raise ValueError("Unknown reduce_op type for {}.".format(func_name))


K
kuizhiqing 已提交
236 237 238 239 240 241
def get_group(id=0):
    """

    Look up a group instance by its group id.

    Args:
        id (int): the group id. Default value is 0.

    Returns:
        Group: the group instance, or None if no group has this id.

    Examples:
        .. code-block:: python

            ...
            gid = paddle.distributed.new_group([2,4,6])
            paddle.distributed.get_group(gid.id)

    """
    # dict.get yields None for an unknown id, like the explicit membership test.
    return _get_group_map().get(id)
K
kuizhiqing 已提交
258 259


260 261 262 263 264 265
def _new_process_group_impl(backend,
                            store,
                            rank,
                            world_size,
                            group_name,
                            pg_options,
                            group_id=0,
                            src_rank=None,
                            dst_rank=None):
    """Create the backend-specific process group object for one group.

    Args:
        backend (str): One of the names in ``_valid_backend_list``.
        store: Store object forwarded to the process group constructor.
        rank (int): Rank of the current process inside the new group.
        world_size (int): Number of ranks in the new group.
        group_name: Not referenced in this function body.
        pg_options: Not referenced in this function body.
        group_id (int): Id forwarded to the process group constructor.
        src_rank: Only allowed (non-None) for the 'heter' backend.
        dst_rank: Only allowed (non-None) for the 'heter' backend.

    Returns:
        The created ``core.ProcessGroup*`` object.

    Raises:
        AssertionError: If ``backend`` is not supported, if src_rank/dst_rank
            are given for a non-heter backend, or if a required CLUSTER_*
            environment variable is missing for the heter backend.
    """
    pg = None
    genv = _get_global_env()
    if backend != 'heter':
        assert src_rank is None and dst_rank is None, (
            "src_rank and dst_rank "
            "can only be set for heter backend.")
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
    if backend == "gloo":
        place = core.CPUPlace()
        pg = core.ProcessGroupGloo(store, rank, world_size, place, group_id)
    elif backend == "nccl":
        place = core.CUDAPlace(genv.device_id)
        pg = core.ProcessGroupNCCL(store, rank, world_size, place, group_id)
    elif backend == "hccl":
        place = core.NPUPlace(genv.device_id)
        pg = core.ProcessGroupHCCL(store, rank, world_size, place, group_id)
    elif backend == "xccl":
        place = core.CustomPlace(genv.device_type, genv.device_id)
        pg = core.ProcessGroupCustom(store, rank, world_size, place, group_id)
    elif backend == "heter":
        # place stays None when compiled with neither CUDA nor NPU.
        place = None
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(genv.device_id)
        elif core.is_compiled_with_npu():
            place = core.NPUPlace(genv.device_id)
        # The heter backend is configured through environment variables:
        # CLUSTER_ID (int), CLUSTER_SIZE (comma-separated ints) and
        # CLUSTER_SWITCH (switch endpoint).
        cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
        assert cluster_id >= 0, "please set the CLUSTER_ID variable."
        cluster_size = os.getenv("CLUSTER_SIZE", None)
        assert cluster_size, "please set the CLUSTER_SIZE variable."
        cluster_size = cluster_size.split(",")
        cluster_size = [int(s) for s in cluster_size]
        switch_ep = os.getenv("CLUSTER_SWITCH", None)
        assert switch_ep, "please set the CLUSTER_SWITCH variable."
        # Offset of this cluster = total size of all preceding clusters.
        cluster_size_cumsum = np.cumsum(cluster_size)
        cluster_offset = 0 if cluster_id == 0 else cluster_size_cumsum[
            cluster_id - 1]
        global_rank = cluster_offset + rank
        global_world_size = cluster_size_cumsum[-1]
        # NOTE(review): the two values computed just above are immediately
        # overwritten by the call below, and _get_global_config is not defined
        # in the visible part of this file - confirm which source is intended.
        global_rank, global_world_size = _get_global_config(backend, rank)
        pg = core.ProcessGroupHeter(store,
                                    rank=global_rank,
                                    world_size=global_world_size,
                                    place=place,
                                    gid=group_id,
                                    local_rank=rank,
                                    local_size=world_size,
                                    gloo_rank=cluster_id,
                                    gloo_size=len(cluster_size),
                                    with_switch=True,
                                    switch_endpoint=switch_ep,
                                    src_rank=src_rank,
                                    dst_rank=dst_rank)

    return pg


S
ShenLiang 已提交
325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348
def barrier(group=None):
    """

    Barrier among all participators in the group.

    Args:
        group (Group): The group instance return by new_group or None for global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    # Processes outside the group have nothing to synchronize with.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        # Use the group's process group and block until the barrier completes.
        group = _get_default_group() if group is None else group
        task = group.process_group.barrier()
        task.wait()
        return

    ring_id = 0 if group is None else group.id

    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        return _legacy_C_ops.barrier(temp, temp, 'ring_id', ring_id)

    op_type = 'barrier'

    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    # NOTE: locals() forwards every local variable of this function to
    # LayerHelper, so local names here are part of the call's input.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [temp]},
                     outputs={'Out': [temp]},
                     attrs={'ring_id': ring_id})
S
ShenLiang 已提交
370 371


L
lilong12 已提交
372 373 374 375 376 377 378
# _custom_gid provides a way for users to
# set the group id, which is usually useful
# to be compatible with the static mode.
_custom_gid = None


def _set_custom_gid(gid):
    """Set the module-level ``_custom_gid`` that new_group reads in dygraph mode."""
    global _custom_gid
    _custom_gid = gid


383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419
def _barrier_by_tcp_store(group_name, store, timeout):
    """Synchronize all ranks through the key-value ``store``.

    Every rank other than rank 0 adds 1 to its own key
    ``"Barrier/<group_name>/<rank>"``. Rank 0 polls the store every 0.1
    seconds until each of those keys reads 1. Nothing is done when the world
    size is smaller than 2.

    Args:
        group_name (str): Name of the group; used to build the key prefix.
        store: Store object providing ``get(key)`` and ``add(key, value)``.
        timeout (datetime.timedelta): Maximum time rank 0 waits for the keys.

    Raises:
        RuntimeError: On rank 0, if some keys are still not ready once
            ``timeout`` has elapsed.
    """
    global_rank = paddle.distributed.get_rank()
    global_world_size = paddle.distributed.get_world_size()

    if global_world_size < 2:
        return

    barrier_prefix = "Barrier/" + group_name + "/"
    is_master = (global_rank == 0)

    def _check_keys_ready(wait_keys):
        start_time = time.time()
        while len(wait_keys) > 0:
            time.sleep(0.1)
            elapse_time = time.time() - start_time
            if datetime.timedelta(seconds=elapse_time) > timeout:
                # The sentences are separated explicitly; adjacent string
                # literals are concatenated without any whitespace.
                raise RuntimeError(
                    "Timeout while initializing process group {}. "
                    "Keys {} are not ready since rank {} is waiting for them. "
                    "Two reasons may cause this error:\n 1. The create process group api should be called by all ranks.\n"
                    " 2. Try to increase the waiting time.\n".format(
                        group_name, wait_keys, global_rank))
            # Keep only the keys that have not been set to 1 yet.
            wait_keys = [key for key in wait_keys if int(store.get(key)) != 1]

    # all the workers set their exiting key and exit
    # the master will wait for all workers' exiting key, ensure to exit in the end
    if is_master:
        wait_keys = [
            barrier_prefix + str(rank) for rank in range(1, global_world_size)
        ]
        _check_keys_ready(wait_keys)
    else:
        store.add(barrier_prefix + str(global_rank), 1)


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members. In dygraph mode,
            None selects all ranks of the default group (non-heter backends).
        backend (str): The backend used to create group. In dygraph mode, None
            falls back to the default backend when the group has more than one
            rank; in static mode only nccl is accepted.
        timeout (datetime.timedelta, optional): The waiting timeout for store relevant options, default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)

    """
    global _custom_gid
    global _group_map
    if in_dygraph_mode():
        global _default_group_name
        # A falsy _custom_gid (None or 0) falls back to a fresh ring id.
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group.")
        # NOTE(review): with backend == 'heter' and ranks=None this len() call
        # raises TypeError - confirm callers always pass ranks for heter.
        size = len(ranks)
        ranks = sorted(ranks)
        # global_rank is only bound by the branch above; the ``size > 1``
        # test short-circuits when that branch was skipped for small groups.
        if backend == 'heter' or (size > 1 and global_rank in ranks):
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            src_rank = ranks[0] if backend == 'heter' else None
            dst_rank = ranks[1] if backend == 'heter' else None
            pg = _new_process_group_impl(backend,
                                         _default_store,
                                         rank,
                                         size,
                                         group_name,
                                         pg_options=None,
                                         group_id=gid,
                                         src_rank=src_rank,
                                         dst_rank=dst_rank)
        else:
            # Not a member (or a single-rank group): no process group is built.
            rank = -1
            pg = None
        group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        # NOTE(liyurui): All processors should hang and wait using tcp store, in case master exit before sub-group is created.
        if backend != 'heter':
            _barrier_by_tcp_store(group_name, _default_store, timeout)
        else:
            print("Warning: store barrier is not supported for heter backend.")
        return group

    # Static / legacy path: only the nccl backend is accepted.
    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', ("backend other than nccl is not supported yet")

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    # NOTE(review): ranks=None is not handled on this path; the membership
    # test below raises TypeError for None - confirm this is intended.
    if global_rank not in ranks:
        gp = Group(-1, -1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, group_size, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            # Initialize the communication context that matches the build.
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            else:
                assert False, ("no cuda device found")
        else:
            # Single-rank group: returned without the all_reduce below.
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = paddle.to_tensor(
        [1], dtype="int32") if _non_static_mode() else fill_constant(
            [0], dtype="int32", value="1")
    paddle.distributed.all_reduce(tmp, use_calc_stream=True)
    paddle.distributed.wait(tmp)
    return gp

551

552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581
def is_initialized():
    """

    Report whether the distributed environment has been initialized, i.e.
    whether the default group is registered.

    Returns (bool): `True` if distributed environment has been initialized, otherwise `False`.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle

            print(paddle.distributed.is_initialized())
            # False

            paddle.distributed.init_parallel_env()
            print(paddle.distributed.is_initialized())
            # True

    """
    registered_names = _group_map_by_name
    return _default_group_name in registered_names


def destroy_process_group(group=None):
    """
    Destroy a given group for communication

    Args:
        group (ProcessGroup, optional): The group to be destroyed. If None, all of the
                                        process groups, including the default group, will
                                        be destroyed and the distributed environment will
                                        be deinitialized.

    Returns : None

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            group = dist.new_group([0, 1])

            dist.destroy_process_group(group)
            print(dist.is_initialized())
            # True
            dist.destroy_process_group()
            print(dist.is_initialized())
            # False

    """
    target = group if group is not None else _get_default_group()
    assert _group_map.get(target.id, None) is not None, "Invalid group."

    if group is None:
        # No group given: wipe every registry.
        for registry in (_group_map, _group_map_by_name, _group_map_backend):
            registry.clear()
        return

    # Remove only the given group from each registry.
    del _group_map[target.id]
    del _group_map_by_name[target.name]
    del _group_map_backend[target]


K
kuizhiqing 已提交
622 623 624 625 626 627 628 629
def wait(tensor, group=None, use_calc_stream=True):
    """

    Synchronize the stream used for ``tensor`` within the group.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, use_calc_stream=True)
            paddle.distributed.wait(tindata)

    """
    if group is None:
        # No group given: ring 0 is used.
        ring_id = 0
    elif not group.is_member():
        # Processes outside the group have nothing to wait for.
        return
    else:
        ring_id = group.id

    if not use_calc_stream:
        _sync_comm_stream(tensor, ring_id)
    else:
        _sync_calc_stream(tensor)


def _sync_calc_stream(tensor):
    """Issue a ``c_sync_calc_stream`` op with ``tensor`` as both input and output.

    In non-static mode the legacy C op is called directly and its result is
    returned; otherwise the op is appended through a LayerHelper.
    """

    if _non_static_mode():
        return _legacy_C_ops.c_sync_calc_stream(tensor, tensor)

    op_type = 'c_sync_calc_stream'

    # NOTE: locals() forwards every local variable of this function to
    # LayerHelper, so local names here are part of the call's input.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
    )
672

673

K
kuizhiqing 已提交
674
def _sync_comm_stream(tensor, ring_id=0):
    """Issue a ``c_sync_comm_stream`` op for ``tensor`` on ring ``ring_id``.

    In non-static mode the legacy C op is called directly and its result is
    returned; otherwise the op is appended through a LayerHelper.
    """

    if _non_static_mode():
        return _legacy_C_ops.c_sync_comm_stream([tensor], [tensor], 'ring_id',
                                                ring_id)

    op_type = 'c_sync_comm_stream'

    # NOTE: locals() forwards every local variable of this function to
    # LayerHelper, so local names here are part of the call's input.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={'ring_id': ring_id},
    )
K
kuizhiqing 已提交
689 690 691


def broadcast(tensor, src, group=None, use_calc_stream=True):
    """

    Broadcast a tensor from the source to all others.
    As shown below, one process is started with a GPU and GPU0 owns data 0. Through broadcast operator,
    data 0 will be sent to all GPUs from GPU0.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png
        :width: 800
        :alt: broadcast
        :align: center

    Args:
        tensor (Tensor): The Tensor to send if current rank is the source, or the Tensor to receive otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        src (int): The source rank.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None in the usual case. In dygraph mode with ``use_calc_stream=False``
        the pending communication task is returned instead, so the caller can
        wait on it.

    Raises:
        ValueError: If ``src`` is not an int.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.broadcast(data, src=1)
            print(data)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs)
    """

    # Processes outside the group do not take part in the broadcast.
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        # The process group expects the rank relative to the group.
        gsrc = group.get_group_rank(src)
        assert gsrc >= 0, ("src rank out of group, need global rank")
        task = group.process_group.broadcast(tensor, gsrc)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    gsrc = src if group is None else group.get_group_rank(src)
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if _non_static_mode():
        return _legacy_C_ops.c_broadcast(tensor, tensor, 'root', gsrc,
                                         'use_calc_stream', use_calc_stream,
                                         'ring_id', ring_id)

    op_type = 'c_broadcast'
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'broadcast')

    # NOTE: locals() forwards every local variable of this function to
    # LayerHelper, so local names here are part of the call's input.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'root': gsrc,
                         'use_calc_stream': use_calc_stream,
                         'ring_id': ring_id,
                     })
772 773


K
kuizhiqing 已提交
774
def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor over all ranks so that all get the result.
    As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The reduce operator is sum. Through all_reduce operator,
    each GPU will have the sum of the data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
        :width: 800
        :alt: all_reduce
        :align: center

    Args:
        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None in the usual case. In dygraph mode with ``use_calc_stream=False``
        the pending communication task is returned instead, so the caller can
        wait on it.

    Raises:
        ValueError: If ``op`` is not one of the four supported reduce
            operations, or if the ring id derived from ``group`` is not an int.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_reduce(data)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
    """
    # Processes outside the group do not take part in the reduction.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "all_reduce")
        group = _get_default_group() if group is None else group
        task = group.process_group.allreduce(tensor, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    if _non_static_mode():
        if op == ReduceOp.SUM:
            return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id)
        elif op == ReduceOp.MAX:
            return _legacy_C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id)
        elif op == ReduceOp.MIN:
            return _legacy_C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id)
        elif op == ReduceOp.PROD:
            return _legacy_C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
                                                   use_calc_stream, 'ring_id',
                                                   ring_id)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'all_reduce')
    if op == ReduceOp.SUM:
        op_type = 'c_allreduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_allreduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_allreduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_allreduce_prod'
    else:
        # Without this branch op_type stays unbound and the LayerHelper call
        # below fails with UnboundLocalError; report the bad argument the
        # same way the non-static branch above does.
        raise ValueError("Unknown parameter: {}.".format(op))
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'ring_id' for all_reduce should be int.")
    # NOTE: locals() forwards every local variable of this function to
    # LayerHelper, so local names here are part of the call's input.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream
                     })
870 871


K
kuizhiqing 已提交
872
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor to the destination from all others. As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator,
    the GPU0 will own the sum of all data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png
        :width: 800
        :alt: reduce
        :align: center

    Args:
        tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        dst (int): The destination rank id.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance returned by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None in eager dygraph mode when ``use_calc_stream`` is True, in static graph mode,
        and when the current process is not a member of ``group``. In eager dygraph mode
        with ``use_calc_stream=False`` the communication task is returned so the caller can
        wait on it. In legacy dygraph mode the result of the underlying operator is returned.

    Raises:
        ValueError: If ``op`` is not one of the supported ``ReduceOp`` values
            (legacy dygraph and static graph modes).

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.reduce(data, dst=0)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs, out for rank 0)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        # The calling process does not belong to the group: nothing to do.
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce")
        group = _get_default_group() if group is None else group
        gdst = group.get_group_rank(dst)
        assert gdst >= 0, ("dst rank out of group, need global rank")
        task = group.process_group.reduce(tensor, gdst, op_type)
        if use_calc_stream:
            task.wait()
            return None
        return task

    ring_id = 0 if group is None else group.id
    gdst = dst if group is None else group.get_group_rank(dst)
    assert gdst >= 0, ("dst rank out of group, need global rank")

    # Resolve the operator name once for both the legacy dygraph and the
    # static graph paths, and reject unknown reductions up front so static
    # graph mode no longer continues with an invalid operator type.
    if op == ReduceOp.SUM:
        op_type = 'c_reduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_reduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_reduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_reduce_prod'
    else:
        raise ValueError("Unknown parameter: {}.".format(op))

    if _non_static_mode():
        # Legacy dygraph: the tensor is passed as both input and output.
        return getattr(_legacy_C_ops,
                       op_type)(tensor, tensor, 'use_calc_stream',
                                use_calc_stream, 'ring_id', ring_id, 'root_id',
                                gdst)

    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'reduce')

    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream,
                         'root_id': gdst,
                     })
977 978


K
kuizhiqing 已提交
979
def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
    """

    Gather tensors from all participators and all get the result. As shown
    below, one process is started with a GPU and the data of this process is represented
    by its group rank. Through the all_gather operator, each GPU will have data
    from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
        group (Group): The group instance returned by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True. In eager dygraph mode this flag is not consulted: the call always
            waits for the communication task to finish.

    Returns:
        None. The gathered tensors are written into ``tensor_list``.

    Raises:
        ValueError: In static graph mode, if ``tensor_list`` is not a list.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            tensor_list = []
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_gather(tensor_list, data)
            print(tensor_list)
            # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    def convert_to_complex(list_of_tensor):
        # Undo the paddle.as_real() applied to the input below.
        return [paddle.as_complex(item) for item in list_of_tensor]

    is_input_complex = (tensor.dtype == paddle.complex64
                        or tensor.dtype == paddle.complex128)
    if is_input_complex:
        # Complex inputs are communicated through their real view.
        tensor = paddle.as_real(tensor)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        if len(tensor_list) == 0:
            # No output buffers supplied: allocate one large enough to hold
            # the contribution of every rank along axis 0.
            tensor_shape = list(tensor.shape)
            tensor_shape[0] *= group.nranks
            out = paddle.empty(tensor_shape, tensor.dtype)
        else:
            out = paddle.concat(tensor_list, axis=0)
        task = group.process_group.all_gather(tensor, out)
        task.wait()
        # Replace the previous content of the caller's list with the result.
        tensor_list.clear()
        list_of_tensor = paddle.split(out, group.nranks, 0)
        if is_input_complex:
            tensor_list.extend(convert_to_complex(list_of_tensor))
        else:
            tensor_list.extend(list_of_tensor)
        return

    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    if _non_static_mode():
        out = _legacy_C_ops.c_allgather(tensor, 'use_calc_stream',
                                        use_calc_stream, 'ring_id', ring_id,
                                        'nranks', nranks)
    else:
        op_type = 'c_allgather'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        if not isinstance(tensor_list, list):
            raise ValueError("The type of 'tensor_list' for all_gather "
                             "should be list.")
        for elem in tensor_list:
            check_variable_and_dtype(elem, 'tensor_list', [
                'float16', 'float32', 'float64', 'int32', 'int64', 'bool',
                'int8', 'uint8', 'complex64', 'complex128'
            ], 'all_gather')
        check_variable_and_dtype(tensor, 'tensor', [
            'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'int8',
            'uint8', 'complex64', 'complex128'
        ], 'all_gather')
        helper.append_op(type=op_type,
                         inputs={'X': [tensor]},
                         outputs={'Out': [out]},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': use_calc_stream,
                             'nranks': nranks
                         })

    # Legacy dygraph / static graph: results are appended to tensor_list
    # (existing entries are not cleared on this path).
    list_of_tensor = paddle.split(out, nranks, 0)
    if is_input_complex:
        tensor_list.extend(convert_to_complex(list_of_tensor))
    else:
        tensor_list.extend(list_of_tensor)


def _convert_object_to_tensor(obj):
    """
    Serialize a picklable object into a uint8 Tensor.

    Args:
        obj (Any): The picklable object to serialize.

    Returns:
        tuple: ``(tensor, tensor.numel())`` where ``tensor`` holds the pickled
        bytes of ``obj`` as uint8 values.
    """
    data = np.frombuffer(pickle.dumps(obj), dtype=np.uint8)
    tensor = paddle.to_tensor(data)
    return tensor, tensor.numel()
1099 1100


1101
def _convert_tensor_to_object(tensor, len_of_tensor):
    """
    Rebuild a Python object from a Tensor of pickled bytes.

    Args:
        tensor (Tensor): Tensor whose leading ``len_of_tensor`` elements are the
            pickled representation of the object.
        len_of_tensor: Number of leading elements that belong to the pickle
            payload; anything after that is ignored.

    Returns:
        Any: The unpickled object.
    """
    # NOTE(security): unpickling can execute arbitrary code, so this must only
    # be fed data produced by trusted peers.
    payload = io.BytesIO(tensor.numpy()[:len_of_tensor])
    return pickle.Unpickler(payload).load()
1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130


def all_gather_object(object_list, obj, group=None):
    """

    Gather picklable objects from all participators and all get the result. Similar to all_gather(), but python object can be passed in.

    Args:
        object_list (list): A list of output object. The datatype of every element in the list is same as the input obj.
        obj (Any): The picklable object to send.
        group (Group): The group instance returned by new_group or None for global default group.

    Returns:
        None. The gathered objects are appended to ``object_list``.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            object_list = []
            if dist.get_rank() == 0:
                obj = {"foo": [1, 2, 3]}
            else:
                obj = {"bar": [4, 5, 6]}
            dist.all_gather_object(object_list, obj)
            print(object_list)
            # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
    """
    assert in_dygraph_mode(
    ), "all_gather_object doesn't support static graph mode."

    tensor, len_of_tensor = _convert_object_to_tensor(obj)
    if paddle.get_device() != "cpu":
        len_of_tensor = len_of_tensor._copy_to(
            paddle.framework._current_expected_place(), False)

    # gather len_of_tensor from all ranks
    list_len_of_tensor = []
    all_gather(list_len_of_tensor, len_of_tensor, group)
    # get the max length from list
    max_len_of_tensor = int(max(list_len_of_tensor).item())
    # Resize the input tensor to the max length so that every rank sends a
    # buffer of the same size; this avoids a hang in all_gather.
    # Note(liyurui): Maybe we should support various length all_gather?
    # The resize currently round-trips through numpy. np.resize fills the extra
    # space with repeated copies of the data; those bytes are dropped again
    # when each payload is sliced back to its own length below.
    numpy_data = tensor.numpy()
    numpy_data = np.resize(numpy_data, [max_len_of_tensor])
    input_tensor = paddle.to_tensor(numpy_data)

    tensor_list = []
    all_gather(tensor_list, input_tensor, group)
    for gathered, gathered_len in zip(tensor_list, list_len_of_tensor):
        object_list.append(_convert_tensor_to_object(gathered, gathered_len))
1164 1165


K
kuizhiqing 已提交
1166
def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
    """

    Scatter a tensor to all participators. As shown below, one process is started with a GPU and the source of the scatter
    is GPU0. Through scatter operator, the data in GPU0 will be sent to all GPUs evenly.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
        :width: 800
        :alt: scatter
        :align: center

    Args:
        tensor (Tensor): The output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool. Default value is None.
            It is required on the source rank and ignored on every other rank.
        src (int): The source rank id. Default value is 0.
        group (Group): The group instance returned by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None in eager dygraph mode when ``use_calc_stream`` is True, in static graph mode,
        and when the current process is not a member of ``group``. In eager dygraph mode
        with ``use_calc_stream=False`` the communication task is returned. In legacy dygraph
        mode the result of the underlying operator is returned.

    Raises:
        ValueError: If ``src`` is not an int, or if the current rank is the source
            and ``tensor_list`` is None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([7, 8, 9])
                data2 = paddle.to_tensor([10, 11, 12])
                dist.scatter(data1, src=1)
            else:
                data1 = paddle.to_tensor([1, 2, 3])
                data2 = paddle.to_tensor([4, 5, 6])
                dist.scatter(data1, tensor_list=[data1, data2], src=1)
            print(data1, data2)
            # [1, 2, 3] [10, 11, 12] (2 GPUs, out for rank 0)
            # [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        rank = group.rank
        nranks = group.nranks
    else:
        ring_id = 0 if group is None else group.id
        gsrc = src if group is None else group.get_group_rank(src)
        rank = _get_global_group().rank if group is None else group.rank
        nranks = _get_global_group().nranks if group is None else group.nranks
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if rank != gsrc:
        # Non-source ranks only need an input of the right size, so the
        # output tensor is replicated once per rank as a placeholder.
        tensor_list = [tensor] * nranks
    elif tensor_list is None:
        # Fail early with a clear message instead of passing None to concat.
        raise ValueError(
            "tensor_list must be provided on the source rank of scatter.")
    temp = paddle.concat(tensor_list, axis=0)

    if in_dygraph_mode():
        task = group.process_group.scatter(temp, tensor, gsrc)
        if use_calc_stream:
            task.wait()
            return None
        return task

    if _non_static_mode():
        return _legacy_C_ops.c_scatter(temp, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'nranks', nranks, 'root', gsrc)

    op_type = 'c_scatter'
    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'scatter')
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [temp]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'root': gsrc,
                         'use_calc_stream': use_calc_stream,
                         'nranks': nranks,
                     })
1260 1261


1262
def _c_identity(tensor, group=None):
    """
    Return a copy of the tensor, mainly used with model parallel.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The process group to work on, or None to use ring id 0.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _legacy_C_ops.c_identity(tensor, 'use_calc_stream', True,
                                        'ring_id', ring_id,
                                        'use_model_parallel', True)

    op_type = 'c_identity'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_identity')

    helper.append_op(type=op_type,
                     inputs={'X': tensor},
                     outputs={'Out': out},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': True,
                         'use_model_parallel': True,
                     })
    return out


1301
def _c_concat(tensor, group=None):
    """
    Return allgather of the tensor, mainly used with model parallel.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The process group to work on, or None to use the
            default group.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return
    group = _get_default_group() if group is None else group
    ring_id = group.id

    # Not used below; retained as-is because it is forwarded to LayerHelper
    # through **locals() in the static graph path.
    global_rank = _get_global_env().rank
    rank = group.rank
    nranks = group.nranks

    if _non_static_mode():
        return _legacy_C_ops.c_concat(tensor, 'ring_id', ring_id,
                                      'use_calc_stream', True, 'rank', rank,
                                      'nranks', nranks, 'use_model_parallel',
                                      True)

    op_type = 'c_concat'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_concat')

    helper.append_op(type=op_type,
                     inputs={'X': tensor},
                     outputs={'Out': out},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': True,
                         'use_model_parallel': True,
                         'nranks': nranks,
                         'rank': rank
                     })
    return out


1349
def _c_split(tensor, group=None):
    """
    Split tensor evenly among all members, mainly used with model parallel.

    The rank of the current process and the number of ranks are derived from
    ``group`` (or from the global environment when ``group`` is None).

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The process group to work on, or None to use ring id 0
            with the global rank and world size.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    global_rank = _get_global_env().rank
    rank = global_rank if group is None else group.get_group_rank(global_rank)
    nranks = _get_global_env().world_size if group is None else group.nranks

    if _non_static_mode():
        return _legacy_C_ops.c_split(tensor, 'use_calc_stream', True, 'ring_id',
                                     ring_id, 'rank', rank, 'nranks', nranks,
                                     'use_model_parallel', True)

    op_type = 'c_split'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_split')

    helper.append_op(type=op_type,
                     inputs={'X': tensor},
                     outputs={'Out': out},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': True,
                         'rank': rank,
                         'nranks': nranks,
                         'use_model_parallel': True,
                     })
    return out


1396 1397 1398 1399 1400
def _mp_allreduce(tensor,
                  op=ReduceOp.SUM,
                  group=None,
                  use_calc_stream=True,
                  use_model_parallel=True):
    """
    All-reduce variant used for model parallel. Only ``ReduceOp.SUM`` is
    accepted in the dygraph modes. In the dygraph modes the reduction is
    applied to ``tensor`` in place; in static graph mode a new output
    variable is created.

    Args:
        tensor (Tensor): The Tensor to reduce. In static graph mode its data
            type should be float16, float32, float64, int32 or int64.
        op (ReduceOp.SUM): The reduction to apply. Default value is ReduceOp.SUM.
        group (Group): The process group to work on, or None for the default.
        use_calc_stream (bool): Whether to use calculation stream (True) or
            communication stream (False). Default to True.
        use_model_parallel (bool): Forwarded to the operator as the
            ``use_model_parallel`` attribute. Default to True.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.

    Raises:
        AssertionError: In eager dygraph mode, if ``op`` is not ReduceOp.SUM.
        ValueError: In legacy dygraph mode, if ``op`` is not ReduceOp.SUM.
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op)

        from paddle.autograd import PyLayer

        class mp_allreduce_eager(PyLayer):

            @staticmethod
            def forward(ctx, tensor, group, use_calc_stream,
                        use_model_parallel):
                # Saved for backward, which runs c_identity on the same ring.
                ctx.ring_id = group.id

                if use_calc_stream:
                    op_type = _get_reduce_op(op, "_mp_allreduce")
                    group.process_group.allreduce_on_calc_stream(
                        tensor, op_type)
                    return tensor
                else:
                    # Use the group's id directly: the enclosing function's
                    # `ring_id` is only assigned after apply() returns, so
                    # referencing it from here raised NameError.
                    return _legacy_C_ops.c_allreduce_sum_(
                        tensor, 'use_calc_stream', use_calc_stream, 'ring_id',
                        group.id, "use_model_parallel", use_model_parallel)

            @staticmethod
            def backward(ctx, dy):
                return _legacy_C_ops.c_identity(dy, 'use_calc_stream', True,
                                                'ring_id', ctx.ring_id,
                                                'use_model_parallel', True)

        return mp_allreduce_eager.apply(tensor, group, use_calc_stream,
                                        use_model_parallel)

    ring_id = 0 if group is None else group.id
    if _in_legacy_dygraph():
        if op == ReduceOp.SUM:
            return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id, "use_model_parallel",
                                                  use_model_parallel)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    op_type = 'c_allreduce_sum'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        op_type)

    helper.append_op(type=op_type,
                     inputs={'X': tensor},
                     outputs={'Out': out},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream,
                         'use_model_parallel': use_model_parallel,
                     })
    return out
1465 1466


1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480
def _c_lookup_table(table, index, start_index=0, name=None):
    """
    Lookup table according to index.

    Args:
        table (Tensor): The input Tensor. Its data type
            should be float16, float32, float64.
        index (Tensor): The index to lookup table. In static graph mode its
            data type must be int32 or int64.
        start_index (int): The initial index for table range.
        name (string): The name of the api

    Returns:
        Tensor.
    """
    if _non_static_mode():
        return _legacy_C_ops.c_embedding(table, index, "start_index",
                                         start_index)

    op_type = 'c_embedding'
    # **locals() is what makes `table` (and `name`) visible to the helper;
    # input_dtype() below looks the table up by its parameter name.
    helper = LayerHelper(op_type, **locals())
    dtype = helper.input_dtype(input_param_name='table')
    check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type)
    tmp = helper.create_variable_for_type_inference(dtype)
    helper.append_op(type='c_embedding',
                     inputs={
                         'Ids': index,
                         'W': table
                     },
                     outputs={'Out': tmp},
                     attrs={"start_index": start_index})
    return tmp

1499

B
Baibaifan 已提交
1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514
class _Linear(layers.Layer):
    """
    Linear layer that owns a weight of shape ``[in_features, out_features]``
    and a bias of shape ``[out_features]`` and applies ``_linear`` to its input.

    Args:
        in_features (int): First dimension of the weight parameter.
        out_features (int): Second dimension of the weight parameter and the
            size of the bias parameter.
        weight_attr: Attribute passed to ``create_parameter`` for the weight.
        bias_attr: Attribute passed to ``create_parameter`` for the bias.
        name (str, optional): Name forwarded to ``_linear`` and shown in
            ``extra_repr``.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        super().__init__()
        self._dtype = self._helper.get_default_dtype()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self.weight = self.create_parameter(shape=[in_features, out_features],
                                            attr=self._weight_attr,
                                            dtype=self._dtype,
                                            is_bias=False)
        self.bias = self.create_parameter(shape=[out_features],
                                          attr=self._bias_attr,
                                          dtype=self._dtype,
                                          is_bias=True)
        self.name = name

    def forward(self, input):
        """Apply ``_linear`` to ``input`` with this layer's weight and bias."""
        return _linear(x=input,
                       weight=self.weight,
                       bias=self.bias,
                       name=self.name)

    def extra_repr(self):
        """Return a summary string with the weight shape, dtype and name."""
        name_str = ', name={}'.format(self.name) if self.name else ''
        return 'in_features={}, out_features={}, dtype={}{}'.format(
            self.weight.shape[0], self.weight.shape[1], self._dtype, name_str)


1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557
def _c_softmax_with_cross_entropy(logits,
                                  label,
                                  group=None,
                                  return_softmax=False):
    """
    Run the ``c_softmax_with_cross_entropy`` operator on ``logits`` and ``label``.

    Args:
        logits (Tensor): The input logits.
        label (Tensor): The labels. It must have either the same number of
            dimensions as ``logits`` or exactly one fewer; in the latter case a
            trailing axis is added before the operator is invoked.
        group (Group): The process group to work on, or None to use ring id 0
            with the global rank and world size.
        return_softmax (bool): Whether to also return the softmax output.
            Default to False.

    Returns:
        The loss Tensor, or the tuple ``(loss, softmax)`` when ``return_softmax``
        is True. None when the current process is not a member of ``group``.

    Raises:
        ValueError: If the number of dimensions of ``label`` is neither equal to
            nor one less than that of ``logits``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id
    global_rank = _get_global_env().rank
    rank = global_rank if group is None else group.get_group_rank(global_rank)
    nranks = _get_global_env().world_size if group is None else group.nranks

    input_dims = len(list(logits.shape))
    label_dims = len(list(label.shape))
    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            'Expected input_dims - 1 = label_dims or input_dims == label_dims '
            '(got input_dims {}, label_dims {})'.format(
                input_dims, label_dims))
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=-1)

    if _non_static_mode():
        softmax, loss = _legacy_C_ops.c_softmax_with_cross_entropy(
            logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks)
        if not return_softmax:
            return loss
        else:
            return loss, softmax

    attrs = {
        'ring_id': ring_id,
        'rank': rank,
        'nranks': nranks,
    }
    helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
    helper.append_op(type='c_softmax_with_cross_entropy',
                     inputs={
                         'Logits': logits,
                         'Label': label
                     },
                     outputs={
                         'Softmax': softmax,
                         'Loss': loss
                     },
                     attrs=attrs)

    if return_softmax:
        return loss, softmax

    return loss

1590

B
Baibaifan 已提交
1591 1592 1593 1594
def _linear(x, weight, bias=None, name=None):
    """
    Function Linear

    Multiply ``x`` by ``weight`` and, if ``bias`` is given, add it along the
    last axis of ``x``.

    Args:
        x (Tensor): The input. In static graph mode it must have fewer than
            4 dimensions and its data type must be float16, float32 or
            float64.
        weight (Tensor): The weight multiplied with ``x``.
        bias (Tensor, optional): The bias added to the product. Default: None.
        name (str, optional): Only reaches ``LayerHelper`` (through
            ``locals()``) in static graph mode. Default: None.

    Returns:
        Tensor: the product of ``x`` and ``weight``, plus ``bias`` when it is
        given.
    """
    if _non_static_mode():
        pre_bias = _varbase_creator(dtype=x.dtype)
        _legacy_C_ops.matmul(x, weight, pre_bias, 'transpose_X', False,
                             'transpose_Y', False, "alpha", 1)
        return dygraph_utils._append_bias_in_dygraph(pre_bias,
                                                     bias,
                                                     axis=len(x.shape) - 1)

    helper = LayerHelper('linear', **locals())
    dtype = x.dtype
    assert len(
        x.shape) < 4, "X dimension greater than 3 is not supported now."

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'linear')
    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')

    inputs = {'X': [x], 'Y': [weight]}
    # matmul_v2 names its transpose attributes trans_x / trans_y; the
    # transpose_X / transpose_Y / alpha attributes belong to the legacy
    # matmul op used in the dygraph branch above.
    attrs = {
        'trans_x': False,
        'trans_y': False,
    }
    tmp = helper.create_variable_for_type_inference(dtype)
    helper.append_op(type='matmul_v2',
                     inputs=inputs,
                     outputs={'Out': tmp},
                     attrs=attrs)
    if bias is None:
        return tmp

    res = helper.create_variable_for_type_inference(dtype)
    helper.append_op(type='elementwise_add',
                     inputs={
                         'X': [tmp],
                         'Y': [bias]
                     },
                     outputs={'Out': [res]},
                     attrs={'axis': len(x.shape) - 1})
    return res


1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649
def _set_var_distributed(var):
    """
    Mark ``var`` as distributed.

    The flag is set on ``var`` itself and on the variables with the same name
    reachable from the current blocks of the default startup and main
    programs. Nothing happens when ``var`` is None.
    """
    if var is not None:
        var.is_distributed = True

        # NOTE: start from current_block and search with _find_var_recursive
        # so that this also works inside while_loop
        static = paddle.static
        blocks = (static.default_startup_program().current_block(),
                  static.default_main_program().current_block())
        for block in blocks:
            block._find_var_recursive(var.name).is_distributed = True


L
lilong12 已提交
1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660
def _parallel_linear(x,
                     num_rows,
                     num_cols,
                     axis,
                     param_attr,
                     bias_attr,
                     gather_out,
                     inner_rank,
                     nranks,
                     split_tensor,
                     name,
                     group=None):
    """
    Parallel Linear

    axis the dimension of the parameter of linear layer.
    axis = 0: the row dimension
    axis = 1: the col dimension

    Args:
        x (Tensor): The input of the linear layer.
        num_rows (int): Number of rows of the weight held by this rank.
        num_cols (int): Number of columns of the weight held by this rank.
        axis (int): The axis along which the weight is split, see above.
        param_attr (ParamAttr): Passed to ``paddle.nn.Linear`` as
            ``weight_attr``.
        bias_attr (ParamAttr|bool): Passed to ``paddle.nn.Linear`` as
            ``bias_attr``.
        gather_out (bool): When False the output of the local linear is
            returned directly; otherwise it is all-reduced (axis = 0) or
            concatenated (axis = 1) over the group.
        inner_rank (int): The ``rank`` attribute of the ``c_concat`` op.
        nranks (int): The ``nranks`` attribute of the ``c_concat`` op.
        split_tensor (bool): When True and axis = 0, ``x`` goes through
            ``_c_split`` before the linear.
        name (str): Passed to ``paddle.nn.Linear``.
        group (Group, optional): The communication group, None for the
            global default group. Default: None.

    Returns:
        Tensor, or None if the current process is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    if axis == 0:
        if split_tensor:
            x = _c_split(x, group=group)
    else:
        x = _c_identity(x, group=group)

    linear = paddle.nn.Linear(num_rows,
                              num_cols,
                              weight_attr=param_attr,
                              bias_attr=bias_attr,
                              name=name)

    # NOTE: npu linear function use matmul_v2 but linear use matmul
    linear_function = _linear if core.is_compiled_with_npu()\
        else paddle.nn.functional.linear
    linear_out = linear_function(
        x,
        linear.weight,
        # NOTE(wangxi): row split, bias need add after allreduce
        None if axis == 0 else linear.bias,
        linear.name)

    _set_var_distributed(linear.weight)
    # set is_distributed for the split bias
    # if a linear layer is split by row, each rank would hold a complete bias
    # and they should be the same in each rank.
    # if a linear layer is split by col, the bias would also be split into
    # each rank as its weight
    if axis == 1 and linear.bias is not None:
        _set_var_distributed(linear.bias)

    if not gather_out:
        return linear_out

    out_shape = list(linear_out.shape)
    # NOTE(review): the first dimension is scaled for the column-parallel
    # case; c_concat presumably concatenates along the last dimension --
    # confirm which dimension is intended here.
    out_shape[0] *= 1 if axis == 0 else nranks
    main_block = paddle.static.default_main_program().current_block()
    out = main_block.create_var(
        shape=out_shape,
        dtype=linear_out.dtype,
        type=linear_out.type,
        lod_level=linear_out.lod_level,
        persistable=False,
        is_data=False,
        need_check_feed=linear_out.desc.need_check_feed())
    if axis == 0:
        main_block.append_op(type='c_allreduce_sum',
                             inputs={'X': linear_out},
                             outputs={'Out': out},
                             attrs={
                                 'ring_id': ring_id,
                                 'use_calc_stream': True,
                                 'use_model_parallel': True
                             })
        # row split: the bias is added once, after the partial results
        # have been summed
        if linear.bias is not None:
            out = out + linear.bias
    else:
        main_block.append_op(type='c_concat',
                             inputs={'X': linear_out},
                             outputs={'Out': out},
                             attrs={
                                 'rank': inner_rank,
                                 'ring_id': ring_id,
                                 'nranks': nranks,
                                 'use_calc_stream': True,
                                 'use_model_parallel': True
                             })
    return out
1739 1740


L
lilong12 已提交
1741 1742 1743 1744 1745 1746 1747
def _parallel_embedding(x,
                        per_part_embeddings,
                        origin_size,
                        param_attr,
                        inner_rank,
                        num_partitions,
                        name,
                        group=None):
    """
    Parallel Embedding

    Args:
        x (Tensor): The input indices.
        per_part_embeddings (int): Number of rows of the weight created on
            this rank.
        origin_size (list|tuple): The shape of the whole weight; only
            ``origin_size[1]`` is read here, as the number of columns of the
            local weight.
        param_attr (ParamAttr): The attribute of the created weight.
        inner_rank (int): The rank inside the model parallel group. The
            start index handed to the lookup is
            ``inner_rank * per_part_embeddings``.
        num_partitions (int): Number of partitions. With 1 partition a plain
            ``paddle.nn.functional.embedding`` is used.
        name (str): The name passed to the lookup.
        group (Group, optional): The communication group, None for the
            global default group. Default: None.

    Returns:
        Tensor, or None if the current process is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return

    helper = LayerHelper("_parallel_embedding", **locals())

    vocab_start_index = inner_rank * per_part_embeddings
    dtype = helper.get_default_dtype()
    size = [per_part_embeddings, origin_size[1]]

    weight = helper.create_parameter(attr=param_attr,
                                     shape=size,
                                     dtype=dtype,
                                     is_bias=False)

    if num_partitions == 1:
        return paddle.nn.functional.embedding(x,
                                              weight=weight,
                                              padding_idx=None,
                                              sparse=False,
                                              name=name)

    startup_block = paddle.static.default_startup_program().global_block()
    main_block = paddle.static.default_main_program().global_block()
    startup_block.vars[weight.name].is_distributed = True
    main_block.vars[weight.name].is_distributed = True

    output_parallel = paddle.distributed.collective._c_lookup_table(
        weight, x, start_index=vocab_start_index, name=name)
    out = paddle.distributed.collective._mp_allreduce(output_parallel,
                                                      group=group,
                                                      use_calc_stream=True,
                                                      use_model_parallel=True)
    return out
1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811


def split(x,
          size,
          operation,
          axis=0,
          num_partitions=1,
          gather_out=True,
          weight_attr=None,
          bias_attr=None,
          name=None):
    """

    Split the weight of the specified operation into multiple devices
    and do the computation in parallel.

    Now the following three cases are supported.

    Case 1: Parallel Embedding
        The weight of the embedding operation is a NxM matrix with N rows and M columns.
        With parallel embedding, the weight is split into num_partitions partitions, each
        of which is a matrix with N/num_partitions rows and M columns.

        Suppose we split the NxM weight into two partitions on device_0 and device_1
        respectively. Then, on each device, the local weight has N/2 rows. device_0 holds
        the rows for the input values within [0, N/2 - 1] and device_1 holds the rows for the
        input values within [N/2, N - 1], where a value V is looked up at row (V - N/2).
        Input values outside the range held by a device are mapped to all zeros on that
        device. Finally, the results on the two devices are sum-reduced.

        The Embedding put on single card is as shown below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_single.png
            :width: 800
            :height: 350
            :alt: single_embedding
            :align: center

        Parallel Embedding is shown as below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_split.png
            :width: 800
            :alt: split_embedding
            :align: center

    Case 2: Row Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With row parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N/num_partitions rows and M columns.

        The linear layer put on single card is shown as below, the input variable is represented by X,
        the weight matrix is represented by W and the output variable is O. The linear layer on single card is
        simple matrix multiplication operation, O = X * W.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_single.png
            :width: 800
            :alt: single_linear
            :align: center

        Row Parallel Linear is shown as below. As the name suggests, Row Parallel Linear splits the weight matrix W into
        [[W_row1], [W_row2]] along the row. And accordingly the input is split along the column into [X_col1, X_col2] and multiplied by their
        respective weight matrices. Finally apply AllReduce on the output from each card to get the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_row.png
            :width: 800
            :alt: split_row
            :align: center

    Case 3: Column Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With column parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N rows and M/num_partitions columns.

        The linear layer put on single card has been illustrated on case 2 and Column Parallel Linear
        is shown as below. The Column Parallel Linear splits the weight matrix W into [W_col1, W_col2] along the column and
        these split matrices respectively multiply the input. Finally apply AllGather on the output from each card to get the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col.png
            :width: 800
            :alt: split_col
            :align: center

    As observed, the column parallel linear and row parallel linear can be combined to skip one ALLGATHER communication
    operator. Furthermore the Attention and MLP can be combined to improve the performance as shown below.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col_row.png
            :width: 800
            :alt: split_col_row
            :align: center

    Args:
        x (Tensor): Input tensor. Its data type should be float16, float32, float64, int32 or int64.
        size (list|tuple): A list or tuple with two elements indicating the shape of the weight.
            The dimension being split must be divisible by num_partitions.
        operation (str): The name of the operation. The supported operations are 'linear' and 'embedding'.
        axis (int, Optional): Indicate along which axis to split the weight. It must be 0 for
            'embedding' and 0 or 1 for 'linear'. Default: 0.
        num_partitions (int, Optional): How many parts the weight is partitioned. It must be positive. Default: 1.
        gather_out (bool, Optional): Whether to gather the output after computation. By default, the output
            on each partitions will be gathered after computation. Default: True.
        weight_attr (ParamAttr, Optional): The parameter attribute for the learnable
            weights(Parameter) of the specified operation. Default: None.
        bias_attr (ParamAttr, Optional): The parameter attribute for the bias
            of the specified operation. Default: None.
        name (str, Optional): The default value is None. Normally there is no need for user to set this
            property. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor.

    Raises:
        ValueError: If called in dynamic graph mode, or if ``axis`` is neither 0 nor 1 for 'linear'.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed.fleet as fleet

            paddle.enable_static()
            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            fleet.init(is_collective=True)
            data = paddle.randint(0, 8, shape=[10,4])
            emb_out = paddle.distributed.split(
                data,
                (8, 8),
                operation="embedding",
                num_partitions=2)

    """
    assert isinstance(
        size,
        (list, tuple)), ("The type of size for "
                         "paddle.distributed.split must be list or tuple.")
    assert len(size) == 2, ("Number of elements in size of "
                            "paddle.distributed.split must be two.")
    assert isinstance(operation, str), ("The type of operation for "
                                        "paddle.distributed.split must be str.")
    supported_operations = [
        'linear',
        'embedding',
    ]
    assert operation in supported_operations, (
        "The operation for "
        "paddle.distributed.split must be one of {}.".format(
            supported_operations))
    # fail early with a clear message instead of a ZeroDivisionError below
    assert num_partitions > 0, (
        "The num_partitions for paddle.distributed.split must be "
        "positive, but received {}.".format(num_partitions))
    if _non_static_mode():
        raise ValueError(
            "paddle.distributed.split cannot be used in dynamic "
            "graph mode, please use ParallelEmbedding, ParallelRowLinear, "
            "ParallelColumnLinear instead.")

    from .fleet import fleet
    assert fleet._role_maker, ("To use paddle.distributed.split, "
                               "you must call fleet.init() firstly.")
    rank = fleet.worker_index()

    # rank within a model parallel group
    inner_rank = rank % num_partitions

    if operation == "embedding":
        assert axis == 0, ("We only support to split the weight of embedding "
                           "along the first axis now.")
        assert size[0] % num_partitions == 0, \
            "The length of the vocabulary must be divisible by num_partitions " \
            "but received vocabulary={} num_partitions={}".format(size[0], num_partitions)

        per_part_size = size[0] // num_partitions
        return _parallel_embedding(x,
                                   per_part_size,
                                   size,
                                   weight_attr,
                                   inner_rank,
                                   num_partitions,
                                   name,
                                   group=None)

    should_split = False
    if axis == 0:
        assert size[0] % num_partitions == 0, (
            "Number of rows of the weight for linear ({}) must be"
            " divisible by num_partitions ({})".format(
                size[0], num_partitions))
        per_part_size = size[0] // num_partitions
        linear_size = (per_part_size, size[1])
        # the input still has the full width, so it has to be split as well
        should_split = x.shape[-1] == size[0]
    elif axis == 1:
        assert size[1] % num_partitions == 0, (
            "Number of column of the weight for linear ({}) must be"
            " divisible by num_partitions ({})".format(
                size[1], num_partitions))
        per_part_size = size[1] // num_partitions
        linear_size = (size[0], per_part_size)
    else:
        raise ValueError("The value of axis must be 0 or 1, but the value "
                         "given is {}.".format(axis))

    return _parallel_linear(x,
                            linear_size[0],
                            linear_size[1],
                            axis,
                            weight_attr,
                            bias_attr,
                            gather_out,
                            inner_rank,
                            num_partitions,
                            should_split,
                            name=name,
                            group=None)
L
lilong12 已提交
2001 2002


L
lilong12 已提交
2003 2004
def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
    """
    Scatter tensors in in_tensor_list to all participators averagely and gather the result tensors in out_tensor_list.
    As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
    Through alltoall operator, the 0_0 in GPU0 will be sent to GPU0 and 0_1 to GPU1, 1_0 in GPU1 sent to GPU0 and 1_1 to GPU1.
    Finally the out_tensor_list in GPU0 includes 0_0 and 1_0, and GPU1 includes 0_1 and 1_1.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/alltoall.png
        :width: 800
        :alt: alltoall
        :align: center

    Args:
        in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool. In static graph mode only
            float16, float32, float64, int32 and int64 pass the check.
        out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
            data type of the input Tensors. In static graph mode it must be an empty list.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.

    Returns:
        None.

    Raises:
        ValueError: In static graph mode, if ``in_tensor_list`` or ``out_tensor_list`` is not a list, or if
            ``out_tensor_list`` is not empty.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            out_tensor_list = []
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
                data2 = paddle.to_tensor([[7, 8, 9], [10, 11, 12]])
            else:
                data1 = paddle.to_tensor([[13, 14, 15], [16, 17, 18]])
                data2 = paddle.to_tensor([[19, 20, 21], [22, 23, 24]])
            dist.alltoall([data1, data2], out_tensor_list)
            print(out_tensor_list)
            # [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0)
            # [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")

        temp = paddle.concat(in_tensor_list, axis=0)
        nranks = len(in_tensor_list)
        if len(out_tensor_list) == 0:
            tensor_shape = list(in_tensor_list[0].shape)
            tensor_shape[0] *= nranks
            out = paddle.empty(tensor_shape, in_tensor_list[0].dtype)
        else:
            out = paddle.concat(out_tensor_list, axis=0)
        task = group.process_group.alltoall(temp, out)
        task.wait()
        out_tensor_list.clear()
        out_tensor_list.extend(paddle.split(out, nranks, 0))
        return

    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        temp = paddle.concat(in_tensor_list, axis=0)
        nranks = len(in_tensor_list)
        out = _legacy_C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id)
    else:
        # Validate the arguments before any op or variable is added to the
        # program, so that a rejected call leaves the program untouched.
        if not isinstance(in_tensor_list, list):
            raise ValueError("The type of 'in_tensor_list' for all_to_all "
                             "should be list.")
        for elem in in_tensor_list:
            check_variable_and_dtype(
                elem, 'in_tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_to_all')
        if not isinstance(out_tensor_list, list):
            raise ValueError("The type of 'out_tensor_list' for all_to_all "
                             "should be list.")
        if len(out_tensor_list) != 0:
            raise ValueError("The 'out_tensor_list' for all_to_all "
                             "must be an empty list.")

        temp = paddle.concat(in_tensor_list, axis=0)
        nranks = len(in_tensor_list)
        op_type = 'alltoall'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=in_tensor_list[0].dtype)
        helper.append_op(type=op_type,
                         inputs={'X': [temp]},
                         outputs={'Out': [out]},
                         attrs={
                             'ring_id': ring_id,
                             'use_calc_stream': use_calc_stream,
                         })
    out_tensor_list.extend(paddle.split(out, nranks, 0))


2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116
def alltoall_single(in_tensor,
                    out_tensor,
                    in_split_sizes=None,
                    out_split_sizes=None,
                    group=None,
                    use_calc_stream=True):
    """
    Scatter a single input tensor to all participators and gather the received tensors in out_tensor.

    .. note::
        ``alltoall_single`` is only supported in eager mode.

    Args:
        in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
        in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
            must be divisible by group size and ``in_tensor`` will be scattered averagely to all participators. Default: None.
        out_split_sizes (list[int], optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
            must be divisible by group size and ``out_tensor`` will be gathered averagely from all participators. Default: None.
        group (Group, optional): The group instance return by ``new_group`` or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.

    Returns:
        None, if ``use_calc_stream`` is set to ``True`` or the current process is not a member of ``group``;
        ``Task`` of ``group``, if ``use_calc_stream`` is set to ``False``.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            size = dist.get_world_size()

            # case 1 (2 GPUs)
            data = paddle.arange(2, dtype='int64') + rank * 2
            # data for rank 0: [0, 1]
            # data for rank 1: [2, 3]
            output = paddle.empty([2], dtype='int64')
            dist.alltoall_single(data, output)
            print(output)
            # output for rank 0: [0, 2]
            # output for rank 1: [1, 3]

            # case 2 (2 GPUs)
            in_split_sizes = [i + 1 for i in range(size)]
            # in_split_sizes for rank 0: [1, 2]
            # in_split_sizes for rank 1: [1, 2]
            out_split_sizes = [rank + 1 for i in range(size)]
            # out_split_sizes for rank 0: [1, 1]
            # out_split_sizes for rank 1: [2, 2]
            data = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank
            # data for rank 0: [[0., 0.], [0., 0.], [0., 0.]]
            # data for rank 1: [[1., 1.], [1., 1.], [1., 1.]]
            output = paddle.empty([(rank + 1) * size, size], dtype='float32')
            group = dist.new_group([0, 1])
            task = dist.alltoall_single(data,
                                        output,
                                        in_split_sizes,
                                        out_split_sizes,
                                        use_calc_stream=False,
                                        group=group)
            task.wait()
            print(output)
            # output for rank 0: [[0., 0.], [1., 1.]]
            # output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]]

    """
    if group is not None and not group.is_member():
        return

    assert in_dygraph_mode(), "Only support alltoall_single in eager mode."

    group = _get_default_group() if group is None else group
    backend = _group_map_backend[group]
    assert backend != 'gloo', ("backend gloo is not supported yet")

    # the process group is handed empty lists when no split sizes are given
    in_split_sizes = [] if in_split_sizes is None else in_split_sizes
    out_split_sizes = [] if out_split_sizes is None else out_split_sizes

    task = group.process_group.alltoall_single(in_tensor, out_tensor,
                                               in_split_sizes, out_split_sizes)
    if not use_calc_stream:
        return task
    task.wait()


S
ShenLiang 已提交
2196 2197 2198 2199
def _get_group_rank(global_rank, group=None):
    """
    Translate ``global_rank`` into the rank used inside ``group``.

    Without a group the global rank is returned unchanged; otherwise the
    result of ``group.get_group_rank(global_rank)`` is returned.
    """
    if group is not None:
        return group.get_group_rank(global_rank)
    return global_rank


L
lilong12 已提交
2200 2201 2202 2203 2204 2205
def send(tensor, dst=0, group=None, use_calc_stream=True):
    """
    Send a tensor to the receiver.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
            In static graph mode only float16, float32, float64, int32 and int64
            pass the check.
        dst (int): The destination rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        None in most cases. In eager mode with ``use_calc_stream`` set to
        ``False``, the task of the process group is returned instead so that
        the caller can wait on it.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return
    # from here on dst is the rank inside the group
    dst = _get_group_rank(dst, group)
    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        task = group.process_group.send(tensor, dst)
        if not use_calc_stream:
            return task
        task.wait()
        return None

    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _legacy_C_ops.send_v2(tensor, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id, 'peer', dst)
    op_type = 'send_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'send')

    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'peer': dst,
                         'use_calc_stream': use_calc_stream,
                     })
L
lilong12 已提交
2263 2264 2265 2266 2267 2268 2269 2270


def recv(tensor, src=0, group=None, use_calc_stream=True):
    """
    Receive a tensor from the sender.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
            Note that the static graph branch only accepts float16, float32,
            float64, int32 and int64.
        src (int): The source rank id.
        group (Group, optional): The group instance returned by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        None if the current process is not a member of ``group``, in static
        graph mode, or in dygraph mode when ``use_calc_stream`` is True (the
        task is waited on before returning). In dygraph mode with
        ``use_calc_stream`` set to False, the communication task is returned.
        In legacy dygraph mode the result of the ``recv_v2`` op is returned.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                dist.send(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                dist.recv(data, src=0)
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    # Processes outside the given group take no part in the communication.
    if group is not None and not group.is_member():
        return

    src = _get_group_rank(src, group)
    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        task = group.process_group.recv(tensor, src)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _legacy_C_ops.recv_v2(tensor, 'use_calc_stream', use_calc_stream,
                                     'ring_id', ring_id, 'peer', src, 'dtype',
                                     tensor.dtype, 'out_shape', tensor.shape)
    op_type = 'recv_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'recv')
    # NOTE: every local defined so far is forwarded to LayerHelper as a keyword
    # argument, so local names in this function must not be renamed casually.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'peer': src,
                         'out_shape': tensor.shape,
                         'dtype': tensor.dtype,
                         'use_calc_stream': use_calc_stream,
                     })
2331 2332 2333 2334 2335 2336 2337 2338 2339 2340 2341 2342 2343 2344 2345 2346 2347 2348 2349 2350 2351 2352


def _check_single_tensor(tensor, tensor_name):
    """
    Validate that ``tensor`` is a paddle Tensor.

    Args:
        tensor: The object to validate.
        tensor_name (str): Name of the argument, used in the error message.

    Raises:
        RuntimeError: If ``tensor`` is neither a ``core.eager.Tensor`` nor a
            ``paddle.Tensor``.
    """
    if not isinstance(tensor, (core.eager.Tensor, paddle.Tensor)):
        # The trailing space after "{}" keeps the parameter name from being
        # glued to the following word in the concatenated message.
        raise RuntimeError("Invalid function argument. Expected parameter {} "
                           "to be of type paddle.Tensor, but it's {}".format(
                               tensor_name, type(tensor)))


def _check_tensor_list(tensor_list, tensor_name):
    """
    Validate that ``tensor_list`` is a list whose elements are all paddle Tensors.

    Args:
        tensor_list: The object to validate.
        tensor_name (str): Name of the argument, used in the error message.

    Raises:
        RuntimeError: If ``tensor_list`` is not a ``list`` or any element is
            neither a ``core.eager.Tensor`` nor a ``paddle.Tensor``.
    """
    if not isinstance(tensor_list, list) or \
        not all(isinstance(t, (core.eager.Tensor, paddle.Tensor)) for t in tensor_list):
        # The trailing space after "{}" keeps the parameter name from being
        # glued to the following word in the concatenated message.
        raise RuntimeError("Invalid function argument. Expected parameter {} "
                           "to be of type paddle.Tensor".format(tensor_name))


def isend(tensor, dst, group=None):
    """
    Sends a tensor asynchronously

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        dst (int): The destination rank.
        group (Group, optional): The group instance returned by new_group or None for global default group. Default: None.

    Returns:
        A distributed task object, or None if the current process is not a
        member of ``group``.

    Raises:
        RuntimeError: If ``tensor`` is not a paddle Tensor, or if called
            outside eager dygraph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)

    """
    _check_single_tensor(tensor, "tensor")
    # Processes outside the given group take no part in the communication.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        # ``dst`` is translated to the rank inside ``group``; a negative result
        # is rejected by the assertion below.
        group_dst_rank = group.get_group_rank(dst)
        assert group_dst_rank >= 0, ("dst rank out of group, need global rank")
        return group.process_group.send(tensor, group_dst_rank)
    else:
        raise RuntimeError("Only support eager dygraph mode.")
2395 2396 2397 2398 2399 2400 2401 2402


def irecv(tensor, src=None, group=None):
    """
    Receive a tensor from the sender asynchronously.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        src (int): The source rank id. The value is passed to
            ``group.get_group_rank`` as is.
            NOTE(review): the signature defaults to None, but presumably a
            valid rank must always be supplied -- confirm.
        group (Group, optional): The group instance returned by new_group or None for global default group. Default: None.

    Returns:
        A distributed task object, or None if the current process is not a
        member of ``group``.

    Raises:
        RuntimeError: If ``tensor`` is not a paddle Tensor, or if called
            outside eager dygraph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([7, 8, 9])
                task = dist.isend(data, dst=1)
            else:
                data = paddle.to_tensor([1, 2, 3])
                task = dist.irecv(data, src=0)
            task.wait()
            print(data)
            # [7, 8, 9] (2 GPUs)
    """
    _check_single_tensor(tensor, "tensor")
    # Processes outside the given group take no part in the communication.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")
        # ``src`` is translated to the rank inside ``group``; a negative result
        # is rejected by the assertion below.
        group_src_rank = group.get_group_rank(src)
        assert group_src_rank >= 0, ("src rank out of group, need global rank")
        return group.process_group.recv(tensor, group_src_rank)
    else:
        raise RuntimeError("Only support eager dygraph mode.")
2444 2445 2446 2447 2448 2449 2450 2451 2452 2453 2454 2455 2456 2457 2458


class P2POp:
    """
    Describes one point-to-point operation for ``batch_isend_irecv``.

    An instance bundles the P2P function to call, the communication buffer,
    the peer rank and the group. Instances are handed to
    ``paddle.distributed.batch_isend_irecv`` to carry out the communication.

    Args:
        op (callable): The function that sends data to or receives data from a
            peer process. Must be ``paddle.distributed.isend`` or
            ``paddle.distributed.irecv``.
        tensor (Tensor): Tensor to send or receive.
        peer (int): The destination or source rank.
        group (Group, optional): The group instance returned by new_group or
            None for global default group. Default: None.

    Raises:
        RuntimeError: If ``op`` is neither ``isend`` nor ``irecv``, or if
            ``tensor`` fails the single-tensor check.

    """

    def __init__(self, op, tensor, peer, group=None):
        if op not in (isend, irecv):
            raise RuntimeError("Invalid ``op`` function. Expected ``op`` "
                               "to be of type ``paddle.distributed.isend`` or "
                               "``paddle.distributed.irecv``.")
        _check_single_tensor(tensor, "tensor")

        # A missing group is resolved to the default group right away, so
        # ``self.group`` always holds a concrete group object.
        if group is None:
            group = _get_default_group()

        self.op = op
        self.tensor = tensor
        self.peer = peer
        self.group = group


@contextlib.contextmanager
def _with_batch_p2p_guard(backend):
    if backend == "nccl":
        core.ProcessGroupNCCL.group_start()
    try:
        yield
    finally:
        if backend == "nccl":
            core.ProcessGroupNCCL.group_end()


def _check_p2p_op_list(p2p_op_list):
    """
    Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
    all ops use the same backend.
    """
    if not isinstance(p2p_op_list, list) or not all(
            isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list):
        raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected to "
                           "to be of type ``paddle.distributed.P2POp``.")

    backend = _group_map_backend[p2p_op_list[0].group]
    if not all(backend == _group_map_backend[p2p_op.group]
               for p2p_op in p2p_op_list):
        raise RuntimeError("All groups need to use the same backend.")


def batch_isend_irecv(p2p_op_list):
    """
    Send or Receive a batch of tensors asynchronously and return a list of requests.

    Each point-to-point operation in ``p2p_op_list`` is issued in order and the
    resulting tasks are collected. The NCCL backend is currently supported.

    Args:
        p2p_op_list: A list of point-to-point operations (each of type
            ``paddle.distributed.P2POp``). The order of the isend/irecv in the
            list matters and it needs to match with corresponding isend/irecv
            on the remote end.

    Returns:
        A list of the distributed tasks returned by the ops in ``p2p_op_list``
        (ops that return None contribute nothing). None is returned if the
        current process is not a member of the first op's group.

    Raises:
        RuntimeError: If ``p2p_op_list`` fails validation, or if called
            outside dygraph mode.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed

            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            world_size = dist.get_world_size()

            send_t = paddle.arange(2) + rank
            # paddle.tensor([0, 1])  # Rank-0
            # paddle.tensor([1, 2])  # Rank-1

            recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)

            send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
            recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)

            tasks = dist.batch_isend_irecv([send_op, recv_op])

            for task in tasks:
                task.wait()

            print(recv_t)
            # paddle.tensor([1, 2])     # Rank-0
            # paddle.tensor([0, 1])     # Rank-1
    """
    _check_p2p_op_list(p2p_op_list)

    # The first op's group decides membership and the backend for the batch.
    group = p2p_op_list[0].group
    if group is not None and not group.is_member():
        return

    if not in_dygraph_mode():
        raise RuntimeError("Don't support static graph mode currently.")

    if group is None:
        group = _get_default_group()

    tasks = []
    with _with_batch_p2p_guard(_group_map_backend[group]):
        for p2p_op in p2p_op_list:
            # Each op is called with its own group, not the first op's group.
            task = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group)
            if task is not None:
                tasks.append(task)
    return tasks


def reduce_scatter(tensor,
                   tensor_list,
                   op=ReduceOp.SUM,
                   group=None,
                   use_calc_stream=True):
    """
    Reduces, then scatters a list of tensors to all processes in a group

    Args:
        tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        tensor_list (list[Tensor]): List of tensors to reduce and scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (Group, optional): The group instance returned by new_group or None for global
            default group. Default: None.
        use_calc_stream (bool, optional): If True, the call waits for the
            communication task to finish and returns None; if False, the task
            is returned without waiting. Default: True.

    Returns:
        Async task handle, if use_calc_stream is set to False.
        None, if use_calc_stream or if not part of the group.

    Raises:
        RuntimeError: If ``tensor`` or ``tensor_list`` fail validation, or if
            called outside dygraph mode.

    Warning:
        This API only supports the dygraph mode.


    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([0, 1])
                data2 = paddle.to_tensor([2, 3])
            else:
                data1 = paddle.to_tensor([4, 5])
                data2 = paddle.to_tensor([6, 7])
            dist.reduce_scatter(data1, [data1, data2])
            print(data1)
            # [4, 6] (2 GPUs, out for rank 0)
            # [8, 10] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(tensor, "tensor")
    _check_tensor_list(tensor_list, "tensor_list")

    # Processes outside the given group take no part in the communication.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "reduce_scatter")
        group = _get_default_group() if group is None else group
        backend = _group_map_backend[group]
        assert backend != 'gloo', ("backend gloo is not supported yet")

        # The inputs are joined along axis 0 into one tensor, which is what
        # the process group's ``_reduce_scatter_base`` receives.
        temp = paddle.concat(tensor_list, axis=0)
        task = group.process_group._reduce_scatter_base(tensor, temp, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task
    else:
        raise RuntimeError("Don't support static graph mode currently.")


def _reduce_scatter_base(output,
                         input,
                         op=ReduceOp.SUM,
                         group=None,
                         use_calc_stream=True):
    """
    Reduces, then scatters a flattened tensor to all processes in a group.

    Args:
        output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        input (Tensor): Input tensor that is of size output tensor size times world size. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        Async task handle, if use_calc_stream is set to False.
        None, if use_calc_stream or if not part of the group.

    Raises:
        RuntimeError: If ``output`` or ``input`` fail the single-tensor check,
            or if called outside dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            rank = dist.get_rank()
            data = paddle.arange(4) + rank
            # [0, 1, 2, 3] (2 GPUs, for rank 0)
            # [1, 2, 3, 4] (2 GPUs, for rank 1)
            output = paddle.empty(shape=[2], dtype=data.dtype)
            dist.collective._reduce_scatter_base(output, data)
            print(output)
            # [1, 3] (2 GPUs, out for rank 0)
            # [5, 7] (2 GPUs, out for rank 1)

    """
    _check_single_tensor(output, "output")
    _check_single_tensor(input, "input")

    # Processes outside the given group take no part in the communication.
    if group is not None and not group.is_member():
        return

    if not in_dygraph_mode():
        raise RuntimeError("Don't support static graph mode currently.")

    reduce_type = _get_reduce_op(op, "_reduce_scatter_base")
    if group is None:
        group = _get_default_group()

    task = group.process_group._reduce_scatter_base(output, input, reduce_type)
    if not use_calc_stream:
        return task
    # Wait for completion before returning to the caller.
    task.wait()
    return None