role_maker.py 14.4 KB
Newer Older
D
dongdaxiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

T
tangwei12 已提交
15
from __future__ import print_function
16

T
tangwei12 已提交
17
__all__ = [
18
    'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
19
    'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker'
T
tangwei12 已提交
20 21
]

22 23
import os

24

T
tangwei12 已提交
25 26
class Role:
    # Enumeration of process roles in a distributed fleet job.
    WORKER = 1  # trainer process
    SERVER = 2  # parameter-server process

D
dongdaxiang 已提交
29 30

class RoleMakerBase(object):
    """
    Base class for assigning a role to the current process in distributed
    training.

    A paddle developer can subclass RoleMakerBase to design a role maker
    for worker or pserver assignment. This base holds only the shared
    state (endpoint lists, role id) plus its accessors; every role query
    must be provided by the subclass.
    """

    def __init__(self):
        self._worker_endpoints = []  # "ip:port" of every trainer
        self._server_endpoints = []  # "ip:port" of every pserver
        self._role_is_generated = False  # flipped once a role is decided
        self._role = None
        self._current_id = -1

    def is_worker(self):
        """Return whether the current process is a worker."""
        raise NotImplementedError("Please implement this method in child class")

    def is_server(self):
        """Return whether the current process is a server."""
        raise NotImplementedError("Please implement this method in child class")

    def is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        Returns:
            bool: True if this is the first node of worker,
                  False if not.
        """
        raise NotImplementedError("Please implement this method in child class")

    def worker_num(self):
        """
        Get current total worker number.

        Returns:
            int: worker number
        """
        raise NotImplementedError("Please implement this method in child class")

    def worker_index(self):
        """
        Get current worker id.

        Returns:
            int: node id
        """
        raise NotImplementedError("Please implement this method in child class")

    def server_index(self):
        """
        Get current server id.

        Returns:
            int: node id
        """
        raise NotImplementedError("Please implement this method in child class")

    def get_trainer_endpoints(self):
        """Return the list of trainer endpoints."""
        return self._worker_endpoints

    def get_pserver_endpoints(self):
        """Return the list of pserver endpoints."""
        return self._server_endpoints
D
dongdaxiang 已提交
104 105


106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
class MultiProcessRoleMaker(RoleMakerBase):
    """
    MultiProcessRoleMaker is a default role maker for multi-process
    GPU training. It works with paddle.distributed.launch.py by-design.

    Every launched process is a TRAINER (worker); there are no parameter
    servers in this mode. Role information is read from the PADDLE_*
    environment variables exported by the launcher.
    """

    def __init__(self):
        super(MultiProcessRoleMaker, self).__init__()
        self._role_is_generated = False
        # Defensive default so worker_num() is usable even before
        # generate_role() has been called.
        self._num_trainers = 1

    def generate_role(self):
        """Read the launcher-provided environment and fix this process's role.

        Idempotent: only the first call inspects the environment.
        """
        # NOTE: `os` is imported at module level; the original re-imported
        # it locally for no reason.
        if not self._role_is_generated:
            self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
            self._num_trainers = 1
            self._training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
            assert (self._training_role == "TRAINER")
            self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
            self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
            if self._worker_endpoints:
                self._worker_endpoints = self._worker_endpoints.split(",")
                self._num_trainers = len(self._worker_endpoints)
            self._role_is_generated = True

    def is_worker(self):
        """Every process launched in this mode is a worker."""
        return True

    def is_server(self):
        """No parameter servers exist in multi-process GPU mode."""
        return False

    def is_first_worker(self):
        return self._current_id == 0

    def worker_index(self):
        return self._current_id

    def worker_num(self):
        # BUG FIX: the original returned self._worker_num, an attribute that
        # is never assigned anywhere in this class (generate_role() sets
        # self._num_trainers), so worker_num() always raised AttributeError.
        return self._num_trainers


D
dongdaxiang 已提交
146
class MPIRoleMaker(RoleMakerBase):
    """
    MPI-API based role maker, the counterpart of K8SRoleMaker.

    mpi4py is required: it is imported lazily in __init__ so the module
    itself can be imported on machines without MPI installed.
    """

    def __init__(self):
        super(MPIRoleMaker, self).__init__()
        from mpi4py import MPI
        self.MPI = MPI
        self._comm = MPI.COMM_WORLD  # communicator over every launched process
        self._node_type_comm = None  # per-role sub-communicator, set by subclass
        self._ips = None  # cached result of _get_ips()
        self._ip = None  # cached local ip

    def _get_rank(self):
        """Return (and cache) this process's rank in COMM_WORLD."""
        self._rank = self._comm.Get_rank()
        return self._rank

    def _get_size(self):
        """Return (and cache) the total number of MPI processes."""
        self._size = self._comm.Get_size()
        return self._size

    def _all_gather(self, obj):
        """Barrier, then allgather `obj` across every process."""
        self._barrier_all()
        return self._comm.allgather(obj)

    def _worker_gather(self, obj):
        """Allgather `obj` across workers only; None on non-workers."""
        if self.is_worker():
            self._node_type_comm.barrier()
            return self._node_type_comm.allgather(obj)
        return None

    def _barrier_all(self):
        """Block until every process in COMM_WORLD reaches this point."""
        self._comm.barrier()

    def _finalize(self):
        """Shut down the current MPI instance."""
        self.MPI.Finalize()

    def _get_ips(self):
        """Collect (once) and return the ip of every process in the job."""
        if not self._ips:
            self._ips = self._comm.allgather(self.get_local_ip())
        return self._ips

    def get_local_ip(self):
        """Resolve, cache and return this host's ip address."""
        import socket
        self._ip = socket.gethostbyname(socket.gethostname())
        return self._ip

    def generate_role(self):
        """
        generate_role() should be called to identify current process's role
        """
        raise NotImplementedError("Please implement this method in child class")
D
dongdaxiang 已提交
224 225 226


class MPISymetricRoleMaker(MPIRoleMaker):
    """
    MPISymetricRoleMaker is designed for worker and server assignment
    under MPI. Typically, a worker and a server node will be appointed
    on each physical node. This role maker can be only used under MPI.

    Even local ranks become servers (_node_type == 0), odd local ranks
    become workers (_node_type == 1).
    """

    def __init__(self):
        super(MPISymetricRoleMaker, self).__init__()
        self._node_type = None  # 0 = server, 1 = worker; set by generate_role()
        self._proc_per_node = 2  # one server + one worker per physical node

    def _check_role_generation(self):
        # Guard shared by every query: generate_role() must have run first.
        if not self._role_is_generated:
            raise NameError("generate_role() should be called first")
        return True

    def is_first_worker(self):
        """
        return whether current process is the first worker assigned by role maker
        """
        if self._check_role_generation():
            return self.is_worker() and 0 == self.worker_index()
        return False

    def worker_num(self):
        return self._worker_num()

    def is_worker(self):
        """
        return whether current process is worker assigned by role maker
        """
        if self._check_role_generation():
            return self._node_type == 1
        return False

    def is_server(self):
        """
        return whether current process is server assigned by role maker
        """
        if self._check_role_generation():
            return self._node_type == 0
        return False

    def _worker_num(self):
        """
        return the current number of worker
        """
        if self._check_role_generation():
            if self.is_worker():
                # BUG FIX: use floor division; on Python 3 "/" yields a
                # float, and a worker count must be an int.
                return self._get_size() // 2
        return 0

    def _server_num(self):
        """
        return the current number of server
        """
        if self._check_role_generation():
            if self.is_server():
                # BUG FIX: floor division — "/" yields a float on Python 3.
                return self._get_size() // 2
        return 0

    def worker_index(self):
        """
        return the index of worker
        """
        if self._check_role_generation():
            # BUG FIX: floor division keeps the index an int on Python 3.
            return self._rank // self._proc_per_node
        return 0

    def server_index(self):
        """
        return the index of server
        """
        if self._check_role_generation():
            # BUG FIX: floor division keeps the index an int on Python 3.
            return self._rank // self._proc_per_node
        return 0

    def _barrier_worker(self):
        """
        barrier all workers in current distributed job
        """
        if self._check_role_generation():
            if self.is_worker():
                self._node_type_comm.barrier()

    def _barrier_server(self):
        """
        barrier all servers in current distributed job
        """
        if self._check_role_generation():
            if self.is_server():
                self._node_type_comm.barrier()

    def generate_role(self):
        """
        generate currently process's role: split COMM_WORLD into a server
        communicator (even local ranks) and a worker communicator (odd ones).
        """
        if not self._role_is_generated:
            # TODO(guru4elephant): only allow to be called once
            self._worker_endpoints = self._get_ips()[1::2]
            self._server_endpoints = self._get_ips()[::2]

            if 0 == self._get_rank() % self._proc_per_node % 2:
                self._node_type = 0
            else:
                self._node_type = 1
            self._node_type_comm = self._comm.Split(self._node_type)
            self._role_is_generated = True
335 336


337 338 339
class PaddleCloudRoleMaker(RoleMakerBase):
    """
    Role maker for PaddleCloud jobs: the role of each process is derived
    from environment variables exported by the scheduler (PADDLE_PORT,
    PADDLE_PSERVERS, TRAINING_ROLE, PADDLE_TRAINER_ID, ...).

    Every query method lazily calls generate_role(), so callers never
    need to invoke it explicitly.
    """

    def __init__(self):
        super(PaddleCloudRoleMaker, self).__init__()
        self._role_is_generated = False

    def generate_role(self):
        """Parse the PaddleCloud environment and decide this process's role.

        Idempotent: only the first call reads the environment.
        """
        if not self._role_is_generated:
            # Build the pserver endpoint list from PADDLE_PSERVERS + PADDLE_PORT.
            self.port = os.getenv("PADDLE_PORT", "6174")
            self.pserver_ips = os.getenv("PADDLE_PSERVERS", "")
            eplist = []
            for ip in self.pserver_ips.split(","):
                eplist.append(':'.join([ip, self.port]))
            self._trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
            self.current_endpoint = os.getenv("POD_IP",
                                              "localhost") + ":" + self.port
            self.role = os.getenv("TRAINING_ROLE", "TRAINER")
            self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
            self.eplist = eplist
            # FIX: the original joined eplist into a comma string, printed it
            # (debug output does not belong in library code), then split it
            # right back. The final value is the same list.
            self.endpoints = eplist
            self._server_endpoints = self.endpoints
            self._worker_endpoints = self.endpoints
            if self.role.upper() == "PSERVER":
                # A pserver's id is its position in the endpoint list;
                # raises ValueError if current_endpoint is not listed.
                self._current_id = self.endpoints.index(self.current_endpoint)
                self._role = Role.SERVER
            else:
                self._current_id = self.trainer_id
                self._role = Role.WORKER
            self._role_is_generated = True

    def _ensure_role(self):
        # Shared lazy-initialization guard for every query below.
        if not self._role_is_generated:
            self.generate_role()

    def is_worker(self):
        self._ensure_role()
        return self._role == Role.WORKER

    def is_server(self):
        self._ensure_role()
        return self._role == Role.SERVER

    def is_first_worker(self):
        self._ensure_role()
        return self._role == Role.WORKER and self._current_id == 0

    def worker_index(self):
        self._ensure_role()
        return self._current_id

    def server_index(self):
        self._ensure_role()
        return self._current_id

    def worker_num(self):
        self._ensure_role()
        return self._trainers
397 398


399 400 401
class UserDefinedRoleMaker(RoleMakerBase):
    def __init__(self,
                 current_id=0,
                 role=Role.WORKER,
                 worker_num=0,
                 server_endpoints=None):
        """
        UserDefinedRoleMaker is designed for worker and server assignment
        under manual. Typically, a worker and a server node will be appointed
        on each physical node, It can be assign by user.

        Args:
            current_id (int): non-negative id of this node within its role.
            role (Role): Role.WORKER or Role.SERVER.
            worker_num (int): non-negative total number of workers.
            server_endpoints (list of str): "ip:port" of every pserver.
                NOTE(review): the default None is rejected by the isinstance
                check below — callers must always pass a list.

        Raises:
            TypeError: if an argument has the wrong type.
            ValueError: if current_id or worker_num is negative.
        """
        super(UserDefinedRoleMaker, self).__init__()

        if not isinstance(current_id, int):
            raise TypeError("current_id must be as int")
        else:
            if current_id < 0:
                # BUG FIX: message said "gather"; it means "greater", matching
                # the wording used by UserDefinedCollectiveRoleMaker.
                raise ValueError("current_id must be greater or equal 0")
            self._current_id = current_id

        if role != Role.WORKER and role != Role.SERVER:
            raise TypeError("role must be as Role")
        else:
            self._role = role

        if not isinstance(worker_num, int):
            raise TypeError("worker_num must be as int")
        else:
            if worker_num < 0:
                # BUG FIX: "gather" -> "greater" (typo in error message).
                raise ValueError("worker_num must be greater or equal 0")
            self._worker_num = worker_num

        if not isinstance(server_endpoints, list):
            raise TypeError("server_endpoints must be as string list")
        else:
            self._server_endpoints = server_endpoints

    def generate_role(self):
        # Roles are fixed at construction time; nothing to compute.
        self._role_is_generated = True

    def is_worker(self):
        return self._role == Role.WORKER

    def is_server(self):
        return self._role == Role.SERVER

    def is_first_worker(self):
        return self._role == Role.WORKER and self._current_id == 0

    def worker_index(self):
        return self._current_id

    def server_index(self):
        return self._current_id

    def worker_num(self):
        return self._worker_num
456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478


class UserDefinedCollectiveRoleMaker(RoleMakerBase):
    def __init__(self, current_id=0, worker_endpoints=None):
        """
        UserDefinedCollectiveRoleMaker is designed for worker assignment
        under manual for collective mode.

        Args:
            current_id (int): non-negative id of this worker.
            worker_endpoints (list of str): "ip:port" of every worker;
                the default None is rejected, a list must be passed.
        """
        super(UserDefinedCollectiveRoleMaker, self).__init__()

        # Validate with guard clauses; semantics identical to the original
        # if/else chains.
        if not isinstance(current_id, int):
            raise TypeError("current_id must be as int")
        if current_id < 0:
            raise ValueError("current_id must be greater or equal 0")
        self._current_id = current_id

        if not isinstance(worker_endpoints, list):
            raise TypeError("worker_endpoints must be as string list")
        self._worker_endpoints = worker_endpoints
        self._worker_num = len(self._worker_endpoints)

    def generate_role(self):
        # Everything was supplied by the user at construction time.
        self._role_is_generated = True

    def is_worker(self):
        """Every node in collective mode is a worker."""
        return True

    def is_first_worker(self):
        return 0 == self._current_id

    def worker_index(self):
        return self._current_id

    def worker_num(self):
        return self._worker_num