#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime

import paddle

# (TODO: GhostScreaming) It will be removed later.
from paddle.fluid import core
from paddle.framework import in_dynamic_mode

from .communication.group import Group, _add_new_group, is_initialized
from .fleet.layers.mpu.mp_ops import _c_concat  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_identity  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_lookup_table  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _Linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _mp_allreduce  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_embedding  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _set_var_distributed  # noqa: F401
from .fleet.layers.mpu.mp_ops import split  # noqa: F401

__all__ = []

_global_env = None


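# Lazily create and cache the global ParallelEnv (rank, world size, device id,
# trainer endpoints), so every caller in this module shares one instance.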
def _get_global_env():
    global _global_env
    if not _global_env:
        _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# group map: the map of all groups; key 0 is the global group
# Dict[int, Group]
_group_map = {}
_global_env_gid = 0

# group map by name: the map from group name to group
# Dict[name, Group]
_group_map_by_name = {}

# backend map by group: the map from group to its backend
# Dict[group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

_valid_backend_list = ['nccl', 'gloo', 'heter', 'xccl', 'bkcl']
_default_store = None  # the default tcp store
_default_backend = None
_default_timeout = datetime.timedelta(seconds=1800)
_start_ring_id = 0


def _set_default_backend(backend):
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    global _default_store
    _default_store = store


def _get_group_map():
    global _group_map
    if _global_env_gid not in _group_map:
        genv = _get_global_env()
        _group_map[_global_env_gid] = Group(
            genv.rank, 0, list(range(genv.world_size))
        )
    return _group_map


def _get_global_group():
    return _get_group_map()[_global_env_gid]


def _get_group_map_by_name():
    global _group_map_by_name
    return _group_map_by_name


def _get_default_group():
    global _group_map_by_name
    assert is_initialized(), (
        "Call paddle.distributed.init_parallel_env first "
        "to initialize the distributed environment."
    )
    return _get_group_map_by_name()[_default_group_name]


def _set_group_map(gid, group):
    global _group_map
    assert gid not in _group_map
    _group_map[gid] = group


def _set_group_map_by_name(name, group):
    global _group_map_by_name
    assert name not in _group_map_by_name
    _group_map_by_name[name] = group


def _set_group_map_backend(group, backend):
    global _group_map_backend
    assert group not in _group_map_backend
    _group_map_backend[group] = backend


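# Allocate a ring id for a new communication group. In dynamic (eager) mode a
# global counter is incremented; otherwise the id is derived from the number of
# existing groups. Both are offset by max(nrings, 9) so they do not collide
# with the ring ids reserved by the global environment.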
def _new_ring_id():
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode
    # rely on the previous syntax.
    if in_dynamic_mode():
        global _start_ring_id
        _start_ring_id += 1
        return _start_ring_id + max(_get_global_env().nrings, 9)
    else:
        return len(_get_group_map()) + max(_get_global_env().nrings, 9)


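# Create the backend-specific ProcessGroup (gloo / nccl / xccl / bkcl) through
# the core bindings. Backends without a branch here (e.g. 'heter') fall through
# and the function returns None.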
def _new_process_group_impl(
    backend,
    store,
    rank,
    world_size,
    group_name,
    pg_options,
    group_id=0,
):
    pg = None
    genv = _get_global_env()
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
    if backend == "gloo":
        pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
    elif backend == "nccl":
        pg = core.ProcessGroupNCCL.create(store, rank, world_size, group_id)
    elif backend == "xccl":
        pg = core.ProcessGroupCustom.create(
            store, genv.device_type, rank, world_size, group_id
        )
    elif backend == "bkcl":
        pg = core.ProcessGroupBKCL.create(store, rank, world_size, group_id)
    return pg


# _custom_gid provides a way for users to set the group id, which is usually
# useful for compatibility with the static graph mode.
_custom_gid = None


def _set_custom_gid(gid):
    global _custom_gid
    _custom_gid = gid


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members.
        backend (str): The backend used to create the group; only nccl is supported now.
        timeout (datetime.timedelta, optional): The waiting timeout for store-relevant options; the default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env: DISTRIBUTED)
            >>> import paddle

            >>> paddle.distributed.init_parallel_env()
            >>> tindata = paddle.randn(shape=[2, 3])
            >>> gp = paddle.distributed.new_group([2, 4, 6])
            >>> paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    global _custom_gid
    global _group_map
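    # Dynamic (eager) mode: build a backend-specific ProcessGroup on the
    # default TCP store and register it in the group bookkeeping maps; other
    # modes fall through to the legacy ring-id based path further below.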
    if in_dynamic_mode():
        global _default_group_name
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group."
            )
        size = len(ranks)
        ranks = sorted(ranks)
        if size > 1 and global_rank in ranks:
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            pg = _new_process_group_impl(
                backend,
                _default_store,
                rank,
                size,
                group_name,
                pg_options=None,
                group_id=gid,
            )
        else:
            rank = -1
            pg = None
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        # TODO: The call below is a new method for group management; it will
        # replace the previous three maps in the future.
        _add_new_group(group)

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.barrier()
        return group

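    # Legacy path (non-dynamic mode): the group is tracked by ring id and its
    # communicator is initialized through the NCCL/BKCL ParallelContext below.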
    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', "backends other than nccl are not supported yet"

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            else:
                raise AssertionError("no cuda or xpu device found")
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = (
        paddle.to_tensor([1], dtype="int32")
        if in_dynamic_mode()
        else paddle.full([0], 1, dtype="int32")
    )
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp


def is_available():
    """
    Check whether the distributed package is available.

    Returns:
        Returns True if the distributed package is available, otherwise False.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> print(paddle.distributed.is_available())

    """
    return core.is_compiled_with_dist()


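# Initialize the default communication context for init_parallel_env: create or
# reuse the global TCP store, then build the comm context for the chosen
# backend on top of it.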
def _init_parallel_env(backend):
    store = core.create_or_get_global_tcp_store()
    global_env = _get_global_env()
    rank = global_env.rank
    world_size = global_env.world_size
    dev_id = global_env.device_id

    if backend == "gloo":
        core.CommContextManager.create_gloo_comm_context(
            store, "0", rank, world_size
        )
    elif backend == "nccl":
        core.CommContextManager.set_cuda_device_id(dev_id)
        core.CommContextManager.create_nccl_comm_context(
            store, "0", rank, world_size
        )