#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os

import paddle

# (TODO: GhostScreaming) It will be removed later.
from paddle.fluid import core
from paddle.framework import in_dynamic_mode

from .communication.group import Group, _add_new_group, is_initialized
from .fleet.layers.mpu.mp_ops import _c_concat  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_identity  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_lookup_table  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _Linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _mp_allreduce  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_embedding  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _set_var_distributed  # noqa: F401
from .fleet.layers.mpu.mp_ops import split  # noqa: F401

__all__ = []

_global_env = None


def _get_global_env():
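    """Return the cached global ParallelEnv, creating it on first use."""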
    global _global_env
    if not _global_env:
        _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# group map : the map of all groups; key 0 is the global group
# Dict[int, Group]
_group_map = {}
_global_env_gid = 0

# group map by name : the map of all groups, keyed by their names
# Dict[name, Group]
_group_map_by_name = {}

# backend map by group : the map from each group to its backend
# Dict[group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

_valid_backend_list = ['nccl', 'gloo', 'heter', 'xccl', 'bkcl']
_default_store = None  # the default tcp store
_default_backend = None
_default_timeout = datetime.timedelta(seconds=1800)
_start_ring_id = 0


def _set_default_backend(backend):
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    global _default_store
    _default_store = store


def _get_group_map():
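    """Return the map of group id to Group, creating the global group lazily."""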
    global _group_map
    if _global_env_gid not in _group_map:
        genv = _get_global_env()
        _group_map[_global_env_gid] = Group(
            genv.rank, 0, list(range(genv.world_size))
        )
    return _group_map


def _get_global_group():
    return _get_group_map()[_global_env_gid]


def _get_group_map_by_name():
    global _group_map_by_name
    return _group_map_by_name


def _get_default_group():
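    """Return the default process group created by init_parallel_env."""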
    global _group_map_by_name
    assert is_initialized(), (
        "Call paddle.distributed.init_parallel_env first "
        "to initialize the distributed environment."
    )
    return _get_group_map_by_name()[_default_group_name]


def _set_group_map(gid, group):
    global _group_map
    assert gid not in _group_map
    _group_map[gid] = group


def _set_group_map_by_name(name, group):
    global _group_map_by_name
    assert name not in _group_map_by_name
    _group_map_by_name[name] = group


def _set_group_map_backend(group, backend):
    global _group_map_backend
    assert group not in _group_map_backend
    _group_map_backend[group] = backend


def _new_ring_id():
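    """Allocate a ring id for a newly created communication group."""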
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode rely on the previous syntax.
    if in_dynamic_mode():
        global _start_ring_id
        _start_ring_id += 1
        return _start_ring_id + max(_get_global_env().nrings, 9)
    else:
        return len(_get_group_map()) + max(_get_global_env().nrings, 9)


def _new_process_group_impl(
    backend,
    store,
    rank,
    world_size,
    group_name,
    pg_options,
    group_id=0,
):
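    """Create the low-level core ProcessGroup object for the given backend."""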
    pg = None
    genv = _get_global_env()
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
    if backend == "gloo":
        pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
    elif backend == "nccl":
        pg = core.ProcessGroupNCCL.create(
            store, genv.device_id, rank, world_size, group_id
        )

    elif backend == "xccl":
        pg = core.ProcessGroupCustom.create(
            store, genv.device_type, rank, world_size, group_id
        )
    elif backend == "bkcl":
        pg = core.ProcessGroupBKCL.create(store, rank, world_size, group_id)
    return pg


# _custom_gid provides a way for users to
# set the group id, which is usually useful
# to be compatible with the static graph mode.
_custom_gid = None


def _set_custom_gid(gid):
    global _custom_gid
    _custom_gid = gid


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members.
        backend (str): The backend used to create the group; only nccl is supported now.
        timeout (datetime.timedelta, optional): The waiting timeout for store-related options; the default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2,4,6])
            paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    global _custom_gid
    global _group_map
    if in_dynamic_mode():
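        # Dynamic graph path: create a core ProcessGroup for the new group and
        # register it in the group maps.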
        global _default_group_name
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group."
            )
        size = len(ranks)
        ranks = sorted(ranks)
        if size > 1 and global_rank in ranks:
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            pg = _new_process_group_impl(
                backend,
                _default_store,
                rank,
                size,
                group_name,
                pg_options=None,
                group_id=gid,
            )
        else:
            rank = -1
            pg = None
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        # TODO: _add_new_group is a new mechanism for group management; it will
        # replace the three maps above in the future.
        _add_new_group(group)

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.barrier()
        return group

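    # Static graph path: allocate a ring id, record the group, and initialize
    # the parallel context if the local rank belongs to the group.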
    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', "backend other than nccl is not supported yet"

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            else:
                raise AssertionError("no cuda device found")
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = (
        paddle.to_tensor([1], dtype="int32")
        if in_dynamic_mode()
        else paddle.full([0], 1, dtype="int32")
    )
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp


def is_available():
    """
    Check whether the distributed package is available.

    Returns:
        Returns True if the distributed package is available, otherwise False.

    Examples:
        .. code-block:: python

            import paddle

            print(paddle.distributed.is_available())

    """
    return core.is_compiled_with_dist()


def _init_parallel_env(backend):
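    # PADDLE_MASTER is expected to hold the master endpoint as "ip:port",
    # which is used to construct the TCPStore.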
    master_endpoint = os.getenv("PADDLE_MASTER", None)
    if master_endpoint:
        master_addr = master_endpoint.split(":")[0]
        master_port = int(master_endpoint.split(":")[1])
        global_env = _get_global_env()
        rank = global_env.rank
        world_size = global_env.world_size
        dev_id = global_env.device_id
        is_master = rank == 0
        store = core.TCPStore(
            master_addr,
            master_port,
            is_master,
            world_size,
        )
        if backend == "gloo":
            core.CommContextManager.create_gloo_comm_context(
                store, 0, rank, world_size
            )
        elif backend == "nccl":
            core.CommContextManager.create_nccl_comm_context(
                store, dev_id, 0, rank, world_size
            )