collective.py 69.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
17
from datetime import timedelta
18
from ..fluid.layer_helper import LayerHelper
19
from ..fluid.framework import Variable
20
from ..fluid.framework import in_dygraph_mode
21
from ..fluid.framework import OpProtoHolder
J
Jiabin Yang 已提交
22
from ..fluid.framework import _non_static_mode
23
from ..fluid.framework import _in_legacy_dygraph
24
from ..fluid.framework import convert_np_dtype_to_dtype_
J
Jiangxinz 已提交
25
from ..fluid.framework import _varbase_creator
26 27 28 29
from ..fluid.data_feeder import convert_dtype
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.data_feeder import check_type
from ..fluid.data_feeder import check_dtype
30 31
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
B
Baibaifan 已提交
32
from ..fluid.dygraph import layers
33 34 35 36
from ..fluid.dygraph.parallel import prepare_context
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
W
wanghuancoder 已提交
37
from paddle import _C_ops
J
Jiangxinz 已提交
38
import paddle.fluid.dygraph_utils as dygraph_utils
39

40
__all__ = []
41 42 43


class ReduceOp:
    """
    Enumerates the element-wise reduction operations understood by the
    collective APIs. Pass one of the following values:

        ReduceOp.SUM

        ReduceOp.MAX

        ReduceOp.MIN

        ReduceOp.PROD

    ``ReduceOp.AVG`` is defined as well, but the ``all_reduce`` and
    ``reduce`` implementations in this module have no branch for it.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            from paddle.distributed import ReduceOp
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.all_reduce(data, op=ReduceOp.SUM)
            out = data.numpy()
            # [[5, 7, 9], [5, 7, 9]]
    """
    # Plain integer codes: SUM=0, MAX=1, MIN=2, PROD=3, AVG=4. The collective
    # functions compare against these and map them to backend op types.
    SUM, MAX, MIN, PROD, AVG = range(5)


class Group():
    """
    The abstract representation of a communication group.

    Attributes:
        rank (int): This process's rank inside the group; a negative value
            means the process is not a member.
        nranks (int): Number of ranks in the group.
        id (int): The group id (the collectives use it as the ring id).
        ranks (list): Global ranks of the group members.
        pg: The backing process group object, or None.
        name (str): Optional group name.
    """

    def __init__(self, rank, rank_num, id=0, ranks=None, pg=None, name=None):
        self.rank = rank
        self.nranks = rank_num
        self.id = id
        # Give every instance its own list: a shared ``[]`` default would be
        # aliased by all groups created without explicit ranks.
        self.ranks = [] if ranks is None else ranks
        self.pg = pg
        self.name = name

    def is_member(self):
        """Return True if this process is in the group and it has 2+ ranks."""
        if self.rank < 0:
            return False
        if self.nranks < 2:
            return False
        return True

    def get_group_rank(self, rank):
        """Return the index of global ``rank`` in this group, or -1."""
        if self.is_member() and rank in self.ranks:
            return self.ranks.index(rank)
        else:
            return -1

    @property
    def process_group(self):
        """The backing process group object (None when not created)."""
        return self.pg

    def __repr__(self):
        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
            self.rank, self.nranks, self.id)
        debug_str += ", ".join(map(str, self.ranks))
        debug_str += "; name: "
        debug_str += self.name if self.name else "None"
        return debug_str


_global_env = None


def _get_global_env():
    """Return the cached ``paddle.distributed.ParallelEnv``, creating it on first use."""
    global _global_env
    # Test the None sentinel explicitly instead of relying on the
    # truthiness of the cached object.
    if _global_env is None:
        _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# Registry of groups keyed by group id (gid); gid 0 is the global group,
# which _get_group_map() seeds lazily.
# Dict[int, Group]
_group_map = {}

# Registry of groups keyed by group name (filled by new_group in eager mode
# and by _set_group_map_by_name).
# Dict[name, Group]
_group_map_by_name = {}

# Name of the default group for init_parallel_env; new_group also uses it as
# the prefix of the group names it generates.
_default_group_name = "_default_pg"

# Backends accepted by _new_process_group_impl.
_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter']
_default_store = None  # the default tcp store
# Backend that new_group (eager mode) falls back to when none is given.
_default_backend = None


def _set_default_backend(backend):
    """Store ``backend`` in the module-level ``_default_backend``.

    new_group uses this value in eager mode when no backend is passed.
    """
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    """Store ``store`` in the module-level ``_default_store``.

    new_group hands this store to _new_process_group_impl in eager mode.
    """
    global _default_store
    _default_store = store


def _get_group_map():
    """
    Return the gid -> Group registry.

    On first access the registry is seeded with gid 0: the global group that
    spans ranks ``0 .. world_size - 1`` of the current ParallelEnv.
    """
    if not _group_map:
        env = _get_global_env()
        _group_map[0] = Group(env.rank, env.world_size,
                              ranks=list(range(env.world_size)))
    return _group_map


def _get_global_group():
    """Return the global group (gid 0)."""
    registry = _get_group_map()
    return registry[0]


def _get_group_map_by_name():
    """Return the name -> Group registry."""
    return _group_map_by_name


def _get_default_group():
    """
    Return the group registered under ``_default_group_name``.

    Raises:
        AssertionError: if no group is registered under that name; the
            message points the caller to init_parallel_env.
    """
    by_name = _get_group_map_by_name()
    assert _default_group_name in by_name, (
        "Call paddle.distributed.init_parallel_env first "
        "to initialize the distributed environment.")
    return by_name[_default_group_name]


def _set_group_map(gid, group):
    """Register ``group`` under ``gid``; the gid must not be taken yet."""
    registry = _group_map
    assert gid not in registry
    registry[gid] = group


def _set_group_map_by_name(name, group):
    """Register ``group`` under ``name``; the name must not be taken yet."""
    registry = _group_map_by_name
    assert name not in registry
    registry[name] = group


def _new_ring_id():
    """
    Return the id to use for the next group.

    The id is the number of groups registered so far plus an offset of
    ``max(nrings, 9)``. Presumably the ids below that offset are kept for the
    rings of the default parallel context — TODO confirm.
    """
    registered = len(_get_group_map())
    offset = max(_get_global_env().nrings, 9)
    return registered + offset


def get_group(id=0):
    """

    Get group instance by group id.

    Args:
        id (int): the group id. Default value is 0.

    Returns:
        Group: the group instance, or None if no group has this id.

    Examples:
        .. code-block:: python

            ...
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.get_group(gp.id)

    """
    return _get_group_map().get(id)


def _new_process_group_impl(backend,
                            store,
                            rank,
                            world_size,
                            group_name,
                            pg_options,
                            group_id=0,
                            src_rank=None,
                            dst_rank=None):
    """
    Build the backend-specific process group object.

    Args:
        backend (str): one of ``_valid_backend_list``.
        store: the store handed to the process-group constructor.
        rank (int): rank of this process inside the group.
        world_size (int): number of ranks in the group.
        group_name (str): group name (not used by this function).
        pg_options: backend options (not used by this function).
        group_id (int): group id forwarded to the constructor.
        src_rank, dst_rank: may only be set for the ``heter`` backend.

    Returns:
        The process group, or None for a listed backend that has no
        constructor here.

    Raises:
        AssertionError: for src/dst ranks on a non-heter backend, for an
            unsupported backend, or when a CLUSTER_* environment variable
            required by the heter backend is missing.
    """
    env = _get_global_env()
    if backend != 'heter':
        assert src_rank is None and dst_rank is None, (
            "src_rank and dst_rank "
            "can only be set for heter backend.")
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend

    if backend == "gloo":
        place = core.CPUPlace()
        return core.ProcessGroupGloo(store, rank, world_size, place, group_id)
    if backend == "nccl":
        place = core.CUDAPlace(env.device_id)
        return core.ProcessGroupNCCL(store, rank, world_size, place, group_id)
    if backend == "hccl":
        place = core.NPUPlace(env.device_id)
        return core.ProcessGroupHCCL(store, rank, world_size, place, group_id)
    if backend != "heter":
        return None

    # heter: use the device place of whichever accelerator Paddle was built for.
    if core.is_compiled_with_cuda():
        place = core.CUDAPlace(env.device_id)
    elif core.is_compiled_with_npu():
        place = core.NPUPlace(env.device_id)
    else:
        place = None

    # The cluster layout comes from the environment: CLUSTER_ID selects this
    # cluster and CLUSTER_SIZE is a comma-separated list of per-cluster sizes.
    cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
    assert cluster_id >= 0, "please set the CLUSTER_ID variable."
    raw_sizes = os.getenv("CLUSTER_SIZE", None)
    assert raw_sizes, "please set the CLUSTER_SIZE variable."
    sizes = [int(token) for token in raw_sizes.split(",")]
    switch_ep = os.getenv("CLUSTER_SWITCH", None)
    assert switch_ep, "please set the CLUSTER_SWITCH variable."

    # Global rank = local rank shifted by the sizes of all earlier clusters.
    running_total = np.cumsum(sizes)
    offset = 0 if cluster_id == 0 else running_total[cluster_id - 1]
    heter_rank = offset + rank
    heter_world_size = running_total[-1]
    return core.ProcessGroupHeter(
        store,
        rank=heter_rank,
        world_size=heter_world_size,
        place=place,
        gid=group_id,
        local_rank=rank,
        local_size=world_size,
        gloo_rank=cluster_id,
        gloo_size=len(sizes),
        with_switch=True,
        switch_endpoint=switch_ep,
        src_rank=src_rank,
        dst_rank=dst_rank)


def barrier(group=None):
    """

    Barrier among all participators in the group.

    Args:
        group (Group): The group instance returned by new_group or None for global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    # Processes outside the group take no part in the barrier.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = group if group is not None else _get_default_group()
        group.process_group.barrier().wait()
        return

    ring_id = group.id if group is not None else 0
    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        return _C_ops.barrier(temp, temp, 'ring_id', ring_id)

    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    op_type = 'barrier'
    # LayerHelper receives this frame's locals (group, ring_id, temp, op_type).
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [temp]},
        attrs={'ring_id': ring_id})


# _custom_gid provides a way for users to
# set the group id, which is usually useful
# to be compatible with the static mode.
# When it is truthy, the eager-mode branch of new_group uses it as the gid
# instead of calling _new_ring_id().
_custom_gid = None


def _set_custom_gid(gid):
    """Set the module-level ``_custom_gid`` used by new_group in eager mode."""
    global _custom_gid
    _custom_gid = gid


def new_group(ranks=None, backend=None):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members. If None, all ranks
            of the global group are used.
        backend (str): The backend used to create group, only nccl is supported now.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)

    """
    if in_dygraph_mode():
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group.")
        size = len(ranks)
        ranks = sorted(ranks)
        # ``global_rank`` is only bound in the branch above; the
        # short-circuit keeps it unevaluated for heter or size <= 1 groups.
        if backend == 'heter' or (size > 1 and global_rank in ranks):
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            src_rank = ranks[0] if backend == 'heter' else None
            dst_rank = ranks[1] if backend == 'heter' else None
            pg = _new_process_group_impl(
                backend,
                _default_store,
                rank,
                size,
                group_name,
                pg_options=None,
                group_id=gid,
                src_rank=src_rank,
                dst_rank=dst_rank)
        else:
            rank = -1
            pg = None
        group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        tmp = paddle.to_tensor([1], dtype="int32")
        paddle.distributed.all_reduce(tmp, group=group, use_calc_stream=True)
        paddle.distributed.wait(tmp)
        return group

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', ("backend other than nccl is not supported yet")

    genv = _get_global_env()
    global_rank = genv.rank
    if ranks is None:
        # Same meaning as in the eager branch: no ranks means every rank.
        ranks = list(range(genv.world_size))

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, -1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, group_size, ring_id, ranks)
        _group_map[ring_id] = gp

        # A single-rank group needs no communicator, but its member must
        # still reach the global all_reduce below: ranks outside the group
        # always enter it, and would wait for this rank otherwise.
        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            else:
                assert False, ("no cuda device found")

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    # NOTE(review): the static-mode tensor has shape [0]; confirm that an
    # empty all_reduce still synchronizes as intended.
    tmp = paddle.to_tensor(
        [1], dtype="int32") if _non_static_mode() else fill_constant(
            [0], dtype="int32", value="1")
    paddle.distributed.all_reduce(tmp, use_calc_stream=True)
    paddle.distributed.wait(tmp)
    return gp


def wait(tensor, group=None, use_calc_stream=True):
    """

    Synchronize a stream of the group with respect to ``tensor``.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to sync the calculation stream (True) or the
            communication stream of the group's ring (False). Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, use_calc_stream=True)
            paddle.distributed.wait(tindata)

    """
    # Non-member processes have nothing to wait for.
    if group is not None and not group.is_member():
        return

    ring_id = group.id if group is not None else 0
    if not use_calc_stream:
        _sync_comm_stream(tensor, ring_id)
        return
    _sync_calc_stream(tensor)


def _sync_calc_stream(tensor):
    """
    Synchronize the calculation stream for ``tensor``.

    In non-static mode the ``c_sync_calc_stream`` C op is called directly and
    its result returned; otherwise an op of that type is appended through
    LayerHelper and None is returned.
    """
    if not _non_static_mode():
        op_type = 'c_sync_calc_stream'
        # LayerHelper receives this frame's locals (tensor, op_type).
        helper = LayerHelper(op_type, **locals())
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [tensor]})
        return None
    return _C_ops.c_sync_calc_stream(tensor, tensor)


def _sync_comm_stream(tensor, ring_id=0):
    """
    Synchronize the communication stream of ring ``ring_id`` for ``tensor``.

    In non-static mode the ``c_sync_comm_stream`` C op is called directly and
    its result returned; otherwise an op of that type is appended through
    LayerHelper and None is returned.
    """
    if not _non_static_mode():
        op_type = 'c_sync_comm_stream'
        # LayerHelper receives this frame's locals (tensor, ring_id, op_type).
        helper = LayerHelper(op_type, **locals())
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [tensor]},
            attrs={'ring_id': ring_id})
        return None
    return _C_ops.c_sync_comm_stream([tensor], [tensor], 'ring_id', ring_id)


def broadcast(tensor, src, group=None, use_calc_stream=True):
    """

    Broadcast a tensor from the source to all others.
    As shown below, 4 GPUs each start 4 processes and GPU0 owns data 0. Through broadcast operator,
    the data 0 will be sent to all GPUs from GPU0.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png
        :width: 800
        :alt: broadcast
        :align: center

    Args:
        tensor (Tensor): The Tensor to send if current rank is the source, or the tensor to receive otherwise. Its data type
            should be float16, float32, float64, int32 or int64.
        src (int): The source rank (a global rank).
        group (Group): The group instance returned by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None. In eager mode with ``use_calc_stream=False`` the pending
        communication task is returned instead.

    Raises:
        ValueError: if ``src`` is not an int.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.broadcast(data, 1)
            out = data.numpy()
            # [[1, 2, 3], [1, 2, 3]]
    """

    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        assert gsrc >= 0, ("src rank out of group, need global rank")
        task = group.process_group.broadcast(tensor, gsrc)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    # Translate the global source rank into its index inside the group.
    gsrc = src if group is None else group.get_group_rank(src)
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if _non_static_mode():
        return _C_ops.c_broadcast(tensor, tensor, 'root', gsrc,
                                  'use_calc_stream', use_calc_stream, 'ring_id',
                                  ring_id)

    op_type = 'c_broadcast'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'broadcast')

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'root': gsrc,
            'use_calc_stream': use_calc_stream,
            'ring_id': ring_id,
        })


def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor over all ranks so that all get the result.
    As shown below, 4 GPUs each start 4 processes and the data on each GPU is represented
    by the GPU number. The reduce operator is sum. Through all_reduce operator,
    each GPU will have the sum of the data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
        :width: 800
        :alt: all_reduce
        :align: center

    Args:
        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance returned by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None. In eager mode with ``use_calc_stream=False`` the pending
        communication task is returned instead.

    Raises:
        ValueError: if ``op`` is not one of the supported ReduceOp values.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import ReduceOp
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.all_reduce(data)
            out = data.numpy()
            # [[5, 7, 9], [5, 7, 9]]
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        if op == ReduceOp.SUM:
            op_type = core.ReduceOp.SUM
        elif op == ReduceOp.MAX:
            op_type = core.ReduceOp.MAX
        elif op == ReduceOp.MIN:
            op_type = core.ReduceOp.MIN
        elif op == ReduceOp.PROD:
            op_type = core.ReduceOp.PRODUCT
        else:
            raise ValueError("Unknown reduce_op type for allreduce.")
        group = _get_default_group() if group is None else group
        task = group.process_group.allreduce(tensor, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    if _non_static_mode():
        if op == ReduceOp.SUM:
            return _C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.MAX:
            return _C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.MIN:
            return _C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.PROD:
            return _C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
                                            use_calc_stream, 'ring_id', ring_id)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'all_reduce')
    if op == ReduceOp.SUM:
        op_type = 'c_allreduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_allreduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_allreduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_allreduce_prod'
    else:
        # Without this branch ``op_type`` would stay unbound and the call
        # below would fail with an UnboundLocalError.
        raise ValueError("Unknown parameter: {}.".format(op))
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'ring_id' for all_reduce should be int.")
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={'ring_id': ring_id,
               'use_calc_stream': use_calc_stream})


def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor to the destination from all others. As shown below, 4 GPUs each start 4 processes and the data on each GPU is represented
    by the GPU number. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator,
    GPU0 will own the sum of all data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png
        :width: 800
        :alt: reduce
        :align: center

    Args:
        tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
            should be float16, float32, float64, int32 or int64.
        dst (int): The destination rank id (a global rank).
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance returned by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None. In eager mode with ``use_calc_stream=False`` the pending
        communication task is returned instead.

    Raises:
        ValueError: if ``op`` is not one of the supported ReduceOp values.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.reduce(data, 0)
            out = data.numpy()
            # on rank 0 (the destination): [[5, 7, 9], [5, 7, 9]]
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        if op == ReduceOp.SUM:
            op_type = core.ReduceOp.SUM
        elif op == ReduceOp.MAX:
            op_type = core.ReduceOp.MAX
        elif op == ReduceOp.MIN:
            op_type = core.ReduceOp.MIN
        elif op == ReduceOp.PROD:
            op_type = core.ReduceOp.PRODUCT
        else:
            raise ValueError("Unknown reduce_op type for reduce.")
        group = _get_default_group() if group is None else group
        gdst = group.get_group_rank(dst)
        assert gdst >= 0, ("dst rank out of group, need global rank")
        task = group.process_group.reduce(tensor, gdst, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    # Translate the global destination rank into its index inside the group.
    gdst = dst if group is None else group.get_group_rank(dst)
    assert gdst >= 0, ("dst rank out of group, need global rank")

    if _non_static_mode():
        if op == ReduceOp.SUM:
            return _C_ops.c_reduce_sum(tensor, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'root_id', gdst)
        elif op == ReduceOp.MAX:
            return _C_ops.c_reduce_max(tensor, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'root_id', gdst)
        elif op == ReduceOp.MIN:
            return _C_ops.c_reduce_min(tensor, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'root_id', gdst)
        elif op == ReduceOp.PROD:
            return _C_ops.c_reduce_prod(tensor, tensor, 'use_calc_stream',
                                        use_calc_stream, 'ring_id', ring_id,
                                        'root_id', gdst)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'reduce')

    if op == ReduceOp.SUM:
        op_type = 'c_reduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_reduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_reduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_reduce_prod'
    else:
        # Reject unknown ops here instead of emitting a generic 'c_reduce' op.
        raise ValueError("Unknown parameter: {}.".format(op))

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
            'root_id': gdst,
        })


def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
    """

    Gather tensors from all participators and all get the result. As shown
    below, 4 GPUs each start 4 processes and the data on each GPU is represented
    by the GPU number. Through the all_gather operator, each GPU will have data
    from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64. In eager mode the list is cleared and refilled
            with one Tensor per rank; in the other modes the gathered Tensors are appended to it.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The group instance returned by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Defaults to True. In eager mode the call always waits for the communication to finish.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            tensor_list = []
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data1 = np.array([[4, 5, 6], [4, 5, 6]])
                np_data2 = np.array([[4, 5, 6], [4, 5, 6]])
                data1 = paddle.to_tensor(np_data1)
                data2 = paddle.to_tensor(np_data2)
                paddle.distributed.all_gather(tensor_list, data1)
            else:
                np_data1 = np.array([[1, 2, 3], [1, 2, 3]])
                np_data2 = np.array([[1, 2, 3], [1, 2, 3]])
                data1 = paddle.to_tensor(np_data1)
                data2 = paddle.to_tensor(np_data2)
                paddle.distributed.all_gather(tensor_list, data2)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        if len(tensor_list) == 0:
            # No buffers supplied: allocate one tensor holding every rank's
            # contribution stacked along axis 0.
            tensor_shape = list(tensor.shape)
            tensor_shape[0] *= group.nranks
            out = paddle.empty(tensor_shape, tensor.dtype)
        else:
            out = paddle.concat(tensor_list, axis=0)
        task = group.process_group.all_gather(tensor, out)
        # NOTE: the eager path always blocks here; use_calc_stream is not
        # consulted.
        task.wait()
        tensor_list.clear()
        tensor_list.extend(paddle.split(out, group.nranks, 0))
        return

    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    if _non_static_mode():
        out = _C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream,
                                 'ring_id', ring_id, 'nranks', nranks)
    else:
        op_type = 'c_allgather'
        # NOTE: LayerHelper receives this function's locals as kwargs.
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        if not isinstance(tensor_list, list):
            raise ValueError("The type of 'tensor_list' for all_gather "
                             "should be list.")
        for elem in tensor_list:
            check_variable_and_dtype(
                elem, 'tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_gather')
        check_variable_and_dtype(
            tensor, 'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'], 'all_gather')
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
                'nranks': nranks
            })

    # The gathered result is one tensor stacked along axis 0; hand it back
    # as one piece per rank.
    tensor_list.extend(paddle.split(out, nranks, 0))


def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
    """

    Scatter a tensor to all participators. As shown below, 4 GPUs each start 4 processes and the source of the scatter
    is GPU0. Through scatter operator, the data in GPU0 will be sent to all GPUs evenly.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
        :width: 800
        :alt: scatter
        :align: center

    Args:
        tensor (Tensor): The output Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64. Default value is None. It is only read on the
            source rank, where it must be provided.
        src (int): The source rank id. Default value is 0.
        group (Group): The group instance returned by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Defaults to True.

    Returns:
        None in most cases. In eager mode with use_calc_stream=False the pending communication task is
        returned; in legacy dygraph mode the result of the c_scatter op is returned.

    Raises:
        ValueError: If ``src`` is not an int, or ``tensor_list`` is None on the source rank.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data1 = np.array([7, 8, 9])
                np_data2 = np.array([10, 11, 12])
            else:
                np_data1 = np.array([1, 2, 3])
                np_data2 = np.array([4, 5, 6])
            data1 = paddle.to_tensor(np_data1)
            data2 = paddle.to_tensor(np_data2)
            if paddle.distributed.ParallelEnv().local_rank == 0:
                paddle.distributed.scatter(data1, src=1)
            else:
                paddle.distributed.scatter(data1, tensor_list=[data1, data2], src=1)
            out = data1.numpy()
    """
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        rank = group.rank
        nranks = group.nranks
    else:
        ring_id = 0 if group is None else group.id
        gsrc = src if group is None else group.get_group_rank(src)
        rank = _get_global_group().rank if group is None else group.rank
        nranks = _get_global_group().nranks if group is None else group.nranks
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if rank != gsrc:
        # Only the source rank's list carries real data; other ranks still
        # need an input of matching shape, so they pass nranks references to
        # the output tensor.
        tensor_list = [tensor] * nranks
    elif tensor_list is None:
        # Previously this fell through to paddle.concat(None) and failed with
        # an unrelated error.
        raise ValueError("tensor_list should be provided on the src rank.")
    temp = paddle.concat(tensor_list, axis=0)

    if in_dygraph_mode():
        task = group.process_group.scatter(temp, tensor, gsrc)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    if _non_static_mode():
        return _C_ops.c_scatter(temp, tensor, 'use_calc_stream',
                                use_calc_stream, 'ring_id', ring_id, 'nranks',
                                nranks, 'root', gsrc)

    op_type = 'c_scatter'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'scatter')
    # NOTE: LayerHelper receives this function's locals as kwargs.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'root': gsrc,
            'use_calc_stream': use_calc_stream,
            'nranks': nranks,
        })


def _c_identity(tensor, group=None):
    """
    Return a copy of the tensor, mainly used with model parallel.

    Runs the ``c_identity`` op on the calculation stream with
    ``use_model_parallel`` enabled.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group, optional): The group instance returned by new_group, or
            None to use ring 0.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        return _C_ops.c_identity(tensor, 'use_calc_stream', True, 'ring_id',
                                 ring_id, 'use_model_parallel', True)
    op_type = 'c_identity'
    # NOTE: LayerHelper receives this function's locals as kwargs.
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_identity')

    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': True,
            'use_model_parallel': True,
        })
    return out


def _c_concat(tensor, group=None):
    """
    Return allgather of the tensor, mainly used with model parallel.

    Runs the ``c_concat`` op, passing this process's rank within ``group``
    and the group size.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group, optional): The group instance returned by new_group, or
            None to use the default group.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return

    group = _get_default_group() if group is None else group
    ring_id = group.id
    rank = group.rank
    nranks = group.nranks

    if _non_static_mode():
        return _C_ops.c_concat(tensor, 'ring_id', ring_id, 'use_calc_stream',
                               True, 'rank', rank, 'nranks', nranks,
                               'use_model_parallel', True)

    op_type = 'c_concat'
    # NOTE: LayerHelper receives this function's locals as kwargs.
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_concat')

    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': True,
            'use_model_parallel': True,
            'nranks': nranks,
            'rank': rank
        })
    return out


def _c_split(tensor, group=None):
    """
    Split tensor evenly among all members, mainly used with model parallel.

    The rank passed to the ``c_split`` op is this process's global rank when
    ``group`` is None, otherwise its rank inside ``group``.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group, optional): The group instance returned by new_group, or
            None to use ring 0 and the global world size.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    global_rank = _get_global_env().rank
    rank = global_rank if group is None else group.get_group_rank(global_rank)
    nranks = _get_global_env().world_size if group is None else group.nranks

    if _non_static_mode():
        return _C_ops.c_split(tensor, 'use_calc_stream', True, 'ring_id',
                              ring_id, 'rank', rank, 'nranks', nranks,
                              'use_model_parallel', True)

    op_type = 'c_split'
    # NOTE: LayerHelper receives this function's locals as kwargs.
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_split')

    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': True,
            'rank': rank,
            'nranks': nranks,
            'use_model_parallel': True,
        })
    return out


def _mp_allreduce(tensor,
                  op=ReduceOp.SUM,
                  group=None,
                  use_calc_stream=True,
                  use_model_parallel=True):
    """
    Sum-allreduce ``tensor`` for model parallelism.

    It is the same as all_reduce, but it supports model parallel. In dynamic
    graph modes it runs the in-place ``c_allreduce_sum_`` op. In eager mode
    the backward pass applies ``c_identity`` to the incoming gradient.

    Args:
        tensor (Tensor): The input Tensor. Its data type should be float16,
            float32, float64, int32 or int64 (checked in static mode).
        op (ReduceOp): The reduction. Only ``ReduceOp.SUM`` is supported.
        group (Group, optional): The group instance returned by new_group, or
            None to use ring 0.
        use_calc_stream (bool): Whether to use the calculation stream.
        use_model_parallel (bool): Forwarded as the ``use_model_parallel``
            attribute of the op.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.

    Raises:
        ValueError: If ``op`` is not ``ReduceOp.SUM`` (in every mode; the
            static path used to sum silently and the eager path used an
            ``assert`` that is stripped under ``python -O``).
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    if op != ReduceOp.SUM:
        raise ValueError("Unknown parameter: {}.".format(op))

    if in_dygraph_mode():
        from paddle.autograd import EagerPyLayer

        class mp_allreduce_eager(EagerPyLayer):
            @staticmethod
            def forward(ctx, tensor, use_calc_stream, ring_id,
                        use_model_parallel):
                ctx.ring_id = ring_id
                return _C_ops.c_allreduce_sum_(
                    tensor, 'use_calc_stream', use_calc_stream, 'ring_id',
                    ring_id, "use_model_parallel", use_model_parallel)

            @staticmethod
            def backward(ctx, dy):
                return _C_ops.c_identity(dy, 'use_calc_stream', True, 'ring_id',
                                         ctx.ring_id, 'use_model_parallel',
                                         True)

        return mp_allreduce_eager.apply(tensor, use_calc_stream, ring_id,
                                        use_model_parallel)

    elif _in_legacy_dygraph():
        return _C_ops.c_allreduce_sum_(
            tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
            "use_model_parallel", use_model_parallel)

    op_type = 'c_allreduce_sum'
    # NOTE: LayerHelper receives this function's locals as kwargs.
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        op_type)

    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
            'use_model_parallel': use_model_parallel,
        })
    return out


def _c_lookup_table(table, index, start_index=0, name=None):
    """
    Lookup table according to index.

    Runs the ``c_embedding`` op on ``table`` with ``index`` as ids.

    Args:
        table (Tensor): The input Tensor. Its data type
            should be float16, float32, float64.
        index (Tensor): The index to lookup table. Its data type should be
            int32 or int64 (checked in static mode).
        start_index (int): The initial index for table range.
        name (string): The name of the api

    Returns:
        Tensor.
    """
    if _non_static_mode():
        return _C_ops.c_embedding(table, index, "start_index", start_index)

    op_type = 'c_embedding'
    # NOTE: LayerHelper receives this function's locals as kwargs;
    # input_dtype below reads the entry named 'table', so that parameter
    # name must stay as is.
    helper = LayerHelper(op_type, **locals())
    dtype = helper.input_dtype(input_param_name='table')
    check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type)
    tmp = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type=op_type,
        inputs={'Ids': index,
                'W': table},
        outputs={'Out': tmp},
        attrs={"start_index": start_index})
    return tmp


class _Linear(layers.Layer):
    """
    Minimal fully-connected layer that delegates its computation to the
    module-level ``_linear`` helper.

    Args:
        in_features (int): first dimension of the weight parameter.
        out_features (int): second dimension of the weight parameter and the
            length of the bias parameter.
        weight_attr: attribute object forwarded to ``create_parameter`` for
            the weight.
        bias_attr: attribute object forwarded to ``create_parameter`` for
            the bias.
        name (str, optional): name forwarded to ``_linear``.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        super().__init__()
        self._dtype = self._helper.get_default_dtype()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        # Parameters are registered in attribute-assignment order: weight
        # first, then bias.
        self.weight = self.create_parameter(
            attr=self._weight_attr,
            shape=[in_features, out_features],
            dtype=self._dtype,
            is_bias=False)
        self.bias = self.create_parameter(
            attr=self._bias_attr,
            shape=[out_features],
            dtype=self._dtype,
            is_bias=True)
        self.name = name

    def forward(self, input):
        return _linear(
            x=input, weight=self.weight, bias=self.bias, name=self.name)

    def extra_repr(self):
        if self.name:
            suffix = ', name={}'.format(self.name)
        else:
            suffix = ''
        rows, cols = self.weight.shape[0], self.weight.shape[1]
        return 'in_features={}, out_features={}, dtype={}{}'.format(
            rows, cols, self._dtype, suffix)


def _c_softmax_with_cross_entropy(logits,
                                  label,
                                  group=None,
                                  return_softmax=False):
    """
    Compute softmax cross-entropy with the ``c_softmax_with_cross_entropy``
    op across the ranks of ``group``.

    Args:
        logits (Tensor): The local logits. Presumably each rank holds a shard
            of the class dimension (TODO: confirm against the op).
        label (Tensor): The labels. Its rank must equal the rank of
            ``logits`` or be one less; in the latter case a trailing axis is
            added with ``paddle.unsqueeze``.
        group (Group, optional): The group instance returned by new_group, or
            None to use ring 0 and the global rank / world size.
        return_softmax (bool): Whether to also return the softmax output.

    Returns:
        ``loss``, or ``(loss, softmax)`` when ``return_softmax`` is True.
        None when the current process is not a member of ``group``.

    Raises:
        ValueError: If the ranks of ``logits`` and ``label`` are incompatible.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id
    global_rank = _get_global_env().rank
    rank = global_rank if group is None else group.get_group_rank(global_rank)
    nranks = _get_global_env().world_size if group is None else group.nranks

    input_dims = len(list(logits.shape))
    label_dims = len(list(label.shape))
    if input_dims - 1 != label_dims and input_dims != label_dims:
        # The previous message misspelled "input_dims" and, through a
        # backslash line continuation, embedded a run of spaces.
        raise ValueError(
            'Expected input_dims - 1 == label_dims or input_dims == label_dims'
            ' (got input_dims {}, label_dims {})'.format(input_dims,
                                                         label_dims))
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=-1)

    if _non_static_mode():
        softmax, loss = _C_ops.c_softmax_with_cross_entropy(
            logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks)
        if not return_softmax:
            return loss
        else:
            return loss, softmax

    attrs = {
        'ring_id': ring_id,
        'rank': rank,
        'nranks': nranks,
    }
    # NOTE: LayerHelper receives this function's locals as kwargs.
    helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
    helper.append_op(
        type='c_softmax_with_cross_entropy',
        inputs={'Logits': logits,
                'Label': label},
        outputs={'Softmax': softmax,
                 'Loss': loss},
        attrs=attrs)

    if return_softmax:
        return loss, softmax

    return loss


def _linear(x, weight, bias=None, name=None):
    """
    Functional linear layer: ``x`` multiplied by ``weight`` (neither
    transposed), then ``bias`` added along axis ``len(x.shape) - 1`` when
    given.

    In dynamic-graph mode this runs the ``matmul`` op. In static mode it
    appends ``matmul_v2`` and, if ``bias`` is not None, ``elementwise_add``.

    Args:
        x (Tensor): The input. In static mode its dtype must be float16,
            float32 or float64 and it must have fewer than 4 dimensions.
        weight (Tensor): The weight matrix.
        bias (Tensor, optional): The bias. Defaults to None.
        name (str, optional): Forwarded to the LayerHelper in static mode.

    Returns:
        Tensor.
    """
    if _non_static_mode():
        # matmul writes its result into the pre-created pre_bias variable.
        pre_bias = _varbase_creator(dtype=x.dtype)
        _C_ops.matmul(x, weight, pre_bias, 'transpose_X', False, 'transpose_Y',
                      False, "alpha", 1)
        return dygraph_utils._append_bias_in_dygraph(
            pre_bias, bias, axis=len(x.shape) - 1)
    else:
        # NOTE: LayerHelper receives this function's locals as kwargs.
        helper = LayerHelper('linear', **locals())
        dtype = x.dtype
        assert len(
            x.shape) < 4, "X latitude is not supported greater than 3 now."

        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'linear')
        check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')

        inputs = {'X': [x], 'Y': [weight]}
        attrs = {
            'transpose_X': False,
            'transpose_Y': False,
            'alpha': 1,
        }
        tmp = helper.create_variable_for_type_inference(dtype)
        helper.append_op(
            type='matmul_v2', inputs=inputs, outputs={'Out': tmp}, attrs=attrs)
        if bias is not None:
            res = helper.create_variable_for_type_inference(dtype)
            helper.append_op(
                type='elementwise_add',
                inputs={'X': [tmp],
                        'Y': [bias]},
                outputs={'Out': [res]},
                attrs={'axis': len(x.shape) - 1})
        else:
            res = tmp
        return res


def _set_var_distributed(var):
    if var is None:
        return

    var.is_distributed = True

    # NOTE: use current_block and find_var_recursive to support while_loop
    startup_block = paddle.static.default_startup_program().current_block()
    main_block = paddle.static.default_main_program().current_block()
    startup_block._find_var_recursive(var.name).is_distributed = True
    main_block._find_var_recursive(var.name).is_distributed = True


L
lilong12 已提交
1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424
def _parallel_linear(x,
                     num_rows,
                     num_cols,
                     axis,
                     param_attr,
                     bias_attr,
                     gather_out,
                     inner_rank,
                     nranks,
                     split_tensor,
                     name,
                     group=None):
    """
    Parallel Linear.

    ``axis`` selects the dimension of the linear layer's weight that is
    split across ranks:

    - axis = 0: the row dimension. The input is split with ``_c_split``
      when ``split_tensor`` is True. Partial results are combined with
      ``c_allreduce_sum`` and the bias is added afterwards.
    - axis = 1: the col dimension. The input passes through ``_c_identity``.
      Per-rank results are gathered with ``c_concat``.

    Returns:
        The local per-rank output when ``gather_out`` is False, otherwise the
        combined output. None when the current process is not a member of
        ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    if axis == 0:
        if split_tensor:
            x = _c_split(x, group=group)
    else:
        x = _c_identity(x, group=group)

    linear = paddle.nn.Linear(
        num_rows,
        num_cols,
        weight_attr=param_attr,
        bias_attr=bias_attr,
        name=name)

    # NOTE: npu linear function use matmul_v2 but linear use matmul
    linear_function = _linear if core.is_compiled_with_npu()\
        else paddle.nn.functional.linear
    linear_out = linear_function(
        x,
        linear.weight,
        # NOTE(wangxi): row split, bias need add after allreduce
        None if axis == 0 else linear.bias,
        linear.name)

    _set_var_distributed(linear.weight)
    # Set is_distributed for a split bias:
    # if a linear layer is split by row, each rank holds a complete bias and
    # the biases should be the same on every rank;
    # if a linear layer is split by col, the bias is split across ranks just
    # like its weight.
    if axis == 1 and linear._bias_attr != False:
        _set_var_distributed(linear.bias)

    if not gather_out:
        return linear_out

    out_shape = list(linear_out.shape)
    if axis != 0:
        # c_concat joins the per-rank column shards along the LAST dimension
        # (the previous code scaled dimension 0 instead).
        out_shape[-1] *= nranks
    main_block = paddle.static.default_main_program().current_block()
    out = main_block.create_var(
        shape=out_shape,
        dtype=linear_out.dtype,
        type=linear_out.type,
        lod_level=linear_out.lod_level,
        persistable=False,
        is_data=False,
        need_check_feed=linear_out.desc.need_check_feed())
    if axis == 0:
        main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': linear_out},
            outputs={'Out': out},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': True,
                'use_model_parallel': True
            })
        if linear.bias is not None:
            out = out + linear.bias
    else:
        main_block.append_op(
            type='c_concat',
            inputs={'X': linear_out},
            outputs={'Out': out},
            attrs={
                'rank': inner_rank,
                'ring_id': ring_id,
                'nranks': nranks,
                'use_calc_stream': True,
                'use_model_parallel': True
            })
    return out


def _parallel_embedding(x,
                        per_part_embeddings,
                        origin_size,
                        param_attr,
                        inner_rank,
                        num_partitions,
                        name,
                        group=None):
    """
    Parallel Embedding.

    Creates this rank's ``[per_part_embeddings, origin_size[1]]`` slice of
    the embedding table. Looks ``x`` up with ``_c_lookup_table`` using a
    start index of ``inner_rank * per_part_embeddings``, then sum-allreduces
    the result across ``group``. With a single partition it falls back to a
    plain ``paddle.nn.functional.embedding`` lookup.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    # NOTE: LayerHelper receives this function's locals as kwargs (including
    # ``name``).
    helper = LayerHelper("_parallel_embedding", **locals())

    per_part_size = per_part_embeddings
    rank = inner_rank

    vocab_start_index = rank * per_part_size
    dtype = helper.get_default_dtype()
    size = [per_part_size, origin_size[1]]

    weight = helper.create_parameter(
        attr=param_attr, shape=size, dtype=dtype, is_bias=False)

    if num_partitions == 1:
        return paddle.nn.functional.embedding(
            x, weight=weight, padding_idx=None, sparse=False, name=name)

    startup_block = paddle.static.default_startup_program().global_block()
    main_block = paddle.static.default_main_program().global_block()
    startup_block.vars[weight.name].is_distributed = True
    main_block.vars[weight.name].is_distributed = True

    # Both helpers are defined in this module; call them directly rather
    # than through the package path.
    output_parallel = _c_lookup_table(
        weight, x, start_index=vocab_start_index, name=name)
    out = _mp_allreduce(
        output_parallel,
        group=group,
        use_calc_stream=True,
        use_model_parallel=True)
    return out


def split(x,
          size,
          operation,
          axis=0,
          num_partitions=1,
          gather_out=True,
          weight_attr=None,
          bias_attr=None,
          name=None):
    """

    Split the weight of the specified operation into multiple devices
    and do the computation in parallel.

    Now the following three cases are supported.

    Case 1: Parallel Embedding
        The weight of the embedding operation is a NxM matrix with N rows and M columns.
        With parallel embedding, the weight is split into num_partitions partitions, each
        of which is a matrix with N/num_partitions rows and M columns.

        Suppose we split the NxM weight into two partitions on device_0 and device_1
        respectively. Then, on each device, the local weight has N/2 rows: device_0 holds
        rows [0, N/2 - 1] of the original weight and device_1 holds rows [N/2, N - 1].
        On each device, every input value inside the local row range is looked up in the
        local weight (on device_1 a value V is looked up at local row V - N/2), and every
        value outside that range is mapped to all zeros. Finally, the results on the two
        devices are sum-reduced.

        The Embedding put on single card is as shown below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_single.png
            :width: 800
            :height: 350
            :alt: single_embedding
            :align: center

        Parallel Embedding is shown as below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_split.png
            :width: 800
            :alt: split_embedding
            :align: center

    Case 2: Row Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With row parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N/num_partitions rows and M columns.

        The linear layer put on single card is shown as below, the input variable is represented by X,
        the weight matrix is represented by W and the output variable is O. The linear layer on single card is
        a simple matrix multiplication operation, O = X * W.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_single.png
            :width: 800
            :alt: single_linear
            :align: center

        Row Parallel Linear is shown as below. As the name suggests, Row Parallel Linear splits the weight matrix W into
        [[W_row1], [W_row2]] along the row. Accordingly, the input is split along the column into [X_col1, X_col2], and each
        part is multiplied by its respective weight matrix. Finally, AllReduce is applied on the output from each card to get the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_row.png
            :width: 800
            :alt: split_row
            :align: center

    Case 3: Column Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With column parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N rows and M/num_partitions columns.

        The linear layer put on single card has been illustrated in case 2 and Column Parallel Linear
        is shown as below. The Column Parallel Linear splits the weight matrix W into [W_col1, W_col2] along the column and
        the split matrices respectively multiply the input. Finally, AllGather is applied on the output from each card to get the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col.png
            :width: 800
            :alt: split_col
            :align: center

    As observed, the column parallel linear and row parallel linear can be combined to skip one ALLGATHER communication
    operator. Furthermore the Attention and MLP can be combined to improve the performance as shown below.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col_row.png
        :width: 800
        :alt: split_col_row
        :align: center

    Args:
        x (Tensor): Input tensor. Its data type should be float16, float32, float64, int32 or int64.
        size (list|tuple): A list or tuple with two elements indicating the shape of the weight.
        operation (str): The name of the operation. The supported operations are 'linear' and 'embedding'.
        axis (int, Optional): Indicate along which axis to split the weight. Default: 0.
        num_partitions (int, Optional): How many parts the weight is partitioned. Default: 1.
        gather_out (bool, Optional): Whether to gather the output after computation. By default, the output
            on each partitions will be gathered after computation. Default: True.
        weight_attr (ParamAttr, Optional): The parameter attribute for the learnable
            weights(Parameter) of the specified operation. Default: None.
        bias_attr (ParamAttr, Optional): The parameter attribute for the bias
            of the specified operation. Default: None.
        name (str, Optional): The default value is None. Normally there is no need for user to set this
            property. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor.

    Raises:
        ValueError: If called in dynamic graph mode, or if ``axis`` is not 0 or 1 for 'linear'.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed.fleet as fleet

            paddle.enable_static()
            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            fleet.init(is_collective=True)
            data = paddle.randint(0, 8, shape=[10,4])
            emb_out = paddle.distributed.split(
                data,
                (8, 8),
                operation="embedding",
                num_partitions=2)

    """
    assert isinstance(size, (list, tuple)), (
        "The type of size for "
        "paddle.distributed.split must be list or tuple.")
    assert len(size) == 2, ("Number of elements in size of "
                            "paddle.distributed.split must be two.")
    assert isinstance(operation, str), ("The type of operation for "
                                        "paddle.distributed.split must be str.")
    supported_operations = [
        'linear',
        'embedding',
    ]
    assert operation in supported_operations, (
        "The operation for "
        "paddle.distributed.split must be one of {}.".format(
            supported_operations))

    # Only the static graph is supported; dygraph users have dedicated layers.
    if _non_static_mode():
        raise ValueError(
            "paddle.distributed.split cannot be used in dynamic "
            "graph mode, please use ParallelEmbedding, ParallelRowLinear, "
            "ParallelColumnLinear instead.")

    from .fleet import fleet
    assert fleet._role_maker, ("To use paddle.distributed.split, "
                               "you must call fleet.init() firstly.")
    rank = fleet.worker_index()

    # rank within a model parallel group
    inner_rank = rank % num_partitions

    if operation == "embedding":
        assert axis == 0, ("We only support to split the weight of embedding "
                           "along the first axis now.")
        assert size[0] % num_partitions == 0, \
            "The length of the vocabulary must be divisible by num_partitions " \
            "but received vocabulary={} num_partitions={}".format(size[0], num_partitions)

        per_part_size = size[0] // num_partitions
        return _parallel_embedding(
            x,
            per_part_size,
            size,
            weight_attr,
            inner_rank,
            num_partitions,
            name,
            group=None)

    # operation == "linear"
    should_split = False
    if axis == 0:
        assert size[0] % num_partitions == 0, (
            "Number of rows of the weight for linear ({}) must be"
            " divisible by num_partitions ({})".format(size[0],
                                                       num_partitions))
        per_part_size = size[0] // num_partitions
        linear_size = (per_part_size, size[1])
        # NOTE(review): presumably tells _parallel_linear that x still carries
        # all size[0] input features and must be split across ranks -- confirm.
        if x.shape[-1] == size[0]:
            should_split = True
    elif axis == 1:
        assert size[1] % num_partitions == 0, (
            "Number of column of the weight for linear ({}) must be"
            " divisible by num_partitions ({})".format(size[1],
                                                       num_partitions))
        per_part_size = size[1] // num_partitions
        linear_size = (size[0], per_part_size)
    else:
        raise ValueError("The value of axis must be 0 or 1, but the value "
                         "given is {}.".format(axis))

    linear_out = _parallel_linear(
        x,
        linear_size[0],
        linear_size[1],
        axis,
        weight_attr,
        bias_attr,
        gather_out,
        inner_rank,
        num_partitions,
        should_split,
        name=name,
        group=None)
    return linear_out


L
lilong12 已提交
1767 1768
def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
    """
    Scatter tensors in in_tensor_list to all participators evenly and gather the result tensors in out_tensor_list.
    As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
    Through alltoall operator, the 0_0 in GPU0 will be sent to GPU0 and 0_1 to GPU1, 1_0 in GPU1 sent to GPU0 and 1_1 to GPU1.
    Finally the out_tensor_list in GPU0 includes 0_0 and 1_0, and GPU1 includes 0_1 and 1_1.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/alltoall.png
        :width: 800
        :alt: alltoall
        :align: center

    Args:
        in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64.
        out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
            data type of the input Tensors. The received tensors are added to this list. In static graph mode it
            must be empty. In eager mode any tensors it already holds are concatenated and used as the output
            buffer, and the list is then replaced by the results.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.
            In eager mode the call always waits for the communication to finish.

    Returns:
        None.

    Raises:
        ValueError: In static graph mode, if ``in_tensor_list`` or ``out_tensor_list`` is not a list,
            or if ``out_tensor_list`` is not empty.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            init_parallel_env()
            out_tensor_list = []
            if paddle.distributed.ParallelEnv().rank == 0:
                np_data1 = np.array([[1, 2, 3], [4, 5, 6]])
                np_data2 = np.array([[7, 8, 9], [10, 11, 12]])
            else:
                np_data1 = np.array([[13, 14, 15], [16, 17, 18]])
                np_data2 = np.array([[19, 20, 21], [22, 23, 24]])
            data1 = paddle.to_tensor(np_data1)
            data2 = paddle.to_tensor(np_data2)
            paddle.distributed.alltoall([data1, data2], out_tensor_list)
            # out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]]
            # out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]]
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        nranks = len(in_tensor_list)
        temp = paddle.concat(in_tensor_list, axis=0)
        if len(out_tensor_list) == 0:
            tensor_shape = list(in_tensor_list[0].shape)
            tensor_shape[0] *= nranks
            out = paddle.empty(tensor_shape, in_tensor_list[0].dtype)
        else:
            # Reuse the caller-provided tensors as the output buffer.
            out = paddle.concat(out_tensor_list, axis=0)
        task = group.process_group.alltoall(temp, out)
        task.wait()
        out_tensor_list.clear()
        out_tensor_list.extend(paddle.split(out, nranks, 0))
        return

    ring_id = 0 if group is None else group.id

    if _non_static_mode():
        temp = paddle.concat(in_tensor_list, axis=0)
        out = _C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream,
                              'ring_id', ring_id)
    else:
        # Validate the arguments before any op is appended to the program, so
        # that a rejected call does not leave a dangling concat op behind.
        if not isinstance(in_tensor_list, list):
            raise ValueError("The type of 'in_tensor_list' for all_to_all "
                             "should be list.")
        for elem in in_tensor_list:
            check_variable_and_dtype(
                elem, 'in_tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_to_all')
        if not isinstance(out_tensor_list, list):
            raise ValueError("The type of 'out_tensor_list' for all_to_all "
                             "should be list.")
        if len(out_tensor_list) != 0:
            raise ValueError("The 'out_tensor_list' for all_to_all "
                             "must be an empty list.")

        temp = paddle.concat(in_tensor_list, axis=0)
        op_type = 'alltoall'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=in_tensor_list[0].dtype)
        helper.append_op(
            type=op_type,
            inputs={'X': [temp]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
            })

    nranks = len(in_tensor_list)
    out_tensor_list.extend(paddle.split(out, nranks, 0))


L
lilong12 已提交
1869 1870 1871 1872 1873 1874 1875 1876
def send(tensor, dst=0, group=None, use_calc_stream=True):
    """
    Send a tensor to the receiver.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32 or int64.
        dst (int): The destination rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        None in static graph mode, or when the current process is not a member of ``group``.
        In eager dygraph mode: None if ``use_calc_stream`` is True (the call waits for the
        send to complete), otherwise the communication task, which the caller can wait on.
        In legacy dygraph mode: the result of the underlying ``send_v2`` op.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.distributed import init_parallel_env

            init_parallel_env()
            if paddle.distributed.ParallelEnv().rank == 0:
                data = paddle.to_tensor([7, 8, 9])
                paddle.distributed.send(data, dst=1)
            else:
                data = paddle.to_tensor([1,2,3])
                paddle.distributed.recv(data, src=0)
            out = data.numpy()
    """
    # Processes outside the group take no part in the communication.
    if group is not None and not group.is_member():
        return

    # Eager mode: go through the process group and hand back the task when
    # the caller asked for asynchronous communication.
    if in_dygraph_mode():
        comm_group = group if group is not None else _get_default_group()
        task = comm_group.process_group.send(tensor, dst)
        if not use_calc_stream:
            return task
        task.wait()
        return None

    ring_id = group.id if group is not None else 0

    # Legacy dygraph: call the op directly.
    if _non_static_mode():
        return _C_ops.send_v2(tensor,
                              'use_calc_stream', use_calc_stream,
                              'ring_id', ring_id,
                              'peer', dst)

    # Static graph: append a send_v2 op to the current program.
    op_type = 'send_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'send')

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        attrs={
            'ring_id': ring_id,
            'peer': dst,
            'use_calc_stream': use_calc_stream,
        })


def recv(tensor, src=0, group=None, use_calc_stream=True):
    """
    Receive a tensor from the sender.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
            should be float16, float32, float64, int32 or int64.
        src (int): The source rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        None in static graph mode, or when the current process is not a member of ``group``.
        In eager dygraph mode: None if ``use_calc_stream`` is True (the call waits for the
        receive to complete), otherwise the communication task, which the caller can wait on.
        In legacy dygraph mode: the result of the underlying ``recv_v2`` op.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.distributed import init_parallel_env

            init_parallel_env()
            if paddle.distributed.ParallelEnv().rank == 0:
                data = paddle.to_tensor([7, 8, 9])
                paddle.distributed.send(data, dst=1)
            else:
                data = paddle.to_tensor([1,2,3])
                paddle.distributed.recv(data, src=0)
            out = data.numpy()
    """
    # Processes outside the group take no part in the communication.
    if group is not None and not group.is_member():
        return

    # Eager mode: receive into `tensor` through the process group and hand
    # back the task when the caller asked for asynchronous communication.
    if in_dygraph_mode():
        comm_group = group if group is not None else _get_default_group()
        task = comm_group.process_group.recv(tensor, src)
        if not use_calc_stream:
            return task
        task.wait()
        return None

    ring_id = group.id if group is not None else 0

    # Legacy dygraph: call the op directly, passing the expected dtype/shape.
    if _non_static_mode():
        return _C_ops.recv_v2(tensor,
                              'use_calc_stream', use_calc_stream,
                              'ring_id', ring_id,
                              'peer', src,
                              'dtype', tensor.dtype,
                              'out_shape', tensor.shape)

    # Static graph: append a recv_v2 op that writes into `tensor`.
    op_type = 'recv_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'recv')
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'peer': src,
            'out_shape': tensor.shape,
            'dtype': tensor.dtype,
            'use_calc_stream': use_calc_stream,
        })