role_maker.py 10.7 KB
Newer Older
D
dongdaxiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

T
tangwei12 已提交
15
from __future__ import print_function
16 17
from enum import Enum

T
tangwei12 已提交
18
# Public names exported by this module.
__all__ = [
    'Role',
    'RoleMakerBase',
    'MPISymetricRoleMaker',
    'UserDefinedRoleMaker',
    'UserDefinedCollectiveRoleMaker',
]

23 24 25 26 27

class Role(Enum):
    """
    Role a process can play in parameter-server style distributed training.
    """
    # Plain int values: a trailing comma here (``WORKER = 1,``) would make the
    # member value the tuple ``(1,)`` instead of ``1``.
    WORKER = 1
    SERVER = 2

D
dongdaxiang 已提交
28 29

class RoleMakerBase(object):
    """
    Base class for deciding which role (worker or parameter server) the
    current process plays in distributed training.

    Subclasses override the role queries (``is_worker``, ``is_server``,
    ``is_first_worker``, ``worker_num``, ``worker_index``,
    ``server_index``); each of them raises NotImplementedError here.
    The endpoint getters are shared and simply return the stored lists.
    """

    def __init__(self):
        # Endpoint lists, empty until a subclass fills them in.
        self._worker_endpoints = []
        self._server_endpoints = []
        # Flag subclasses flip once the role has been assigned.
        self._role_is_generated = False
        self._role = None
        # Placeholder id until a subclass assigns a real one.
        self._current_id = -1

    def is_worker(self):
        """
        Tell whether the current process is a worker.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("Please implement this method in child class")

    def is_server(self):
        """
        Tell whether the current process is a server.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("Please implement this method in child class")

    def is_first_worker(self):
        """
        Tell whether the current process is the first worker instance.

        Returns:
            bool: True for the first worker, False otherwise (in subclasses).

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("Please implement this method in child class")

    def worker_num(self):
        """
        Report the total number of workers in the job.

        Returns:
            int: worker count (in subclasses).

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("Please implement this method in child class")

    def worker_index(self):
        """
        Report the id of the current worker.

        Returns:
            int: worker id (in subclasses).

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("Please implement this method in child class")

    def server_index(self):
        """
        Report the id of the current server.

        Returns:
            int: server id (in subclasses).

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("Please implement this method in child class")

    def get_trainer_endpoints(self):
        """
        Returns:
            list: the stored worker (trainer) endpoints; the same list object
            this role maker holds, not a copy.
        """
        return self._worker_endpoints

    def get_pserver_endpoints(self):
        """
        Returns:
            list: the stored parameter-server endpoints; the same list object
            this role maker holds, not a copy.
        """
        return self._server_endpoints
D
dongdaxiang 已提交
103 104 105


class MPIRoleMaker(RoleMakerBase):
    """
    Role maker built on the MPI API through mpi4py; the MPI counterpart of
    K8SRoleMaker. Subclasses decide the concrete role in generate_role().
    """

    def __init__(self):
        super(MPIRoleMaker, self).__init__()
        # Imported here rather than at module level, so importing this module
        # does not require mpi4py.
        from mpi4py import MPI
        self.MPI = MPI
        self._comm = MPI.COMM_WORLD
        # Communicator restricted to one role; set up by subclasses.
        self._node_type_comm = None
        # Caches filled by _get_ips() and get_local_ip().
        self._ips = None
        self._ip = None

    def _get_rank(self):
        """
        Query this process's rank in COMM_WORLD, store it in ``self._rank``
        and return it.
        """
        rank = self._comm.Get_rank()
        self._rank = rank
        return rank

    def _get_size(self):
        """
        Query the number of processes in COMM_WORLD, store it in
        ``self._size`` and return it.
        """
        size = self._comm.Get_size()
        self._size = size
        return size

    def _all_gather(self, obj):
        """
        Synchronize every process, then gather ``obj`` from all ranks with
        mpi4py's ``allgather``.
        """
        self._barrier_all()
        return self._comm.allgather(obj)

    def _worker_gather(self, obj):
        """
        Gather ``obj`` across worker processes only.

        Returns:
            The result of ``allgather`` on the worker communicator when this
            process is a worker; None otherwise.
        """
        if not self.is_worker():
            return None
        comm = self._node_type_comm
        comm.barrier()
        return comm.allgather(obj)

    def _barrier_all(self):
        """
        Block until every process in COMM_WORLD has reached this barrier.
        """
        self._comm.barrier()

    def _finalize(self):
        """
        Hook for finalizing the MPI instance; currently does nothing.
        """
        pass

    def _get_ips(self):
        """
        Return the IP list of all processes in the job. The list is gathered
        with ``allgather`` when nothing (or an empty result) is cached yet.
        """
        if self._ips:
            return self._ips
        local_ip = self.get_local_ip()
        self._ips = self._comm.allgather(local_ip)
        return self._ips

    def get_local_ip(self):
        """
        Resolve this host's name to an IP address, store it in ``self._ip``
        and return it.
        """
        import socket
        host_name = socket.gethostname()
        self._ip = socket.gethostbyname(host_name)
        return self._ip

    def generate_role(self):
        """
        Assign the current process its role; subclasses must implement it.
        """
        raise NotImplementedError("Please implement this method in child class")
D
dongdaxiang 已提交
183 184 185


class MPISymetricRoleMaker(MPIRoleMaker):
    """
    MPISymetricRoleMaker is designed for worker and server assignment
    under MPI. Typically, a worker and a server node will be appointed
    on each physical node. This role maker can be only used under MPI.

    Ranks are grouped ``_proc_per_node`` (2) at a time: the even rank of each
    pair becomes a server (node type 0) and the odd rank a worker (node
    type 1).
    """

    def __init__(self):
        super(MPISymetricRoleMaker, self).__init__()
        # 0 = server, 1 = worker; None until generate_role() runs.
        self._node_type = None
        # Processes per physical node: one server plus one worker.
        self._proc_per_node = 2

    def _check_role_generation(self):
        """
        Make sure generate_role() has been called.

        Returns:
            bool: True (only returns when the role exists).

        Raises:
            NameError: if generate_role() has not been called yet.
        """
        if not self._role_is_generated:
            raise NameError("generate_role() should be called first")
        return True

    def is_first_worker(self):
        """
        Return whether the current process is the worker with index 0.
        """
        if self._check_role_generation():
            return self.is_worker() and 0 == self.worker_index()
        return False

    def worker_num(self):
        """
        Return the number of workers as reported by ``_worker_num``:
        the worker count on a worker process, 0 on a server process.
        """
        return self._worker_num()

    def is_worker(self):
        """
        Return whether the current process was assigned the worker role.
        """
        if self._check_role_generation():
            return self._node_type == 1
        return False

    def is_server(self):
        """
        Return whether the current process was assigned the server role.
        """
        if self._check_role_generation():
            return self._node_type == 0
        return False

    def _worker_num(self):
        """
        Return the number of workers (one per group of ``_proc_per_node``
        ranks) when called on a worker; 0 otherwise.
        """
        if self._check_role_generation():
            if self.is_worker():
                # Floor division keeps the result an int; plain ``/`` returns
                # a float under Python 3.
                return self._get_size() // self._proc_per_node
        return 0

    def _server_num(self):
        """
        Return the number of servers (one per group of ``_proc_per_node``
        ranks) when called on a server; 0 otherwise.
        """
        if self._check_role_generation():
            if self.is_server():
                return self._get_size() // self._proc_per_node
        return 0

    def worker_index(self):
        """
        Return the index of this worker, i.e. its rank's group number.
        """
        if self._check_role_generation():
            # ``//`` so that e.g. rank 3 maps to index 1, not 1.5.
            return self._rank // self._proc_per_node
        return 0

    def server_index(self):
        """
        Return the index of this server, i.e. its rank's group number.
        """
        if self._check_role_generation():
            return self._rank // self._proc_per_node
        return 0

    def _barrier_worker(self):
        """
        Barrier over all worker processes; a no-op on servers.
        """
        if self._check_role_generation():
            if self.is_worker():
                self._node_type_comm.barrier()

    def _barrier_server(self):
        """
        Barrier over all server processes; a no-op on workers.
        """
        if self._check_role_generation():
            if self.is_server():
                self._node_type_comm.barrier()

    def generate_role(self):
        """
        Assign the current process its role and build the per-role
        communicator. Subsequent calls do nothing.
        """
        if not self._role_is_generated:
            # TODO(guru4elephant): only allow to be called once
            ips = self._get_ips()
            # Even ranks are servers and odd ranks are workers, matching the
            # node-type assignment below.
            self._worker_endpoints = ips[1::2]
            self._server_endpoints = ips[::2]

            if 0 == self._get_rank() % self._proc_per_node % 2:
                self._node_type = 0
            else:
                self._node_type = 1
            # Split COMM_WORLD so workers and servers each get their own
            # communicator for role-local barriers and gathers.
            self._node_type_comm = self._comm.Split(self._node_type)
            self._role_is_generated = True
294 295 296 297 298


class UserDefinedRoleMaker(RoleMakerBase):
    def __init__(self,
                 current_id=0,
                 role=Role.WORKER,
                 worker_num=0,
                 server_endpoints=None):
        """
        UserDefinedRoleMaker assigns the worker/server role from values the
        user supplies explicitly, instead of discovering them at runtime.

        Args:
            current_id (int): id of this process; must be >= 0.
            role (Role): Role.WORKER or Role.SERVER.
            worker_num (int): total number of workers; must be >= 0.
            server_endpoints (list): list of server endpoint strings. This
                argument is effectively required: the default None fails the
                list check and raises TypeError.

        Raises:
            TypeError: if an argument has the wrong type.
            ValueError: if current_id or worker_num is negative.
        """
        super(UserDefinedRoleMaker, self).__init__()

        if not isinstance(current_id, int):
            raise TypeError("current_id must be as int")
        if current_id < 0:
            raise ValueError("current_id must be greater or equal 0")
        self._current_id = current_id

        if not isinstance(role, Role):
            raise TypeError("role must be as Role")
        self._role = role

        if not isinstance(worker_num, int):
            raise TypeError("worker_num must be as int")
        if worker_num < 0:
            raise ValueError("worker_num must be greater or equal 0")
        self._worker_num = worker_num

        if not isinstance(server_endpoints, list):
            raise TypeError("server_endpoints must be as string list")
        self._server_endpoints = server_endpoints

    def is_worker(self):
        """Return whether the user-assigned role is Role.WORKER."""
        return self._role == Role.WORKER

    def is_server(self):
        """Return whether the user-assigned role is Role.SERVER."""
        return self._role == Role.SERVER

    def is_first_worker(self):
        """Return whether this process is a worker with id 0."""
        return self._role == Role.WORKER and self._current_id == 0

    def worker_index(self):
        """Return the user-supplied id of this process."""
        return self._current_id

    def server_index(self):
        """Return the user-supplied id of this process."""
        return self._current_id

    def worker_num(self):
        """Return the user-supplied total number of workers."""
        return self._worker_num
350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383


class UserDefinedCollectiveRoleMaker(RoleMakerBase):
    def __init__(self, current_id=0, worker_endpoints=None):
        """
        UserDefinedCollectiveRoleMaker assigns worker roles from user-supplied
        values for collective mode, where every process is a worker.

        Args:
            current_id (int): id of this worker; must be >= 0.
            worker_endpoints (list): list of worker endpoint strings. This
                argument is effectively required: the default None fails the
                list check and raises TypeError.

        Raises:
            TypeError: if an argument has the wrong type.
            ValueError: if current_id is negative.
        """
        super(UserDefinedCollectiveRoleMaker, self).__init__()

        if not isinstance(current_id, int):
            raise TypeError("current_id must be as int")
        if current_id < 0:
            raise ValueError("current_id must be greater or equal 0")
        self._current_id = current_id

        if not isinstance(worker_endpoints, list):
            raise TypeError("worker_endpoints must be as string list")
        self._worker_endpoints = worker_endpoints
        # One worker per endpoint.
        self._worker_num = len(self._worker_endpoints)

    def is_worker(self):
        """Every process is a worker in collective mode."""
        return True

    def is_server(self):
        """
        Collective mode has no parameter servers, so this is always False.
        Overridden so callers do not hit the base-class NotImplementedError.
        """
        return False

    def is_first_worker(self):
        """Return whether this worker has id 0."""
        return self._current_id == 0

    def worker_index(self):
        """Return the user-supplied id of this worker."""
        return self._current_id

    def worker_num(self):
        """Return the number of worker endpoints supplied by the user."""
        return self._worker_num