#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pickle
import io
import datetime
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import in_dygraph_mode
from ..fluid.framework import _non_static_mode
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.layers.tensor import fill_constant
import paddle
import paddle.fluid.core as core
from paddle import _legacy_C_ops
from .fleet.layers.mpu.mp_ops import split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_identity  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_concat  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _mp_allreduce  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_lookup_table  # noqa: F401
from .fleet.layers.mpu.mp_ops import _Linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _set_var_distributed  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy  # noqa: F401
from .fleet.layers.mpu.mp_ops import _linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_embedding  # noqa: F401
from .communication.group import Group, _add_new_group, is_initialized

__all__ = []

_global_env = None


def _get_global_env():
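    # Lazily create and cache the global ParallelEnv, which describes this
    # process's rank, world size and device.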
    global _global_env
    if not _global_env:
        _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# group map: the map of all groups; key 0 is the global group
# Dict[int, Group]
_group_map = {}
_global_env_gid = 0

# group map by name: the map from group names to groups
# Dict[name, Group]
_group_map_by_name = {}

# backend map by group: the map from groups to their backends
# Dict[group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl', 'bkcl']
_default_store = None  # the default tcp store
_default_backend = None
_default_timeout = datetime.timedelta(seconds=1800)
_start_ring_id = 0


def _set_default_backend(backend):
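    # Remember the default backend; new_group falls back to it when no
    # backend is given.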
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
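    # Remember the default TCP store used to bootstrap newly created process
    # groups.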
    global _default_store
    _default_store = store


def _get_group_map():
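    # Return the ring_id -> Group map; the global group (id 0) is created
    # lazily from the current ParallelEnv.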
    global _group_map
    if _global_env_gid not in _group_map:
        genv = _get_global_env()
        _group_map[_global_env_gid] = Group(
            genv.rank, 0, list(range(genv.world_size))
        )
    return _group_map


def _get_global_group():
    return _get_group_map()[_global_env_gid]


def _get_group_map_by_name():
    global _group_map_by_name
    return _group_map_by_name


def _get_default_group():
    global _group_map_by_name
    assert is_initialized(), (
        "Call paddle.distributed.init_parallel_env first "
        "to initialize the distributed environment."
    )
    return _get_group_map_by_name()[_default_group_name]


def _set_group_map(gid, group):
    global _group_map
    assert gid not in _group_map
    _group_map[gid] = group


def _set_group_map_by_name(name, group):
    global _group_map_by_name
    assert name not in _group_map_by_name
    _group_map_by_name[name] = group


def _set_group_map_backend(group, backend):
    global _group_map_backend
    assert group not in _group_map_backend
    _group_map_backend[group] = backend


def _new_ring_id():
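    # Allocate a ring id for a newly created communication group.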
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode rely on the previous syntax.
    if in_dygraph_mode():
        global _start_ring_id
        _start_ring_id += 1
        return _start_ring_id + max(_get_global_env().nrings, 9)
    else:
        return len(_get_group_map()) + max(_get_global_env().nrings, 9)


def _new_process_group_impl(
    backend,
    store,
    rank,
    world_size,
    group_name,
    pg_options,
    group_id=0,
):
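    # Construct the backend-specific ProcessGroup (gloo, nccl, xccl or bkcl)
    # bound to the given store, rank, world size and group id.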
    pg = None
    genv = _get_global_env()
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
    if backend == "gloo":
        place = core.CPUPlace()
        pg = core.ProcessGroupGloo(store, rank, world_size, place, group_id)
    elif backend == "nccl":
        place = core.CUDAPlace(genv.device_id)
        pg = core.ProcessGroupNCCL(store, rank, world_size, place, group_id)
    elif backend == "xccl":
        place = core.CustomPlace(genv.device_type, genv.device_id)
        pg = core.ProcessGroupCustom(store, rank, world_size, place, group_id)
    elif backend == "bkcl":
        place = core.XPUPlace(genv.device_id)
        pg = core.ProcessGroupBKCL(store, rank, world_size, place, group_id)
    return pg


def barrier(group=None):
    """

    Barrier among all participators in the group.

    Args:
        group (Group): The group instance returned by new_group or None for the global default group.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        task = group.process_group.barrier()
        task.wait()
        return

    ring_id = 0 if group is None else group.id

    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        return _legacy_C_ops.barrier(temp, temp, 'ring_id', ring_id)

    op_type = 'barrier'

    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [temp]},
        attrs={'ring_id': ring_id},
    )


# _custom_gid provides a way for users to
# set the group id, which is usually useful
# to be compatible with the static mode.
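# A minimal sketch of the intended call order (hypothetical caller code):
#     _set_custom_gid(gid)    # pin the id the static program expects
#     new_group(ranks)        # the group is created with exactly that id
#     _set_custom_gid(None)   # restore automatic id allocation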
_custom_gid = None


def _set_custom_gid(gid):
    global _custom_gid
    _custom_gid = gid


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members.
        backend (str): The backend used to create the group; only nccl is supported now.
        timeout (datetime.timedelta, optional): The waiting timeout for store-related operations; default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    global _custom_gid
    global _group_map
    if in_dygraph_mode():
        global _default_group_name
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group."
            )
        size = len(ranks)
        ranks = sorted(ranks)
        if size > 1 and global_rank in ranks:
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            pg = _new_process_group_impl(
                backend,
                _default_store,
                rank,
                size,
                group_name,
                pg_options=None,
                group_id=gid,
            )
        else:
            rank = -1
            pg = None
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        # TODO: The method below is a new method for group management, will replace the previous
        # three in the future.
        _add_new_group(group)

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.barrier()
        return group

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', "backend other than nccl is not supported yet"

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
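            # In this non-eager path, the ring is initialized through a
            # backend-specific ParallelContext instead of a ProcessGroup.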
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            else:
                assert False, "no supported device found"
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = (
        paddle.to_tensor([1], dtype="int32")
        if _non_static_mode()
        else fill_constant([0], dtype="int32", value="1")
    )
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp


def wait(tensor, group=None, use_calc_stream=True):
    """

    Wait to sync the stream for the given group.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False).
            Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, sync_op=True)
            paddle.distributed.wait(tindata)

    """

    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id

    if use_calc_stream:
        _sync_calc_stream(tensor)
    else:
        _sync_comm_stream(tensor, ring_id)


def _sync_calc_stream(tensor):
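    # Synchronize the calculation stream so that pending kernels producing
    # `tensor` have finished.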

    if _non_static_mode():
        return _legacy_C_ops.c_sync_calc_stream(tensor, tensor)

    op_type = 'c_sync_calc_stream'

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
    )


def _sync_comm_stream(tensor, ring_id=0):
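    # Synchronize the communication stream of the given ring so that
    # outstanding collective operations on `tensor` have completed.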

    if _non_static_mode():
        return _legacy_C_ops.c_sync_comm_stream(
            [tensor], [tensor], 'ring_id', ring_id
        )

    op_type = 'c_sync_comm_stream'

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={'ring_id': ring_id},
    )


def all_gather(tensor_list, tensor, group=None, sync_op=True):
    """

    Gather tensors from all participators so that every participator gets the result. As shown
    below, one process is started with a GPU and the data of this process is represented
    by its group rank. Through the all_gather operator, each GPU will have data
    from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
        group (Group, optional): The group instance returned by new_group or None for the global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            tensor_list = []
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_gather(tensor_list, data)
            print(tensor_list)
            # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    def convert_to_complex(list_of_tensor):
        list_of_complex = []
        for tensor in list_of_tensor:
            list_of_complex.append(paddle.as_complex(tensor))
        return list_of_complex

    is_input_complex = (
        tensor.dtype == paddle.complex64 or tensor.dtype == paddle.complex128
    )
    if is_input_complex:
        tensor = paddle.as_real(tensor)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        if len(tensor_list) == 0:
            tensor_shape = list(tensor.shape)
            tensor_shape[0] *= group.nranks
            out = paddle.empty(tensor_shape, tensor.dtype)
        else:
            out = paddle.concat(tensor_list, axis=0)
        task = group.process_group.all_gather_into_tensor(out, tensor, sync_op)
        task.wait()
        tensor_list.clear()
        list_of_tensor = paddle.split(out, group.nranks, 0)
        if is_input_complex:
            tensor_list.extend(convert_to_complex(list_of_tensor))
        else:
            tensor_list.extend(list_of_tensor)
        return

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    if _non_static_mode():
        out = _legacy_C_ops.c_allgather(
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'nranks',
            nranks,
        )
    else:
        op_type = 'c_allgather'
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        if not isinstance(tensor_list, list):
            raise ValueError(
                "The type of 'tensor_list' for all_gather " "should be list."
            )
        for elem in tensor_list:
            check_variable_and_dtype(
                elem,
                'tensor_list',
                [
                    'float16',
                    'float32',
                    'float64',
                    'int32',
                    'int64',
                    'bool',
                    'int8',
                    'uint8',
                    'complex64',
                    'complex128',
                ],
                'all_gather',
            )
        check_variable_and_dtype(
            tensor,
            'tensor',
            [
                'float16',
                'float32',
                'float64',
                'int32',
                'int64',
                'bool',
                'int8',
                'uint8',
                'complex64',
                'complex128',
            ],
            'all_gather',
        )
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
                'nranks': nranks,
            },
        )

    list_of_tensor = paddle.split(out, nranks, 0)
    if is_input_complex:
        tensor_list.extend(convert_to_complex(list_of_tensor))
    else:
        tensor_list.extend(list_of_tensor)


def _convert_object_to_tensor(obj):
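    # Pickle the object into a 1-D uint8 tensor; the returned element count
    # lets the receiver strip any padding added before all_gather.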
    _pickler = pickle.Pickler
    f = io.BytesIO()
    _pickler(f).dump(obj)
    data = np.frombuffer(f.getvalue(), dtype=np.uint8)
    tensor = paddle.to_tensor(data)
    return tensor, tensor.numel()


def _convert_tensor_to_object(tensor, len_of_tensor):
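    # Inverse of _convert_object_to_tensor: unpickle the first len_of_tensor
    # bytes of the tensor.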
    _unpickler = pickle.Unpickler
    return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()


def all_gather_object(object_list, obj, group=None):
    """

    Gather picklable objects from all participators so that every participator gets the result. Similar to all_gather(), but a picklable Python object can be passed in.

    Args:
        object_list (list): A list of output objects. The datatype of every element in the list is the same as the input obj.
        obj (Any): The picklable object to send.
        group (Group): The group instance returned by new_group or None for the global default group.

    Returns:
        None.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            object_list = []
            if dist.get_rank() == 0:
                obj = {"foo": [1, 2, 3]}
            else:
                obj = {"bar": [4, 5, 6]}
            dist.all_gather_object(object_list, obj)
            print(object_list)
            # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
    """
    assert (
        in_dygraph_mode()
    ), "all_gather_object doesn't support static graph mode."

    tensor, len_of_tensor = _convert_object_to_tensor(obj)

    # gather len_of_tensor from all ranks
    list_len_of_tensor = []
    all_gather(list_len_of_tensor, len_of_tensor, group)
    # get the max length from list
    max_len_of_tensor = int(max(list_len_of_tensor).item())
    # resize the input tensor to max length avoid hang in all gather
    # Note(liyurui): Maybe we should support various length all_gather?
    # Now this operation is efficient for we don't support resize in python.
    numpy_data = tensor.numpy()
    numpy_data = np.resize(numpy_data, [max_len_of_tensor])
    input_tensor = paddle.to_tensor(numpy_data)

    tensor_list = []
    all_gather(tensor_list, input_tensor, group)
    for i, tensor in enumerate(tensor_list):
        object_list.append(
            _convert_tensor_to_object(tensor, list_len_of_tensor[i])
        )