role_maker.py 7.9 KB
Newer Older
D
dongdaxiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import sys
D
dongdaxiang 已提交
15

16 17 18 19 20 21 22
from enum import Enum


class Role(Enum):
    """
    Role a process can play in distributed training.
    """
    # No trailing comma here: ``WORKER = 1,`` would make the member's value
    # the tuple ``(1,)`` instead of the int ``1``.
    WORKER = 1
    SERVER = 2

D
dongdaxiang 已提交
23 24

class RoleMakerBase(object):
    """
    Base class that decides which role (worker or parameter server) the
    current process plays in distributed training.

    Subclasses supply the role-specific logic by overriding ``_is_worker``,
    ``_is_server`` and ``_generate_role``; this base class only keeps the
    endpoint lists and the "role already generated" flag.
    """

    def __init__(self):
        # Endpoint lists start empty; subclasses fill them in.
        self._trainer_endpoints = []
        self._pserver_endpoints = []
        # Set to True by a subclass once its role has been assigned.
        self._role_is_generated = False

    def _is_worker(self):
        """
        Tell whether the current process is a worker. Must be overridden.
        """
        raise NotImplementedError("Please implement this method in child class")

    def _is_server(self):
        """
        Tell whether the current process is a server. Must be overridden.
        """
        raise NotImplementedError("Please implement this method in child class")

    def _get_local_ip(self):
        """
        Resolve this host's name to an IP address with
        ``socket.gethostbyname``, remember it in ``self._ip`` and return it.
        """
        import socket
        host_name = socket.gethostname()
        self._ip = socket.gethostbyname(host_name)
        return self._ip

    def _get_trainer_endpoints(self):
        """
        Give back the list of trainer endpoints held by this role maker.
        """
        endpoints = self._trainer_endpoints
        return endpoints

    def _get_pserver_endpoints(self):
        """
        Give back the list of parameter-server endpoints held by this role
        maker.
        """
        endpoints = self._pserver_endpoints
        return endpoints

    def _generate_role(self):
        """
        Work out the current process's role. Must be overridden, and must be
        called before any role query is made.
        """
        raise NotImplementedError("Please implement this method in child class")


class MPIRoleMaker(RoleMakerBase):
    """
    MPIRoleMaker is a MPI-API based role maker which is a counter-part of K8SRoleMaker
    mpi4py will be used if a developer inherits MPIRoleMaker
    """

    def __init__(self):
        super(MPIRoleMaker, self).__init__()
        # Imported here (not at module level) so that mpi4py is only required
        # when an MPI-based role maker is actually constructed.
        from mpi4py import MPI
        self._comm = MPI.COMM_WORLD
        self.MPI = MPI
        # Lazily filled cache of every process's ip; see _get_ips().
        self._ips = None

    def _get_rank(self):
        """
        Return the rank of the current process in ``COMM_WORLD``.

        The value is also stored in ``self._rank``.
        """
        self._rank = self._comm.Get_rank()
        return self._rank

    def _get_size(self):
        """
        Return the number of processes in ``COMM_WORLD``.

        The value is also stored in ``self._size``.
        """
        self._size = self._comm.Get_size()
        return self._size

    def _all_gather(self, obj):
        """
        Gather ``obj`` from all processes in ``COMM_WORLD``.

        All processes are first synchronized with a barrier, then
        ``allgather`` is called. With mpi4py's lowercase ``allgather`` the
        result is a list holding one object per rank.
        """
        self._barrier_all()
        return self._comm.allgather(obj)

    def _worker_gather(self, obj):
        """
        Gather ``obj`` among worker processes only.

        Returns the gathered result on workers and ``None`` on every other
        process.

        NOTE: ``self._node_type_comm`` is not created by this class; a
        subclass must set it (and implement ``_is_worker``) before this is
        called.
        """
        if self._is_worker():
            self._node_type_comm.barrier()
            return self._node_type_comm.allgather(obj)
        return None

    def _barrier_all(self):
        """
        Block until every process in ``COMM_WORLD`` reaches this barrier.
        """
        self._comm.barrier()

    def _get_ips(self):
        """
        Collect the ip of every process in the current distributed job.

        The list is gathered once with ``allgather`` and cached in
        ``self._ips``; later calls return the cached list without
        communicating.
        """
        if self._ips is None:
            self._ips = self._comm.allgather(self._get_local_ip())
        return self._ips

    def _finalize(self):
        """
        Finalize the current MPI instance.

        Currently a no-op: no MPI finalization call is made here.
        """
        pass
D
dongdaxiang 已提交
138 139 140


class MPISymetricRoleMaker(MPIRoleMaker):
    """
    MPISymetricRoleMaker is designed for worker and server assignment
    under MPI. Typically, a worker and a server node will be appointed
    on each physical node. This role maker can be only used under MPI.
    """

    def __init__(self):
        super(MPISymetricRoleMaker, self).__init__()
        # 0 = server, 1 = worker; assigned by _generate_role().
        self._node_type = None
        # Consecutive ranks grouped together: one server + one worker.
        self._proc_per_node = 2

    def _check_role_generation(self):
        """
        Ensure _generate_role() has run; otherwise print an error and exit
        the process with status -1.

        Returns True when the role has been generated.
        """
        if not self._role_is_generated:
            sys.stderr.write("generate_role() should be called first\n")
            sys.exit(-1)
        return True

    def _is_first_worker(self):
        """
        Return whether the current process is the worker with index 0.
        """
        if self._check_role_generation():
            return self._is_worker() and 0 == self._worker_index()
        return False

    def _is_worker(self):
        """
        Return whether the current process was assigned the worker role.
        """
        if self._check_role_generation():
            return self._node_type == 1
        return False

    def _is_server(self):
        """
        Return whether the current process was assigned the server role.
        """
        if self._check_role_generation():
            return self._node_type == 0
        return False

    def _worker_num(self):
        """
        Return the number of workers in the job.

        Only meaningful on worker processes; 0 is returned on servers.
        """
        if self._check_role_generation():
            if self._is_worker():
                # Integer division: on Python 3 ``/`` would yield a float.
                return self._get_size() // self._proc_per_node
        return 0

    def _server_num(self):
        """
        Return the number of servers in the job.

        Only meaningful on server processes; 0 is returned on workers.
        """
        if self._check_role_generation():
            if self._is_server():
                return self._get_size() // self._proc_per_node
        return 0

    def _worker_index(self):
        """
        Return the index of the current process among workers.

        ``self._rank`` is set by ``_get_rank()`` during ``_generate_role()``.
        """
        if self._check_role_generation():
            # Floor division keeps the index an int (e.g. rank 1 -> 0, not
            # 0.5), which _is_first_worker() relies on.
            return self._rank // self._proc_per_node
        return 0

    def _server_index(self):
        """
        Return the index of the current process among servers.
        """
        if self._check_role_generation():
            return self._rank // self._proc_per_node
        return 0

    def _barrier_worker(self):
        """
        Barrier over all workers; a no-op on server processes.
        """
        if self._check_role_generation():
            if self._is_worker():
                self._node_type_comm.barrier()

    def _barrier_server(self):
        """
        Barrier over all servers; a no-op on worker processes.
        """
        if self._check_role_generation():
            if self._is_server():
                self._node_type_comm.barrier()

    def _generate_role(self):
        """
        Generate the current process's role. Only the first call has effect.

        The ip of every process is gathered and used as both the trainer
        and pserver endpoint list. Within each group of ``_proc_per_node``
        consecutive ranks, an even local rank becomes a server (node type 0)
        and an odd local rank becomes a worker (node type 1). COMM_WORLD is
        then split by node type into ``self._node_type_comm``.
        """
        if not self._role_is_generated:
            # TODO(guru4elephant): only allow to be called once
            self._trainer_endpoints = self._get_ips()
            self._pserver_endpoints = self._get_ips()

            # _get_rank() also caches self._rank for the *_index() methods.
            if 0 == self._get_rank() % self._proc_per_node % 2:
                self._node_type = 0
            else:
                self._node_type = 1
            self._node_type_comm = self._comm.Split(self._node_type)
            self._role_is_generated = True
248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281


class UserDefinedRoleMaker(RoleMakerBase):
    """
    UserDefinedRoleMaker is designed for worker and server assignment
    under manual configuration. Typically, a worker and a server node will
    be appointed on each physical node; the assignment is given by the user.
    """

    def __init__(self,
                 current_id=0,
                 current_endpoint=None,
                 workers=0,
                 worker_endpoints=None,
                 servers=0,
                 server_endpoints=None,
                 role=Role.WORKER):
        """
        Store the user-supplied job description.

        Args:
            current_id: id of the current process.
            current_endpoint: endpoint of the current process.
            workers: number of workers.
            worker_endpoints: endpoints of the workers.
            servers: number of servers.
            server_endpoints: endpoints of the servers.
            role: ``Role.WORKER`` or ``Role.SERVER`` for this process.
        """
        super(UserDefinedRoleMaker, self).__init__()

        self.current_id = current_id
        self.current_endpoint = current_endpoint
        self.workers = workers
        self.worker_endpoints = worker_endpoints
        self.servers = servers
        self.server_endpoints = server_endpoints
        self.role = role

    def _is_worker(self):
        """
        Return whether the user configured this process as a worker.
        """
        return self.role == Role.WORKER

    def _is_server(self):
        """
        Return whether the user configured this process as a server.
        """
        return self.role == Role.SERVER

    def _generate_role(self):
        """
        Mark the role as generated; the role itself was given by the user.
        """
        # RoleMakerBase.__init__ defines the flag as ``_role_is_generated``;
        # previously only the misspelled ``role_is_generated_`` was set, so
        # the real flag stayed False.
        self._role_is_generated = True
        # Kept for backward compatibility with code reading the old name.
        self.role_is_generated_ = True