client.py 17.3 KB
Newer Older
F
fuyw 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cloudpickle
F
fuyw 已提交
16
import datetime
F
fuyw 已提交
17
import os
F
fuyw 已提交
18 19
import socket
import sys
F
fuyw 已提交
20 21
import threading
import zmq
22
import parl
B
Bo Zhou 已提交
23
from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook
F
fuyw 已提交
24
from parl.remote import remote_constants
B
Bo Zhou 已提交
25
import time
26
import glob
F
fuyw 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40


class Client(object):
    """Base class for the remote client.

    For each training task, there is a global client in the cluster which
    submits jobs to the master node. Different `@parl.remote_class` objects
    connect to the same global client in a training task.

    Attributes:
        submit_job_socket (zmq.Context.socket): A socket which submits job to
                                                the master node.
        pyfiles (bytes): A serialized dictionary containing the code of python
                         files in local working directory.
        executable_path (str): Directory containing the executable python
                               script.
        start_time (time): A timestamp to record the start time of the program.

    """

    def __init__(self, master_address, process_id, distributed_files=None):
        """
        Args:
            master_address (str): address of the master node, in `ip:port`
                                  format.
            process_id (int): id of the process that created the Client.
                              Should use os.getpid() to get the process id.
            distributed_files (list): A list of files to be distributed at all
                                      remote instances (e.g. the configuration
                                      file for initialization). Defaults to no
                                      extra files.
        """
        if distributed_files is None:
            distributed_files = []
        self.master_address = master_address
        self.process_id = process_id
        self.ctx = zmq.Context()
        # Serializes job submissions on `submit_job_socket` and updates of
        # `actor_num`, which are touched from several threads.
        self.lock = threading.Lock()
        self.heartbeat_socket_initialized = threading.Event()
        self.master_is_alive = True
        self.client_is_alive = True
        self.log_monitor_url = None

        # Must be set before `_create_sockets`, whose heartbeat thread
        # reports it to the master.
        self.executable_path = self.get_executable_path()

        self.actor_num = 0

        self._create_sockets(master_address)
        self.check_version()
        self.pyfiles = self.read_local_files(distributed_files)

    def get_executable_path(self):
        """Return the directory containing the main script.

        Falls back to the current working directory when `__main__` has no
        `__file__` attribute (e.g. an interactive interpreter).
        """
        main_module = sys.modules['__main__']
        if hasattr(main_module, '__file__'):
            # os.path.dirname understands the platform separator, unlike
            # slicing at the last '/'.
            return os.path.dirname(os.path.abspath(main_module.__file__))
        # The working directory is already a directory: do not strip its
        # last component.
        return os.getcwd()

    def read_local_files(self, distributed_files=None):
        """Read local python code and store them in a dictionary, which will
        then be sent to the job.

        Args:
            distributed_files (list): A list of files to be distributed at all
                                      remote instances (e.g. the configuration
                                      file for initialization). Shell-style
                                      wildcards (glob patterns) are supported,
                                      e.g.
                                          distributed_files = ['./*.npy', './test*']

        Returns:
            A cloudpickled dictionary with two entries: 'python_files' (the
            `.py` files next to the main script, keyed by file name) and
            'other_files' (the matched distributed files, keyed by path).

        Raises:
            ValueError: if a pattern in `distributed_files` matches no file.
        """
        if distributed_files is None:
            distributed_files = []

        parsed_distributed_files = set()
        for distributed_file in distributed_files:
            parsed_list = glob.glob(distributed_file)
            if not parsed_list:
                raise ValueError(
                    "no local file is matched with '{}', please check your input"
                    .format(distributed_file))
            # exclude the directories
            for pathname in parsed_list:
                if not os.path.isdir(pathname):
                    parsed_distributed_files.add(pathname)

        pyfiles = dict()
        pyfiles['python_files'] = {}
        pyfiles['other_files'] = {}

        if isnotebook():
            main_folder = './'
        else:
            # `or './'` covers a bare script name such as `train.py` (and an
            # empty argv[0]); dirname also handles scripts at the filesystem
            # root, where the old split/join produced an empty path.
            main_folder = os.path.dirname(sys.argv[0]) or './'

        for file_name in os.listdir(main_folder):
            file_path = os.path.join(main_folder, file_name)
            # Skip directories whose names happen to end with '.py'.
            if not file_name.endswith('.py') or not os.path.isfile(file_path):
                continue
            with open(file_path, 'rb') as code_file:
                pyfiles['python_files'][file_name] = code_file.read()

        for file_name in parsed_distributed_files:
            assert os.path.exists(file_name)
            assert not os.path.isabs(
                file_name
            ), "[XPARL] Please do not distribute a file with absolute path."
            with open(file_name, 'rb') as f:
                pyfiles['other_files'][file_name] = f.read()
        return cloudpickle.dumps(pyfiles)

    def _create_sockets(self, master_address):
        """Create the socket to the master and register this client with it.

        Each client has 1 socket at start:

        (1) submit_job_socket: submits jobs to master node.

        It also starts the thread replying to the master's heartbeats, then
        sends a CLIENT_CONNECT_TAG request carrying the heartbeat address,
        the host name and the client id; the reply provides the log monitor
        url.

        Raises:
            Exception: if the master does not answer within
                       `remote_constants.HEARTBEAT_TIMEOUT_S` seconds.
        """
        # submit_job_socket: submits job to master
        self.submit_job_socket = self.ctx.socket(zmq.REQ)
        self.submit_job_socket.linger = 0
        self.submit_job_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        self.submit_job_socket.connect("tcp://{}".format(master_address))
        self.start_time = time.time()
        thread = threading.Thread(target=self._reply_heartbeat)
        thread.daemon = True
        thread.start()
        # `reply_master_heartbeat_address` is only available once the
        # heartbeat thread has bound its socket.
        self.heartbeat_socket_initialized.wait()

        self.client_id = "{}_{}".format(
            self.reply_master_heartbeat_address.replace(':', '_'),
            int(time.time()))

        # check if the master is connected properly
        try:
            self.submit_job_socket.send_multipart([
                remote_constants.CLIENT_CONNECT_TAG,
                to_byte(self.reply_master_heartbeat_address),
                to_byte(socket.gethostname()),
                to_byte(self.client_id),
            ])
            message = self.submit_job_socket.recv_multipart()
            self.log_monitor_url = to_str(message[1])
        except zmq.error.Again:
            logger.warning("[Client] Can not connect to the master, please "
                           "check if master is started and ensure the input "
                           "address {} is correct.".format(master_address))
            self.master_is_alive = False
            raise Exception("Client can not connect to the master, please "
                            "check if master is started and ensure the input "
                            "address {} is correct.".format(master_address))

    def check_version(self):
        """Verify that the parl & python versions of this 'client' process
        match those of the 'master' process.

        Raises:
            AssertionError: if the parl version or the python major version
                            differs from the master's.
            NotImplementedError: if the master replies with an unexpected tag.
        """
        self.submit_job_socket.send_multipart(
            [remote_constants.CHECK_VERSION_TAG])
        message = self.submit_job_socket.recv_multipart()
        tag = message[0]
        if tag != remote_constants.NORMAL_TAG:
            raise NotImplementedError
        master_parl_version = to_str(message[1])
        master_python_version = to_str(message[2])
        client_parl_version = parl.__version__
        client_python_version = str(sys.version_info.major)
        # Raised explicitly rather than via `assert`, so that the check is
        # not stripped when running under `python -O`.
        if (client_parl_version != master_parl_version
                or client_python_version != master_python_version):
            raise AssertionError(
                "Version mismatch: the 'master' is of version 'parl={}, "
                "python={}'. However, 'parl={}, python={}' is provided in "
                "your environment.".format(
                    master_parl_version, master_python_version,
                    client_parl_version, client_python_version))

    def _reply_heartbeat(self):
        """Reply heartbeat signals to the master node.

        Runs in the daemon thread started by `_create_sockets`. Each reply
        carries the executable path, the current actor number, the elapsed
        time since `start_time` and the log monitor url. The loop ends when
        the client is disconnected, or when no heartbeat arrives within
        `remote_constants.HEARTBEAT_RCVTIMEO_S` seconds, in which case the
        master is marked as dead.
        """
        heartbeat_socket = self.ctx.socket(zmq.REP)
        heartbeat_socket.linger = 0
        heartbeat_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
        reply_master_heartbeat_port = \
            heartbeat_socket.bind_to_random_port(addr="tcp://*")
        self.reply_master_heartbeat_address = "{}:{}".format(
            get_ip_address(), reply_master_heartbeat_port)
        self.heartbeat_socket_initialized.set()
        # Whether at least one heartbeat has been received; selects the hint
        # logged when the master stops responding.
        connected = False
        while self.client_is_alive and self.master_is_alive:
            try:
                heartbeat_socket.recv_multipart()
                connected = True
                elapsed_time = datetime.timedelta(
                    seconds=int(time.time() - self.start_time))
                heartbeat_socket.send_multipart([
                    remote_constants.HEARTBEAT_TAG,
                    to_byte(self.executable_path),
                    to_byte(str(self.actor_num)),
                    to_byte(str(elapsed_time)),
                    to_byte(str(self.log_monitor_url)),
                ])  # TODO: remove additional information
            except zmq.error.Again:
                if connected:
                    logger.warning("[Client] Cannot connect to the master."
                                   "Please check if it is still alive.")
                else:
                    logger.warning(
                        "[Client] Cannot connect to the master."
                        "Please check the firewall between client and master.(e.g., ping the master IP)"
                    )
                self.master_is_alive = False
        heartbeat_socket.close(0)
        logger.warning("Client exit replying heartbeat for master.")

    def _check_and_monitor_job(self, job_heartbeat_address,
                               ping_heartbeat_address, max_memory):
        """Check that a job is alive and, if so, start monitoring it.

        Sometimes the client may receive a job that is dead, thus we have to
        check if this job is still alive before adding it to the `actor_num`.

        Args:
            job_heartbeat_address (str): `ip:port` the monitor thread sends
                                         heartbeats to once the check passed.
            ping_heartbeat_address (str): `ip:port` used for the one-off
                                          liveness check; the ping also
                                          carries `max_memory`.
            max_memory (float): memory limit (MB) forwarded to the job.

        Returns:
            True if the job answered the ping and a monitor thread was
            started, False otherwise.
        """
        # job_heartbeat_socket: sends heartbeat signal to job
        job_heartbeat_socket = self.ctx.socket(zmq.REQ)
        job_heartbeat_socket.linger = 0
        # Short (0.9s) timeout, presumably so a dead job does not stall
        # `submit_job` for long.
        job_heartbeat_socket.setsockopt(zmq.RCVTIMEO, int(0.9 * 1000))
        job_heartbeat_socket.connect("tcp://" + ping_heartbeat_address)
        try:
            job_heartbeat_socket.send_multipart(
                [remote_constants.HEARTBEAT_TAG,
                 to_byte(str(max_memory))])
            job_heartbeat_socket.recv_multipart()
        except zmq.error.Again:
            job_heartbeat_socket.close(0)
            logger.error(
                "[Client] connects to a finished job, will try again, ping_heartbeat_address:{}"
                .format(ping_heartbeat_address))
            return False
        # Reuse the socket for the long-lived heartbeat connection.
        job_heartbeat_socket.disconnect("tcp://" + ping_heartbeat_address)
        job_heartbeat_socket.connect("tcp://" + job_heartbeat_address)
        job_heartbeat_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)

        # a thread for sending heartbeat signals to job
        thread = threading.Thread(
            target=self._create_job_monitor, args=(job_heartbeat_socket, ))
        thread.daemon = True
        thread.start()
        return True

    def _create_job_monitor(self, job_heartbeat_socket):
        """Send heartbeat signals to check target's status.

        Decrements `actor_num` when the job reports that it exceeded its
        memory limit or stops answering, then closes the socket.

        Args:
            job_heartbeat_socket (zmq.Context.socket): REQ socket connected
                                                       to the job's heartbeat
                                                       address.
        """
        job_is_alive = True
        while job_is_alive and self.client_is_alive:
            try:
                job_heartbeat_socket.send_multipart(
                    [remote_constants.HEARTBEAT_TAG])
                job_message = job_heartbeat_socket.recv_multipart()
                stop_job = to_str(job_message[1])
                job_address = to_str(job_message[2])

                if stop_job == 'True':
                    logger.error(
                        'Job {} exceeds max memory usage, will stop this job.'.
                        format(job_address))
                    with self.lock:
                        self.actor_num -= 1
                    job_is_alive = False
                else:
                    time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)

            except zmq.error.Again:
                job_is_alive = False
                with self.lock:
                    self.actor_num -= 1
                    logger.error(
                        '[xparl] lost connection with a job, current actor num: {}'
                        .format(self.actor_num))

            except zmq.error.ZMQError:
                break

        job_heartbeat_socket.close(0)

    def submit_job(self, max_memory):
        """Send a job to the Master node.

        When a `@parl.remote_class` object is created, the global client
        sends a job to the master node. Then the master node will allocate
        a vacant job from its job pool to the remote object. Jobs that turn
        out to be dead are skipped and another one is requested.

        Args:
            max_memory (float): Maximum memory (MB) can be used by each remote
                                instance, the unit is in MB and None means
                                unlimited.

        Returns:
            job_address(str): IP address of the job. None if there is no available CPU in the cluster.

        Raises:
            Exception: if the master is known to be unreachable.
        """
        if not self.master_is_alive:
            raise Exception("Client can not submit job to the master, "
                            "please check if master is connected.")

        while True:
            # A lock to prevent multiple actors from submitting job at the
            # same time. `with` guarantees the release even when
            # `recv_multipart` times out (zmq.error.Again); a lock left held
            # there would deadlock every later submission.
            with self.lock:
                self.submit_job_socket.send_multipart([
                    remote_constants.CLIENT_SUBMIT_TAG,
                    to_byte(self.reply_master_heartbeat_address),
                    to_byte(self.client_id),
                ])
                message = self.submit_job_socket.recv_multipart()

            tag = message[0]

            if tag == remote_constants.NORMAL_TAG:
                job_address = to_str(message[1])
                job_heartbeat_address = to_str(message[2])
                ping_heartbeat_address = to_str(message[3])

                check_result = self._check_and_monitor_job(
                    job_heartbeat_address, ping_heartbeat_address,
                    max_memory)
                if check_result:
                    with self.lock:
                        self.actor_num += 1
                    return job_address
                # The job was dead: loop and request another one.

            # no vacant CPU resources, cannot submit a new job
            elif tag == remote_constants.CPU_TAG:
                # wait 1 second to avoid requesting in a high frequency.
                time.sleep(1)
                return None
            else:
                raise NotImplementedError


# The global client of the current process: set by `connect()`, rebuilt by
# `get_global_client()` when called from a different process, and reset to
# None by `disconnect()`.
GLOBAL_CLIENT = None


def connect(master_address, distributed_files=None):
    """Create a global client which connects to the master node.

    .. code-block:: python

        parl.connect(master_address='localhost:1234')

    The client is created once per process: calling `connect` again from the
    same process keeps the existing client, while a call from another process
    (different `os.getpid()`) creates a new one.

    Args:
        master_address (str): The address of the Master node to connect to,
                              in `{ip}:{port}` format.
        distributed_files (list): A list of files to be distributed at all
                                  remote instances (e.g. the configuration
                                  file for initialization). Defaults to no
                                  extra files.

    Raises:
        AssertionError: if `master_address` is not in `{ip}:{port}` format.
        Exception: An exception is raised if the master node is not started.
    """
    global GLOBAL_CLIENT
    # Raised explicitly instead of `assert` so the validation is kept under
    # `python -O`; the exception type is unchanged for existing callers.
    if len(master_address.split(":")) != 2:
        raise AssertionError("Please input address in {ip}:{port} format")
    if distributed_files is None:
        distributed_files = []
    cur_process_id = os.getpid()
    if GLOBAL_CLIENT is None or GLOBAL_CLIENT.process_id != cur_process_id:
        GLOBAL_CLIENT = Client(master_address, cur_process_id,
                               distributed_files)
    logger.info("Remote actors log url: {}".format(
        GLOBAL_CLIENT.log_monitor_url))


def get_global_client():
    """Return the global client, creating a fresh one for a new process.

    To support process-based programming, a client made in another process
    is not reused: a new `Client` is connected to the same master address on
    the first call from the current process.

    Returns:
        The global client.
    """
    global GLOBAL_CLIENT
    assert GLOBAL_CLIENT is not None, (
        "Cannot get the client to submit the job, have you connected to the "
        "cluster by calling parl.connect(master_ip, master_port)?")

    current_pid = os.getpid()
    if GLOBAL_CLIENT.process_id == current_pid:
        return GLOBAL_CLIENT
    GLOBAL_CLIENT = Client(GLOBAL_CLIENT.master_address, current_pid)
    return GLOBAL_CLIENT


def disconnect():
    """Disconnect the global client from the master node.

    Marks the client as no longer alive (the flag checked by its background
    heartbeat loops) and drops the global reference. Only logs a hint when
    no client exists.
    """
    global GLOBAL_CLIENT
    if GLOBAL_CLIENT is None:
        logger.info(
            "No client to be released. Please make sure that you have called `parl.connect`"
        )
        return
    GLOBAL_CLIENT.client_is_alive = False
    GLOBAL_CLIENT = None