client.py 15.1 KB
Newer Older
F
fuyw 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cloudpickle
F
fuyw 已提交
16
import datetime
F
fuyw 已提交
17
import os
F
fuyw 已提交
18 19
import socket
import sys
F
fuyw 已提交
20 21 22 23
import threading
import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.remote import remote_constants
B
Bo Zhou 已提交
24
import time
F
fuyw 已提交
25 26 27 28 29 30 31 32 33 34 35 36 37 38


class Client(object):
    """Base class for the remote client.

    For each training task, there is a global client in the cluster which
    submits jobs to the master node. Different `@parl.remote_class` objects
    connect to the same global client in a training task.

    Attributes:
        submit_job_socket (zmq.Context.socket): A socket which submits job to
                                                the master node.
        pyfiles (bytes): A serialized dictionary containing the code of python
                         files in local working directory.
        executable_path (str): File path of the executable python script.
        start_time (time): A timestamp to record the start time of the program.

    """

    def __init__(self, master_address, process_id, distributed_files=None):
        """
        Args:
            master_address (str): ip address of the master node.
            process_id (str): id of the process that created the Client.
                              Should use os.getpid() to get the process id.
            distributed_files (list): A list of files to be distributed at all
                                      remote instances(e,g. the configuration
                                      file for initialization) .

        """
        # Use None as the sentinel default: a mutable default ([]) would be
        # shared across every call of __init__.
        distributed_files = [] if distributed_files is None else distributed_files
        self.master_address = master_address
        self.process_id = process_id
        self.ctx = zmq.Context()
        self.lock = threading.Lock()
        self.heartbeat_socket_initialized = threading.Event()
        self.master_is_alive = True
        self.client_is_alive = True

        self.executable_path = self.get_executable_path()

        # number of actors currently created through this client
        self.actor_num = 0

        self._create_sockets(master_address)
        self.pyfiles = self.read_local_files(distributed_files)

    def get_executable_path(self):
        """Return the directory that contains the executable python script."""
        mod = sys.modules['__main__']
        if hasattr(mod, '__file__'):
            executable_path = os.path.abspath(mod.__file__)
        else:
            executable_path = os.getcwd()
        # os.path.dirname is portable; slicing at rfind('/') breaks on
        # platforms whose path separator is not '/'.
        return os.path.dirname(executable_path)

    def read_local_files(self, distributed_files=None):
        """Read local python code and store them in a dictionary, which will
        then be sent to the job.

        Args:
            distributed_files (list): A list of files to be distributed at all
                                      remote instances(e,g. the configuration
                                      file for initialization) .

        Returns:
            A cloudpickled dictionary containing the python code in current
            working directory.
        """
        distributed_files = [] if distributed_files is None else distributed_files
        pyfiles = dict()
        pyfiles['python_files'] = {}
        pyfiles['other_files'] = {}

        code_files = filter(lambda x: x.endswith('.py'), os.listdir('./'))

        # Explicit existence checks instead of `assert`, which is stripped
        # when Python runs with the -O flag.
        for file in code_files:
            if not os.path.exists(file):
                raise Exception(
                    'Failed to create the client, the file {} does not exist.'.
                    format(file))
            with open(file, 'rb') as code_file:
                code = code_file.read()
                pyfiles['python_files'][file] = code

        for file in distributed_files:
            if not os.path.exists(file):
                raise Exception(
                    'Failed to create the client, the file {} does not exist.'.
                    format(file))
            with open(file, 'rb') as f:
                content = f.read()
                pyfiles['other_files'][file] = content

        # append entry file to code list
        main_file = sys.argv[0]
        with open(main_file, 'rb') as code_file:
            code = code_file.read()
            # parl/remote/remote_decorator.py -> remote_decorator.py
            file_name = main_file.split(os.sep)[-1]
            pyfiles['python_files'][file_name] = code
        return cloudpickle.dumps(pyfiles)

    def _create_sockets(self, master_address):
        """ Each client has 1 sockets as start:

        (1) submit_job_socket: submits jobs to master node.
        """

        # submit_job_socket: submits job to master
        self.submit_job_socket = self.ctx.socket(zmq.REQ)
        self.submit_job_socket.linger = 0
        self.submit_job_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        self.submit_job_socket.connect("tcp://{}".format(master_address))
        self.start_time = time.time()
        thread = threading.Thread(target=self._reply_heartbeat)
        # Thread.setDaemon() is deprecated; assign the attribute directly.
        thread.daemon = True
        thread.start()
        self.heartbeat_socket_initialized.wait()

        # check if the master is connected properly
        try:
            self.submit_job_socket.send_multipart([
                remote_constants.CLIENT_CONNECT_TAG,
                to_byte(self.heartbeat_master_address),
                to_byte(socket.gethostname())
            ])
            _ = self.submit_job_socket.recv_multipart()
        except zmq.error.Again as e:
            logger.warning("[Client] Can not connect to the master, please "
                           "check if master is started and ensure the input "
                           "address {} is correct.".format(master_address))
            self.master_is_alive = False
            raise Exception("Client can not connect to the master, please "
                            "check if master is started and ensure the input "
                            "address {} is correct.".format(master_address))

    def _reply_heartbeat(self):
        """Reply heartbeat signals to the master node."""
        # Named `reply_socket` so it does not shadow the stdlib `socket`
        # module used in `_create_sockets`.
        reply_socket = self.ctx.socket(zmq.REP)
        reply_socket.linger = 0
        reply_socket.setsockopt(zmq.RCVTIMEO,
                                remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
        heartbeat_master_port =\
            reply_socket.bind_to_random_port(addr="tcp://*")
        self.heartbeat_master_address = "{}:{}".format(get_ip_address(),
                                                       heartbeat_master_port)
        self.heartbeat_socket_initialized.set()
        connected = False
        while self.client_is_alive and self.master_is_alive:
            try:
                message = reply_socket.recv_multipart()
                elapsed_time = datetime.timedelta(
                    seconds=int(time.time() - self.start_time))
                reply_socket.send_multipart([
                    remote_constants.HEARTBEAT_TAG,
                    to_byte(self.executable_path),
                    to_byte(str(self.actor_num)),
                    to_byte(str(elapsed_time))
                ])
                connected = True
            except zmq.error.Again as e:
                if connected:
                    logger.warning("[Client] Cannot connect to the master."
                                   "Please check if it is still alive.")
                else:
                    logger.warning(
                        "[Client] Cannot connect to the master."
                        "Please check the firewall between client and master.(e.g., ping the master IP)"
                    )
                self.master_is_alive = False
        reply_socket.close(0)
        logger.warning("Client exit replying heartbeat for master.")

    def _check_and_monitor_job(self, job_heartbeat_address,
                               ping_heartbeat_address, max_memory):
        """ Sometimes the client may receive a job that is dead, thus
        we have to check if this job is still alive before sending it to the actor.

        Returns:
            True if the job answered the ping and a monitor thread was
            started; False if the job looks dead.
        """
        # job_heartbeat_socket: sends heartbeat signal to job
        job_heartbeat_socket = self.ctx.socket(zmq.REQ)
        job_heartbeat_socket.linger = 0
        job_heartbeat_socket.setsockopt(zmq.RCVTIMEO, int(0.9 * 1000))
        job_heartbeat_socket.connect("tcp://" + ping_heartbeat_address)
        try:
            job_heartbeat_socket.send_multipart(
                [remote_constants.HEARTBEAT_TAG,
                 to_byte(str(max_memory))])
            job_heartbeat_socket.recv_multipart()
        except zmq.error.Again:
            job_heartbeat_socket.close(0)
            logger.error(
                "[Client] connects to a finished job, will try again, ping_heartbeat_address:{}"
                .format(ping_heartbeat_address))
            return False
        # Reuse the socket for the long-lived heartbeat to the job itself.
        job_heartbeat_socket.disconnect("tcp://" + ping_heartbeat_address)
        job_heartbeat_socket.connect("tcp://" + job_heartbeat_address)
        job_heartbeat_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)

        # a thread for sending heartbeat signals to job
        thread = threading.Thread(
            target=self._create_job_monitor, args=(job_heartbeat_socket, ))
        thread.daemon = True
        thread.start()
        return True

    def _create_job_monitor(self, job_heartbeat_socket):
        """Send heartbeat signals to check target's status"""

        job_is_alive = True
        while job_is_alive and self.client_is_alive:
            try:
                job_heartbeat_socket.send_multipart(
                    [remote_constants.HEARTBEAT_TAG])
                job_message = job_heartbeat_socket.recv_multipart()
                stop_job = to_str(job_message[1])
                job_address = to_str(job_message[2])

                if stop_job == 'True':
                    logger.error(
                        'Job {} exceeds max memory usage, will stop this job.'.
                        format(job_address))
                    # `with` guarantees the lock is released even on error.
                    with self.lock:
                        self.actor_num -= 1
                    job_is_alive = False
                else:
                    time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)

            except zmq.error.Again as e:
                job_is_alive = False
                with self.lock:
                    self.actor_num -= 1
                    logger.error(
                        '[xparl] lost connection with a job, current actor num: {}'
                        .format(self.actor_num))

            except zmq.error.ZMQError as e:
                break

        job_heartbeat_socket.close(0)

    def submit_job(self, max_memory):
        """Send a job to the Master node.

        When a `@parl.remote_class` object is created, the global client
        sends a job to the master node. Then the master node will allocate
        a vacant job from its job pool to the remote object.

        Args:
            max_memory (float): Maximum memory (MB) can be used by each remote
                                instance, the unit is in MB and default value is
                                none(unlimited).

        Returns:
            job_address(str): IP address of the job. None if there is no available CPU in the cluster.
        """
        if not self.master_is_alive:
            raise Exception("Client can not submit job to the master, "
                            "please check if master is connected.")

        while True:
            # A lock to prevent multiple actors from submitting job at the
            # same time. Released in `finally` so a recv timeout cannot
            # leave the lock held forever.
            self.lock.acquire()
            try:
                self.submit_job_socket.send_multipart([
                    remote_constants.CLIENT_SUBMIT_TAG,
                    to_byte(self.heartbeat_master_address)
                ])
                message = self.submit_job_socket.recv_multipart()
            finally:
                self.lock.release()

            tag = message[0]

            if tag == remote_constants.NORMAL_TAG:
                job_address = to_str(message[1])
                job_heartbeat_address = to_str(message[2])
                ping_heartbeat_address = to_str(message[3])

                check_result = self._check_and_monitor_job(
                    job_heartbeat_address, ping_heartbeat_address,
                    max_memory)
                if check_result:
                    with self.lock:
                        self.actor_num += 1
                    return job_address

            # no vacant CPU resources, cannot submit a new job
            elif tag == remote_constants.CPU_TAG:
                job_address = None
                # wait 1 second to avoid requesting in a high frequency.
                time.sleep(1)
                return job_address
            else:
                raise NotImplementedError
F
fuyw 已提交
322 323 324 325 326


# The process-wide client shared by all `@parl.remote_class` objects;
# created lazily by `connect()` and replaced when the process id changes.
GLOBAL_CLIENT = None


327
def connect(master_address, distributed_files=None):
    """Create a global client which connects to the master node.

    .. code-block:: python

        parl.connect(master_address='localhost:1234')

    Args:
        master_address (str): The address of the Master node to connect to.
        distributed_files (list): A list of files to be distributed at all
                                  remote instances(e,g. the configuration
                                  file for initialization) .

    Raises:
        Exception: An exception is raised if the master node is not started.
    """

    assert len(master_address.split(":")) == 2, "Please input address in " +\
        "{ip}:{port} format"
    # Use None as the sentinel default: a mutable default ([]) would be
    # shared across every call of connect().
    distributed_files = [] if distributed_files is None else distributed_files
    global GLOBAL_CLIENT
    cur_process_id = os.getpid()
    if GLOBAL_CLIENT is None:
        GLOBAL_CLIENT = Client(master_address, cur_process_id,
                               distributed_files)
    else:
        # A new process (e.g. after fork) must get its own client.
        if GLOBAL_CLIENT.process_id != cur_process_id:
            GLOBAL_CLIENT = Client(master_address, cur_process_id,
                                   distributed_files)
F
fuyw 已提交
356 357 358 359 360


def get_global_client():
    """Return the global client.

    To support process-based programming, a fresh global client is created
    whenever the current process differs from the one that built the
    existing client.

    Returns:
        The global client.
    """
    global GLOBAL_CLIENT
    assert GLOBAL_CLIENT is not None, (
        "Cannot get the client to submit the job, have you connected to the "
        "cluster by calling parl.connect(master_ip, master_port)?")

    pid_now = os.getpid()
    if pid_now != GLOBAL_CLIENT.process_id:
        GLOBAL_CLIENT = Client(GLOBAL_CLIENT.master_address, pid_now)
    return GLOBAL_CLIENT


def disconnect():
    """Disconnect the global client from the master node."""
    global GLOBAL_CLIENT
    if GLOBAL_CLIENT is None:
        # Nothing to tear down; the user never called parl.connect.
        logger.info(
            "No client to be released. Please make sure that you have called `parl.connect`"
        )
        return
    GLOBAL_CLIENT.client_is_alive = False
    GLOBAL_CLIENT = None