collective.py 67.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import warnings
from datetime import timedelta

import numpy as np

from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import Variable
from ..fluid.framework import OpProtoHolder
from ..fluid.framework import in_dygraph_mode
from ..fluid.framework import convert_np_dtype_to_dtype_
from ..fluid.framework import _varbase_creator
from ..fluid.data_feeder import convert_dtype
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.data_feeder import check_type
from ..fluid.data_feeder import check_dtype
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
from ..fluid.dygraph import layers
from ..fluid.dygraph.parallel import prepare_context
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle import _C_ops
import paddle.fluid.dygraph_utils as dygraph_utils

# Empty on purpose: `from ... import *` on this module exports no names.
__all__ = []


class ReduceOp:
    """
    Specify the type of operation used for element-wise reductions.
    It should be one of the following values:

        ReduceOp.SUM

        ReduceOp.MAX

        ReduceOp.MIN

        ReduceOp.PROD

    ``ReduceOp.AVG`` is also defined. However, ``all_reduce`` and ``reduce``
    in this module only accept SUM, MAX, MIN and PROD. They raise
    ``ValueError`` for any other value.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            from paddle.distributed import ReduceOp
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.all_reduce(data, op=ReduceOp.SUM)
            out = data.numpy()
            # [[5, 7, 9], [5, 7, 9]]
    """
    SUM = 0
    MAX = 1
    MIN = 2
    PROD = 3
    AVG = 4


class Group:
    """
    The abstract representation of group.

    Attributes:
        rank (int): rank of the current process inside this group. A negative
            value means the process is not a member.
        nranks (int): number of processes in the group.
        id (int): the group id. The collective ops in this module use it as
            their ring id.
        ranks (list): global ranks of the group members.
        pg: the underlying process group object (may be None).
        name (str): optional group name.
    """

    def __init__(self, rank, rank_num, id=0, ranks=None, pg=None, name=None):
        self.rank = rank
        self.nranks = rank_num
        self.id = id
        # Each instance gets its own list. The former `ranks=[]` default was
        # a single list object shared by every Group built without ranks.
        self.ranks = [] if ranks is None else ranks
        self.pg = pg
        self.name = name

    def is_member(self):
        """Return True if the current process belongs to a group of >= 2 ranks."""
        if self.rank < 0:
            return False
        if self.nranks < 2:
            return False
        return True

    def get_group_rank(self, rank):
        """Map a global ``rank`` to its index in this group, or -1 if absent."""
        if self.is_member() and rank in self.ranks:
            return self.ranks.index(rank)
        else:
            return -1

    @property
    def process_group(self):
        """The underlying process group object passed as ``pg``."""
        return self.pg

    def __repr__(self):
        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
            self.rank, self.nranks, self.id)
        debug_str += ", ".join(map(str, self.ranks))
        debug_str += "; name: "
        debug_str += self.name if self.name else "None"
        return debug_str


# Lazily-created ParallelEnv shared by the helpers below.
_global_env = None


def _get_global_env():
    """Return the cached ``paddle.distributed.ParallelEnv``.

    A new one is created whenever the cached value is falsy.
    """
    global _global_env
    if _global_env:
        return _global_env
    _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# group map : the map of all group, 0 for GlobalGroup
# Dict[int, Group]
_group_map = {}

# group map by name : the map of all groups from their names
# Dict[name, Group]
_group_map_by_name = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

_valid_backend_list = ['nccl', 'gloo', 'hccl']
_default_store = None  # the default tcp store
_default_backend = None  # backend chosen by _init_parallel_env


def _get_group_map():
    """Return the id -> Group map.

    When the map is empty, the global group (all ranks of the global
    environment) is registered under id 0 first.
    """
    global _group_map
    if not _group_map:
        genv = _get_global_env()
        _group_map[0] = Group(
            genv.rank, genv.world_size, ranks=list(range(genv.world_size)))
    return _group_map


def _get_global_group():
    """Return the global group registered under id 0."""
    return _get_group_map()[0]


def _get_group_map_by_name():
    """Return the name -> Group map.

    Raises:
        AssertionError: if the default environment has not been initialized
            (no group registered under the default group name).
    """
    assert _default_group_name in _group_map_by_name, (
        "Call paddle.distributed.init_parallel_env first "
        "to initialize the distributed environment.")
    return _group_map_by_name


def _get_default_group():
    """Return the default group created by the environment initialization.

    The initialization check (and its assertion message) is performed by
    ``_get_group_map_by_name``.
    """
    return _get_group_map_by_name()[_default_group_name]


def _new_ring_id():
    """Return a ring id for a group about to be created by ``new_group``.

    The id is the number of entries in the id -> Group map, offset by
    ``max(nrings, 9)`` of the global environment. The offset presumably keeps
    clear of ring ids already used by the global environment (TODO confirm).
    """
    return len(_get_group_map()) + max(_get_global_env().nrings, 9)


def _new_group_name_id():
    """Return an id for a group about to be created by ``_new_group``.

    The id is the number of named groups offset by ``max(nrings, 9)`` of the
    global environment, mirroring ``_new_ring_id``.
    """
    return len(_get_group_map_by_name()) + max(_get_global_env().nrings, 9)


def get_group(id=0):
    """

    Get group instance by group id.

    Args:
        id (int): the group id. Default value is 0.

    Returns:
        Group: the group instance, or None if no group is registered under
        ``id``.

    Examples:
        .. code-block:: python

            ...
            gid = paddle.distributed.new_group([2,4,6])
            paddle.distributed.get_group(gid.id)

    """

    return _get_group_map().get(id)


def _new_process_group_impl(backend, store, rank, world_size, group_name,
                            pg_options):
    """Build the C++ process group object for ``backend``.

    Args:
        backend (str): 'gloo', 'nccl' or 'hccl'.
        store: the store shared by the processes. It is wrapped in a
            ``core.GlooStore`` for the gloo backend.
        rank (int): rank of the current process inside the group.
        world_size (int): number of processes in the group.
        group_name (str): name of the group (not used here).
        pg_options: group options (not used here).

    Returns:
        The process group object, or None for any other backend value.
    """
    if backend == "gloo":
        wrapped_store = core.GlooStore(store)
        return core.ProcessGroupGloo(wrapped_store, rank, world_size)
    if backend == "nccl":
        return core.ProcessGroupNCCL(store, rank, world_size)
    if backend == "hccl":
        return core.ProcessGroupHCCL(store, rank, world_size)
    return None


def _init_parallel_env(rank=None,
                       world_size=None,
                       backend="nccl",
                       timeout=timedelta(0),
                       pg_options=None):
    """

    Initializes the default distributed environment.

    Args:
        rank (int, optional): the rank of the current process or device from 0 to world_size (exclusive).
            If you launch your training with paddle.distributed.run or
            paddle.distributed.launch module, None can be given. Default: None.
        world_size (int, optional): total number of processes or devices.
            If you launch your training with paddle.distributed.run or
            paddle.distributed.launch module, None can be given. Default: None.
        backend (str, optional): the name of the backend used to initialize
            the distributed environment. The value can be one of 'nccl' for
            GPU, 'gloo' for CPU or 'hccl' for NPU. Default: 'nccl'.
        timeout (datetime.timedelta, optional): timeout used for operations of
            the group. Default: datetime.timedelta(0) which means no timeout.
        pg_options (dict, optional): options for the group. Default: None.

    Returns:
        Group: the default group. None is returned instead, after a warning,
        when rank and world_size are both None and the PADDLE_TRAINER_ID /
        PADDLE_TRAINERS_NUM environment variables are not set.

    Raises:
        ValueError: if no master endpoint is found in MASTER_ADDR/MASTER_PORT,
            PADDLE_MASTER or PADDLE_TRAINER_ENDPOINTS.

    Examples:

        .. code-block:: python

            # filename: train.py
            import paddle
            paddle.distributed.init_parallel_env(0, 1)

            # how to start
            # python paddle.distributed.run --gpus="0,1" train.py

    """
    # Both module-level defaults are rebound below. Without these declarations
    # the assignments would only create function locals, and the module-level
    # `_default_backend` would stay None.
    global _default_backend
    global _default_store

    assert _default_group_name not in _group_map_by_name, (
        "The default distributed environment has been initialized.")

    assert backend in _valid_backend_list, (
        "Backend must be one of {}, but the given one is: {}".format(
            _valid_backend_list, backend))
    _default_backend = backend

    assert isinstance(timeout, timedelta), (
        "timeout must be of the type datetime.timedelta.")

    if rank is None or world_size is None:
        assert rank is None and world_size is None, (
            "rank and world_size should be unset at the same time.")
        trainer_id = os.getenv("PADDLE_TRAINER_ID", None)
        trainer_num = os.getenv("PADDLE_TRAINERS_NUM", None)
        if trainer_id is None or trainer_num is None:
            warnings.warn("If rank and world_size are both None, please start "
                          "your training with paddle.distributed.run or "
                          "paddle.distributed.launch module. Otherwise, "
                          "init_parallel_env will do nothing.")
            return None
        rank = int(trainer_id)
        world_size = int(trainer_num)

    assert rank >= 0 and world_size > rank and world_size > 1, (
        "rank must be non-negative and world_size must be the "
        "maximum rank plus one. Moreover, at least two processes are "
        "required to create a process group.")

    master_addr = os.getenv("MASTER_ADDR", None)
    master_port = os.getenv("MASTER_PORT", None)
    if not master_addr or not master_port:
        # Fall back to the first "host:port" entry of the launcher-provided
        # endpoint list.
        endpoints = os.getenv("PADDLE_MASTER", None)
        if endpoints is None:
            endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", None)
        if not endpoints:
            raise ValueError(
                "The environment variable 'MASTER_ADDR' and 'MASTER_PORT' "
                "must be specified, for example 'export MASTER_ADDR=127.0.0.1' "
                "and 'export MASTER_PORT=54612'. Or you can start your training "
                "with paddle.distributed.run or "
                "paddle.distributed.launch module.")
        master_addr, master_port = endpoints.split(",")[0].split(":")

    master_port = int(master_port)

    # Rank 0 is passed to the TCP store as the master (is_master=True).
    is_master = rank == 0
    _default_store = core.TCPStore(master_addr, master_port, is_master,
                                   world_size, timeout)

    pg = _new_process_group_impl(backend, _default_store, rank, world_size,
                                 _default_group_name, pg_options)
    ranks = list(range(world_size))
    group = Group(
        rank, world_size, id=0, ranks=ranks, pg=pg, name=_default_group_name)

    paddle.fluid.dygraph.parallel_helper._set_parallel_ctx(True)
    _group_map_by_name[_default_group_name] = group
    return group


def _new_group(ranks=None,
               backend=None,
               group_name=None,
               timeout=timedelta(0),
               pg_options=None):
    """
    Create a new process group.

    Args:
        ranks (list, optional): list of global ranks for the new group. If
            None is given, all processes are used. Default: None.
        backend (str, optional): the name of the backend used to initialize
            the distributed environment. Default: the one for init_parallel_env
            (the module-level ``_default_backend``).
        group_name (str, optional): name of the new group. It must differ from
            the default group's name. If None, a name is generated.
            Default: None.
        timeout (datetime.timedelta, optional): timeout used for operations of
            the group. Default: datetime.timedelta(0). Note: this function does
            not currently forward it to the process group.
        pg_options (dict, optional): options for the group. Default: None.

    Returns:
        Group: the new group. On processes that are not in ``ranks``, the
        group has rank -1 and no underlying process group.

    Raises:
        ValueError: if ``group_name`` equals the default group's name.

    Examples:

        .. code-block:: python

            import paddle
            paddle.distributed.init_parallel_env(0, 1)
            paddle.distributed.new_group([0, 1])

            # how to start
            # python paddle.distributed.run --gpus="0,1" train.py

    """
    # Honour the documented default. Forwarding None made
    # `_new_process_group_impl` silently return no process group.
    if backend is None:
        backend = _default_backend
    if group_name is None:
        group_name = _default_group_name + str(_new_group_name_id())
    if group_name == _default_group_name:
        raise ValueError("group_name must be specified and it cannot be '{}' "
                         "which is used for the default process group created "
                         "by init_parallel_env.".format(_default_group_name))
    global_group = _get_default_group()
    global_rank = global_group.rank
    global_ranks = global_group.ranks
    if ranks is None:
        ranks = global_ranks
    assert len(ranks) <= len(global_ranks), (
        "Size of new group must be less than or "
        "equal to that of the default global group.")
    size = len(ranks)
    assert size > 1, "A group must have at least two members."
    # `sorted` builds a new list, so the caller's list is never aliased.
    ranks = sorted(ranks)
    if global_rank in ranks:
        rank = ranks.index(global_rank)
        pg = _new_process_group_impl(backend, _default_store, rank, size,
                                     group_name, pg_options)
    else:
        # Non-members keep a placeholder group without a process group.
        rank = -1
        pg = None
    group = Group(
        rank,
        size,
        id=_new_group_name_id(),
        ranks=ranks,
        pg=pg,
        name=group_name)
    _group_map_by_name[group_name] = group

    return group


def barrier(group=None):
    """

    Barrier among all participators in the group.

    Args:
        group (Group): The group instance return by new_group or None for global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    # Processes outside the group do not take part in the barrier.
    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id

    # A one-element tensor serves as both the input and the output of the op.
    temp = fill_constant([1], dtype="int32", value="1")
    if in_dygraph_mode():
        return _C_ops.barrier(temp, temp, 'ring_id', ring_id)

    op_type = 'barrier'

    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [temp]},
        attrs={'ring_id': ring_id})


def new_group(ranks=None, backend=None):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members. If None, every rank
            of the global environment (``range(world_size)``) is used.
        backend (str): The backend used to create group, only nccl is supported now.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)

    """

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', ("backend other than nccl is not supported yet")

    genv = _get_global_env()
    global_rank = genv.rank

    # `ranks=None` used to fail on the membership test below. Fall back to
    # every rank of the global environment instead.
    if ranks is None:
        ranks = list(range(genv.world_size))
    # Sort once for members and non-members alike. `sorted` builds a new list,
    # so the caller's list is not aliased by the Group.
    ranks = sorted(ranks)

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, -1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, group_size, ring_id, ranks)
        _group_map[ring_id] = gp

        # A single-member group gets no communication context and skips the
        # synchronising all_reduce below.
        if group_size < 2:
            return gp

        strategy = core.ParallelStrategy()
        strategy.nranks = group_size
        strategy.local_rank = group_rank
        strategy.trainer_endpoints = [
            genv.trainer_endpoints[i] for i in ranks
        ]
        strategy.current_endpoint = genv.current_endpoint
        strategy.nrings = 1

        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(genv.device_id)
            core.NCCLParallelContext(strategy,
                                     place).init_with_ring_id(ring_id)
        elif core.is_compiled_with_npu():
            place = core.NPUPlace(genv.device_id)
            core.HCCLParallelContext(strategy,
                                     place).init_with_ring_id(ring_id)
        elif core.is_compiled_with_mlu():
            place = core.MLUPlace(genv.device_id)
            core.CNCLParallelContext(strategy,
                                     place).init_with_ring_id(ring_id)
        else:
            assert False, ("no cuda device found")

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    # NOTE(review): the static-graph tensor has shape [0] while the dygraph
    # one has a single element; confirm this is intended.
    tmp = paddle.to_tensor(
        [1], dtype="int32") if in_dygraph_mode() else fill_constant(
            [0], dtype="int32", value="1")
    paddle.distributed.all_reduce(tmp, use_calc_stream=True)
    paddle.distributed.wait(tmp)
    return gp


def wait(tensor, group=None, use_calc_stream=True):
    """

    wait to sync stream for group.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, use_calc_stream=True)
            paddle.distributed.wait(tindata)

    """

    # Processes outside the group have nothing to synchronise.
    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id

    if use_calc_stream:
        _sync_calc_stream(tensor)
    else:
        _sync_comm_stream(tensor, ring_id)


def _sync_calc_stream(tensor):
    """Run the ``c_sync_calc_stream`` op on ``tensor``.

    In dygraph mode the C++ op is called directly. Otherwise the op is
    appended to the current program, with ``tensor`` as both input and output.
    """

    if in_dygraph_mode():
        return _C_ops.c_sync_calc_stream(tensor, tensor)

    op_type = 'c_sync_calc_stream'

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]}, )


def _sync_comm_stream(tensor, ring_id=0):
    """Run the ``c_sync_comm_stream`` op on ``tensor`` for ring ``ring_id``.

    In dygraph mode the C++ op is called directly. Otherwise the op is
    appended to the current program, with ``tensor`` as both input and output.
    """

    if in_dygraph_mode():
        return _C_ops.c_sync_comm_stream([tensor], [tensor], 'ring_id', ring_id)

    op_type = 'c_sync_comm_stream'

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={'ring_id': ring_id}, )


def broadcast(tensor, src, group=None, use_calc_stream=True):
    """

    Broadcast a tensor from the source to all others.
    As shown below, 4 GPUs each start 4 processes and GPU0 owns data 0. Through broadcast operator,
    the data 0 will be sent to all GPUs from GPU0.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png
        :width: 800
        :alt: broadcast
        :align: center

    Args:
        tensor (Tensor): The Tensor to send if current rank is the source, or the tensor to receive otherwise. Its data type
            should be float16, float32, float64, int32 or int64.
        src (int): The source rank. This is a global rank. When ``group`` is
            given, it is translated to the rank inside that group.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Raises:
        ValueError: if ``src`` is not an int.
        AssertionError: if ``src`` is not a member of ``group``.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.broadcast(data, 1)
            out = data.numpy()
            # [[1, 2, 3], [1, 2, 3]]
    """

    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    ring_id = 0 if group is None else group.id
    # The op expects the root as a rank inside the group.
    gsrc = src if group is None else group.get_group_rank(src)
    assert gsrc >= 0, ("src rank out of group, need global rank")

    if in_dygraph_mode():
        return _C_ops.c_broadcast(tensor, tensor, 'root', gsrc,
                                  'use_calc_stream', use_calc_stream, 'ring_id',
                                  ring_id)

    op_type = 'c_broadcast'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'broadcast')

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'root': gsrc,
            'use_calc_stream': use_calc_stream,
            'ring_id': ring_id,
        })


def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor over all ranks so that all get the result.
    As shown below, 4 GPUs each start 4 processes and the data on each GPU is represented
    by the GPU number. The reduce operator is sum. Through all_reduce operator,
    each GPU will have the sum of the data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
        :width: 800
        :alt: all_reduce
        :align: center

    Args:
        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Raises:
        ValueError: if ``op`` is not one of SUM, MAX, MIN or PROD.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import ReduceOp
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.all_reduce(data)
            out = data.numpy()
            # [[5, 7, 9], [5, 7, 9]]
    """
    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id
    if in_dygraph_mode():
        # Dispatch directly to the C++ op matching the requested reduction.
        if op == ReduceOp.SUM:
            return _C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.MAX:
            return _C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.MIN:
            return _C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.PROD:
            return _C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
                                            use_calc_stream, 'ring_id', ring_id)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'all_reduce')
    if op not in [ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PROD]:
        raise ValueError("The op for all_reduce must be one of ReduceOp.PROD, "
                         "ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN.")
    if op == ReduceOp.SUM:
        op_type = 'c_allreduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_allreduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_allreduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_allreduce_prod'
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'ring_id' for all_reduce should be int.")
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={'ring_id': ring_id,
               'use_calc_stream': use_calc_stream})


def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """

    Reduce a tensor to the destination from all others. As shown below, 4 GPUs each start 4 processes and the data on each GPU is represented
    by the GPU number. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator,
    the GPU0 will own the sum of all data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png
        :width: 800
        :alt: reduce
        :align: center

    Args:
        tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
            should be float16, float32, float64, int32 or int64.
        dst (int): The destination rank id. This is a global rank. When
            ``group`` is given, it is translated to the rank inside that group.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Raises:
        ValueError: if ``dst`` is not an int, or ``op`` is not one of SUM,
            MAX, MIN or PROD.
        AssertionError: if ``dst`` is not a member of ``group``.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data = np.array([[4, 5, 6], [4, 5, 6]])
            else:
                np_data = np.array([[1, 2, 3], [1, 2, 3]])
            data = paddle.to_tensor(np_data)
            paddle.distributed.reduce(data, 0)
            out = data.numpy()
            # [[5, 7, 9], [5, 7, 9]] on rank 0
    """
    if group is not None and not group.is_member():
        return

    if not isinstance(dst, int):
        raise ValueError("dst should be int.")

    ring_id = 0 if group is None else group.id
    # The op expects the root as a rank inside the group.
    gdst = dst if group is None else group.get_group_rank(dst)
    assert gdst >= 0, ("dst rank out of group, need global rank")

    if in_dygraph_mode():
        if op == ReduceOp.SUM:
            return _C_ops.c_reduce_sum(tensor, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'root_id', gdst)
        elif op == ReduceOp.MAX:
            return _C_ops.c_reduce_max(tensor, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'root_id', gdst)
        elif op == ReduceOp.MIN:
            return _C_ops.c_reduce_min(tensor, tensor, 'use_calc_stream',
                                       use_calc_stream, 'ring_id', ring_id,
                                       'root_id', gdst)
        elif op == ReduceOp.PROD:
            return _C_ops.c_reduce_prod(tensor, tensor, 'use_calc_stream',
                                        use_calc_stream, 'ring_id', ring_id,
                                        'root_id', gdst)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    # Report dtype errors under this API's own name.
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'reduce')
    if op not in [ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PROD]:
        raise ValueError("The op for reduce must be one of ReduceOp.PROD, "
                         "ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN.")

    if op == ReduceOp.SUM:
        op_type = 'c_reduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_reduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_reduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_reduce_prod'

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
            'root_id': gdst,
        })


def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
    """

    Gather tensors from all participators and all get the result. As shown
    below, 4 GPUs each start a process and the data on each GPU is represented
    by the GPU number. Through the all_gather operator, each GPU will have data
    from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64. The gathered result is split along axis 0
            into ``nranks`` Tensors, which are appended to this list in place.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Raises:
        ValueError: If ``tensor_list`` is not a list.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            tensor_list = []
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data1 = np.array([[4, 5, 6], [4, 5, 6]])
                np_data2 = np.array([[4, 5, 6], [4, 5, 6]])
                data1 = paddle.to_tensor(np_data1)
                data2 = paddle.to_tensor(np_data2)
                paddle.distributed.all_gather(tensor_list, data1)
            else:
                np_data1 = np.array([[1, 2, 3], [1, 2, 3]])
                np_data2 = np.array([[1, 2, 3], [1, 2, 3]])
                data1 = paddle.to_tensor(np_data1)
                data2 = paddle.to_tensor(np_data2)
                paddle.distributed.all_gather(tensor_list, data2)
    """
    if group is not None and not group.is_member():
        return

    # Validate the output container before any communication is issued, in
    # both dygraph and static mode. Previously only static mode checked this;
    # dygraph ran the collective and then failed on `.extend`.
    if not isinstance(tensor_list, list):
        raise ValueError("The type of 'tensor_list' for all_gather "
                         "should be list.")

    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    if in_dygraph_mode():
        out = _C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream,
                                 'ring_id', ring_id, 'nranks', nranks)
    else:
        op_type = 'c_allgather'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        for elem in tensor_list:
            check_variable_and_dtype(
                elem, 'tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_gather')
        check_variable_and_dtype(
            tensor, 'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'], 'all_gather')
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
                'nranks': nranks
            })

    # Split the gathered result along axis 0 into one Tensor per rank.
    tensor_list.extend(paddle.split(out, nranks, 0))
932 933


K
kuizhiqing 已提交
934
def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
    """

    Scatter a tensor to all participators. As shown below, 4 GPUs each start a process and the source of the scatter
    is GPU0. Through scatter operator, the data in GPU0 will be split evenly and sent to all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
        :width: 800
        :alt: scatter
        :align: center

    Args:
        tensor (Tensor): The output Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64. It is required on the source rank and ignored
            on all other ranks. Default value is None.
        src (int): The source rank id. Default value is 0.
        group (Group): The group instance return by new_group or None for global default group.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None. (In dynamic graph mode, the value returned by the underlying
        ``c_scatter`` op is passed through.)

    Raises:
        ValueError: If ``src`` is not an int, or if ``tensor_list`` is None on
            the source rank.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
                np_data1 = np.array([7, 8, 9])
                np_data2 = np.array([10, 11, 12])
            else:
                np_data1 = np.array([1, 2, 3])
                np_data2 = np.array([4, 5, 6])
            data1 = paddle.to_tensor(np_data1)
            data2 = paddle.to_tensor(np_data2)
            if paddle.distributed.ParallelEnv().local_rank == 0:
                paddle.distributed.scatter(data1, src=1)
            else:
                paddle.distributed.scatter(data1, tensor_list=[data1, data2], src=1)
            out = data1.numpy()
    """
    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    ring_id = 0 if group is None else group.id
    gsrc = src if group is None else group.get_group_rank(src)
    assert gsrc >= 0, ("src rank out of group, need global rank")
    rank = _get_global_group().rank if group is None else group.rank
    nranks = _get_global_group().nranks if group is None else group.nranks

    if rank != gsrc:
        # Non-source ranks do not provide data. Build a stand-in input from
        # nranks references to `tensor` (presumably only its size matters to
        # c_scatter on these ranks).
        tensor_list = []
        for _ in range(nranks):
            tensor_list.append(tensor)
    elif tensor_list is None:
        # Fail early with a clear message instead of inside paddle.concat.
        raise ValueError("tensor_list must be provided on the source rank "
                         "(src={}).".format(src))
    temp = paddle.concat(tensor_list, axis=0)
    if in_dygraph_mode():
        return _C_ops.c_scatter(temp, tensor, 'use_calc_stream',
                                use_calc_stream, 'ring_id', ring_id, 'nranks',
                                nranks, 'root', gsrc)
    op_type = 'c_scatter'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'scatter')
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'root': gsrc,
            'use_calc_stream': use_calc_stream,
            'nranks': nranks,
        })


1020
def _c_identity(tensor, group=None):
L
lilong12 已提交
1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031
    """
    Return a copy of the tensor, mainly used with model parallel.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (int): The id of the process group to work on.

    Returns:
        Tensor.
    """
1032 1033 1034 1035 1036
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    if in_dygraph_mode():
W
wanghuancoder 已提交
1037 1038
        return _C_ops.c_identity(tensor, 'use_calc_stream', True, 'ring_id',
                                 ring_id, 'use_model_parallel', True)
L
lilong12 已提交
1039 1040 1041
    op_type = 'c_identity'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
1042

L
lilong12 已提交
1043 1044 1045
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_identity')
1046

L
lilong12 已提交
1047 1048 1049 1050 1051
    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
1052
            'ring_id': ring_id,
L
lilong12 已提交
1053 1054 1055 1056 1057 1058
            'use_calc_stream': True,
            'use_model_parallel': True,
        })
    return out


1059
def _c_concat(tensor, group=None):
1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074
    """
    Return allgather of the tensor, mainly used with model parallel.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (int): The id of the process group to work on.

    Returns:
        Tensor.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

1075 1076 1077 1078
    global_rank = _get_global_env().rank
    rank = global_rank if group is None else group.get_group_rank(global_rank)
    nranks = _get_global_env().world_size if group is None else group.nranks

1079
    if in_dygraph_mode():
W
wanghuancoder 已提交
1080 1081 1082
        return _C_ops.c_concat(tensor, 'ring_id', ring_id, 'use_calc_stream',
                               True, 'rank', rank, 'nranks', nranks,
                               'use_model_parallel', True)
1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099

    op_type = 'c_concat'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_concat')

    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': True,
            'use_model_parallel': True,
1100 1101
            'nranks': nranks,
            'rank': rank
1102 1103 1104 1105
        })
    return out


1106
def _c_split(tensor, group=None):
L
lilong12 已提交
1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118
    """
    Split tensor evenly among all members, mainly used with model parallel.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        rank (int): The rank of the current process.
        group (int): The id of the process group to work on.

    Returns:
        Tensor.
    """
1119 1120 1121 1122
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

1123 1124 1125 1126
    global_rank = _get_global_env().rank
    rank = global_rank if group is None else group.get_group_rank(global_rank)
    nranks = _get_global_env().world_size if group is None else group.nranks

1127
    if in_dygraph_mode():
W
wanghuancoder 已提交
1128 1129 1130
        return _C_ops.c_split(tensor, 'use_calc_stream', True, 'ring_id',
                              ring_id, 'rank', rank, 'nranks', nranks,
                              'use_model_parallel', True)
1131

L
lilong12 已提交
1132 1133 1134
    op_type = 'c_split'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
1135

L
lilong12 已提交
1136 1137 1138
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_split')
1139

L
lilong12 已提交
1140 1141 1142 1143 1144
    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
1145
            'ring_id': ring_id,
L
lilong12 已提交
1146 1147 1148 1149 1150 1151 1152 1153
            'use_calc_stream': True,
            'rank': rank,
            'nranks': nranks,
            'use_model_parallel': True,
        })
    return out


1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166
def _mp_allreduce(tensor,
                  op=ReduceOp.SUM,
                  group=None,
                  use_calc_stream=True,
                  use_model_parallel=True):
    """
    All-reduce for model parallelism.

    Like ``all_reduce``, but it also passes the ``use_model_parallel``
    attribute. In dynamic graph mode it reduces ``tensor`` in place via
    ``c_allreduce_sum_``. Only ``ReduceOp.SUM`` is supported.

    Args:
        tensor (Tensor): The Tensor to reduce. Its data type
            should be float16, float32, float64, int32 or int64.
        op (ReduceOp): Reduction to apply. Must be ``ReduceOp.SUM``.
        group (Group, optional): The communication group. ``None`` selects
            ring_id 0.
        use_calc_stream (bool): Whether to use the calculation stream.
        use_model_parallel (bool): Value of the op's ``use_model_parallel``
            attribute.

    Returns:
        Tensor: The reduced tensor, or None when the current process is not
        a member of ``group``.

    Raises:
        ValueError: If ``op`` is not ``ReduceOp.SUM``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    # Only SUM is implemented. Validate it for both execution modes:
    # previously static mode silently emitted c_allreduce_sum for any op.
    if op != ReduceOp.SUM:
        raise ValueError("Unknown parameter: {}.".format(op))

    if in_dygraph_mode():
        return _C_ops.c_allreduce_sum_(
            tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
            "use_model_parallel", use_model_parallel)

    op_type = 'c_allreduce_sum'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        op_type)

    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
            'use_model_parallel': use_model_parallel,
        })
    return out
1191 1192


1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207
def _c_lookup_table(table, index, start_index=0, name=None):
    """
    Look up rows of ``table`` for the ids in ``index`` with the
    ``c_embedding`` op.

    Args:
        table (Tensor): The lookup table. Its data type
            should be float16, float32, float64.
        index (Tensor): The ids to look up. Must be int32 or int64; this is
            checked in static mode only.
        start_index (int): Passed to ``c_embedding`` as its ``start_index``
            attribute; the first id covered by ``table``.
        name (string): The name of the api.

    Returns:
        Tensor.
    """
    if in_dygraph_mode():
        return _C_ops.c_embedding(table, index, "start_index", start_index)

    op_type = 'c_embedding'
    helper = LayerHelper(op_type, **locals())
    # The output dtype follows the table's dtype.
    dtype = helper.input_dtype(input_param_name='table')
    check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type)
    out = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type=op_type,
        inputs={'Ids': index, 'W': table},
        outputs={'Out': out},
        attrs={"start_index": start_index})
    return out

1223

B
Baibaifan 已提交
1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261
class _Linear(layers.Layer):
    """
    Minimal fully connected layer computing ``input @ weight + bias`` through
    the module-level ``_linear`` helper.

    Args:
        in_features (int): Number of rows of ``weight``.
        out_features (int): Number of columns of ``weight`` and length of
            ``bias``.
        weight_attr (ParamAttr, optional): Attribute for the weight parameter.
        bias_attr (ParamAttr, optional): Attribute for the bias parameter.
        name (str, optional): Forwarded to ``_linear``.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        super(_Linear, self).__init__()
        self._dtype = self._helper.get_default_dtype()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        # Weight is created before bias so parameter creation order is unchanged.
        self.weight = self.create_parameter(
            shape=[in_features, out_features],
            attr=self._weight_attr,
            dtype=self._dtype,
            is_bias=False)
        self.bias = self.create_parameter(
            shape=[out_features],
            attr=self._bias_attr,
            dtype=self._dtype,
            is_bias=True)
        self.name = name

    def forward(self, input):
        return _linear(
            x=input, weight=self.weight, bias=self.bias, name=self.name)

    def extra_repr(self):
        in_features = self.weight.shape[0]
        out_features = self.weight.shape[1]
        suffix = ', name={}'.format(self.name) if self.name else ''
        return 'in_features={}, out_features={}, dtype={}{}'.format(
            in_features, out_features, self._dtype, suffix)


1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282
def _c_softmax_with_cross_entropy(logits,
                                  label,
                                  group=None,
                                  return_softmax=False):
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id
    global_rank = _get_global_env().rank
    rank = global_rank if group is None else group.get_group_rank(global_rank)
    nranks = _get_global_env().world_size if group is None else group.nranks

    input_dims = len(list(logits.shape))
    label_dims = len(list(label.shape))
    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            'Expected nput_dims - 1 = label_dims or input_dims == label_dims\
             (got nput_dims{}, label_dims{})'.format(input_dims, label_dims))
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=-1)

    if in_dygraph_mode():
W
wanghuancoder 已提交
1283
        softmax, loss = _C_ops.c_softmax_with_cross_entropy(
1284 1285 1286 1287 1288 1289
            logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks)
        if not return_softmax:
            return loss
        else:
            return loss, softmax

W
WangXi 已提交
1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310
    attrs = {
        'ring_id': ring_id,
        'rank': rank,
        'nranks': nranks,
    }
    helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
    helper.append_op(
        type='c_softmax_with_cross_entropy',
        inputs={'Logits': logits,
                'Label': label},
        outputs={'Softmax': softmax,
                 'Loss': loss},
        attrs=attrs)

    if return_softmax:
        return loss, softmax

    return loss

1311

B
Baibaifan 已提交
1312 1313 1314 1315 1316 1317
def _linear(x, weight, bias=None, name=None):
    """
    Compute ``x @ weight`` plus an optional ``bias``.

    In dynamic graph mode this runs the ``matmul`` op (alpha=1) and adds
    ``bias`` with ``dygraph_utils._append_bias_in_dygraph`` on axis
    ``len(x.shape) - 1``. In static graph mode it appends a ``matmul_v2`` op,
    followed by an ``elementwise_add`` on the same axis when ``bias`` is given.

    Args:
        x (Tensor): Input. In static mode it must be float16, float32 or
            float64 and have fewer than 4 dimensions.
        weight (Tensor): Matrix multiplied on the right of ``x``.
        bias (Tensor, optional): Added to the product. Default None.
        name (str, optional): Forwarded to LayerHelper (via ``locals()``) in
            static mode.

    Returns:
        Tensor: The result.
    """
    if in_dygraph_mode():
        pre_bias = _varbase_creator(dtype=x.dtype)
        _C_ops.matmul(x, weight, pre_bias, 'transpose_X', False, 'transpose_Y',
                      False, "alpha", 1)
        return dygraph_utils._append_bias_in_dygraph(
            pre_bias, bias, axis=len(x.shape) - 1)

    helper = LayerHelper('linear', **locals())
    dtype = x.dtype
    assert len(x.shape) < 4, \
        "X dimension greater than 3 is not supported now."

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'linear')
    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')

    inputs = {'X': [x], 'Y': [weight]}
    # matmul_v2 names its transpose attributes trans_x / trans_y and has no
    # alpha attribute. The previous keys (transpose_X / transpose_Y / alpha,
    # the v1 matmul names) did not match any matmul_v2 attribute.
    attrs = {'trans_x': False, 'trans_y': False}
    tmp = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type='matmul_v2', inputs=inputs, outputs={'Out': tmp}, attrs=attrs)
    if bias is None:
        return tmp

    res = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type='elementwise_add',
        inputs={'X': [tmp],
                'Y': [bias]},
        outputs={'Out': [res]},
        attrs={'axis': len(x.shape) - 1})
    return res


1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366
def _set_var_distributed(var):
    if var is None:
        return

    var.is_distributed = True

    # NOTE: use current_block and find_var_recursive to support while_loop
    startup_block = paddle.static.default_startup_program().current_block()
    main_block = paddle.static.default_main_program().current_block()
    startup_block._find_var_recursive(var.name).is_distributed = True
    main_block._find_var_recursive(var.name).is_distributed = True


L
lilong12 已提交
1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377
def _parallel_linear(x,
                     num_rows,
                     num_cols,
                     axis,
                     param_attr,
                     bias_attr,
                     gather_out,
                     inner_rank,
                     nranks,
                     split_tensor,
                     name,
                     group=None):
    """
    Parallel Linear

    Build this rank's shard of a linear layer whose weight is split across
    ``nranks`` ranks.

    ``axis`` is the dimension of the linear layer's weight that is split:
        axis = 0: the row dimension (input features); partial results are
            combined with ``c_allreduce_sum`` and the bias is added afterwards.
        axis = 1: the column dimension (output features); per-rank outputs
            are combined with ``c_concat``.

    Args:
        x (Tensor): Input tensor.
        num_rows (int): ``in_features`` of this rank's local ``paddle.nn.Linear``.
        num_cols (int): ``out_features`` of this rank's local ``paddle.nn.Linear``.
        axis (int): 0 for row split, 1 for column split.
        param_attr: Weight attribute of the local linear layer.
        bias_attr: Bias attribute of the local linear layer.
        gather_out (bool): If False, return the local (un-combined) output.
        inner_rank (int): Rank passed as the ``rank`` attribute of ``c_concat``.
        nranks (int): Number of partitions.
        split_tensor (bool): For axis 0, whether to split ``x`` with
            ``_c_split`` before the matmul.
        name (str): Name of the local linear layer.
        group (Group, optional): The communication group; ``None`` uses ring 0.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    if axis == 0:
        if split_tensor:
            x = _c_split(x, group=group)
    else:
        x = _c_identity(x, group=group)

    linear = paddle.nn.Linear(
        num_rows,
        num_cols,
        weight_attr=param_attr,
        bias_attr=bias_attr,
        name=name)

    # NOTE: npu linear function use matmul_v2 but linear use matmul
    linear_function = (_linear if core.is_compiled_with_npu() else
                       paddle.nn.functional.linear)
    linear_out = linear_function(
        x,
        linear.weight,
        # NOTE(wangxi): row split, bias need add after allreduce
        None if axis == 0 else linear.bias,
        linear.name)

    _set_var_distributed(linear.weight)
    # set is_distributed for splited bias
    # if a linear layer is splited by row, each rank would hold a complete bias and they should be the same in each rank.
    # if a linear layer is splited by col, the bias would also be split into each rank as its weight
    if axis == 1 and linear._bias_attr != False:
        _set_var_distributed(linear.bias)

    if not gather_out:
        return linear_out

    out_shape = list(linear_out.shape)
    if axis == 1:
        # c_concat joins the per-rank column shards along the LAST axis, so
        # only that dimension grows. Scaling dim 0 (often a -1 batch size)
        # declared a wrong shape such as [-2, ...].
        out_shape[-1] *= nranks
    main_block = paddle.static.default_main_program().current_block()
    out = main_block.create_var(
        shape=out_shape,
        dtype=linear_out.dtype,
        type=linear_out.type,
        lod_level=linear_out.lod_level,
        persistable=False,
        is_data=False,
        need_check_feed=linear_out.desc.need_check_feed())
    if axis == 0:
        main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': linear_out},
            outputs={'Out': out},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': True,
                'use_model_parallel': True
            })
        if linear.bias is not None:
            out = out + linear.bias
    else:
        main_block.append_op(
            type='c_concat',
            inputs={'X': linear_out},
            outputs={'Out': out},
            attrs={
                'rank': inner_rank,
                'ring_id': ring_id,
                'nranks': nranks,
                'use_calc_stream': True,
                'use_model_parallel': True
            })
    return out
1459 1460


L
lilong12 已提交
1461 1462 1463 1464 1465 1466 1467
def _parallel_embedding(x,
                        per_part_embeddings,
                        origin_size,
                        param_attr,
                        inner_rank,
                        num_partitions,
                        name,
                        group=None):
    """
    Parallel Embedding

    Build this rank's shard of an embedding table split along the vocabulary
    (row) dimension. Each rank looks up its shard with ``c_embedding``, and
    the per-rank results are summed with ``_mp_allreduce``.

    Args:
        x (Tensor): The ids to look up.
        per_part_embeddings (int): Number of rows in this rank's shard.
        origin_size (list|tuple): Size of the full table; ``origin_size[1]``
            is used as the embedding width.
        param_attr: Attribute of the shard's weight parameter.
        inner_rank (int): This rank's index; the shard starts at row
            ``inner_rank * per_part_embeddings``.
        num_partitions (int): Number of shards. With 1 shard, a plain
            ``paddle.nn.functional.embedding`` is returned.
        name (str): Name forwarded to the lookup ops.
        group (Group, optional): The communication group for the all-reduce.

    Returns:
        Tensor, or None when the current process is not a member of ``group``.
    """
    if group is not None and not group.is_member():
        return

    helper = LayerHelper("_parallel_embedding", **locals())

    per_part_size = per_part_embeddings
    rank = inner_rank

    # First vocabulary id covered by this rank's shard (c_embedding start_index).
    vocab_start_index = rank * per_part_size
    dtype = helper.get_default_dtype()
    size = [per_part_size, origin_size[1]]

    weight = helper.create_parameter(
        attr=param_attr, shape=size, dtype=dtype, is_bias=False)

    if num_partitions == 1:
        return paddle.nn.functional.embedding(
            x, weight=weight, padding_idx=None, sparse=False, name=name)

    # Flag the sharded weight as distributed in both the startup and main
    # programs, using the same helper as _parallel_linear.
    _set_var_distributed(weight)

    output_parallel = paddle.distributed.collective._c_lookup_table(
        weight, x, start_index=vocab_start_index, name=name)
    out = paddle.distributed.collective._mp_allreduce(
        output_parallel,
        group=group,
        use_calc_stream=True,
        use_model_parallel=True)
    return out
1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527


def split(x,
          size,
          operation,
          axis=0,
          num_partitions=1,
          gather_out=True,
          weight_attr=None,
          bias_attr=None,
          name=None):
    """

    Split the weight of the specified operation into multiple devices
    and do the computation in parallel.

    Now the following three cases are supported.

    Case 1: Parallel Embedding
        The weight of the embedding operation is a NxM matrix with N rows and M columns.
        With parallel embedding, the weight is split along the rows into num_partitions
        partitions, each of which is a matrix with N/num_partitions rows and M columns.

        Suppose we split the NxM weight into two partitions on device_0 and device_1
        respectively. Then, on each device, the local weight has N/2 rows: device_0 holds
        rows [0, N/2 - 1] of the original weight and device_1 holds rows [N/2, N - 1].
        Each device only looks up the input values that fall into its own row range (for
        example, a value V in [N/2, N-1] is looked up as local row (V - N/2) on device_1),
        and all other values are mapped to all zeros. Finally, the results on the two
        devices are sum-reduced.

        The Embedding put on single card is as shown below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_single.png
            :width: 800
            :height: 350
            :alt: single_embedding
            :align: center

        Parallel Embedding is shown as below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_split.png
            :width: 800
            :alt: split_embedding
            :align: center

    Case 2: Row Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With row parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N/num_partitions rows and M columns.

        The linear layer put on single card is shown as below, the input variable is represented by X,
        the weight matrix is represented by W and the output variable is O. The linear layer on single card is
        a simple matrix multiplication operation, O = X * W.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_single.png
            :width: 800
            :alt: single_linear
            :align: center

        Row Parallel Linear is shown as below. As the name suggests, Row Parallel Linear splits the weight matrix W into
        [[W_row1], [W_row2]] along the row. Accordingly, the input is split along the column into [X_col1, X_col2], and each
        part is multiplied by its respective weight matrix. Finally, AllReduce is applied on the output from each card to get the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_row.png
            :width: 800
            :alt: split_row
            :align: center

    Case 3: Column Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With column parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N rows and M/num_partitions columns.

        The linear layer put on single card has been illustrated in case 2 and Column Parallel Linear
        is shown as below. The Column Parallel Linear splits the weight matrix W into [W_col1, W_col2] along the column and
        each of the split matrices multiplies the input. Finally, AllGather is applied on the output from each card to get the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col.png
            :width: 800
            :alt: split_col
            :align: center

    As observed, the column parallel linear and row parallel linear can be combined to skip one ALLGATHER communication
    operator. Furthermore the Attention and MLP can be combined to improve the performance as shown below.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col_row.png
            :width: 800
            :alt: split_col_row
            :align: center

    Args:
        x (Tensor): Input tensor. Its data type should be float16, float32, float64, int32 or int64.
        size (list|tuple): A list or tuple with two elements indicating the shape of the weight.
        operation (str): The name of the operation. The supported operations are 'linear' and 'embedding'.
        axis (int, Optional): Indicate along which axis to split the weight. 'embedding' only supports 0;
            'linear' supports 0 (row parallel) and 1 (column parallel). Default: 0.
        num_partitions (int, Optional): How many parts the weight is partitioned. It must be a positive
            integer that divides size[axis]. Default: 1.
        gather_out (bool, Optional): Whether to gather the output after computation. By default, the output
            on each partition will be gathered after computation. Only used by 'linear'. Default: True.
        weight_attr (ParamAttr, Optional): The parameter attribute for the learnable
            weights(Parameter) of the specified operation. Default: None.
        bias_attr (ParamAttr, Optional): The parameter attribute for the bias
            of the specified operation. Only used by 'linear'. Default: None.
        name (str, Optional): The default value is None. Normally there is no need for user to set this
            property. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: The output of the parallel operation.

    Raises:
        AssertionError: If size, operation or num_partitions is invalid, if the weight cannot be
            evenly split, if axis is not 0 for 'embedding', or if fleet.init() has not been called.
        ValueError: If called in dynamic graph mode, or if axis is neither 0 nor 1 for 'linear'.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed.fleet as fleet

            paddle.enable_static()
            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            fleet.init(is_collective=True)
            data = paddle.randint(0, 8, shape=[10,4])
            emb_out = paddle.distributed.split(
                data,
                (8, 8),
                operation="embedding",
                num_partitions=2)

    """
    assert isinstance(size, (list, tuple)), (
        "The type of size for "
        "paddle.distributed.split must be list or tuple.")
    assert len(size) == 2, ("Number of elements in size of "
                            "paddle.distributed.split must be two.")
    assert isinstance(operation, str), ("The type of operation for "
                                        "paddle.distributed.split must be str.")
    supported_operations = [
        'linear',
        'embedding',
    ]
    assert operation in supported_operations, (
        "The operation for "
        "paddle.distributed.split must be one of {}.".format(
            supported_operations))
    # num_partitions is used as a modulus and divisor below; reject values
    # that would otherwise surface as ZeroDivisionError or negative sizes.
    assert num_partitions >= 1, (
        "num_partitions for paddle.distributed.split must be a positive "
        "integer, but received {}.".format(num_partitions))

    if in_dygraph_mode():
        raise ValueError(
            "paddle.distributed.split cannot be used in dynamic "
            "graph mode, plese use ParallelEmbedding, ParallelRowLinear, "
            "ParallelColumnLinear instead.")

    from .fleet import fleet
    assert fleet._role_maker, ("To use paddle.distributed.split, "
                               "you must call fleet.init() firstly.")
    rank = fleet.worker_index()

    # rank within a model parallel group
    inner_rank = rank % num_partitions

    if operation == "embedding":
        assert axis == 0, ("We only support to split the weight of embedding "
                           "along the first axis now.")
        assert size[0] % num_partitions == 0, \
            "The length of the vocabulary must be divisible by num_partitions " \
            "but received vocabulary={} num_partitions={}".format(size[0], num_partitions)

        per_part_size = size[0] // num_partitions
        return _parallel_embedding(
            x,
            per_part_size,
            size,
            weight_attr,
            inner_rank,
            num_partitions,
            name,
            group=None)

    # operation == 'linear'
    should_split = False
    if axis == 0:
        assert size[0] % num_partitions == 0, (
            "Number of rows of the weight for linear ({}) must be"
            " divisible by num_partitions ({})".format(size[0],
                                                       num_partitions))
        per_part_size = size[0] // num_partitions
        linear_size = (per_part_size, size[1])
        # NOTE(review): the input still carries all size[0] features, so it
        # presumably has to be split before the local matmul -- confirm in
        # _parallel_linear.
        if x.shape[-1] == size[0]:
            should_split = True
    elif axis == 1:
        assert size[1] % num_partitions == 0, (
            "Number of column of the weight for linear ({}) must be"
            " divisible by num_partitions ({})".format(size[1],
                                                       num_partitions))
        per_part_size = size[1] // num_partitions
        linear_size = (size[0], per_part_size)
    else:
        raise ValueError("The value of axis must be 0 or 1, but the value "
                         "given is {}.".format(axis))

    return _parallel_linear(
        x,
        linear_size[0],
        linear_size[1],
        axis,
        weight_attr,
        bias_attr,
        gather_out,
        inner_rank,
        num_partitions,
        should_split,
        name=name,
        group=None)
L
lilong12 已提交
1718 1719


L
lilong12 已提交
1720 1721
def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
    """
    Scatter tensors in in_tensor_list to all participators evenly and gather the result tensors in out_tensor_list.
    As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
    Through alltoall operator, the 0_0 in GPU0 will be sent to GPU0 and 0_1 to GPU1, 1_0 in GPU1 sent to GPU0 and 1_1 to GPU1.
    Finally the out_tensor_list in GPU0 includes 0_0 and 1_0, and GPU1 includes 0_1 and 1_1.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/alltoall.png
        :width: 800
        :alt: alltoall
        :align: center

    Args:
        in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64.
        out_tensor_list (list): A list that receives the output Tensors; the results are appended to it.
            In static graph mode it must be empty when passed in. The data type of the output Tensors is the
            same as the data type of the input Tensors.
        group (Group, optional): The group instance returned by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.

    Returns:
        None. If group is given and the current process is not a member of it, the call returns immediately
        and out_tensor_list is left untouched.

    Raises:
        ValueError: In static graph mode, if in_tensor_list or out_tensor_list is not a list, or if
            out_tensor_list is not empty.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            init_parallel_env()
            out_tensor_list = []
            if paddle.distributed.ParallelEnv().rank == 0:
                np_data1 = np.array([[1, 2, 3], [4, 5, 6]])
                np_data2 = np.array([[7, 8, 9], [10, 11, 12]])
            else:
                np_data1 = np.array([[13, 14, 15], [16, 17, 18]])
                np_data2 = np.array([[19, 20, 21], [22, 23, 24]])
            data1 = paddle.to_tensor(np_data1)
            data2 = paddle.to_tensor(np_data2)
            paddle.distributed.alltoall([data1, data2], out_tensor_list)
            # out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]]
            # out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]]
    """
    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id

    if not in_dygraph_mode():
        # Validate the arguments before they are used (concat, len, dtype
        # lookup); otherwise a bad argument fails there with an unrelated
        # error instead of the messages below.
        if not isinstance(in_tensor_list, list):
            raise ValueError("The type of 'in_tensor_list' for all_to_all "
                             "should be list.")
        for elem in in_tensor_list:
            check_variable_and_dtype(
                elem, 'in_tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_to_all')
        if not isinstance(out_tensor_list, list):
            raise ValueError("The type of 'out_tensor_list' for all_to_all "
                             "should be list.")
        if len(out_tensor_list) != 0:
            raise ValueError("The 'out_tensor_list' for all_to_all "
                             "must be an empty list.")

    # All inputs are exchanged as one tensor concatenated along axis 0 and
    # split back into nranks pieces afterwards.
    temp = paddle.concat(in_tensor_list, axis=0)
    nranks = len(in_tensor_list)
    if in_dygraph_mode():
        out = _C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream,
                              'ring_id', ring_id)
    else:
        op_type = 'alltoall'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=in_tensor_list[0].dtype)
        helper.append_op(
            type=op_type,
            inputs={'X': [temp]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
            })
    out_tensor_list.extend(paddle.split(out, nranks, 0))


L
lilong12 已提交
1805 1806 1807 1808 1809 1810 1811 1812
def send(tensor, dst=0, group=None, use_calc_stream=True):
    """
    Send ``tensor`` to the process whose rank is ``dst``.

    Args:
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32 or int64.
        dst (int): The destination rank id.
        group (Group, optional): The group instance returned by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        None. In static graph mode a ``send_v2`` op is appended to the current
        program; in dynamic graph mode the return value of the ``send_v2`` op
        call is passed through. A process that is not a member of ``group``
        returns immediately.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.distributed import init_parallel_env

            init_parallel_env()
            if paddle.distributed.ParallelEnv().rank == 0:
                data = paddle.to_tensor([7, 8, 9])
                paddle.distributed.send(data, dst=1)
            else:
                data = paddle.to_tensor([1,2,3])
                paddle.distributed.recv(data, src=0)
            out = data.numpy()
    """
    if group is not None and not group.is_member():
        return
    ring_id = group.id if group is not None else 0

    if not in_dygraph_mode():
        # Static graph: validate, then record a send_v2 op in the program.
        op_type = 'send_v2'
        check_variable_and_dtype(
            tensor, 'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'], 'send')
        helper = LayerHelper(op_type, **locals())
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            attrs={
                'ring_id': ring_id,
                'peer': dst,
                'use_calc_stream': use_calc_stream,
            })
        return None

    # Dynamic graph: run the op eagerly.
    return _C_ops.send_v2(tensor, 'use_calc_stream', use_calc_stream,
                          'ring_id', ring_id, 'peer', dst)


def recv(tensor, src=0, group=None, use_calc_stream=True):
    """
    Receive a tensor from the process whose rank is ``src`` into ``tensor``.

    Args:
        tensor (Tensor): The Tensor to receive. Its data type
            should be float16, float32, float64, int32 or int64.
        src (int): The source rank id.
        group (Group, optional): The group instance returned by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.

    Returns:
        None. In static graph mode a ``recv_v2`` op writing into ``tensor`` is
        appended to the current program; in dynamic graph mode the return
        value of the ``recv_v2`` op call is passed through. A process that is
        not a member of ``group`` returns immediately.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.distributed import init_parallel_env

            init_parallel_env()
            if paddle.distributed.ParallelEnv().rank == 0:
                data = paddle.to_tensor([7, 8, 9])
                paddle.distributed.send(data, dst=1)
            else:
                data = paddle.to_tensor([1,2,3])
                paddle.distributed.recv(data, src=0)
            out = data.numpy()
    """
    if group is not None and not group.is_member():
        return
    ring_id = group.id if group is not None else 0

    if not in_dygraph_mode():
        # Static graph: validate, then record a recv_v2 op whose output is
        # the caller-provided tensor (its shape and dtype describe the data).
        op_type = 'recv_v2'
        check_variable_and_dtype(
            tensor, 'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'], 'recv')
        helper = LayerHelper(op_type, **locals())
        helper.append_op(
            type=op_type,
            outputs={'Out': [tensor]},
            attrs={
                'ring_id': ring_id,
                'peer': src,
                'out_shape': tensor.shape,
                'dtype': tensor.dtype,
                'use_calc_stream': use_calc_stream,
            })
        return None

    # Dynamic graph: run the op eagerly, receiving into `tensor`.
    return _C_ops.recv_v2(tensor, 'use_calc_stream', use_calc_stream,
                          'ring_id', ring_id, 'peer', src, 'dtype',
                          tensor.dtype, 'out_shape', tensor.shape)