#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pickle
import io
import datetime
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import in_dygraph_mode
from ..fluid.framework import _non_static_mode
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.layers.tensor import fill_constant
import paddle
import paddle.fluid.core as core
from paddle import _legacy_C_ops
from .fleet.layers.mpu.mp_ops import split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_identity  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_concat  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _mp_allreduce  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_lookup_table  # noqa: F401
from .fleet.layers.mpu.mp_ops import _Linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _set_var_distributed  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy  # noqa: F401
from .fleet.layers.mpu.mp_ops import _linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_embedding  # noqa: F401
from .communication.group import Group, _add_new_group, is_initialized

__all__ = []

_global_env = None


def _get_global_env():
    global _global_env
    if not _global_env:
        _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# group map : the map of all groups, 0 for GlobalGroup
# Dict[int, Group]
_group_map = {}
_global_env_gid = 0

# group map by name : the map from group name to group
# Dict[name, Group]
_group_map_by_name = {}

# backend map by group : the map from group to its backend
# Dict[group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl', 'bkcl']
_default_store = None  # the default tcp store
_default_backend = None
_default_timeout = datetime.timedelta(seconds=1800)
_start_ring_id = 0


def _set_default_backend(backend):
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    global _default_store
    _default_store = store


def _get_group_map():
    global _group_map
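    # Lazily create and register the global default group (gid 0) covering all ranks.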
    if _global_env_gid not in _group_map:
        genv = _get_global_env()
        _group_map[_global_env_gid] = Group(
            genv.rank, 0, list(range(genv.world_size))
        )
    return _group_map


def _get_global_group():
    return _get_group_map()[_global_env_gid]


def _get_group_map_by_name():
    global _group_map_by_name
    return _group_map_by_name


def _get_default_group():
    global _group_map_by_name
    assert is_initialized(), (
        "Call paddle.distributed.init_parallel_env first "
        "to initialize the distributed environment."
    )
    return _get_group_map_by_name()[_default_group_name]


def _set_group_map(gid, group):
    global _group_map
    assert gid not in _group_map
    _group_map[gid] = group


def _set_group_map_by_name(name, group):
    global _group_map_by_name
    assert name not in _group_map_by_name
    _group_map_by_name[name] = group


def _set_group_map_backend(group, backend):
    global _group_map_backend
    assert group not in _group_map_backend
    _group_map_backend[group] = backend


def _new_ring_id():
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode rely on the previous syntax.
    if in_dygraph_mode():
        global _start_ring_id
        _start_ring_id += 1
        return _start_ring_id + max(_get_global_env().nrings, 9)
    else:
        return len(_get_group_map()) + max(_get_global_env().nrings, 9)


def _new_process_group_impl(
    backend,
    store,
    rank,
    world_size,
    group_name,
    pg_options,
    group_id=0,
):
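    # Create a backend-specific ProcessGroup (gloo/nccl/xccl/bkcl) bound to the given store and group id.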
    pg = None
    genv = _get_global_env()
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
    if backend == "gloo":
        pg = core.ProcessGroupGloo(store, rank, world_size, group_id)
    elif backend == "nccl":
        pg = core.ProcessGroupNCCL(store, rank, world_size, group_id)
    elif backend == "xccl":
        pg = core.ProcessGroupCustom(
            store, genv.device_type, rank, world_size, group_id
        )
    elif backend == "bkcl":
        pg = core.ProcessGroupBKCL(store, rank, world_size, group_id)
    return pg


def barrier(group=None):
    """

    Barrier among all participators in the group.

    Args:
        group (Group): The group instance returned by new_group or None for the global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        place = paddle.fluid.framework._current_expected_place()
        if isinstance(place, paddle.fluid.core.CPUPlace):
            task = group.process_group.barrier()
        else:
            device_id = place.get_device_id()
            task = group.process_group.barrier(device_id)
        task.wait()
        return

    ring_id = 0 if group is None else group.id

    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        return _legacy_C_ops.barrier(temp, temp, 'ring_id', ring_id)

    op_type = 'barrier'

    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [temp]},
        attrs={'ring_id': ring_id},
    )


# _custom_gid provides a way for users to
# set the group id, which is usually useful
# to be compatible with the static mode.
_custom_gid = None


def _set_custom_gid(gid):
    global _custom_gid
    _custom_gid = gid


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members.
        backend (str): The backend used to create the group, only nccl is supported now.
        timeout (datetime.timedelta, optional): The waiting timeout for store related options, default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    global _custom_gid
    global _group_map
    if in_dygraph_mode():
        global _default_group_name
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group."
            )
        size = len(ranks)
        ranks = sorted(ranks)
        if size > 1 and global_rank in ranks:
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            pg = _new_process_group_impl(
                backend,
                _default_store,
                rank,
                size,
                group_name,
                pg_options=None,
                group_id=gid,
            )
        else:
            rank = -1
            pg = None
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        # TODO: The method below is a new method for group management, will replace the previous
        # three in the future.
        _add_new_group(group)

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.barrier()
        return group

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', "backend other than nccl is not supported yet"

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
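            # Build a ParallelStrategy and initialize the device-specific parallel context for this ring.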
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            else:
                assert False, "no cuda device found"
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = (
        paddle.to_tensor([1], dtype="int32")
        if _non_static_mode()
        else fill_constant([0], dtype="int32", value="1")
    )
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp


def wait(tensor, group=None, use_calc_stream=True):
    """

    Wait to sync the stream for the group.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Defaults to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, sync_op=True)
            paddle.distributed.wait(tindata)

    """

    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id

    if use_calc_stream:
        _sync_calc_stream(tensor)
    else:
        _sync_comm_stream(tensor, ring_id)


def _sync_calc_stream(tensor):
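    # Synchronize the calculation (compute) stream so pending work on `tensor` is finished.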

    if _non_static_mode():
        return _legacy_C_ops.c_sync_calc_stream(tensor, tensor)

    op_type = 'c_sync_calc_stream'

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
    )


def _sync_comm_stream(tensor, ring_id=0):
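    # Synchronize the communication stream of the given ring so pending collectives on `tensor` are finished.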

    if _non_static_mode():
        return _legacy_C_ops.c_sync_comm_stream(
            [tensor], [tensor], 'ring_id', ring_id
        )

    op_type = 'c_sync_comm_stream'

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={'ring_id': ring_id},
    )


def all_gather(tensor_list, tensor, group=None, sync_op=True):
    """

    Gather tensors from all participators and all get the result. As shown
    below, one process is started with a GPU and the data of this process is represented
    by its group rank. Through the all_gather operator, each GPU will have data
    from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
        group (Group, optional): The group instance returned by new_group or None for the global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            tensor_list = []
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_gather(tensor_list, data)
            print(tensor_list)
            # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    def convert_to_complex(list_of_tensor):
        list_of_complex = []
        for tensor in list_of_tensor:
            list_of_complex.append(paddle.as_complex(tensor))
        return list_of_complex

    is_input_complex = (
        tensor.dtype == paddle.complex64 or tensor.dtype == paddle.complex128
    )
    if is_input_complex:
        tensor = paddle.as_real(tensor)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        if len(tensor_list) == 0:
            tensor_shape = list(tensor.shape)
            tensor_shape[0] *= group.nranks
            out = paddle.empty(tensor_shape, tensor.dtype)
        else:
            out = paddle.concat(tensor_list, axis=0)
        task = group.process_group.all_gather_into_tensor(out, tensor, sync_op)
        task.wait()
        tensor_list.clear()
        list_of_tensor = paddle.split(out, group.nranks, 0)
        if is_input_complex:
            tensor_list.extend(convert_to_complex(list_of_tensor))
        else:
            tensor_list.extend(list_of_tensor)
        return

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    if _non_static_mode():
        out = _legacy_C_ops.c_allgather(
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'nranks',
            nranks,
        )
    else:
        op_type = 'c_allgather'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        if not isinstance(tensor_list, list):
            raise ValueError(
                "The type of 'tensor_list' for all_gather " "should be list."
            )
        for elem in tensor_list:
            check_variable_and_dtype(
                elem,
                'tensor_list',
                [
                    'float16',
                    'float32',
                    'float64',
                    'int32',
                    'int64',
                    'bool',
                    'int8',
                    'uint8',
                    'complex64',
                    'complex128',
                ],
                'all_gather',
            )
        check_variable_and_dtype(
            tensor,
            'tensor',
            [
                'float16',
                'float32',
                'float64',
                'int32',
                'int64',
                'bool',
                'int8',
                'uint8',
                'complex64',
                'complex128',
            ],
            'all_gather',
        )
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
                'nranks': nranks,
            },
        )

    list_of_tensor = paddle.split(out, nranks, 0)
    if is_input_complex:
        tensor_list.extend(convert_to_complex(list_of_tensor))
    else:
        tensor_list.extend(list_of_tensor)


def _convert_object_to_tensor(obj):
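    # Pickle the object into a 1-D uint8 tensor and return it together with its element count.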
    _pickler = pickle.Pickler
    f = io.BytesIO()
    _pickler(f).dump(obj)
    data = np.frombuffer(f.getvalue(), dtype=np.uint8)
    tensor = paddle.to_tensor(data)
    return tensor, tensor.numel()


def _convert_tensor_to_object(tensor, len_of_tensor):
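    # Unpickle an object from the first `len_of_tensor` bytes of the uint8 tensor.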
    _unpickler = pickle.Unpickler
    return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()


def all_gather_object(object_list, obj, group=None):
    """

    Gather picklable objects from all participators and all get the result. Similar to all_gather(), but Python objects can be passed in.

    Args:
        object_list (list): A list of output objects. The datatype of every element in the list is the same as that of the input obj.
        obj (Any): The picklable object to send.
        group (Group): The group instance returned by new_group or None for the global default group.

    Returns:
        None.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            object_list = []
            if dist.get_rank() == 0:
                obj = {"foo": [1, 2, 3]}
            else:
                obj = {"bar": [4, 5, 6]}
            dist.all_gather_object(object_list, obj)
            print(object_list)
            # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
    """
    assert (
        in_dygraph_mode()
    ), "all_gather_object doesn't support static graph mode."

    tensor, len_of_tensor = _convert_object_to_tensor(obj)

    # gather len_of_tensor from all ranks
    list_len_of_tensor = []
    all_gather(list_len_of_tensor, len_of_tensor, group)
    # get the max length from list
    max_len_of_tensor = int(max(list_len_of_tensor).item())
    # resize the input tensor to max length avoid hang in all gather
    # Note(liyurui): Maybe we should support various length all_gather?
    # Now this operation is efficient for we don't support resize in python.
    numpy_data = tensor.numpy()
    numpy_data = np.resize(numpy_data, [max_len_of_tensor])
    input_tensor = paddle.to_tensor(numpy_data)

    tensor_list = []
    all_gather(tensor_list, input_tensor, group)
    for i, tensor in enumerate(tensor_list):
        object_list.append(
            _convert_tensor_to_object(tensor, list_len_of_tensor[i])
        )