#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
# Public names exported by `from <this module> import *`.
__all__ = [
    'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
    'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker'
]

class Role:
    """
    Role holds the integer constants that identify what a process does
    in distributed training. Role makers in this module store one of
    these values and compare against it in is_worker() / is_server().
    """
    WORKER = 1
    SERVER = 2


class RoleMakerBase(object):
    """
    RoleMakerBase is a base class for assigning a role to current process
    in distributed training.
    A paddle developer can implement RoleMakerBase to design a role maker
    for worker or pserver assignment.
    """

    def __init__(self):
        # Endpoint lists handed back by get_trainer_endpoints() and
        # get_pserver_endpoints(); empty until a subclass fills them in.
        self._worker_endpoints = []
        self._server_endpoints = []
        # Flipped to True by a subclass's generate_role().
        self._role_is_generated = False
        # Role constant of the current process; None until assigned.
        self._role = None
        # Index of the current process; -1 is the initial placeholder.
        self._current_id = -1

    def is_worker(self):
        """
        Check whether the current process is a worker.

        Returns:
            bool: True if the current process is a worker, False if not.

        Raises:
            NotImplementedError: always, child classes must override.
        """
        raise NotImplementedError("Please implement this method in child class")

    def is_server(self):
        """
        Check whether the current process is a server.

        Returns:
            bool: True if the current process is a server, False if not.

        Raises:
            NotImplementedError: always, child classes must override.
        """
        raise NotImplementedError("Please implement this method in child class")

    def is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        Returns:
            bool: True if this is the first node of worker,
                  False if not.

        Raises:
            NotImplementedError: always, child classes must override.
        """
        raise NotImplementedError("Please implement this method in child class")

    def worker_num(self):
        """
        Get current total worker number.

        Returns:
            int: worker number

        Raises:
            NotImplementedError: always, child classes must override.
        """
        raise NotImplementedError("Please implement this method in child class")

    def worker_index(self):
        """
        Get current worker id.

        Returns:
            int: node id

        Raises:
            NotImplementedError: always, child classes must override.
        """
        raise NotImplementedError("Please implement this method in child class")

    def server_index(self):
        """
        Get current server id.

        Returns:
            int: node id

        Raises:
            NotImplementedError: always, child classes must override.
        """
        raise NotImplementedError("Please implement this method in child class")

    def get_trainer_endpoints(self):
        """
        Get the trainer (worker) endpoints.

        Returns:
            list: the stored worker endpoint list (not a copy)
        """
        return self._worker_endpoints

    def get_pserver_endpoints(self):
        """
        Get the parameter server endpoints.

        Returns:
            list: the stored server endpoint list (not a copy)
        """
        return self._server_endpoints


class MPIRoleMaker(RoleMakerBase):
    """
    MPIRoleMaker is a MPI-API based role maker; the original authors
    describe it as the counter-part of K8SRoleMaker (not defined in this
    file). mpi4py is imported when an instance is constructed, so it is
    only required if a developer actually uses an MPIRoleMaker subclass.
    """

    def __init__(self):
        super(MPIRoleMaker, self).__init__()
        # Imported here rather than at module level so that importing this
        # module does not require mpi4py to be installed.
        from mpi4py import MPI
        self.MPI = MPI
        # Communicator spanning every process of the job.
        self._comm = MPI.COMM_WORLD
        # Communicator restricted to processes of the same node type;
        # left as None here and assigned by a subclass.
        self._node_type_comm = None
        # Cached results of _get_ips() / get_local_ip().
        self._ips = None
        self._ip = None

    def _get_rank(self):
        """
        Query the rank of the current process in the global communicator,
        remember it in self._rank and return it.
        """
        self._rank = self._comm.Get_rank()
        return self._rank

    def _get_size(self):
        """
        Query the number of processes in the global communicator,
        remember it in self._size and return it.
        """
        self._size = self._comm.Get_size()
        return self._size

    def _all_gather(self, obj):
        """
        Barrier on all processes, then call MPI's allgather on the global
        communicator and return its result.
        """
        self._barrier_all()
        return self._comm.allgather(obj)

    def _worker_gather(self, obj):
        """
        Gather obj among workers only.

        On a worker this barriers on the node-type communicator and then
        returns the result of its allgather; on any other process it
        returns None without communicating.
        """
        if self.is_worker():
            self._node_type_comm.barrier()
            return self._node_type_comm.allgather(obj)
        return None

    def _barrier_all(self):
        """
        Barrier on the global communicator.
        """
        self._comm.barrier()

    def _finalize(self):
        """
        finalize the current MPI instance. Currently a no-op.
        """
        pass

    def _get_ips(self):
        """
        collect current distributed job's ip list

        The list is gathered from all processes on first use and cached;
        later calls return the cached list.
        """
        if not self._ips:
            self._ips = self._comm.allgather(self.get_local_ip())
        return self._ips

    def get_local_ip(self):
        """
        Resolve the local hostname to an ip address, remember it in
        self._ip and return it.
        """
        import socket
        self._ip = socket.gethostbyname(socket.gethostname())
        return self._ip

    def generate_role(self):
        """
        generate_role() should be called to identify current process's role

        Raises:
            NotImplementedError: always, child classes must override.
        """
        raise NotImplementedError("Please implement this method in child class")


class MPISymetricRoleMaker(MPIRoleMaker):
    """
    MPISymetricRoleMaker is designed for worker and server assignment
    under MPI. Typically, a worker and a server node will be appointed
    on each physical node. This role maker can be only used under MPI.
    """

    def __init__(self):
        super(MPISymetricRoleMaker, self).__init__()
        # 0 for server, 1 for worker; None until generate_role() runs.
        self._node_type = None
        # One server process plus one worker process per physical node.
        self._proc_per_node = 2

    def _check_role_generation(self):
        """
        Make sure generate_role() has been called.

        Returns:
            bool: True when the role has been generated.

        Raises:
            NameError: if generate_role() has not been called yet.
        """
        if not self._role_is_generated:
            raise NameError("generate_role() should be called first")
        return True

    def is_first_worker(self):
        """
        return whether current process is the first worker assigned by role maker
        """
        if self._check_role_generation():
            return self.is_worker() and 0 == self.worker_index()
        return False

    def worker_num(self):
        """
        return the number of workers, see _worker_num()
        """
        return self._worker_num()

    def is_worker(self):
        """
        return whether current process is worker assigned by role maker
        """
        if self._check_role_generation():
            return self._node_type == 1
        return False

    def is_server(self):
        """
        return whether current process is server assigned by role maker
        """
        if self._check_role_generation():
            return self._node_type == 0
        return False

    def _worker_num(self):
        """
        return the current number of worker

        Half of the MPI processes are workers. Returns 0 when the current
        process is not a worker.
        """
        if self._check_role_generation():
            if self.is_worker():
                # Floor division keeps the count an int on Python 3 too
                # (this file does not enable `from __future__ import division`).
                return self._get_size() // 2
        return 0

    def _server_num(self):
        """
        return the current number of server

        Half of the MPI processes are servers. Returns 0 when the current
        process is not a server.
        """
        if self._check_role_generation():
            if self.is_server():
                # Floor division, see _worker_num().
                return self._get_size() // 2
        return 0

    def worker_index(self):
        """
        return the index of worker
        """
        if self._check_role_generation():
            # Floor division so the index stays an int on Python 3.
            return self._rank // self._proc_per_node
        return 0

    def server_index(self):
        """
        return the index of server
        """
        if self._check_role_generation():
            # Floor division so the index stays an int on Python 3.
            return self._rank // self._proc_per_node
        return 0

    def _barrier_worker(self):
        """
        barrier all workers in current distributed job
        """
        if self._check_role_generation():
            if self.is_worker():
                self._node_type_comm.barrier()

    def _barrier_server(self):
        """
        barrier all servers in current distributed job
        """
        if self._check_role_generation():
            if self.is_server():
                self._node_type_comm.barrier()

    def generate_role(self):
        """
        generate currently process's role

        Every second entry of the gathered ip list is taken as a server
        endpoint (positions 0, 2, ...) and the rest as worker endpoints
        (positions 1, 3, ...). The current process becomes a server or a
        worker depending on its rank, and the global communicator is
        split by node type. Calling this again after the role has been
        generated does nothing.
        """
        if not self._role_is_generated:
            # TODO(guru4elephant): only allow to be called once
            ips = self._get_ips()
            self._worker_endpoints = ips[1::2]
            self._server_endpoints = ips[::2]

            # _get_rank() also stores self._rank, which worker_index() and
            # server_index() read later.
            if 0 == self._get_rank() % self._proc_per_node % 2:
                self._node_type = 0
            else:
                self._node_type = 1
            self._node_type_comm = self._comm.Split(self._node_type)
            self._role_is_generated = True


class PaddleCloudRoleMaker(RoleMakerBase):
    """
    PaddleCloudRoleMaker assigns the role of the current process from
    environment variables:

        PADDLE_PORT          port shared by all endpoints (default "6174")
        PADDLE_PSERVERS      comma separated pserver ips (default "")
        PADDLE_TRAINERS_NUM  number of trainers (default "1")
        POD_IP               ip of the current process (default "localhost")
        TRAINING_ROLE        "PSERVER" for a server, anything else is
                             treated as a trainer (default "TRAINER")
        PADDLE_TRAINER_ID    index of the current trainer (default "0")
    """

    def __init__(self):
        super(PaddleCloudRoleMaker, self).__init__()
        # Filled in by generate_role(); 0 until then so that worker_num()
        # can be called safely beforehand.
        self._worker_num = 0

    def generate_role(self):
        """
        Read the environment and record the role, index and endpoints of
        the current process. Calling this again after the role has been
        generated does nothing.

        Raises:
            ValueError: if TRAINING_ROLE is PSERVER and POD_IP:PADDLE_PORT
                is not one of the endpoints built from PADDLE_PSERVERS, or
                if a numeric variable cannot be parsed as int.
        """
        if self._role_is_generated:
            return

        # Function-scope import, as done elsewhere in this file.
        import os

        self.port = os.getenv("PADDLE_PORT", "6174")
        self.pserver_ips = os.getenv("PADDLE_PSERVERS", "")
        # Skip empty entries so that an unset PADDLE_PSERVERS yields no
        # endpoints instead of a bogus ":<port>" one.
        eplist = [
            ':'.join([ip, self.port]) for ip in self.pserver_ips.split(",")
            if ip
        ]
        self.eplist = eplist
        self.endpoints = list(eplist)
        self.trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
        self.current_endpoint = os.getenv("POD_IP",
                                          "localhost") + ":" + self.port
        self.role = os.getenv("TRAINING_ROLE", "TRAINER")
        self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))

        # Populate the attributes the accessors below and the base class
        # getters actually read.
        self._server_endpoints = list(eplist)
        self._worker_num = self.trainers
        if self.role.upper() == "PSERVER":
            self._role = Role.SERVER
            self._current_id = self.endpoints.index(self.current_endpoint)
        else:
            self._role = Role.WORKER
            self._current_id = self.trainer_id
        # Public mirror of _current_id, kept because earlier revisions of
        # this method assigned `current_id`.
        self.current_id = self._current_id
        self._role_is_generated = True

    def is_worker(self):
        """
        return whether current process is a worker (trainer)
        """
        return self._role == Role.WORKER

    # Backward-compatible alias for the historical misspelling.
    is_wokrer = is_worker

    def is_server(self):
        """
        return whether current process is a pserver
        """
        return self._role == Role.SERVER

    def is_first_worker(self):
        """
        return whether current process is the worker with index 0
        """
        return self._role == Role.WORKER and self._current_id == 0

    def worker_index(self):
        """
        return the index of the current process
        """
        return self._current_id

    def server_index(self):
        """
        return the index of the current process
        """
        return self._current_id

    def worker_num(self):
        """
        return the number of trainers read from PADDLE_TRAINERS_NUM,
        or 0 before generate_role() has been called
        """
        return self._worker_num


class UserDefinedRoleMaker(RoleMakerBase):
    """
    UserDefinedRoleMaker takes the role, index, worker count and server
    endpoints of the current process directly from the caller.
    """

    def __init__(self,
                 current_id=0,
                 role=Role.WORKER,
                 worker_num=0,
                 server_endpoints=None):
        """
        UserDefinedRoleMaker is designed for worker and server assignment
        under manual. Typically, a worker and a server node will be appointed
        on each physical node, It can be assign by user.

        Args:
            current_id (int): index of the current process, must be >= 0.
            role: Role.WORKER or Role.SERVER.
            worker_num (int): total number of workers, must be >= 0.
            server_endpoints (list): server endpoints; a list is required,
                so the None default must be overridden by the caller.

        Raises:
            TypeError: if an argument has the wrong type or role is not a
                Role constant.
            ValueError: if current_id or worker_num is negative.
        """
        super(UserDefinedRoleMaker, self).__init__()

        if not isinstance(current_id, int):
            raise TypeError("current_id must be as int")
        else:
            if current_id < 0:
                raise ValueError("current_id must be greater or equal 0")
            self._current_id = current_id

        if role != Role.WORKER and role != Role.SERVER:
            raise TypeError("role must be as Role")
        else:
            self._role = role

        if not isinstance(worker_num, int):
            raise TypeError("worker_num must be as int")
        else:
            if worker_num < 0:
                raise ValueError("worker_num must be greater or equal 0")
            self._worker_num = worker_num

        if not isinstance(server_endpoints, list):
            raise TypeError("server_endpoints must be as string list")
        else:
            self._server_endpoints = server_endpoints

    def generate_role(self):
        """
        Mark the role as generated; everything was set in __init__.
        """
        self._role_is_generated = True

    def is_worker(self):
        """
        return whether current process is a worker
        """
        return self._role == Role.WORKER

    def is_server(self):
        """
        return whether current process is a server
        """
        return self._role == Role.SERVER

    def is_first_worker(self):
        """
        return whether current process is the worker with index 0
        """
        return self._role == Role.WORKER and self._current_id == 0

    def worker_index(self):
        """
        return the index given at construction
        """
        return self._current_id

    def server_index(self):
        """
        return the index given at construction
        """
        return self._current_id

    def worker_num(self):
        """
        return the worker number given at construction
        """
        return self._worker_num


class UserDefinedCollectiveRoleMaker(RoleMakerBase):
    """
    UserDefinedCollectiveRoleMaker takes the index and worker endpoints
    of the current process directly from the caller; every process is a
    worker.
    """

    def __init__(self, current_id=0, worker_endpoints=None):
        """
        UserDefinedCollectiveRoleMaker is designed for worker assignment
        under manual for collective mode.

        Args:
            current_id (int): index of the current process, must be >= 0.
            worker_endpoints (list): worker endpoints; a list is required,
                so the None default must be overridden by the caller.

        Raises:
            TypeError: if an argument has the wrong type.
            ValueError: if current_id is negative.
        """
        super(UserDefinedCollectiveRoleMaker, self).__init__()

        if not isinstance(current_id, int):
            raise TypeError("current_id must be as int")
        else:
            if current_id < 0:
                raise ValueError("current_id must be greater or equal 0")
            self._current_id = current_id

        if not isinstance(worker_endpoints, list):
            raise TypeError("worker_endpoints must be as string list")
        else:
            self._worker_endpoints = worker_endpoints
        # The worker count is the number of endpoints supplied.
        self._worker_num = len(self._worker_endpoints)

    def generate_role(self):
        """
        Mark the role as generated; everything was set in __init__.
        """
        self._role_is_generated = True

    def is_worker(self):
        """
        Always True: every process is a worker in this role maker.
        """
        return True

    def is_first_worker(self):
        """
        return whether current process has index 0
        """
        return self._current_id == 0

    def worker_index(self):
        """
        return the index given at construction
        """
        return self._current_id

    def worker_num(self):
        """
        return the number of worker endpoints given at construction
        """
        return self._worker_num