role_maker.py 17.7 KB
Newer Older
D
dongdaxiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

T
tangwei12 已提交
15
from __future__ import print_function
16

T
tangwei12 已提交
17
__all__ = [
18
    'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
19
    'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker'
T
tangwei12 已提交
20 21
]

22 23
import os

24

class Role:
    """Process roles in distributed training: worker or parameter server."""
    WORKER = 1
    SERVER = 2

class RoleMakerBase(object):
    """
    Base class for assigning a role to the current process in
    distributed training.

    A paddle developer can subclass RoleMakerBase to design a role
    maker for worker or pserver assignment; the abstract queries below
    must be provided by the subclass.
    """

    def __init__(self):
        # Endpoint lists are populated by subclasses during role generation.
        self._worker_endpoints = []
        self._server_endpoints = []
        self._role_is_generated = False
        self._role = None
        self._current_id = -1

    def is_worker(self):
        """Return whether the current process is a worker."""
        raise NotImplementedError("Please implement this method in child class")

    def is_server(self):
        """Return whether the current process is a parameter server."""
        raise NotImplementedError("Please implement this method in child class")

    def is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        Returns:
            bool: True if this is the first node of worker, False if not.
        """
        raise NotImplementedError("Please implement this method in child class")

    def worker_num(self):
        """
        Get current total worker number.

        Returns:
            int: worker number
        """
        raise NotImplementedError("Please implement this method in child class")

    def worker_index(self):
        """
        Get current worker id.

        Returns:
            int: node id
        """
        raise NotImplementedError("Please implement this method in child class")

    def server_index(self):
        """
        Get current server id.

        Returns:
            int: node id
        """
        raise NotImplementedError("Please implement this method in child class")

    def get_trainer_endpoints(self):
        """Return the list of trainer (worker) endpoints."""
        return self._worker_endpoints

    def get_pserver_endpoints(self):
        """Return the list of parameter server endpoints."""
        return self._server_endpoints

    def to_string(self):
        """Return a human-readable summary of the current role state."""
        return "role: {}, current_id: {}, worker_endpoints: {}, server_endpoints: {}".format(
            self._role, self._current_id, self._worker_endpoints,
            self._server_endpoints)

class MPIRoleMaker(RoleMakerBase):
    """
    MPI-API based role maker, the counterpart of a K8S role maker.
    mpi4py will be used if a developer inherits MPIRoleMaker.
    """

    def __init__(self):
        super(MPIRoleMaker, self).__init__()
        from mpi4py import MPI
        self.MPI = MPI
        self._comm = MPI.COMM_WORLD
        # Communicator spanning only same-type nodes; set by subclasses.
        self._node_type_comm = None
        self._ips = None  # cached ip list of the whole job
        self._ip = None   # cached local ip

    def _get_rank(self):
        """Return (and cache) this process's rank in the MPI world."""
        self._rank = self._comm.Get_rank()
        return self._rank

    def _get_size(self):
        """Return (and cache) the total number of MPI processes."""
        self._size = self._comm.Get_size()
        return self._size

    def _all_gather(self, obj):
        """Barrier, then allgather ``obj`` across the whole MPI world."""
        self._barrier_all()
        return self._comm.allgather(obj)

    def _worker_gather(self, obj):
        """Allgather ``obj`` among worker nodes only; None on non-workers."""
        if not self.is_worker():
            return None
        self._node_type_comm.barrier()
        return self._node_type_comm.allgather(obj)

    def _barrier_all(self):
        """Block until every process in the MPI world reaches this point."""
        self._comm.barrier()

    def _finalize(self):
        """Finalize the current MPI instance."""
        self.MPI.Finalize()

    def _get_ips(self):
        """Collect (once) and return the ip list of the distributed job."""
        if not self._ips:
            self._ips = self._comm.allgather(self.get_local_ip())
        return self._ips

    def get_local_ip(self):
        """Resolve, cache, and return the local host's ip address."""
        import socket
        hostname = socket.gethostname()
        self._ip = socket.gethostbyname(hostname)
        return self._ip

    def generate_role(self):
        """
        generate_role() should be called to identify current process's role.
        """
        raise NotImplementedError("Please implement this method in child class")
class MPISymetricRoleMaker(MPIRoleMaker):
    """
    MPISymetricRoleMaker is designed for worker and server assignment
    under MPI. Typically, a worker and a server node will be appointed
    on each physical node. This role maker can be only used under MPI.
    """

    def __init__(self):
        super(MPISymetricRoleMaker, self).__init__()
        # 1 == worker, 0 == server; assigned in generate_role().
        self._node_type = None
        # One worker process plus one server process per physical node.
        self._proc_per_node = 2
        self._pserver_rand_port = 0

    def _check_role_generation(self):
        """Return True if generate_role() was called; raise NameError otherwise."""
        if not self._role_is_generated:
            raise NameError("generate_role() should be called first")
        return True

    def is_first_worker(self):
        """
        return whether current process is the first worker assigned by role maker
        """
        if self._check_role_generation():
            return self.is_worker() and 0 == self.worker_index()
        return False

    def get_pserver_endpoints(self):
        """Return server endpoints with a port chosen identically on all nodes."""
        if self._pserver_rand_port <= 0:
            import random
            random.seed(self._server_num())
            # port will be randomly generated from 60001 to 63999
            # random seed is server num so that all nodes will get
            # the same port
            self._pserver_rand_port = random.randint(60001, 64000)
        endpoints = [
            x + ":" + str(self._pserver_rand_port)
            for x in self._server_endpoints
        ]
        return endpoints

    def worker_num(self):
        return self._worker_num()

    def is_worker(self):
        """
        return whether current process is worker assigned by role maker
        """
        if self._check_role_generation():
            return self._node_type == 1
        return False

    def is_server(self):
        """
        return whether current process is server assigned by role maker
        """
        if self._check_role_generation():
            return self._node_type == 0
        return False

    def _worker_num(self):
        """
        return the current number of worker
        """
        if self._check_role_generation():
            if self.is_worker():
                # BUGFIX: floor division keeps the count an int on Python 3
                # (plain "/" would return a float there).
                return self._get_size() // self._proc_per_node
        return 0

    def _server_num(self):
        """
        return the current number of server
        """
        if self._check_role_generation():
            return self._get_size() // self._proc_per_node
        else:
            # NOTE(review): unreachable in practice -- _check_role_generation()
            # raises instead of returning False; kept for compatibility.
            self.generate_role()
            return self._get_size() // self._proc_per_node

    def worker_index(self):
        """
        return the index of worker
        """
        if self._check_role_generation():
            # BUGFIX: "//" so the index is an int on Python 3.
            return self._rank // self._proc_per_node
        else:
            self.generate_role()
            return self._get_size() // 2

    def server_index(self):
        """
        return the index of server
        """
        if self._check_role_generation():
            # BUGFIX: "//" so the index is an int on Python 3.
            return self._rank // self._proc_per_node
        else:
            self.generate_role()
            return self._get_size() // self._proc_per_node

    def _barrier_worker(self):
        """
        barrier all workers in current distributed job
        """
        if self._check_role_generation():
            if self.is_worker():
                self._node_type_comm.barrier()
        else:
            raise Exception("You should check role generation first")

    def _barrier_server(self):
        """
        barrier all servers in current distributed job
        """
        if self._check_role_generation():
            if self.is_server():
                self._node_type_comm.barrier()
        else:
            raise Exception("You should check role generation first")

    def generate_role(self):
        """
        generate currently process's role
        """
        if not self._role_is_generated:
            # TODO(guru4elephant): only allow to be called once
            # Odd ranks (per node) become workers, even ranks become servers.
            self._worker_endpoints = self._get_ips()[1::2]
            self._server_endpoints = self._get_ips()[::2]

            if 0 == self._get_rank() % self._proc_per_node % 2:
                self._node_type = 0
            else:
                self._node_type = 1
            self._node_type_comm = self._comm.Split(self._node_type)
            self._role_is_generated = True
        else:
            raise Exception("You should check role generation first")
326 327


class PaddleCloudRoleMaker(RoleMakerBase):
    """
    Role maker that derives the process role from environment variables
    set by PaddleCloud (or a compatible launcher).
    """

    def __init__(self, is_collective=False):
        super(PaddleCloudRoleMaker, self).__init__()
        self._role_is_generated = False
        # collective mode: every process is a trainer, no parameter servers
        self._is_collective = is_collective

    def generate_role(self):
        """Parse environment variables once and cache role/id/endpoints."""
        if self._role_is_generated:
            return
        if not self._is_collective:
            try:
                port = os.environ["PADDLE_PORT"]
                pserver_ips = os.environ["PADDLE_PSERVERS"].split(",")
                if "," in port:
                    ports = port.split(",")
                else:
                    ports = [port] * len(pserver_ips)
                eplist = []
                # note that, we usually assign the same port to different ips
                # if we run parameter server training in local mode
                # port should be different in environment variables
                for i, ip in enumerate(pserver_ips):
                    eplist.append(':'.join([ip, ports[i]]))

                trainers_num = int(os.environ["PADDLE_TRAINERS_NUM"])
                training_role = os.environ["TRAINING_ROLE"]

                if training_role == "TRAINER":
                    role = Role.WORKER
                    current_id = int(os.environ["PADDLE_TRAINER_ID"])
                elif training_role == "PSERVER":
                    role = Role.SERVER
                    cur_ip = os.environ["POD_IP"]
                    cur_idx = pserver_ips.index(cur_ip)
                    current_id = eplist.index(":".join(
                        [cur_ip, ports[cur_idx]]))
                else:
                    raise ValueError(
                        "TRAINING_ROLE must be PSERVER or TRAINER")
            except (KeyError, ValueError) as e:
                # BUGFIX: also catch missing environment variables (KeyError
                # previously leaked to the caller) and keep the original
                # error detail instead of discarding it.
                raise ValueError(
                    "something wrong with PaddleCloud, please check environment"
                    ": %s" % e)

            self._trainers_num = trainers_num
            self._server_endpoints = eplist
            self._role = role
            self._current_id = current_id
        else:
            self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
            self._training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
            assert (self._training_role == "TRAINER")
            self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
            self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
            assert self._worker_endpoints is not None, "can't find PADDLE_TRAINER_ENDPOINTS"
            self._worker_endpoints = self._worker_endpoints.split(",")
            self._trainers_num = len(self._worker_endpoints)

        self._role_is_generated = True

    def get_pserver_endpoints(self):
        # Lazily generate the role so getters work without an explicit call.
        if not self._role_is_generated:
            self.generate_role()
        return self._server_endpoints

    def is_worker(self):
        if not self._role_is_generated:
            self.generate_role()
        return self._role == Role.WORKER

    def is_server(self):
        if not self._role_is_generated:
            self.generate_role()
        return self._role == Role.SERVER

    def is_first_worker(self):
        if not self._role_is_generated:
            self.generate_role()
        return self._role == Role.WORKER and self._current_id == 0

    def worker_index(self):
        if not self._role_is_generated:
            self.generate_role()
        return self._current_id

    def server_index(self):
        if not self._role_is_generated:
            self.generate_role()
        return self._current_id

    def worker_num(self):
        if not self._role_is_generated:
            self.generate_role()
        return self._trainers_num
426 427


class UserDefinedRoleMaker(RoleMakerBase):
    def __init__(self,
                 current_id=0,
                 role=Role.WORKER,
                 worker_num=0,
                 server_endpoints=None):
        """
        UserDefinedRoleMaker is designed for worker and server assignment
        under manual. Typically, a worker and a server node will be appointed
        on each physical node, It can be assign by user.
        """
        super(UserDefinedRoleMaker, self).__init__()

        # server_endpoints must be a non-empty list of unique strings.
        if not isinstance(server_endpoints, list):
            raise TypeError("server_endpoints must be as string list")
        if len(server_endpoints) <= 0:
            raise ValueError(
                "the length of server_endpoints list must be greater than 0")
        if len(server_endpoints) != len(set(server_endpoints)):
            raise ValueError("server_endpoints can't have duplicate elements")
        for endpoint in server_endpoints:
            if not isinstance(endpoint, str):
                raise TypeError(
                    "every element in server_endpoints list must be as string")
        self._server_endpoints = server_endpoints

        if role != Role.WORKER and role != Role.SERVER:
            raise TypeError("role must be as Role")
        self._role = role

        # current_id must be a non-negative int; a server id must also
        # index into server_endpoints.
        if not isinstance(current_id, int):
            raise TypeError("current_id must be as int")
        if current_id < 0:
            raise ValueError("current_id must be greater than or equal to 0")
        if self._role == Role.SERVER and current_id >= len(server_endpoints):
            raise ValueError(
                "if role is Role.SERVER, current_id must be less than or equal to len(server_endpoints) - 1"
            )
        self._current_id = current_id

        if not isinstance(worker_num, int):
            raise TypeError("worker_num must be as int")
        if worker_num <= 0:
            raise ValueError("worker_num must be greater than 0")
        self._worker_num = worker_num

    def generate_role(self):
        # Roles are fully specified by the constructor arguments.
        self._role_is_generated = True

    def is_worker(self):
        return self._role == Role.WORKER

    def is_server(self):
        return self._role == Role.SERVER

    def is_first_worker(self):
        return self._role == Role.WORKER and self._current_id == 0

    def worker_index(self):
        return self._current_id

    def server_index(self):
        return self._current_id

    def worker_num(self):
        return self._worker_num
class UserDefinedCollectiveRoleMaker(RoleMakerBase):
    def __init__(self, current_id=0, worker_endpoints=None):
        """
        UserDefinedCollectiveRoleMaker is designed for worker assignment
        under manual for collective mode.
        """
        super(UserDefinedCollectiveRoleMaker, self).__init__()

        # worker_endpoints must be a non-empty list of unique strings.
        if not isinstance(worker_endpoints, list):
            raise TypeError("worker_endpoints must be as string list")
        if len(worker_endpoints) <= 0:
            raise ValueError(
                "the length of worker_endpoints list must be greater than 0")
        if len(worker_endpoints) != len(set(worker_endpoints)):
            raise ValueError("worker_endpoints can't have duplicate elements")
        for endpoint in worker_endpoints:
            if not isinstance(endpoint, str):
                raise TypeError(
                    "every element in worker_endpoints list must be as string")
        self._worker_endpoints = worker_endpoints

        # current_id must index into worker_endpoints.
        if not isinstance(current_id, int):
            raise TypeError("current_id must be as int")
        if current_id < 0:
            raise ValueError("current_id must be greater than or equal to 0")
        if current_id >= len(worker_endpoints):
            raise ValueError(
                "current_id must be less than or equal to len(worker_endpoints) - 1"
            )
        self._current_id = current_id

        self._worker_num = len(self._worker_endpoints)

    def generate_role(self):
        # Role is fixed at construction time; nothing to compute here.
        self._role_is_generated = True

    def is_worker(self):
        # In collective mode every process is a worker.
        return True

    def is_first_worker(self):
        return self._current_id == 0

    def worker_index(self):
        return self._current_id

    def worker_num(self):
        return self._worker_num