#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
import pickle
import io
import datetime
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import in_dygraph_mode
from ..fluid.framework import _non_static_mode
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.layers.tensor import fill_constant
import paddle
import paddle.fluid.core as core
from paddle import _legacy_C_ops
from .fleet.layers.mpu.mp_ops import split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_identity  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_concat  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _mp_allreduce  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_lookup_table  # noqa: F401
from .fleet.layers.mpu.mp_ops import _Linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _set_var_distributed  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy  # noqa: F401
from .fleet.layers.mpu.mp_ops import _linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_embedding  # noqa: F401
from .communication.group import Group, _add_new_group, is_initialized


__all__ = []

# Lazily constructed ParallelEnv shared by the helpers in this module.
_global_env = None


def _get_global_env():
    """Return the module-wide ``ParallelEnv``, creating it on first use.

    Returns:
        paddle.distributed.ParallelEnv: the cached environment object.
    """
    global _global_env
    # Compare with None explicitly so a cached object is never rebuilt just
    # because it happens to evaluate as falsy.
    if _global_env is None:
        _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# Every communication group, keyed by its integer group id.
# Dict[int, Group]; id ``_global_env_gid`` (0) holds the global group.
_group_map = {}
_global_env_gid = 0

# Communication groups keyed by group name.
# Dict[str, Group]
_group_map_by_name = {}

# Backend recorded for each group.
# Dict[Group, backend]
_group_map_backend = {}

# Name under which the default group of init_parallel_env is registered.
_default_group_name = "_default_pg"

# Backends that _new_process_group_impl knows how to build.
_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl']

# Defaults installed through _set_default_store / _set_default_backend.
_default_store = None  # the default tcp store
_default_backend = None

# Default store timeout: 30 minutes.
_default_timeout = datetime.timedelta(minutes=30)

# Counter behind the dygraph-mode ring ids handed out by _new_ring_id.
_start_ring_id = 0


def _set_default_backend(backend):
    """Record ``backend`` as the module-wide default backend.

    ``new_group`` falls back to this value when it is called without an
    explicit backend in dygraph mode.
    """
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    """Record ``store`` as the module-wide default store.

    ``new_group`` hands this store to every process group it creates in
    dygraph mode.
    """
    global _default_store
    _default_store = store


def _get_group_map():
    """Return the gid -> Group registry, seeding the global group on demand.

    On first access a Group covering ranks ``0 .. world_size - 1`` of the
    global ParallelEnv is registered under ``_global_env_gid``.
    """
    if _global_env_gid in _group_map:
        return _group_map
    env = _get_global_env()
    every_rank = list(range(env.world_size))
    _group_map[_global_env_gid] = Group(env.rank, 0, every_rank)
    return _group_map


def _get_global_group():
    """Return the global (world) Group, creating it on first access."""
    registry = _get_group_map()
    return registry[_global_env_gid]


def _get_group_map_by_name():
    """Return the name -> Group registry (the live module-level dict)."""
    return _group_map_by_name


def _get_default_group():
    """Return the default group registered by ``init_parallel_env``.

    Raises:
        AssertionError: if the distributed environment is not initialized.
    """
    assert is_initialized(), (
        "Call paddle.distributed.init_parallel_env first "
        "to initialize the distributed environment."
    )
    registry = _get_group_map_by_name()
    return registry[_default_group_name]


def _set_group_map(gid, group):
    """Register ``group`` under ``gid``; the id must not be registered yet."""
    is_taken = gid in _group_map
    assert not is_taken
    _group_map.update({gid: group})


def _set_group_map_by_name(name, group):
    """Register ``group`` under ``name``; the name must not be registered yet."""
    is_taken = name in _group_map_by_name
    assert not is_taken
    _group_map_by_name.update({name: group})


def _set_group_map_backend(group, backend):
    """Record ``backend`` for ``group``; the group must not have one yet."""
    is_taken = group in _group_map_backend
    assert not is_taken
    _group_map_backend.update({group: backend})


def _new_ring_id():
    """Allocate a ring id for a new communication group.

    Ids are offset by ``max(nrings, 9)`` of the global ParallelEnv. In
    dygraph mode a module-level counter is bumped on every call; otherwise
    the id is derived from the current size of the group registry.
    """
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode
    # rely on the previous syntax.
    global _start_ring_id
    if in_dygraph_mode():
        _start_ring_id += 1
        next_id = _start_ring_id
    else:
        next_id = len(_get_group_map())
    return next_id + max(_get_global_env().nrings, 9)


def _parse_heter_cluster_env(rank):
    """Read the heter-cluster layout from environment variables.

    Reads ``CLUSTER_ID`` (index of this cluster), ``CLUSTER_SIZE`` (a
    comma-separated list with the number of ranks of each cluster) and
    ``CLUSTER_SWITCH`` (the switch endpoint).

    Args:
        rank (int): rank of this process inside its own cluster.

    Returns:
        tuple: ``(cluster_id, cluster_size, switch_ep, global_rank,
        global_world_size)``. ``global_rank`` is ``rank`` shifted by the
        sizes of all clusters before ``cluster_id``; ``global_world_size``
        is the sum of all cluster sizes.

    Raises:
        AssertionError: if a variable is missing or ``CLUSTER_ID`` does not
            index an entry of ``CLUSTER_SIZE``.
    """
    cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
    assert cluster_id >= 0, "please set the CLUSTER_ID variable."
    cluster_size = os.getenv("CLUSTER_SIZE", None)
    assert cluster_size, "please set the CLUSTER_SIZE variable."
    cluster_size = [int(s) for s in cluster_size.split(",")]
    switch_ep = os.getenv("CLUSTER_SWITCH", None)
    assert switch_ep, "please set the CLUSTER_SWITCH variable."
    # Without this check CLUSTER_ID == len(CLUSTER_SIZE) slips through the
    # offset computation below and yields gloo_rank >= gloo_size.
    assert cluster_id < len(cluster_size), (
        "CLUSTER_ID (%d) must be smaller than the number of clusters in "
        "CLUSTER_SIZE (%d)." % (cluster_id, len(cluster_size))
    )
    cluster_size_cumsum = np.cumsum(cluster_size)
    cluster_offset = (
        0 if cluster_id == 0 else int(cluster_size_cumsum[cluster_id - 1])
    )
    # Convert numpy scalars to plain Python ints before they reach the
    # C++ binding.
    global_rank = cluster_offset + rank
    global_world_size = int(cluster_size_cumsum[-1])
    return cluster_id, cluster_size, switch_ep, global_rank, global_world_size


def _new_process_group_impl(
    backend,
    store,
    rank,
    world_size,
    group_name,
    pg_options,
    group_id=0,
    src_rank=None,
    dst_rank=None,
):
    """Build the ``core.ProcessGroup*`` object that implements ``backend``.

    Args:
        backend (str): one of ``_valid_backend_list``.
        store: store handed to the process-group constructor.
        rank (int): rank of this process inside the new group.
        world_size (int): number of ranks in the new group.
        group_name (str): group name (not used by this function).
        pg_options: backend options (not used by this function).
        group_id (int, optional): id of the new group. Default 0.
        src_rank (int, optional): only allowed for the ``heter`` backend.
        dst_rank (int, optional): only allowed for the ``heter`` backend.

    Returns:
        The backend-specific process group instance.

    Raises:
        AssertionError: for an unsupported backend, for src/dst ranks given
            to a non-heter backend, or for an incomplete heter environment.
    """
    genv = _get_global_env()
    if backend != 'heter':
        assert src_rank is None and dst_rank is None, (
            "src_rank and dst_rank can only be set for heter backend."
        )
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend

    if backend == "gloo":
        place = core.CPUPlace()
        return core.ProcessGroupGloo(store, rank, world_size, place, group_id)
    if backend == "nccl":
        place = core.CUDAPlace(genv.device_id)
        return core.ProcessGroupNCCL(store, rank, world_size, place, group_id)
    if backend == "hccl":
        place = core.NPUPlace(genv.device_id)
        return core.ProcessGroupHCCL(store, rank, world_size, place, group_id)
    if backend == "xccl":
        place = core.CustomPlace(genv.device_type, genv.device_id)
        return core.ProcessGroupCustom(
            store, rank, world_size, place, group_id
        )
    if backend == "heter":
        place = None
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(genv.device_id)
        elif core.is_compiled_with_npu():
            place = core.NPUPlace(genv.device_id)
        (
            cluster_id,
            cluster_size,
            switch_ep,
            global_rank,
            global_world_size,
        ) = _parse_heter_cluster_env(rank)
        # NOTE(review): these values used to be computed and then
        # immediately overwritten by `_get_global_config(backend, rank)`, a
        # helper not defined in the visible module source. The values
        # derived here from the same environment variables are now used
        # directly; confirm against ProcessGroupHeter's expectations.
        return core.ProcessGroupHeter(
            store,
            rank=global_rank,
            world_size=global_world_size,
            place=place,
            gid=group_id,
            local_rank=rank,
            local_size=world_size,
            gloo_rank=cluster_id,
            gloo_size=len(cluster_size),
            with_switch=True,
            switch_endpoint=switch_ep,
            src_rank=src_rank,
            dst_rank=dst_rank,
        )
    # Reachable only when assertions are disabled (python -O) and the
    # backend is unknown; the original code also returned None here.
    return None


def barrier(group=None):
    """

    Barrier among all participators in the group.

    Args:
        group (Group): The group instance returned by new_group, or None for the global default group.

    Returns:
        None. (In legacy dynamic mode the output of the legacy barrier op is returned.)

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            paddle.distributed.barrier()
    """
    # Ranks outside the group do not take part in the barrier.
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        target = group if group is not None else _get_default_group()
        target.process_group.barrier().wait()
        return

    ring_id = group.id if group is not None else 0

    temp = fill_constant([1], dtype="int32", value="1")
    if _non_static_mode():
        # Legacy dynamic mode: run the barrier op eagerly.
        return _legacy_C_ops.barrier(temp, temp, 'ring_id', ring_id)

    op_type = 'barrier'
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'group' for barrier must be int.")
    # LayerHelper receives every local bound at this point via locals().
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [temp]},
        attrs={'ring_id': ring_id},
    )


# Optional group id chosen by the caller. When it is truthy, new_group in
# dygraph mode uses it instead of allocating a fresh ring id, which helps
# keep group ids compatible with static mode.
_custom_gid = None


def _set_custom_gid(gid):
    """Set the module-level ``_custom_gid`` consulted by ``new_group``."""
    global _custom_gid
    _custom_gid = gid


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list, optional): The global ranks of group members. If None, all ranks are used.
        backend (str, optional): The backend used to create the group. In dygraph mode None falls
            back to the module default backend; in static graph mode only nccl is supported.
        timeout (datetime.timedelta, optional): The waiting timeout for store relevant options,
            default is 30 minutes. (Not consumed by the code of this function.)

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    if in_dygraph_mode():
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group."
            )
        size = len(ranks)
        ranks = sorted(ranks)
        # global_rank is bound only by the branch above; the short-circuit
        # below reads it only when that branch ran (non-heter, size > 1).
        if backend == 'heter' or (size > 1 and global_rank in ranks):
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            src_rank = ranks[0] if backend == 'heter' else None
            dst_rank = ranks[1] if backend == 'heter' else None
            pg = _new_process_group_impl(
                backend,
                _default_store,
                rank,
                size,
                group_name,
                pg_options=None,
                group_id=gid,
                src_rank=src_rank,
                dst_rank=dst_rank,
            )
        else:
            rank = -1
            pg = None
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        # TODO: The method below is a new method for group management, will replace the previous
        # three in the future.
        _add_new_group(group)

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.barrier()
        return group

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', "backend other than nccl is not supported yet"

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    # None means "all ranks", as in the dygraph branch (previously a TypeError
    # here). Sort for members and non-members alike so every process records
    # the same Group.ranks ordering.
    ranks = sorted(range(genv.world_size) if ranks is None else ranks)

    if global_rank not in ranks:
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        # A single-member group needs no communicator and no final sync.
        if group_size < 2:
            return gp

        strategy = core.ParallelStrategy()
        strategy.nranks = group_size
        strategy.local_rank = group_rank
        strategy.trainer_endpoints = [
            genv.trainer_endpoints[i] for i in ranks
        ]
        strategy.current_endpoint = genv.current_endpoint
        strategy.nrings = 1

        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(genv.device_id)
            core.NCCLParallelContext(strategy, place).init_with_ring_id(
                ring_id
            )
        elif core.is_compiled_with_npu():
            place = core.NPUPlace(genv.device_id)
            core.HCCLParallelContext(strategy, place).init_with_ring_id(
                ring_id
            )
        elif core.is_compiled_with_mlu():
            place = core.MLUPlace(genv.device_id)
            core.CNCLParallelContext(strategy, place).init_with_ring_id(
                ring_id
            )
        elif core.is_compiled_with_xpu():
            place = core.XPUPlace(genv.device_id)
            core.BKCLParallelContext(strategy, place).init_with_ring_id(
                ring_id
            )
        else:
            assert False, "no cuda device found"

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    # NOTE(review): in static graph mode the sync tensor has shape [0], while
    # the dynamic-mode one has shape [1]; confirm that an all_reduce over an
    # empty tensor still synchronizes the ranks.
    tmp = (
        paddle.to_tensor([1], dtype="int32")
        if _non_static_mode()
        else fill_constant([0], dtype="int32", value="1")
    )
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp


def wait(tensor, group=None, use_calc_stream=True):
    """

    Wait to sync stream for group.

    Args:
        tensor (Tensor): The Tensor used before sync.
        group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
            Default to True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            paddle.distributed.all_reduce(tindata, sync_op=True)
            paddle.distributed.wait(tindata)

    """
    # Ranks outside the group have nothing to synchronize.
    if group is not None and not group.is_member():
        return

    ring_id = group.id if group is not None else 0
    if not use_calc_stream:
        _sync_comm_stream(tensor, ring_id)
        return
    _sync_calc_stream(tensor)


def _sync_calc_stream(tensor):
    """Issue a ``c_sync_calc_stream`` op on ``tensor``.

    In dynamic mode the legacy op runs immediately and its result is
    returned; otherwise the op is appended to the current program.
    """
    if _non_static_mode():
        return _legacy_C_ops.c_sync_calc_stream(tensor, tensor)

    op_type = 'c_sync_calc_stream'
    # LayerHelper receives every local bound at this point via locals().
    helper = LayerHelper(op_type, **locals())
    # The op reads and writes the same tensor.
    io_slot = [tensor]
    helper.append_op(
        type=op_type,
        inputs={'X': io_slot},
        outputs={'Out': io_slot},
    )


def _sync_comm_stream(tensor, ring_id=0):
    """Issue a ``c_sync_comm_stream`` op on ``tensor`` for ring ``ring_id``.

    In dynamic mode the legacy op runs immediately and its result is
    returned; otherwise the op is appended to the current program.
    """
    if _non_static_mode():
        return _legacy_C_ops.c_sync_comm_stream(
            [tensor], [tensor], 'ring_id', ring_id
        )

    op_type = 'c_sync_comm_stream'
    # LayerHelper receives every local bound at this point via locals().
    helper = LayerHelper(op_type, **locals())
    op_attrs = {'ring_id': ring_id}
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs=op_attrs,
    )


# Data types checked by all_gather in static graph mode, both for the input
# tensor and for any entries already present in tensor_list.
_ALL_GATHER_STATIC_DTYPES = [
    'float16',
    'float32',
    'float64',
    'int32',
    'int64',
    'bool',
    'int8',
    'uint8',
    'complex64',
    'complex128',
]


def all_gather(tensor_list, tensor, group=None, sync_op=True):
    """

    Gather tensors from all participators and all get the result. As shown
    below, one process is started with a GPU and the data of this process is represented
    by its group rank. Through the all_gather operator, each GPU will have data
    from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
        :width: 800
        :alt: all_gather
        :align: center

    Args:
        tensor_list (list): A list of output Tensors. On return it holds exactly one gathered
            Tensor per rank (previous contents are discarded). Its data type should be float16,
            float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
        group (Group, optional): The group instance return by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. The default value is True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            tensor_list = []
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_gather(tensor_list, data)
            print(tensor_list)
            # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    def convert_to_complex(list_of_tensor):
        # Undo the paddle.as_real view applied to complex inputs below.
        return [paddle.as_complex(t) for t in list_of_tensor]

    is_input_complex = (
        tensor.dtype == paddle.complex64 or tensor.dtype == paddle.complex128
    )
    # Complex tensors are communicated through their real view.
    if is_input_complex:
        tensor = paddle.as_real(tensor)

    if in_dygraph_mode():
        group = _get_default_group() if group is None else group
        if len(tensor_list) == 0:
            tensor_shape = list(tensor.shape)
            tensor_shape[0] *= group.nranks
            out = paddle.empty(tensor_shape, tensor.dtype)
        else:
            # Pre-filled complex outputs must get the same real view as the
            # input, otherwise the concatenated buffer would not match
            # `tensor` in dtype and shape.
            pieces = tensor_list
            if is_input_complex:
                pieces = [
                    paddle.as_real(t)
                    if t.dtype in (paddle.complex64, paddle.complex128)
                    else t
                    for t in tensor_list
                ]
            out = paddle.concat(pieces, axis=0)
        task = group.process_group.all_gather_into_tensor(out, tensor, sync_op)
        task.wait()
        tensor_list.clear()
        list_of_tensor = paddle.split(out, group.nranks, 0)
        if is_input_complex:
            tensor_list.extend(convert_to_complex(list_of_tensor))
        else:
            tensor_list.extend(list_of_tensor)
        return

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    if _non_static_mode():
        out = _legacy_C_ops.c_allgather(
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            'nranks',
            nranks,
        )
    else:
        op_type = 'c_allgather'
        # LayerHelper receives every local bound at this point via locals().
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
        if not isinstance(tensor_list, list):
            raise ValueError(
                "The type of 'tensor_list' for all_gather " "should be list."
            )
        for elem in tensor_list:
            check_variable_and_dtype(
                elem, 'tensor_list', _ALL_GATHER_STATIC_DTYPES, 'all_gather'
            )
        check_variable_and_dtype(
            tensor, 'tensor', _ALL_GATHER_STATIC_DTYPES, 'all_gather'
        )
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
                'nranks': nranks,
            },
        )

    list_of_tensor = paddle.split(out, nranks, 0)
    # Overwrite rather than append, matching the dygraph branch above.
    tensor_list.clear()
    if is_input_complex:
        tensor_list.extend(convert_to_complex(list_of_tensor))
    else:
        tensor_list.extend(list_of_tensor)


def _convert_object_to_tensor(obj):
    """Pickle ``obj`` into a uint8 tensor.

    Returns:
        tuple: ``(tensor, tensor.numel())``, i.e. the byte tensor and its
        element count.
    """
    buffer = io.BytesIO()
    pickle.Pickler(buffer).dump(obj)
    raw_bytes = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
    byte_tensor = paddle.to_tensor(raw_bytes)
    return byte_tensor, byte_tensor.numel()


def _convert_tensor_to_object(tensor, len_of_tensor):
    """Unpickle the first ``len_of_tensor`` bytes of a uint8 tensor.

    Args:
        tensor: byte tensor as produced by ``_convert_object_to_tensor``,
            possibly padded past the payload. Any object with a ``.numpy()``
            method returning a uint8 array works.
        len_of_tensor: number of meaningful bytes, either an int or a
            single-element tensor.

    Returns:
        The deserialized Python object.
    """
    # SECURITY: unpickling can execute arbitrary code from a crafted
    # payload. The bytes here come from peer ranks of the same job; never
    # route untrusted data through this helper.
    num_bytes = int(len_of_tensor)
    return pickle.loads(tensor.numpy()[:num_bytes].tobytes())


def all_gather_object(object_list, obj, group=None):
    """

    Gather picklable objects from all participators and all get the result. Similar to all_gather(), but python object can be passed in.

    Args:
        object_list (list): A list of output object. The gathered objects are appended to it; the datatype of every element is the same as the input obj.
        obj (Any): The picklable object to send.
        group (Group): The group instance return by new_group or None for global default group.

    Returns:
        None.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            object_list = []
            if dist.get_rank() == 0:
                obj = {"foo": [1, 2, 3]}
            else:
                obj = {"bar": [4, 5, 6]}
            dist.all_gather_object(object_list, obj)
            print(object_list)
            # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
    """
    assert (
        in_dygraph_mode()
    ), "all_gather_object doesn't support static graph mode."

    # Ranks outside `group` do not take part. all_gather would return without
    # filling list_len_of_tensor, and max() below would then raise on an
    # empty list. Return early, as all_gather, barrier and wait do.
    if group is not None and not group.is_member():
        return

    tensor, len_of_tensor = _convert_object_to_tensor(obj)

    # gather len_of_tensor from all ranks
    list_len_of_tensor = []
    all_gather(list_len_of_tensor, len_of_tensor, group)
    # get the max length from list
    max_len_of_tensor = int(max(list_len_of_tensor).item())
    # Pad every payload to the common max length so that all ranks send
    # tensors of the same shape, which avoids a hang in all_gather.
    # NOTE(liyurui): a variable-length all_gather would avoid this host-side
    # numpy round-trip. np.resize fills the extra space by repeating the
    # data; that padding is dropped again when each payload is cut to its
    # true length during unpickling.
    numpy_data = tensor.numpy()
    numpy_data = np.resize(numpy_data, [max_len_of_tensor])
    input_tensor = paddle.to_tensor(numpy_data)

    tensor_list = []
    all_gather(tensor_list, input_tensor, group)
    for gathered, length in zip(tensor_list, list_len_of_tensor):
        object_list.append(_convert_tensor_to_object(gathered, length))