collective.py 64.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
17
from datetime import timedelta
18
from ..fluid.layer_helper import LayerHelper
19
import paddle.fluid.framework as framework
20
from ..fluid.framework import Variable
21
from ..fluid.framework import in_dygraph_mode
22
from ..fluid.framework import OpProtoHolder
J
Jiabin Yang 已提交
23
from ..fluid.framework import _non_static_mode
24
from ..fluid.framework import convert_np_dtype_to_dtype_
J
Jiangxinz 已提交
25
from ..fluid.framework import _varbase_creator
26 27 28 29
from ..fluid.data_feeder import convert_dtype
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.data_feeder import check_type
from ..fluid.data_feeder import check_dtype
30 31
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
B
Baibaifan 已提交
32
from ..fluid.dygraph import layers
33 34 35 36
from ..fluid.dygraph.parallel import prepare_context
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
W
wanghuancoder 已提交
37
from paddle import _C_ops
J
Jiangxinz 已提交
38
import paddle.fluid.dygraph_utils as dygraph_utils
39

40
# Empty: ``from <this module> import *`` exports no names from this module.
__all__ = []
41 42 43


class ReduceOp:
    """
    Specify the type of operation used for element-wise reductions.
    It should be one of the following values:

        ReduceOp.SUM

        ReduceOp.MAX

        ReduceOp.MIN

        ReduceOp.PROD

    ``ReduceOp.AVG`` is also defined, but the reduction functions in this
    module (``all_reduce`` and ``reduce``) do not handle it.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            from paddle.distributed import ReduceOp
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.all_reduce(data, op=ReduceOp.SUM)
            out = data.numpy()
            # [[5, 7, 9], [5, 7, 9]]
    """
    SUM = 0
    MAX = 1
    MIN = 2
    PROD = 3
    # Defined, but not accepted by all_reduce/reduce in this module.
    AVG = 4
80 81


K
kuizhiqing 已提交
82 83 84 85
class Group():
    """
    The abstract representation of group.

    Args:
        rank (int): The rank of the current process inside this group, or a
            negative value when the current process is not a member.
        rank_num (int): The number of ranks in the group (stored as ``nranks``).
        id (int): The group id. Default value is 0.
        ranks (list, optional): The global ranks of the group members.
            Default value is None, which is stored as a new empty list.
        pg (optional): The process group object backing this group, exposed
            through the ``process_group`` property. Default value is None.
        name (str, optional): The group name. Default value is None.
    """

    def __init__(self, rank, rank_num, id=0, ranks=None, pg=None, name=None):
        self.rank = rank
        self.nranks = rank_num
        self.id = id
        # Give every instance its own list: a shared ``[]`` default would let
        # a mutation of one group's ranks show up in every other such group.
        self.ranks = [] if ranks is None else ranks
        self.pg = pg
        self.name = name

    def is_member(self):
        """Return True if the current process belongs to this group.

        A negative ``rank`` marks a non-member; groups with fewer than two
        ranks also report False.
        """
        if self.rank < 0:
            return False
        if self.nranks < 2:
            return False
        return True

    def get_group_rank(self, rank):
        """Map a global ``rank`` to its index within this group.

        Returns -1 when the current process is not a member or ``rank`` is
        not one of ``ranks``.
        """
        if self.is_member() and rank in self.ranks:
            return self.ranks.index(rank)
        else:
            return -1

    @property
    def process_group(self):
        """The process group object passed as ``pg`` (may be None)."""
        return self.pg

    def __repr__(self):
        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
            self.rank, self.nranks, self.id)
        debug_str += ", ".join(map(str, self.ranks))
        debug_str += "; name: "
        debug_str += self.name if self.name else "None"
        return debug_str

K
kuizhiqing 已提交
120 121 122 123 124 125 126 127 128 129 130 131 132 133 134

# Cached ParallelEnv; created on demand by _get_global_env().
_global_env = None


def _get_global_env():
    """Return the cached ``paddle.distributed.ParallelEnv``.

    The environment is constructed on the first call (or whenever the cached
    value is falsy) and the same object is returned afterwards.
    """
    global _global_env
    if _global_env:
        return _global_env
    _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# Registry of every group keyed by ring id; id 0 is the global group.
# Dict[int, Group]
_group_map = {}

# Registry of groups keyed by their name.
# Dict[name, Group]
_group_map_by_name = {}

# Name of the default group created through init_parallel_env.
_default_group_name = "_default_pg"

# Backend names for which _new_process_group_impl builds a process group.
_valid_backend_list = ['nccl', 'gloo', 'hccl']

# Module-level defaults for the process-group path; both start unset.
_default_store = None  # the default tcp store
_default_backend = None

K
kuizhiqing 已提交
146 147 148 149 150

def _get_group_map():
    """Return the id -> Group registry.

    When the registry is empty, it is seeded with the global group under
    id 0, covering every rank in ``range(world_size)`` of the global env.
    """
    global _group_map
    if _group_map:
        return _group_map
    env = _get_global_env()
    _group_map[0] = Group(
        env.rank, env.world_size, ranks=list(range(env.world_size)))
    return _group_map


def _get_global_group():
    """Return the global group, i.e. the registry entry stored under id 0."""
    group_map = _get_group_map()
    return group_map[0]


160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
def _get_group_map_by_name():
    """Return the name -> Group registry.

    Raises:
        AssertionError: if no group is registered under the default group
            name yet, i.e. the distributed environment is not initialized.
    """
    initialized = _default_group_name in _group_map_by_name
    assert initialized, ("Call paddle.distributed.init_parallel_env first "
                         "to initialize the distributed environment.")
    return _group_map_by_name


def _get_default_group():
    """Return the group registered under the default group name.

    Raises:
        AssertionError: if the distributed environment is not initialized.
    """
    initialized = _default_group_name in _group_map_by_name
    assert initialized, ("Call paddle.distributed.init_parallel_env first "
                         "to initialize the distributed environment.")
    groups_by_name = _get_group_map_by_name()
    return groups_by_name[_default_group_name]


K
kuizhiqing 已提交
175 176 177 178 179 180 181 182 183 184
def _new_ring_id():
    """Compute the ring id for the next group to be created.

    The id is the number of groups registered so far (the global group
    included) plus ``max(nrings, 9)`` taken from the global environment.
    """
    registered = len(_get_group_map())
    offset = max(_get_global_env().nrings, 9)
    return registered + offset


def get_group(id=0):
    """

    Get group instance by group id.

    Args:
        id (int): the group id. Default value is 0.

    Returns:
        Group: the group instance, or None when no group has this id.

    Examples:
        .. code-block:: python

            ...
            gid = paddle.distributed.new_group([2,4,6])
            paddle.distributed.get_group(gid.id)

    """
    return _get_group_map().get(id)
K
kuizhiqing 已提交
201 202


203 204 205 206 207 208 209
def _new_process_group_impl(backend,
                            store,
                            rank,
                            world_size,
                            group_name,
                            pg_options,
                            group_id=0):
    """Create the ``core`` process group object for ``backend``.

    Recognized backends are "gloo", "nccl" and "hccl"; any other value
    yields None. ``group_name`` and ``pg_options`` are accepted but unused.
    """
    backend_classes = (
        ("gloo", "ProcessGroupGloo"),
        ("nccl", "ProcessGroupNCCL"),
        ("hccl", "ProcessGroupHCCL"),
    )
    for known_backend, class_name in backend_classes:
        if backend == known_backend:
            # Only the matching class attribute is looked up on ``core``.
            pg_class = getattr(core, class_name)
            return pg_class(store, rank, world_size, group_id)
    return None


S
ShenLiang 已提交
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
def barrier(group=None):
    """

    Block until every participator of ``group`` reaches the barrier.

    Args:
        group (Group): The group instance return by new_group or None for global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    # Processes outside the group skip the barrier entirely.
    if group is not None and not group.is_member():
        return

    if framework._in_eager_mode_ and in_dygraph_mode():
        target = _get_default_group() if group is None else group
        target.process_group.barrier().wait()
        return

    ring_id = 0 if group is None else group.id
    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        return _C_ops.barrier(temp, temp, 'ring_id', ring_id)

    op_type = 'barrier'
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    # NOTE: locals() is forwarded, so the local names here (group, ring_id,
    # temp, op_type) are deliberately the same as before.
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [temp]},
        attrs={'ring_id': ring_id})


K
kuizhiqing 已提交
269 270 271
def new_group(ranks=None, backend=None):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list, optional): The global ranks of group members. Default value
            is None, which selects every rank of the global group.
        backend (str, optional): The backend used to create group. Outside eager
            mode only nccl is supported now. In eager mode, None falls back to
            the module-level default backend.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)

    """
    global _group_map
    if framework._in_eager_mode_:
        gid = _new_ring_id()
        group_name = _default_group_name + str(gid)
        global_group = _get_default_group()
        global_rank = global_group.rank
        global_ranks = global_group.ranks
        # _new_process_group_impl returns no process group for an unknown
        # backend (None included), so reuse the module-level default backend
        # (presumably recorded when the default group was set up).
        if backend is None:
            backend = _default_backend
        if ranks is None:
            ranks = global_ranks
        assert len(ranks) <= len(global_ranks), (
            "Size of new group must be less than or "
            "equal to that of the default global group.")
        size = len(ranks)
        assert size > 1, "A group must have at least two memebers."
        ranks = sorted(ranks)
        if global_rank in ranks:
            rank = ranks.index(global_rank)
            pg = _new_process_group_impl(
                backend,
                _default_store,
                rank,
                size,
                group_name,
                pg_options=None,
                group_id=gid)
        else:
            # Non-members get a placeholder group without a process group.
            rank = -1
            pg = None
        group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group

        return group

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', ("backend other than nccl is not supported yet")

    genv = _get_global_env()
    global_rank = genv.rank

    # Previously ``global_rank not in None`` raised TypeError; default to the
    # whole world, as the eager path does with the global group's ranks.
    if ranks is None:
        ranks = list(range(genv.world_size))

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, -1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, group_size, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            else:
                assert False, ("no cuda device found")
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    if _non_static_mode():
        tmp = paddle.to_tensor([1], dtype="int32")
    else:
        # NOTE(review): shape [0] builds an empty tensor (barrier() uses [1]);
        # confirm this is intended.
        tmp = fill_constant([0], dtype="int32", value="1")
    paddle.distributed.all_reduce(tmp, use_calc_stream=True)
    paddle.distributed.wait(tmp)
    return gp

382

K
kuizhiqing 已提交
383 384 385 386 387 388 389 390
def wait(tensor, group=None, use_calc_stream=True):
    """

    Synchronize a stream of ``group`` on ``tensor``.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync. None means the
            global default group (ring id 0).
        use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, use_calc_stream=True)
            paddle.distributed.wait(tindata)

    """
    if group is not None and not group.is_member():
        return

    ring_id = group.id if group is not None else 0
    if not use_calc_stream:
        _sync_comm_stream(tensor, ring_id)
        return
    _sync_calc_stream(tensor)


def _sync_calc_stream(tensor):
    """Run (dynamic mode) or append (static graph) a ``c_sync_calc_stream`` op.

    ``tensor`` is used as both the input and the output of the op.
    """
    if not _non_static_mode():
        op_type = 'c_sync_calc_stream'
        helper = LayerHelper(op_type, **locals())
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [tensor]}, )
        return None
    return _C_ops.c_sync_calc_stream(tensor, tensor)
432

433

K
kuizhiqing 已提交
434
def _sync_comm_stream(tensor, ring_id=0):
    """Run (dynamic mode) or append (static graph) a ``c_sync_comm_stream`` op.

    Args:
        tensor (Tensor): Used as both the input and the output of the op.
        ring_id (int): The ring id of the communication group. Default 0.
    """
    if not _non_static_mode():
        op_type = 'c_sync_comm_stream'
        helper = LayerHelper(op_type, **locals())
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [tensor]},
            attrs={'ring_id': ring_id}, )
        return None
    return _C_ops.c_sync_comm_stream([tensor], [tensor], 'ring_id', ring_id)


def broadcast(tensor, src, group=None, use_calc_stream=True):
    """

    Broadcast a tensor from the source to all others.
    As shown below, 4 GPUs each start 4 processes and GPU0 owns data 0. Through broadcast operator,
    the data 0 will be sent to all GPUs from GPU0.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png
        :width: 800
        :alt: broadcast
        :align: center

    Args:
        tensor (Tensor): The Tensor to send if current rank is the source, or the tensor to receive otherwise. Its data type
            should be float16, float32, float64, int32 or int64.
        src (int): The source rank (a global rank).
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None. In eager mode with ``use_calc_stream=False``, the task object
        returned by the process group is returned instead.

    Raises:
        ValueError: if ``src`` is not an int.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.broadcast(data, 1)
            out = data.numpy()
            # [[1, 2, 3], [1, 2, 3]]
    """

    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if framework._in_eager_mode_ and in_dygraph_mode():
        group = _get_default_group() if group is None else group
        # The process group expects the source's rank within the group.
        gsrc = group.get_group_rank(src)
        assert gsrc >= 0, ("src rank out of group, need global rank")
        task = group.process_group.broadcast(tensor, gsrc)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    gsrc = src if group is None else group.get_group_rank(src)
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if _non_static_mode():
        return _C_ops.c_broadcast(tensor, tensor, 'root', gsrc,
                                  'use_calc_stream', use_calc_stream, 'ring_id',
                                  ring_id)

    op_type = 'c_broadcast'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'broadcast')

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'root': gsrc,
            'use_calc_stream': use_calc_stream,
            'ring_id': ring_id,
        })


K
kuizhiqing 已提交
535
def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor over all ranks so that all get the result.
    As shown below, 4 GPUs each start 4 processes and the data on each GPU is represnted
    by the GPU number. The reduce operator is sum. Through all_reduce operator, 
    each GPU will have the sum of the data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
        :width: 800
        :alt: all_reduce
        :align: center

    Args:
        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None. In eager mode with ``use_calc_stream=False``, the task object
        returned by the process group is returned instead.

    Raises:
        ValueError: if ``op`` is not one of the supported reduce operations.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import ReduceOp
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.all_reduce(data)
            out = data.numpy()
            # [[5, 7, 9], [5, 7, 9]]
    """
    if group is not None and not group.is_member():
        return

    if framework._in_eager_mode_ and in_dygraph_mode():
        if op == ReduceOp.SUM:
            op_type = core.ReduceOp.SUM
        elif op == ReduceOp.MAX:
            op_type = core.ReduceOp.MAX
        elif op == ReduceOp.MIN:
            op_type = core.ReduceOp.MIN
        elif op == ReduceOp.PROD:
            # PROD is documented above; the core enum names it PRODUCT.
            op_type = core.ReduceOp.PRODUCT
        else:
            raise ValueError("Unknown reduce_op type for allreduce.")
        group = _get_default_group() if group is None else group
        task = group.process_group.allreduce(tensor, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    if _non_static_mode():
        if op == ReduceOp.SUM:
            return _C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.MAX:
            return _C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.MIN:
            return _C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.PROD:
            return _C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
                                            use_calc_stream, 'ring_id', ring_id)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'all_reduce')
    if op == ReduceOp.SUM:
        op_type = 'c_allreduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_allreduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_allreduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_allreduce_prod'
    else:
        # Previously fell through with ``op_type`` unbound (UnboundLocalError).
        raise ValueError("Unknown parameter: {}.".format(op))
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'ring_id' for all_reduce should be int.")
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={'ring_id': ring_id,
               'use_calc_stream': use_calc_stream})
636 637


K
kuizhiqing 已提交
638
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor to the destination from all others. As shown below, 4 GPUs each start 4 processes and the data on each GPU is respresnted
    by the GPU number. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator,
    the GPU0 will owns the sum of all data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png
        :width: 800
        :alt: reduce
        :align: center

    Args:
        tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
            should be float16, float32, float64, int32 or int64.
        dst (int): The destination rank id (a global rank).
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None. In eager mode with ``use_calc_stream=False``, the task object
        returned by the process group is returned instead.

    Raises:
        ValueError: if ``op`` is not one of the supported reduce operations.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.reduce(data, 0)
            out = data.numpy()
            # on rank 0 (the destination): [[5, 7, 9], [5, 7, 9]]
    """
    if group is not None and not group.is_member():
        return

    if framework._in_eager_mode_ and in_dygraph_mode():
        if op == ReduceOp.SUM:
            op_type = core.ReduceOp.SUM
        elif op == ReduceOp.MAX:
            op_type = core.ReduceOp.MAX
        elif op == ReduceOp.MIN:
            op_type = core.ReduceOp.MIN
        elif op == ReduceOp.PROD:
            # PROD is documented above; the core enum names it PRODUCT.
            op_type = core.ReduceOp.PRODUCT
        else:
            raise ValueError("Unknown reduce_op type for reduce.")
        group = _get_default_group() if group is None else group
        gdst = group.get_group_rank(dst)
        assert gdst >= 0, ("dst rank out of group, need global rank")
        task = group.process_group.reduce(tensor, gdst, op_type)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    ring_id = 0 if group is None else group.id
    gdst = dst if group is None else group.get_group_rank(dst)
    assert gdst >= 0, ("dst rank out of group, need global rank")

    if _non_static_mode():
        if op == ReduceOp.SUM:
            return _C_ops.c_reduce_sum(tensor, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'root_id', gdst)
        elif op == ReduceOp.MAX:
            return _C_ops.c_reduce_max(tensor, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'root_id', gdst)
        elif op == ReduceOp.MIN:
            return _C_ops.c_reduce_min(tensor, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'root_id', gdst)
        elif op == ReduceOp.PROD:
            return _C_ops.c_reduce_prod(tensor, tensor, 'use_calc_stream',
                                        use_calc_stream, 'ring_id', ring_id,
                                        'root_id', gdst)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'reduce')

    if op == ReduceOp.SUM:
        op_type = 'c_reduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_reduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_reduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_reduce_prod'
    else:
        # Previously an unknown op silently emitted a generic 'c_reduce' op.
        raise ValueError("Unknown parameter: {}.".format(op))

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
            'root_id': gdst,
        })


K
kuizhiqing 已提交
753
def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
    """

    Gather tensors from all participators and all get the result. As shown
    below, 4 GPUs each start 4 processes and the data on each GPU is represnted
    by the GPU number. Through the all_gather operator, each GPU will have data
    from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64. It receives
            one tensor per rank: in eager mode the list is cleared first,
            otherwise the gathered tensors are appended to it.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            tensor_list = []
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data1 = np.array([[4, 5, 6], [4, 5, 6]])
                np_data2 = np.array([[4, 5, 6], [4, 5, 6]])
                data1 = paddle.to_tensor(np_data1)
                data2 = paddle.to_tensor(np_data2)
                paddle.distributed.all_gather(tensor_list, data1)
            else:
                np_data1 = np.array([[1, 2, 3], [1, 2, 3]])
                np_data2 = np.array([[1, 2, 3], [1, 2, 3]])
                data1 = paddle.to_tensor(np_data1)
                data2 = paddle.to_tensor(np_data2)
                paddle.distributed.all_gather(tensor_list, data2)
    """
    if group is not None and not group.is_member():
        return

    if framework._in_eager_mode_ and in_dygraph_mode():
        group = _get_default_group() if group is None else group
        if len(tensor_list) == 0:
            # paddle.concat needs at least one input, and the documented usage
            # passes an empty list: allocate the gather buffer from
            # ``tensor``'s shape instead, with dim 0 scaled by nranks.
            tensor_shape = list(tensor.shape)
            tensor_shape[0] *= group.nranks
            out = paddle.empty(tensor_shape, tensor.dtype)
        else:
            out = paddle.concat(tensor_list)
        task = group.process_group.all_gather(tensor, out)
        task.wait()
        tensor_list.clear()
        tensor_list.extend(paddle.split(out, group.nranks, 0))
        return

    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    if _non_static_mode():
        out = _C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream,
                                 'ring_id', ring_id, 'nranks', nranks)
    else:
        op_type = 'c_allgather'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        if not isinstance(tensor_list, list):
            raise ValueError("The type of 'tensor_list' for all_gather "
                             "should be list.")
        for elem in tensor_list:
            check_variable_and_dtype(
                elem, 'tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_gather')
        check_variable_and_dtype(
            tensor, 'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'], 'all_gather')
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
                'nranks': nranks
            })

    tensor_list.extend(paddle.split(out, nranks, 0))
846 847


K
kuizhiqing 已提交
848
def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
    """

    Scatter a tensor to all participators. As shown below, 4 processes are started on 4 GPUs and GPU0 is
    the source of the scatter. Through the scatter operator, the data on GPU0 is split evenly and each
    piece is sent to one GPU.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
        :width: 800
        :alt: scatter
        :align: center

    Args:
        tensor (Tensor): The output Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64. It is only used on the source rank; on every
            other rank it is ignored. Default value is None.
        src (int): The source rank id. Default value is 0.
        group (Group): The group instance returned by new_group, or None for the global default group.
        use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False).
            Default to True.

    Returns:
        None in most cases. In eager mode with ``use_calc_stream=False`` the communication
        task is returned so that the caller can wait on it. In legacy dynamic mode the
        output of the ``c_scatter`` op is returned.

    Raises:
        ValueError: If ``src`` is not an int.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data1 = np.array([7, 8, 9])
                np_data2 = np.array([10, 11, 12])
            else:
                np_data1 = np.array([1, 2, 3])
                np_data2 = np.array([4, 5, 6])
            data1 = paddle.to_tensor(np_data1)
            data2 = paddle.to_tensor(np_data2)
            if paddle.distributed.ParallelEnv().local_rank == 0:
                paddle.distributed.scatter(data1, src=1)
            else:
                paddle.distributed.scatter(data1, tensor_list=[data1, data2], src=1)
            out = data1.numpy()
    """
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    if framework._in_eager_mode_ and in_dygraph_mode():
        group = _get_default_group() if group is None else group
        gsrc = group.get_group_rank(src)
        rank = group.rank
        nranks = group.nranks
    else:
        ring_id = 0 if group is None else group.id
        gsrc = src if group is None else group.get_group_rank(src)
        rank = _get_global_group().rank if group is None else group.rank
        nranks = _get_global_group().nranks if group is None else group.nranks
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if rank != gsrc:
        # Non-source ranks ignore tensor_list; nranks copies of `tensor` only
        # provide an input of the right size for the concatenation below.
        tensor_list = []
        for _ in range(nranks):
            tensor_list.append(tensor)
    temp = paddle.concat(tensor_list, axis=0)
    if framework._in_eager_mode_ and in_dygraph_mode():
        task = group.process_group.scatter(temp, tensor, gsrc)
        if use_calc_stream:
            task.wait()
            return None
        else:
            return task

    # Legacy (non-eager) dynamic mode. The eager case has already returned, so
    # this must use the same `_non_static_mode()` check as the other collective
    # helpers in this file. Re-testing `in_dygraph_mode()` here repeated the
    # eager condition and sent legacy dygraph down the static-graph path.
    if _non_static_mode():
        return _C_ops.c_scatter(temp, tensor, 'use_calc_stream',
                                use_calc_stream, 'ring_id', ring_id, 'nranks',
                                nranks, 'root', gsrc)
    op_type = 'c_scatter'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'scatter')
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'root': gsrc,
            'use_calc_stream': use_calc_stream,
            'nranks': nranks,
        })


948
def _c_identity(tensor, group=None):
L
lilong12 已提交
949 950 951 952 953 954 955 956 957 958 959
    """
    Return a copy of the tensor, mainly used with model parallel.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (int): The id of the process group to work on.

    Returns:
        Tensor.
    """
960 961 962 963
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

J
Jiabin Yang 已提交
964
    if _non_static_mode():
W
wanghuancoder 已提交
965 966
        return _C_ops.c_identity(tensor, 'use_calc_stream', True, 'ring_id',
                                 ring_id, 'use_model_parallel', True)
L
lilong12 已提交
967 968 969
    op_type = 'c_identity'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
970

L
lilong12 已提交
971 972 973
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_identity')
974

L
lilong12 已提交
975 976 977 978 979
    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
980
            'ring_id': ring_id,
L
lilong12 已提交
981 982 983 984 985 986
            'use_calc_stream': True,
            'use_model_parallel': True,
        })
    return out


987
def _c_concat(tensor, group=None):
988 989 990 991 992 993 994 995 996 997 998 999 1000
    """
    Return allgather of the tensor, mainly used with model parallel.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (int): The id of the process group to work on.

    Returns:
        Tensor.
    """
    if group is not None and not group.is_member():
        return
1001 1002
    group = _get_default_group() if group is None else group
    ring_id = group.id
1003

1004
    global_rank = _get_global_env().rank
1005 1006
    rank = group.rank
    nranks = group.nranks
1007

J
Jiabin Yang 已提交
1008
    if _non_static_mode():
W
wanghuancoder 已提交
1009 1010 1011
        return _C_ops.c_concat(tensor, 'ring_id', ring_id, 'use_calc_stream',
                               True, 'rank', rank, 'nranks', nranks,
                               'use_model_parallel', True)
1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028

    op_type = 'c_concat'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_concat')

    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': True,
            'use_model_parallel': True,
1029 1030
            'nranks': nranks,
            'rank': rank
1031 1032 1033 1034
        })
    return out


1035
def _c_split(tensor, group=None):
L
lilong12 已提交
1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047
    """
    Split tensor evenly among all members, mainly used with model parallel.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        rank (int): The rank of the current process.
        group (int): The id of the process group to work on.

    Returns:
        Tensor.
    """
1048 1049 1050 1051
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

1052 1053 1054 1055
    global_rank = _get_global_env().rank
    rank = global_rank if group is None else group.get_group_rank(global_rank)
    nranks = _get_global_env().world_size if group is None else group.nranks

J
Jiabin Yang 已提交
1056
    if _non_static_mode():
W
wanghuancoder 已提交
1057 1058 1059
        return _C_ops.c_split(tensor, 'use_calc_stream', True, 'ring_id',
                              ring_id, 'rank', rank, 'nranks', nranks,
                              'use_model_parallel', True)
1060

L
lilong12 已提交
1061 1062 1063
    op_type = 'c_split'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
1064

L
lilong12 已提交
1065 1066 1067
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_split')
1068

L
lilong12 已提交
1069 1070 1071 1072 1073
    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
1074
            'ring_id': ring_id,
L
lilong12 已提交
1075 1076 1077 1078 1079 1080 1081 1082
            'use_calc_stream': True,
            'rank': rank,
            'nranks': nranks,
            'use_model_parallel': True,
        })
    return out


1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093
def _mp_allreduce(tensor,
                  op=ReduceOp.SUM,
                  group=None,
                  use_calc_stream=True,
                  use_model_parallel=True):
    """
    Sum-allreduce ``tensor`` across ``group`` for model parallelism.

    Behaves like ``all_reduce`` but passes ``use_model_parallel`` to the op.
    In dynamic mode the in-place ``c_allreduce_sum_`` op is used; in static
    mode a ``c_allreduce_sum`` op writing to a new variable is appended.

    Args:
        tensor (Tensor): The input Tensor. Its data type should be float16,
            float32, float64, int32 or int64.
        op (ReduceOp, optional): The reduction. Only ``ReduceOp.SUM`` is
            supported. Default: ``ReduceOp.SUM``.
        group (Group, optional): The communication group. ``None`` uses ring 0.
        use_calc_stream (bool, optional): Whether to use the calculation stream.
            Default: True.
        use_model_parallel (bool, optional): Forwarded to the op. Default: True.

    Returns:
        Tensor: The reduced tensor, or None when the current process is not a
        member of ``group``.

    Raises:
        ValueError: If ``op`` is not ``ReduceOp.SUM``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    # Only summation is implemented. Validate before branching so static mode
    # rejects other ops too, instead of silently emitting a sum.
    if op != ReduceOp.SUM:
        raise ValueError("Unknown parameter: {}.".format(op))

    if _non_static_mode():
        return _C_ops.c_allreduce_sum_(
            tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
            "use_model_parallel", use_model_parallel)

    op_type = 'c_allreduce_sum'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        op_type)

    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
            'use_model_parallel': use_model_parallel,
        })
    return out
1120 1121


1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135
def _c_lookup_table(table, index, start_index=0, name=None):
    """
    Look up rows of ``table`` for the ids in ``index`` using the ``c_embedding`` op.

    Args:
        table (Tensor): The (possibly vocabulary-sharded) embedding table. Its
            data type should be float16, float32 or float64.
        index (Tensor): The ids to look up; int32 or int64.
        start_index (int): The id that corresponds to row 0 of ``table``, i.e.
            the start of this shard's vocabulary range. Default: 0.
        name (str, optional): The name of the API. Default: None.

    Returns:
        Tensor: The looked-up embeddings.
    """
    if _non_static_mode():
        return _C_ops.c_embedding(table, index, "start_index", start_index)

    # Static graph: the output dtype follows the `table` argument, which
    # LayerHelper picks up from the keyword arguments built by locals().
    op_type = 'c_embedding'
    helper = LayerHelper(op_type, **locals())
    dtype = helper.input_dtype(input_param_name='table')
    check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type)
    looked_up = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type=op_type,
        inputs={'Ids': index, 'W': table},
        outputs={'Out': looked_up},
        attrs={"start_index": start_index})
    return looked_up

1152

B
Baibaifan 已提交
1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190
class _Linear(layers.Layer):
    """
    Minimal fully connected layer that computes its output with ``_linear``.

    The weight has shape ``[in_features, out_features]`` and the bias has
    shape ``[out_features]``; both use the layer's default dtype.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        super().__init__()
        self._dtype = self._helper.get_default_dtype()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        # Weight first, then bias, matching parameter creation order.
        self.weight = self.create_parameter(
            shape=[in_features, out_features],
            attr=self._weight_attr,
            dtype=self._dtype,
            is_bias=False)
        self.bias = self.create_parameter(
            shape=[out_features],
            attr=self._bias_attr,
            dtype=self._dtype,
            is_bias=True)
        self.name = name

    def forward(self, input):
        return _linear(
            x=input, weight=self.weight, bias=self.bias, name=self.name)

    def extra_repr(self):
        suffix = ', name={}'.format(self.name) if self.name else ''
        in_features, out_features = self.weight.shape[0], self.weight.shape[1]
        return 'in_features={}, out_features={}, dtype={}{}'.format(
            in_features, out_features, self._dtype, suffix)


1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210
def _c_softmax_with_cross_entropy(logits,
                                  label,
                                  group=None,
                                  return_softmax=False):
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id
    global_rank = _get_global_env().rank
    rank = global_rank if group is None else group.get_group_rank(global_rank)
    nranks = _get_global_env().world_size if group is None else group.nranks

    input_dims = len(list(logits.shape))
    label_dims = len(list(label.shape))
    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            'Expected nput_dims - 1 = label_dims or input_dims == label_dims\
             (got nput_dims{}, label_dims{})'.format(input_dims, label_dims))
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=-1)

J
Jiabin Yang 已提交
1211
    if _non_static_mode():
W
wanghuancoder 已提交
1212
        softmax, loss = _C_ops.c_softmax_with_cross_entropy(
1213 1214 1215 1216 1217 1218
            logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks)
        if not return_softmax:
            return loss
        else:
            return loss, softmax

W
WangXi 已提交
1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239
    attrs = {
        'ring_id': ring_id,
        'rank': rank,
        'nranks': nranks,
    }
    helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
    helper.append_op(
        type='c_softmax_with_cross_entropy',
        inputs={'Logits': logits,
                'Label': label},
        outputs={'Softmax': softmax,
                 'Loss': loss},
        attrs=attrs)

    if return_softmax:
        return loss, softmax

    return loss

1240

B
Baibaifan 已提交
1241 1242 1243 1244
def _linear(x, weight, bias=None, name=None):
    """
    Compute ``x @ weight`` and, if given, add ``bias`` along the last axis of ``x``.

    ``_parallel_linear`` uses this function on NPU builds. In dynamic mode it
    runs the ``matmul`` op and adds the bias with
    ``dygraph_utils._append_bias_in_dygraph``. In static mode it appends a
    ``matmul_v2`` op, followed by an ``elementwise_add`` when ``bias`` is not
    None.

    Args:
        x (Tensor): The input. In static mode its dtype must be float16,
            float32 or float64, and it must have fewer than 4 dimensions.
        weight (Tensor): The right-hand operand of the matmul.
        bias (Tensor, optional): Bias added along the last axis of ``x``.
            Default: None.
        name (str, optional): Forwarded to the ``LayerHelper`` in static mode.
            Default: None.

    Returns:
        Tensor: The matmul result, with the bias added when one is given.
    """
    if _non_static_mode():
        product = _varbase_creator(dtype=x.dtype)
        _C_ops.matmul(x, weight, product, 'transpose_X', False, 'transpose_Y',
                      False, "alpha", 1)
        return dygraph_utils._append_bias_in_dygraph(
            product, bias, axis=len(x.shape) - 1)

    # Static graph. At this point locals() holds exactly the arguments
    # (x, weight, bias, name), which LayerHelper receives as kwargs.
    helper = LayerHelper('linear', **locals())
    dtype = x.dtype
    assert len(
        x.shape) < 4, "X latitude is not supported greater than 3 now."

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'linear')
    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')

    # NOTE(review): these attribute names follow the legacy `matmul` op, while
    # the op appended is `matmul_v2`; confirm whether matmul_v2 reads them.
    matmul_attrs = {'transpose_X': False, 'transpose_Y': False, 'alpha': 1}
    product = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type='matmul_v2',
        inputs={'X': [x], 'Y': [weight]},
        outputs={'Out': product},
        attrs=matmul_attrs)
    if bias is None:
        return product

    biased = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type='elementwise_add',
        inputs={'X': [product], 'Y': [bias]},
        outputs={'Out': [biased]},
        attrs={'axis': len(x.shape) - 1})
    return biased


1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295
def _set_var_distributed(var):
    if var is None:
        return

    var.is_distributed = True

    # NOTE: use current_block and find_var_recursive to support while_loop
    startup_block = paddle.static.default_startup_program().current_block()
    main_block = paddle.static.default_main_program().current_block()
    startup_block._find_var_recursive(var.name).is_distributed = True
    main_block._find_var_recursive(var.name).is_distributed = True


L
lilong12 已提交
1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306
def _parallel_linear(x,
                     num_rows,
                     num_cols,
                     axis,
                     param_attr,
                     bias_attr,
                     gather_out,
                     inner_rank,
                     nranks,
                     split_tensor,
                     name,
1307
                     group=None):
1308 1309
    """
    Parallel Linear
1310 1311 1312

    axis the dimension of the parameter of linear layer. 
    axis = 0: the row dimension
1313
    axis = 1: the col dimension
1314
    
1315
    """
1316 1317 1318 1319
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

L
lilong12 已提交
1320 1321
    if axis == 0:
        if split_tensor:
1322
            x = _c_split(x, group=group)
1323
    else:
L
lilong12 已提交
1324 1325
        x = _c_identity(x, group=group)

1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343
    linear = paddle.nn.Linear(
        num_rows,
        num_cols,
        weight_attr=param_attr,
        bias_attr=bias_attr,
        name=name)

    # NOTE: npu linear function use matmul_v2 but linear use matmul
    linear_function = _linear if core.is_compiled_with_npu()\
        else paddle.nn.functional.linear
    linear_out = linear_function(
        x,
        linear.weight,
        # NOTE(wangxi): row split, bias need add after allreduce
        None if axis == 0 else linear.bias,
        linear.name)

    _set_var_distributed(linear.weight)
1344 1345 1346 1347
    # set is_distributed for splited bias
    # if a linear layer is splited by row, each rank would hold a complete bias and they should be the same in each rank.
    # if a linear layer is splited by col, the bias would also be split into each rank as its weight
    if axis == 1 and linear._bias_attr != False:
1348
        _set_var_distributed(linear.bias)
L
lilong12 已提交
1349 1350 1351 1352 1353

    if not gather_out: return linear_out

    out_shape = list(linear_out.shape)
    out_shape[0] *= 1 if axis == 0 else nranks
1354
    main_block = paddle.static.default_main_program().current_block()
L
lilong12 已提交
1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368
    out = main_block.create_var(
        shape=out_shape,
        dtype=linear_out.dtype,
        type=linear_out.type,
        lod_level=linear_out.lod_level,
        persistable=False,
        is_data=False,
        need_check_feed=linear_out.desc.need_check_feed())
    if axis == 0:
        main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': linear_out},
            outputs={'Out': out},
            attrs={
1369
                'ring_id': ring_id,
L
lilong12 已提交
1370 1371 1372
                'use_calc_stream': True,
                'use_model_parallel': True
            })
1373 1374
        if linear.bias is not None:
            out = out + linear.bias
L
lilong12 已提交
1375 1376 1377 1378 1379 1380
    else:
        main_block.append_op(
            type='c_concat',
            inputs={'X': linear_out},
            outputs={'Out': out},
            attrs={
1381
                'rank': inner_rank,
1382
                'ring_id': ring_id,
L
lilong12 已提交
1383 1384 1385 1386 1387
                'nranks': nranks,
                'use_calc_stream': True,
                'use_model_parallel': True
            })
    return out
1388 1389


L
lilong12 已提交
1390 1391 1392 1393 1394 1395 1396
def _parallel_embedding(x,
                        per_part_embeddings,
                        origin_size,
                        param_attr,
                        inner_rank,
                        num_partitions,
                        name,
1397
                        group=None):
1398 1399 1400
    """
    Parallel Embedding
    """
1401 1402 1403 1404
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420
    helper = LayerHelper("_parallel_embedding", **locals())

    per_part_size = per_part_embeddings
    rank = inner_rank

    vocab_start_index = rank * per_part_size
    dtype = helper.get_default_dtype()
    size = [per_part_size, origin_size[1]]

    weight = helper.create_parameter(
        attr=param_attr, shape=size, dtype=dtype, is_bias=False)

    if num_partitions == 1:
        return paddle.nn.functional.embedding(
            x, weight=weight, padding_idx=None, sparse=False, name=name)

1421 1422
    startup_block = paddle.static.default_startup_program().global_block()
    main_block = paddle.static.default_main_program().global_block()
1423 1424 1425 1426 1427 1428 1429 1430 1431 1432
    startup_block.vars[weight.name].is_distributed = True
    main_block.vars[weight.name].is_distributed = True

    output_parallel = paddle.distributed.collective._c_lookup_table(
        weight, x, start_index=vocab_start_index, name=name)
    out = paddle.distributed.collective._mp_allreduce(
        output_parallel,
        group=group,
        use_calc_stream=True,
        use_model_parallel=True)
L
lilong12 已提交
1433
    return out
1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456


def split(x,
          size,
          operation,
          axis=0,
          num_partitions=1,
          gather_out=True,
          weight_attr=None,
          bias_attr=None,
          name=None):
    """

    Split the weight of the specified operation into multiple devices
    and do the computation in parallel.

    Now the following three cases are supported.

    Case 1: Parallel Embedding
        The weight of the embedding operation is a NxM matrix with N rows and M columns.
        With parallel embedding, the weight is split into num_partitions partitions, each
        of which is a matrix with N/num_partitions rows and M columns.

        Suppose we split the NxM weight into two partitions on device_0 and device_1
        respectively. device_0 holds the rows for ids in [0, N/2 - 1] and device_1 holds
        the rows for ids in [N/2, N - 1]. Each device only looks up the ids that fall in
        its own range (an id V on device_1 is looked up as row V - N/2); ids outside
        that range contribute nothing on that device. Finally, the results on the two
        devices are sum-reduced.

        The Embedding put on single card is as shown below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_single.png
            :width: 800
            :height: 350
            :alt: single_embedding
            :align: center

        Parallel Embedding is shown as below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_split.png
            :width: 800
            :alt: split_embedding
            :align: center

    Case 2: Row Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With row parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N/num_partitions rows and M columns.

        The linear layer put on single card is shown as below. The input variable is represented by X,
        the weight matrix by W and the output variable by O. The linear layer on a single card is a
        simple matrix multiplication, O = X * W.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_single.png
            :width: 800
            :alt: single_linear
            :align: center

        Row Parallel Linear is shown as below. As the name suggests, Row Parallel Linear splits the weight matrix W into
        [[W_row1], [W_row2]] along the row. Accordingly, the input is split along the column into [X_col1, X_col2], and each
        part is multiplied by its respective weight matrix. Finally, AllReduce is applied to the outputs of all cards to get
        the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_row.png
            :width: 800
            :alt: split_row
            :align: center

    Case 3: Column Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With column parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N rows and M/num_partitions columns.

        The linear layer put on single card has been illustrated in case 2. Column Parallel Linear
        is shown as below. It splits the weight matrix W into [W_col1, W_col2] along the column, and
        the input is multiplied by each of these matrices. Finally, AllGather is applied to the outputs
        of all cards to get the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col.png
            :width: 800
            :alt: split_col
            :align: center

    As observed, column parallel linear and row parallel linear can be combined to skip one ALLGATHER communication
    operator. Furthermore, Attention and MLP can be combined this way to improve performance, as shown below.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col_row.png
        :width: 800
        :alt: split_col_row
        :align: center

    Args:
        x (Tensor): Input tensor. Its data type should be float16, float32, float64, int32 or int64.
        size (list|tuple): A list or tuple with two elements indicating the shape of the weight.
        operation (str): The name of the operation. The supported operations are 'linear' and 'embedding'.
        axis (int, Optional): Indicate along which axis to split the weight. 'embedding' only supports 0;
            'linear' supports 0 and 1. Default: 0.
        num_partitions (int, Optional): How many parts the weight is partitioned into. Default: 1.
        gather_out (bool, Optional): Whether to gather the output after computation. By default, the output
            on each partition will be gathered after computation. Default: True.
        weight_attr (ParamAttr, Optional): The parameter attribute for the learnable
            weights(Parameter) of the specified operation. Default: None.
        bias_attr (ParamAttr, Optional): The parameter attribute for the bias
            of the specified operation. Default: None.
        name (str, Optional): The default value is None. Normally there is no need for user to set this
            property. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor.

    Raises:
        ValueError: If called in dynamic graph mode, or if ``axis`` is neither 0 nor 1 for 'linear'.
        AssertionError: If ``size``, ``operation``, ``axis`` or ``num_partitions`` are invalid.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed.fleet as fleet

            paddle.enable_static()
            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            fleet.init(is_collective=True)
            data = paddle.randint(0, 8, shape=[10,4])
            emb_out = paddle.distributed.split(
                data,
                (8, 8),
                operation="embedding",
                num_partitions=2)

    """
    # ---- argument validation ----
    assert isinstance(size, (list, tuple)), (
        "The type of size for "
        "paddle.distributed.split must be list or tuple.")
    assert len(size) == 2, ("Number of elements in size of "
                            "paddle.distributed.split must be two.")
    assert isinstance(operation, str), ("The type of operation for "
                                        "paddle.distributed.split must be str.")
    supported_operations = ['linear', 'embedding']
    assert operation in supported_operations, (
        "The operation for "
        "paddle.distributed.split must be one of {}.".format(
            supported_operations))

    if _non_static_mode():
        raise ValueError(
            "paddle.distributed.split cannot be used in dynamic "
            "graph mode, plese use ParallelEmbedding, ParallelRowLinear, "
            "ParallelColumnLinear instead.")

    from .fleet import fleet
    assert fleet._role_maker, ("To use paddle.distributed.split, "
                               "you must call fleet.init() firstly.")
    rank = fleet.worker_index()
    nranks = fleet.worker_num()

    # rank within a model parallel group
    inner_rank = rank % num_partitions

    # ---- Case 1: embedding, split along the vocabulary ----
    if operation == "embedding":
        assert axis == 0, ("We only support to split the weight of embedding "
                           "along the first axis now.")
        assert size[0] % num_partitions == 0, \
            "The length of the vocabulary must be divisible by num_partitions " \
            "but received vocabulary={} num_partitions={}".format(size[0], num_partitions)
        return _parallel_embedding(
            x,
            size[0] // num_partitions,
            size,
            weight_attr,
            inner_rank,
            num_partitions,
            name,
            group=None)

    # ---- Cases 2 and 3: linear, split by rows or by columns ----
    if axis == 0:
        assert size[0] % num_partitions == 0, (
            "Number of rows of the weight for linear ({}) must be"
            " divisible by num_partitions ({})".format(size[0],
                                                       num_partitions))
        linear_size = (size[0] // num_partitions, size[1])
        # The input must be split too when it carries the full row dimension.
        should_split = bool(x.shape[-1] == size[0])
    elif axis == 1:
        assert size[1] % num_partitions == 0, (
            "Number of column of the weight for linear ({}) must be"
            " divisible by num_partitions ({})".format(size[1],
                                                       num_partitions))
        linear_size = (size[0], size[1] // num_partitions)
        should_split = False
    else:
        raise ValueError("The value of axis must be 0 or 1, but the value "
                         "given is {}.".format(axis))

    return _parallel_linear(
        x,
        linear_size[0],
        linear_size[1],
        axis,
        weight_attr,
        bias_attr,
        gather_out,
        inner_rank,
        num_partitions,
        should_split,
        name=name,
        group=None)
L
lilong12 已提交
1647 1648


L
lilong12 已提交
1649 1650
def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
    """
    Scatter tensors in in_tensor_list to all participators averagely and gather the result tensors in out_tensor_list.
    As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
    Through alltoall operator, the 0_0 in GPU0 will be sent to GPU0 and 0_1 to GPU1, 1_0 in GPU1 sent to GPU0 and 1_1 to GPU1.
    Finally the out_tensor_list in GPU0 includes 0_0 and 1_0, and GPU1 includes 0_1 and 1_1.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/alltoall.png
        :width: 800
        :alt: alltoall
        :align: center

    Args:
        in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64.
        out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
            data type of the input Tensors. The list is filled in place with the received tensors (it is cleared
            first in eager mode; it must be empty in static mode).
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            init_parallel_env()
            out_tensor_list = []
            if paddle.distributed.ParallelEnv().rank == 0:
                np_data1 = np.array([[1, 2, 3], [4, 5, 6]])
                np_data2 = np.array([[7, 8, 9], [10, 11, 12]])
            else:
                np_data1 = np.array([[13, 14, 15], [16, 17, 18]])
                np_data2 = np.array([[19, 20, 21], [22, 23, 24]])
            data1 = paddle.to_tensor(np_data1)
            data2 = paddle.to_tensor(np_data2)
            paddle.distributed.alltoall([data1, data2], out_tensor_list)
            # out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]]
            # out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]]
    """
    if group is not None and not group.is_member():
        return

    if framework._in_eager_mode_ and in_dygraph_mode():
        group = _get_default_group() if group is None else group
    else:
        ring_id = 0 if group is None else group.id

    # All input chunks are sent as one contiguous tensor; the result is split
    # back into `nranks` equal pieces along axis 0 below.
    temp = paddle.concat(in_tensor_list, axis=0)
    nranks = len(in_tensor_list)
    if framework._in_eager_mode_ and in_dygraph_mode():
        if len(out_tensor_list) == 0:
            # The documented usage passes an empty list, and paddle.concat
            # cannot concatenate zero tensors. Allocate a receive buffer with
            # the same shape/dtype as the concatenated input instead.
            out = paddle.empty_like(temp)
        else:
            out = paddle.concat(out_tensor_list, axis=0)
        task = group.process_group.alltoall(temp, out)
        task.wait()
        out_tensor_list.clear()
        out_tensor_list.extend(paddle.split(out, nranks, 0))
        return

    if _non_static_mode():
        out = _C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream,
                              'ring_id', ring_id)
    else:
        op_type = 'alltoall'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=in_tensor_list[0].dtype)

        if not isinstance(in_tensor_list, list):
            raise ValueError("The type of 'in_tensor_list' for all_to_all "
                             "should be list.")
        for elem in in_tensor_list:
            check_variable_and_dtype(
                elem, 'in_tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_to_all')
        if not isinstance(out_tensor_list, list):
            raise ValueError("The type of 'out_tensor_list' for all_to_all "
                             "should be list.")
        if len(out_tensor_list) != 0:
            raise ValueError("The 'out_tensor_list' for all_to_all "
                             "must be an empty list.")
        helper.append_op(
            type=op_type,
            inputs={'X': [temp]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
            })
    out_tensor_list.extend(paddle.split(out, nranks, 0))


L
lilong12 已提交
1746 1747 1748 1749 1750 1751 1752 1753
def send(tensor, dst=0, group=None, use_calc_stream=True):
    """
    Send a tensor to the receiver.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32 or int64.
        dst (int): The destination rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        None. In eager dygraph mode with ``use_calc_stream=False`` the
        unfinished communication task is returned instead, so the caller
        can wait on it later.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.distributed import init_parallel_env

            init_parallel_env()
            if paddle.distributed.ParallelEnv().rank == 0:
                data = paddle.to_tensor([7, 8, 9])
                paddle.distributed.send(data, dst=1)
            else:
                data = paddle.to_tensor([1,2,3])
                paddle.distributed.recv(data, src=0)
            out = data.numpy()
    """
    # Ranks outside the requested group take no part in the exchange.
    if group is not None and not group.is_member():
        return

    if framework._in_eager_mode_ and in_dygraph_mode():
        if group is None:
            group = _get_default_group()
        task = group.process_group.send(tensor, dst)
        # Asynchronous callers receive the task; synchronous ones block here.
        if not use_calc_stream:
            return task
        task.wait()
        return None

    ring_id = group.id if group is not None else 0

    if _non_static_mode():
        return _C_ops.send_v2(tensor, 'use_calc_stream', use_calc_stream,
                              'ring_id', ring_id, 'peer', dst)

    # Static graph: validate the input and append a send_v2 op.
    op_type = 'send_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'send')

    helper = LayerHelper(op_type, **locals())
    send_attrs = {
        'ring_id': ring_id,
        'peer': dst,
        'use_calc_stream': use_calc_stream,
    }
    helper.append_op(type=op_type, inputs={'X': [tensor]}, attrs=send_attrs)


def recv(tensor, src=0, group=None, use_calc_stream=True):
    """
    Receive a tensor from the sender.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
            should be float16, float32, float64, int32 or int64.
        src (int): The source rank id.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        None. In eager dygraph mode with ``use_calc_stream=False`` the
        unfinished communication task is returned instead, so the caller
        can wait on it later.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.distributed import init_parallel_env

            init_parallel_env()
            if paddle.distributed.ParallelEnv().rank == 0:
                data = paddle.to_tensor([7, 8, 9])
                paddle.distributed.send(data, dst=1)
            else:
                data = paddle.to_tensor([1,2,3])
                paddle.distributed.recv(data, src=0)
            out = data.numpy()
    """
    # Ranks outside the requested group take no part in the exchange.
    if group is not None and not group.is_member():
        return

    if framework._in_eager_mode_ and in_dygraph_mode():
        if group is None:
            group = _get_default_group()
        task = group.process_group.recv(tensor, src)
        # Asynchronous callers receive the task; synchronous ones block here.
        if not use_calc_stream:
            return task
        task.wait()
        return None

    ring_id = group.id if group is not None else 0

    if _non_static_mode():
        return _C_ops.recv_v2(tensor, 'use_calc_stream', use_calc_stream,
                              'ring_id', ring_id, 'peer', src, 'dtype',
                              tensor.dtype, 'out_shape', tensor.shape)

    # Static graph: validate the buffer and append a recv_v2 op that writes
    # into `tensor` using its current shape and dtype.
    op_type = 'recv_v2'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'recv')
    helper = LayerHelper(op_type, **locals())
    recv_attrs = {
        'ring_id': ring_id,
        'peer': src,
        'out_shape': tensor.shape,
        'dtype': tensor.dtype,
        'use_calc_stream': use_calc_stream,
    }
    helper.append_op(type=op_type, outputs={'Out': [tensor]}, attrs=recv_attrs)