#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import sys
import paddle
import collections
import numpy as np
from itertools import product
from functools import reduce
from ..utils.log_util import logger

__all__ = ['CommunicateTopology', 'HybridCommunicateGroup']

_HYBRID_PARALLEL_GROUP = None


class ParallelMode(object):
    """
    These are all the parallel modes currently supported:

    - DATA_PARALLEL: Distribute input data to different devices.
    - TENSOR_PARALLEL: Shard tensors in the network to different devices.
    - PIPELINE_PARALLEL: Place different layers of the network on different devices.
    - SHARDING_PARALLEL: Shard the model parameters, parameter gradients and the
                         corresponding optimizer states across devices.

    Examples:
        .. code-block:: python

            import paddle
            parallel_mode = paddle.distributed.ParallelMode
            print(parallel_mode.DATA_PARALLEL)  # 0

    """
    DATA_PARALLEL = 0
    TENSOR_PARALLEL = 1
    PIPELINE_PARALLEL = 2
    SHARDING_PARALLEL = 3


class CommunicateTopology(object):
    def __init__(self,
                 hybrid_group_names=["data", "pipe", "sharding", "model"],
                 dims=[1, 1, 1, 1]):
        self._parallel_names = hybrid_group_names
        self._dims = dims
        self.coordinate = collections.namedtuple('Coordinate',
                                                 self._parallel_names)
        self._world_size = reduce(lambda x, y: x * y, self._dims)

        ranges = [range(d) for d in self._dims]
        all_coordinate = [self.coordinate(*x) for x in product(*ranges)]

        self._coord2rank = dict(zip(all_coordinate, range(len(all_coordinate))))
        self._rank2coord = dict(
            zip(self._coord2rank.values(), self._coord2rank.keys()))

    def get_hybrid_group_names(self):
        return self._parallel_names

    def get_dim(self, axis_name):
        return self._dims[self._parallel_names.index(axis_name)]

    def world_size(self):
        return self._world_size

    def get_rank(self, **args):
        assert len(args) == len(self._dims)
        key = self.coordinate(**args)
        assert key in self._coord2rank.keys()
        return self._coord2rank[key]

    def get_coord(self, rank):
        assert rank < self._world_size
        assert rank in self._rank2coord.keys()
        return self._rank2coord[rank]

    def get_axis_list(self, axis_name, index):
        axis = self._parallel_names.index(axis_name)
        ranks = [
            self._coord2rank[coord] for coord in self._coord2rank.keys()
            if coord[axis] == index
        ]
        ranks.sort()
        return ranks

    def get_dim_size(self, axis_name):
        assert axis_name in self._parallel_names
        return self._dims[self._parallel_names.index(axis_name)]

    def get_comm_list(self, axis_name):
        assert axis_name in self._parallel_names
        other_axis_names = [
            name for name in self._parallel_names if name != axis_name
        ]

        ranges = []
        for name in other_axis_names:
            dim_num = self.get_dim_size(name)
            ranges.append(range(dim_num))

        all_result = []
        for x in product(*ranges):
            key_coord = {}
            for other_name in other_axis_names:
                key_coord[other_name] = x[other_axis_names.index(other_name)]

            result = []
            for i in range(0, self.get_dim_size(axis_name)):
                key_coord[axis_name] = i
                result.append(self._coord2rank[self.coordinate(**key_coord)])
            all_result.append(result)

        return all_result

    def get_rank_from_stage(self, global_rank, **kwargs):
        coord = self.get_coord(global_rank)
        tf = coord._replace(**kwargs)._asdict()
        return self.get_rank(**tf)
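
    # Illustrative sketch (not part of the class API), to make the coordinate <->
    # rank mapping concrete: with the default axis order
    # ["data", "pipe", "sharding", "model"] and dims=[2, 2, 1, 1], ranks are
    # assigned in row-major order over the coordinates, so:
    #
    #     topo = CommunicateTopology(dims=[2, 2, 1, 1])
    #     topo.get_rank(data=1, pipe=0, sharding=0, model=0)  # -> 2
    #     topo.get_coord(3)       # -> Coordinate(data=1, pipe=1, sharding=0, model=0)
    #     topo.get_comm_list("pipe")     # -> [[0, 1], [2, 3]]
    #     topo.get_axis_list("pipe", 0)  # -> [0, 2]
    #     topo.get_rank_from_stage(0, pipe=1)  # -> 1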


class HybridCommunicateGroup(object):
    def __init__(self, topology):
        self.nranks = paddle.distributed.get_world_size()
        self.global_rank = paddle.distributed.get_rank()
        self._topo = topology

        self._dp_degree = self._topo.get_dim('data')
        self._mp_degree = self._topo.get_dim('model')
        self._pp_degree = self._topo.get_dim('pipe')
        self._sharding_degree = self._topo.get_dim('sharding')

        self._data_parallel_id = self._get_data_parallel_id()
        self._model_parallel_id = self._get_model_parallel_id()
        self._sharding_parallel_id = self._get_sharding_parallel_id()
        self.stage_id = self._get_pipe_parallel_id()

        assert self._check_valid_topo(
        ), "Here is an unreasonable topology setting. world_size: {}, but " \
            "mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}".format(self.nranks,
            self._mp_degree, self._sharding_degree, self._pp_degree, self._dp_degree)

        # create comm group for data parallel
        self._dp_group, self._dp_comm_group = self._set_comm_group("data")

        # create comm group for model parallel
        self._mp_group, self._mp_comm_group = self._set_comm_group("model")

        # create comm group for pipe parallel
        self._pp_group, self._pp_comm_group = self._set_comm_group("pipe")

        # create comm group for sharding parallel
        self._sharding_group, self._sharding_comm_group = self._set_comm_group(
            "sharding")

        # create global group for check inf_nan / clip global norm
        self._check_group, self._check_comm_group = self._set_check_group(
            "data")

        # mark whether this rank holds the first / last pipeline stage
        self.is_first_stage = (self.stage_id == 0)
        self.is_last_stage = (self.stage_id == (self._pp_degree - 1))

        # create p2p_groups
        if self._pp_degree > 1:
            self._set_p2p_group()

        debug_str = "HybridParallelInfo: rank_id: %d, mp_degree: %d, " \
                    "sharding_degree: %d, pp_degree: %d, dp_degree: %d" % (self.global_rank, self._mp_degree,
                    self._sharding_degree, self._pp_degree, self._dp_degree)
        debug_str += ", mp_group: %s,  sharding_group: %s, pp_group: %s, dp_group: %s, check/clip group: %s" % (
            self._mp_group, self._sharding_group, self._pp_group,
            self._dp_group, self._check_group)
        logger.info(debug_str)

        global _HYBRID_PARALLEL_GROUP
        _HYBRID_PARALLEL_GROUP = self
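
    # Minimal construction sketch (illustrative only; in practice fleet builds this
    # object during hybrid-parallel initialization). Assuming the job was launched
    # with 8 ranks:
    #
    #     topo = CommunicateTopology(["data", "pipe", "sharding", "model"],
    #                                [2, 2, 1, 2])
    #     hcg = HybridCommunicateGroup(topo)
    #     hcg.get_model_parallel_rank()       # this rank's index in its mp group
    #     hcg.get_data_parallel_world_size()  # -> 2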

    def get_parallel_mode(self):
        # There are four modes: DataParallel / TensorParallel / PipelineParallel / ShardingParallel.
        # NOTE: when sharding is combined with another parallelism, sharding should act like an
        # optimizer and add its parallel logic within that parallelism;
        # when sharding is used alone, it has its own parallel mode for its parallel logic.
        # TODO: modify the other 3 parallel modes to support sharding.
        if self._mp_degree == 1 and self._pp_degree == 1 and self._dp_degree == 1 and self._sharding_degree > 1:
            return ParallelMode.SHARDING_PARALLEL
        elif self._mp_degree == 1 and self._pp_degree == 1:
            return ParallelMode.DATA_PARALLEL
        elif self._mp_degree > 1 and self._pp_degree == 1:
            # initialize the seed
            return ParallelMode.TENSOR_PARALLEL
        elif self._pp_degree > 1:
            return ParallelMode.PIPELINE_PARALLEL
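
    # Illustrative mapping of how the branches above resolve, for degrees
    # (dp, mp, pp, sharding):
    #   (4, 1, 1, 1) -> ParallelMode.DATA_PARALLEL
    #   (2, 2, 1, 1) -> ParallelMode.TENSOR_PARALLEL
    #   (1, 1, 4, 1) -> ParallelMode.PIPELINE_PARALLEL
    #   (1, 1, 1, 4) -> ParallelMode.SHARDING_PARALLEL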

    def _check_valid_topo(self):
        return self._dp_degree * self._mp_degree * self._pp_degree * self._sharding_degree == self.nranks

    def _set_comm_group(self, parallel_method="data"):
        parallel_group = []
        parallel_comm_group = None
        parallel_groups = self._topo.get_comm_list(parallel_method)

        for group in parallel_groups:
            comm_group = paddle.distributed.new_group(ranks=group)
            if self.global_rank in group:
                parallel_group = group
                parallel_comm_group = comm_group

        assert len(parallel_group) > 0
        assert parallel_comm_group is not None

        return parallel_group, parallel_comm_group
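
    # Illustrative sketch: with dims [data=2, pipe=2, sharding=1, model=1],
    # _set_comm_group("data") iterates over the comm lists [[0, 2], [1, 3]];
    # new_group() is called on every rank for every list, and each rank keeps
    # only the group that contains its own global rank (e.g. rank 1 keeps [1, 3]).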

    def _set_check_group(self, parallel_method="data"):
        parallel_group = []
        parallel_comm_group = None
        parallel_size = self._topo.get_dim(parallel_method)
        for idx in range(parallel_size):
            parallel_groups = self._topo.get_axis_list(parallel_method, idx)
            comm_group = paddle.distributed.new_group(ranks=parallel_groups)
            if self.global_rank in parallel_groups:
                parallel_group = parallel_groups
                parallel_comm_group = comm_group

        assert len(parallel_group) > 0
        assert parallel_comm_group is not None

        return parallel_group, parallel_comm_group
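
    # Illustrative sketch: with dims [data=2, pipe=2, sharding=1, model=1], the
    # check groups built from the "data" axis are [[0, 1], [2, 3]]; each group
    # spans every non-data axis for one data-parallel index, so inf/nan checks
    # and global-norm clipping can reduce across a whole model replica.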

    def _get_p2p_next_rank(self):
        assert hasattr(self, 'next_rank'), "next_rank has not been initialized"
        return self.next_rank

    def _get_p2p_prev_rank(self):
        assert hasattr(self, 'prev_rank'), "prev_rank has not been initialized"
        return self.prev_rank

    def _set_p2p_group(self):
        comm_lists = self._topo.get_comm_list('pipe')

        self.send_next_group = None
        self.send_prev_group = None
        self.recv_next_group = None
        self.recv_prev_group = None

        for comm_ranks in comm_lists:
            assert len(comm_ranks) == self._pp_degree
            for idx, rank in enumerate(comm_ranks):
                curr_rank = rank
                next_rank = comm_ranks[(idx + 1) % self._pp_degree]
                prev_rank = comm_ranks[(idx - 1) % self._pp_degree]

                if self.global_rank == curr_rank:
                    self.next_rank = next_rank
                    self.prev_rank = prev_rank

                next_group = paddle.distributed.new_group(
                    ranks=[curr_rank, next_rank])
                if self.global_rank == curr_rank:
                    self.send_next_group = next_group
                elif self.global_rank == next_rank:
                    self.recv_prev_group = next_group

                prev_group = paddle.distributed.new_group(
                    ranks=[prev_rank, curr_rank])

                if self.global_rank == curr_rank:
                    self.send_prev_group = prev_group
                elif self.global_rank == prev_rank:
                    self.recv_next_group = prev_group

        assert self.send_next_group is not None
        assert self.send_prev_group is not None
        assert self.recv_next_group is not None
        assert self.recv_prev_group is not None
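
    # Illustrative sketch: for a pipeline comm list [0, 1, 2, 3] (pp_degree=4),
    # rank 1 ends up with next_rank=2 and prev_rank=0; its send_next_group and
    # recv_next_group cover the pair {1, 2}, while its send_prev_group and
    # recv_prev_group cover the pair {0, 1}.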

    def topology(self):
        return self._topo

    def get_global_rank(self):
        return self.global_rank

    # data parallel message:
    def _get_data_parallel_id(self):
        return self._topo.get_coord(self.global_rank).data

    def get_data_parallel_rank(self):
        return self._data_parallel_id

    def get_data_parallel_world_size(self):
        return self._dp_degree

    def get_data_parallel_group(self):
        return self._dp_comm_group

    def get_data_parallel_group_src_rank(self):
        return self._dp_comm_group.ranks[0]

    # model parallel message:
    def _get_model_parallel_id(self):
        return self._topo.get_coord(self.global_rank).model

    def get_model_parallel_rank(self):
        return self._model_parallel_id

    def get_model_parallel_world_size(self):
        return self._mp_degree

    def get_model_parallel_group(self):
        return self._mp_comm_group

    def get_model_parallel_group_src_rank(self):
        return self._mp_comm_group.ranks[0]

    # pipeline parallel message
    def _get_pipe_parallel_id(self):
        return self._topo.get_coord(self.global_rank).pipe

    def get_stage_id(self):
        return self.stage_id

    def get_pipe_parallel_world_size(self):
        return self._pp_degree

    def get_pipe_parallel_group(self):
        return self._pp_comm_group

    def get_p2p_groups(self):
        return self.send_next_group, self.send_prev_group, self.recv_next_group, self.recv_prev_group

    # sharding parallel message:
    def _get_sharding_parallel_id(self):
        return self._topo.get_coord(self.global_rank).sharding

    def get_sharding_parallel_rank(self):
        return self._sharding_parallel_id

    def get_sharding_parallel_world_size(self):
        return self._sharding_degree

    def get_sharding_parallel_group(self):
        return self._sharding_comm_group

    def get_sharding_parallel_group_src_rank(self):
        # TODO: should the src rank be related to the shard rank of each parameter?
        return self._sharding_comm_group.ranks[0]

    # check parallel group
    def get_check_parallel_group(self):
        return self._check_comm_group

    def get_rank_from_stage(self, stage_id, **kwargs):
        return self._topo.get_rank_from_stage(self.global_rank,
                                              pipe=stage_id,
                                              **kwargs)
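
    # Illustrative sketch: with dims [data=2, pipe=2, sharding=1, model=1], on
    # global rank 0 (pipeline stage 0), get_rank_from_stage(1) returns 1 -- the
    # rank that holds pipeline stage 1 within this rank's data/sharding/model slice.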


class _CommunicateGroup(object):
    """ tmp for static """

    def __init__(self):
        global _HYBRID_PARALLEL_GROUP
        _HYBRID_PARALLEL_GROUP = self
        self.groups = dict()

    def set_comm_group(self, group_name, group_rank, group_size, ring_id,
                       group_ranks):
        group = paddle.distributed.collective.Group(group_rank, ring_id,
                                                    group_ranks)
        self.groups[group_name] = group

    def get_group(self, group_name):
        assert group_name in self.groups
        return self.groups[group_name]

    def get_model_parallel_group(self):
        return self.get_group('model')

    def get_model_parallel_world_size(self):
        return self.get_group('model').nranks

    def get_model_parallel_rank(self):
        return self.get_group('model').rank