topology.py 13.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from functools import reduce
17 18 19
from itertools import product

import paddle
20
from paddle.distributed.utils.nccl_utils import check_nccl_version_for_p2p
21

22 23
from ..utils.log_util import logger

24 25
# Public names exported by ``from ... import *``.
__all__ = ["CommunicateTopology", "HybridCommunicateGroup"]

# Most recently constructed group registry. Both
# HybridCommunicateGroup.__init__ and _CommunicateGroup.__init__ overwrite
# it with the new instance; it starts unset.
_HYBRID_PARALLEL_GROUP = None

28

29
class ParallelMode:
    """
    These are all the parallel modes currently supported:

        - DATA_PARALLEL: Distributes input data to different devices.
        - TENSOR_PARALLEL: Shards tensors in the network across different devices.
        - PIPELINE_PARALLEL: Places different layers of the network on different devices.
        - SHARDING_PARALLEL: Shards the model parameters, the parameter gradients and the optimizer states of those parameters across devices.

    The modes are plain ``int`` class attributes (not an ``enum.Enum``), so
    they compare equal to the integers 0 to 3.

    Examples:
        .. code-block:: python

            import paddle
            parallel_mode = paddle.distributed.ParallelMode
            print(parallel_mode.DATA_PARALLEL)  # 0

    """

    DATA_PARALLEL = 0
    TENSOR_PARALLEL = 1
    PIPELINE_PARALLEL = 2
    SHARDING_PARALLEL = 3
52 53


54
class CommunicateTopology:
    """Cartesian topology mapping global ranks to named parallel coordinates.

    Ranks are assigned to coordinates in the order produced by
    ``itertools.product`` over the axis sizes, i.e. the last axis in
    ``hybrid_group_names`` varies fastest.

    Args:
        hybrid_group_names (Sequence[str]): Axis names. They become the field
            names of the ``Coordinate`` namedtuple, so they must be unique,
            valid identifiers. Defaults to data, pipe, sharding, model.
        dims (Sequence[int]): Size of each axis, in the same order as
            ``hybrid_group_names``. Defaults to all ones.
    """

    def __init__(
        self,
        hybrid_group_names=("data", "pipe", "sharding", "model"),
        dims=(1, 1, 1, 1),
    ):
        # Keep private list copies: get_hybrid_group_names() hands the stored
        # list out, so storing a caller's (or a shared default) list would let
        # outside mutation leak into this and other topologies.
        self._parallel_names = list(hybrid_group_names)
        self._dims = list(dims)

        self.coordinate = collections.namedtuple(
            'Coordinate', self._parallel_names
        )
        # Initializer 1 keeps an axis-less topology well defined.
        self._world_size = reduce(lambda x, y: x * y, self._dims, 1)

        ranges = [range(d) for d in self._dims]
        all_coordinate = [self.coordinate(*x) for x in product(*ranges)]

        # Rank i <-> i-th coordinate of the product enumeration.
        self._coord2rank = {
            coord: rank for rank, coord in enumerate(all_coordinate)
        }
        self._rank2coord = dict(enumerate(all_coordinate))

    def get_hybrid_group_names(self):
        """Return the list of axis names (the stored list, not a copy)."""
        return self._parallel_names

    def get_dim(self, axis_name):
        """Return the size of ``axis_name`` (raises ValueError if unknown)."""
        return self._dims[self._parallel_names.index(axis_name)]

    def world_size(self):
        """Return the total number of ranks (product of all axis sizes)."""
        return self._world_size

    def get_rank(self, **args):
        """Return the global rank at the coordinate given as ``axis=index``
        keyword arguments; every axis must be specified."""
        assert len(args) == len(self._dims)
        key = self.coordinate(**args)
        assert key in self._coord2rank
        return self._coord2rank[key]

    def get_coord(self, rank):
        """Return the ``Coordinate`` namedtuple of global ``rank``."""
        assert rank < self._world_size
        assert rank in self._rank2coord
        return self._rank2coord[rank]

    def get_axis_list(self, axis_name, index):
        """Return, in ascending order, every rank whose ``axis_name``
        coordinate equals ``index``."""
        axis = self._parallel_names.index(axis_name)
        return sorted(
            rank
            for coord, rank in self._coord2rank.items()
            if coord[axis] == index
        )

    def get_dim_size(self, axis_name):
        """Like :meth:`get_dim`, but asserts that ``axis_name`` exists."""
        assert axis_name in self._parallel_names
        return self._dims[self._parallel_names.index(axis_name)]

    def get_comm_list(self, axis_name):
        """Return the rank groups that communicate along ``axis_name``.

        One list is produced per combination of the other axes' indices
        (in ``itertools.product`` order). Each list holds the ranks obtained
        by varying only ``axis_name`` from 0 to its size - 1.
        """
        assert axis_name in self._parallel_names
        other_axis_names = [
            name for name in self._parallel_names if name != axis_name
        ]
        ranges = [range(self.get_dim_size(name)) for name in other_axis_names]
        axis_size = self.get_dim_size(axis_name)

        all_result = []
        for x in product(*ranges):
            key_coord = dict(zip(other_axis_names, x))
            result = []
            for i in range(axis_size):
                key_coord[axis_name] = i
                result.append(self._coord2rank[self.coordinate(**key_coord)])
            all_result.append(result)

        return all_result

    def get_rank_from_stage(self, global_rank, **kwargs):
        """Return the rank whose coordinate equals that of ``global_rank``
        with the axes given in ``kwargs`` replaced."""
        coord = self.get_coord(global_rank)
        tf = coord._replace(**kwargs)._asdict()
        return self.get_rank(**tf)

139

140
class HybridCommunicateGroup:
    """Builds and owns the communication groups of a hybrid-parallel job.

    From a :class:`CommunicateTopology`, the constructor creates one
    communication group per parallel axis ("data", "model", "pipe",
    "sharding"), the two check groups, and, when the pipeline degree is
    greater than 1, the previous/next point-to-point ranks of the current
    rank. The new instance is stored in the module-level
    ``_HYBRID_PARALLEL_GROUP``.

    Args:
        topology (CommunicateTopology): topology whose axes include "data",
            "model", "pipe" and "sharding"; the product of their sizes must
            equal ``paddle.distributed.get_world_size()``.
    """

    def __init__(self, topology):
        self.nranks = paddle.distributed.get_world_size()
        self.global_rank = paddle.distributed.get_rank()
        self._topo = topology

        self._dp_degree = self._topo.get_dim('data')
        self._mp_degree = self._topo.get_dim('model')
        self._pp_degree = self._topo.get_dim('pipe')
        self._sharding_degree = self._topo.get_dim('sharding')

        self._data_parallel_id = self._get_data_parallel_id()
        self._model_parallel_id = self._get_model_parallel_id()
        self._sharding_parallel_id = self._get_sharding_parallel_id()
        self.stage_id = self._get_pipe_parallel_id()

        assert self._check_vaild_topo(), self._invalid_topo_message()

        # NOTE(review): every rank calls paddle.distributed.new_group for
        # every group below, including groups it is not a member of, which
        # suggests group creation must happen in the same order on all
        # ranks. Do not reorder these calls without confirming.

        # create comm group for data parallel
        self._dp_group, self._dp_comm_group = self._set_comm_group("data")

        # create comm group for model parallel
        self._mp_group, self._mp_comm_group = self._set_comm_group("model")

        # create comm group for pipe parallel
        self._pp_group, self._pp_comm_group = self._set_comm_group("pipe")

        # create comm group for sharding parallel
        self._sharding_group, self._sharding_comm_group = self._set_comm_group(
            "sharding"
        )

        # create global group for check inf_nan / clip global norm
        self._check_group, self._check_comm_group = self._set_check_group(
            "data"
        )

        (
            self.sharding_check_group,
            self.sharding_check_comm_group,
        ) = self._set_check_group("sharding")

        # pipeline stage position of this rank
        self.is_first_stage = self.stage_id == 0
        self.is_last_stage = self.stage_id == (self._pp_degree - 1)

        # create p2p_groups
        if self._pp_degree > 1:
            check_nccl_version_for_p2p()
            self._set_p2p_group()

        debug_str = (
            "HybridParallelInfo: rank_id: %d, mp_degree: %d, "
            "sharding_degree: %d, pp_degree: %d, dp_degree: %d"
            % (
                self.global_rank,
                self._mp_degree,
                self._sharding_degree,
                self._pp_degree,
                self._dp_degree,
            )
        )
        debug_str += (
            ", mp_group: %s,  sharding_group: %s, pp_group: %s, dp_group: %s, check/clip group: %s"
            % (
                self._mp_group,
                self._sharding_group,
                self._pp_group,
                self._dp_group,
                self._check_group,
            )
        )
        logger.info(debug_str)

        global _HYBRID_PARALLEL_GROUP
        _HYBRID_PARALLEL_GROUP = self

    def get_parallel_mode(self):
        """Return the ParallelMode that drives this job.

        There are four modes: DataParallel / TensorParallel /
        PipelineParallel / ShardingParallel.
        NOTE: when sharding is combined with another parallelism, sharding
        should act like an optimizer and add its logic within that
        parallelism; only when sharding is used alone (every other degree is
        1) does it get its own mode.
        TODO: modify the 3 other parallelisms to support sharding.
        """
        if (
            self._mp_degree == 1
            and self._pp_degree == 1
            and self._dp_degree == 1
            and self._sharding_degree > 1
        ):
            return ParallelMode.SHARDING_PARALLEL
        elif self._mp_degree == 1 and self._pp_degree == 1:
            return ParallelMode.DATA_PARALLEL
        elif self._mp_degree > 1 and self._pp_degree == 1:
            return ParallelMode.TENSOR_PARALLEL
        elif self._pp_degree > 1:
            return ParallelMode.PIPELINE_PARALLEL

    def _check_vaild_topo(self):
        """Return True iff dp * mp * pp * sharding degrees equal the world size."""
        return (
            self._dp_degree
            * self._mp_degree
            * self._pp_degree
            * self._sharding_degree
            == self.nranks
        )

    def _invalid_topo_message(self):
        """Build the error message reported when ``_check_vaild_topo`` fails."""
        return (
            "Here is an unreasonable topology setting. world_size: {}, but "
            "mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}".format(
                self.nranks,
                self._mp_degree,
                self._sharding_degree,
                self._pp_degree,
                self._dp_degree,
            )
        )

    def _set_comm_group(self, parallel_method="data"):
        """Create a group for every rank list along ``parallel_method``.

        Returns:
            (list[int], group): the rank list that contains this rank and the
            communication group created for it.
        """
        parallel_group = []
        parallel_comm_group = None
        parallel_groups = self._topo.get_comm_list(parallel_method)

        for group in parallel_groups:
            # Created for every list, not only the one holding this rank.
            comm_group = paddle.distributed.new_group(ranks=group)
            if self.global_rank in group:
                parallel_group = group
                parallel_comm_group = comm_group

        assert len(parallel_group) > 0
        assert parallel_comm_group is not None

        logger.info(
            "Total {} {} comm group(s) create successfully!".format(
                len(parallel_groups), parallel_method
            )
        )
        return parallel_group, parallel_comm_group

    def _set_check_group(self, parallel_method="data"):
        """Create one group per index of ``parallel_method``.

        Each group holds all ranks that share that index along the axis
        (see ``CommunicateTopology.get_axis_list``).

        Returns:
            (list[int], group): the rank list containing this rank and its
            communication group.
        """
        parallel_group = []
        parallel_comm_group = None
        parallel_size = self._topo.get_dim(parallel_method)
        for idx in range(parallel_size):
            axis_ranks = self._topo.get_axis_list(parallel_method, idx)
            comm_group = paddle.distributed.new_group(ranks=axis_ranks)
            if self.global_rank in axis_ranks:
                parallel_group = axis_ranks
                parallel_comm_group = comm_group

        assert len(parallel_group) > 0
        assert parallel_comm_group is not None

        return parallel_group, parallel_comm_group

    def _get_p2p_next_rank(self):
        """Return the next pipeline rank (only set when pp_degree > 1)."""
        assert hasattr(self, 'next_rank'), "next_rank has not been inited"
        return self.next_rank

    def _get_p2p_prev_rank(self):
        """Return the previous pipeline rank (only set when pp_degree > 1)."""
        assert hasattr(self, 'prev_rank'), "prev_rank has not been inited"
        return self.prev_rank

    def _set_p2p_group(self):
        """Record this rank's ring neighbours within its pipeline group.

        Neighbours wrap around: the last stage's next rank is the first
        stage, and vice versa.
        """
        comm_lists = self._topo.get_comm_list('pipe')

        for comm_ranks in comm_lists:
            assert len(comm_ranks) == self._pp_degree
            for idx, rank in enumerate(comm_ranks):
                curr_rank = rank
                next_rank = comm_ranks[(idx + 1) % self._pp_degree]
                prev_rank = comm_ranks[(idx - 1) % self._pp_degree]

                if self.global_rank == curr_rank:
                    self.next_rank = next_rank
                    self.prev_rank = prev_rank

    def topology(self):
        return self._topo

    def get_global_rank(self):
        return self.global_rank

    # data parallel message:
    def _get_data_parallel_id(self):
        return self._topo.get_coord(self.global_rank).data

    def get_data_parallel_rank(self):
        return self._data_parallel_id

    def get_data_parallel_world_size(self):
        return self._dp_degree

    def get_data_parallel_group(self):
        return self._dp_comm_group

    def get_data_parallel_group_src_rank(self):
        return self._dp_comm_group.ranks[0]

    # model parallel message:
    def _get_model_parallel_id(self):
        return self._topo.get_coord(self.global_rank).model

    def get_model_parallel_rank(self):
        return self._model_parallel_id

    def get_model_parallel_world_size(self):
        return self._mp_degree

    def get_model_parallel_group(self):
        return self._mp_comm_group

    def get_model_parallel_group_src_rank(self):
        return self._mp_comm_group.ranks[0]

    # pipeline parallel message
    def _get_pipe_parallel_id(self):
        return self._topo.get_coord(self.global_rank).pipe

    def get_stage_id(self):
        return self.stage_id

    def get_pipe_parallel_world_size(self):
        return self._pp_degree

    def get_pipe_parallel_group(self):
        return self._pp_comm_group

    # sharding parallel message:
    def _get_sharding_parallel_id(self):
        return self._topo.get_coord(self.global_rank).sharding

    def get_sharding_parallel_rank(self):
        return self._sharding_parallel_id

    def get_sharding_parallel_world_size(self):
        return self._sharding_degree

    def get_sharding_parallel_group(self):
        return self._sharding_comm_group

    def get_sharding_parallel_group_src_rank(self):
        # TODO should the src rank related to the shard rank for each parameter ?
        return self._sharding_comm_group.ranks[0]

    # check parallel group
    def get_check_parallel_group(self, sharding=False):
        """Return the sharding check group if ``sharding`` else the data one."""
        if sharding:
            return self.sharding_check_comm_group
        else:
            return self._check_comm_group

    def get_rank_from_stage(self, stage_id, **kwargs):
        """Return the rank sharing this rank's coordinate except for
        ``pipe=stage_id`` and any axes overridden in ``kwargs``."""
        return self._topo.get_rank_from_stage(
            self.global_rank, pipe=stage_id, **kwargs
        )
W
WangXi 已提交
394 395


396
class _CommunicateGroup:
    """Temporary registry of named communication groups.

    The original note on this class reads "tmp for static" (presumably the
    static-graph mode — confirm). Creating an instance registers it as the
    module-level ``_HYBRID_PARALLEL_GROUP``.
    """

    def __init__(self):
        # group name -> paddle.distributed.collective.Group
        self.groups = {}
        global _HYBRID_PARALLEL_GROUP
        _HYBRID_PARALLEL_GROUP = self

    def set_comm_group(
        self, group_name, group_rank, group_size, ring_id, group_ranks
    ):
        """Create a Group from (rank, ring_id, ranks) and register it under
        ``group_name``, replacing any previous entry.

        ``group_size`` is accepted but not used.
        """
        self.groups[group_name] = paddle.distributed.collective.Group(
            group_rank, ring_id, group_ranks
        )

    def get_group(self, group_name):
        """Return the group registered under ``group_name`` (must exist)."""
        registered = self.groups
        assert group_name in registered
        return registered[group_name]

    def get_model_parallel_group(self):
        """Return the group registered as 'model'."""
        return self.get_group('model')

    def get_model_parallel_world_size(self):
        """Return ``nranks`` of the 'model' group."""
        model_group = self.get_group('model')
        return model_group.nranks

    def get_model_parallel_rank(self):
        """Return ``rank`` of the 'model' group."""
        model_group = self.get_group('model')
        return model_group.rank