job.py 20.2 KB
Newer Older
F
fuyw 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

H
Hongsheng Zeng 已提交
15 16 17
# Work around known cloudpickle compatibility problems.
import compatible_trick

B
Bo Zhou 已提交
18 19
import os
os.environ['XPARL'] = 'True'
20
os.environ['CUDA_VISIBLE_DEVICES'] = ''
F
fuyw 已提交
21 22 23
import argparse
import cloudpickle
import pickle
24 25
import psutil
import re
F
fuyw 已提交
26 27 28 29 30 31
import sys
import tempfile
import threading
import time
import traceback
import zmq
32
from multiprocessing import Process, Pipe
F
fuyw 已提交
33 34 35 36 37
from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.utils.communication import loads_argument, loads_return,\
    dumps_argument, dumps_return
from parl.remote import remote_constants
from parl.utils.exceptions import SerializeError, DeserializeError
B
Bo Zhou 已提交
38
from parl.remote.message import InitializedJob
39
from parl.remote.utils import load_remote_class, redirect_stdout_to_file
F
fuyw 已提交
40 41 42 43 44 45


class Job(object):
    """Base class for the job.

    After establishing connection with the remote object, the job creates a
    remote class instance locally and enters a loop waiting for commands
    from the remote object. The command loop runs in the thread that
    constructs the job, while the heartbeat/ping sockets are served by
    daemon threads.

    """

    def __init__(self, worker_address, log_server_address):
        """Start the job, serve the remote object, then exit the process.

        The constructor drives the whole life cycle: it starts the socket
        thread, runs the command loop (``run``), asks the worker to kill
        this job and finally leaves the process with ``os._exit(0)``.

        Args:
            worker_address(str): worker_address for sending job information(e.g, pid)
            log_server_address(str): address that is forwarded to the worker
                inside the ``InitializedJob`` message.

        Attributes:
            pid (int): Job process ID.
            max_memory (float): Maximum memory (MB) can be used by each remote instance.
            init_memory (float): memory (MB) used by this process when the
                job was created; it is added to ``max_memory`` when the
                memory limit is checked.
        """
        self.max_memory = None

        # Pipes used by run() to hand the reply-socket address and the job id
        # over to the socket thread (_create_sockets).
        self.job_address_receiver, job_address_sender = Pipe()
        self.job_id_receiver, job_id_sender = Pipe()

        self.worker_address = worker_address
        self.log_server_address = log_server_address
        self.job_ip = get_ip_address()
        self.pid = os.getpid()

        # NOTE: In Windows, it will raise errors when creating threading.Lock
        # before starting multiprocess.Process.
        self.lock = threading.Lock()
        self._new_daemon_thread(self._create_sockets).start()

        process = psutil.Process(self.pid)
        self.init_memory = float(process.memory_info()[0]) / (1024**2)

        self.run(job_address_sender, job_id_sender)

        # The lock is deliberately held until the process exits, so the
        # client-heartbeat thread cannot use kill_job_socket at the same time.
        with self.lock:
            self._request_job_kill()
            os._exit(0)

    @staticmethod
    def _new_daemon_thread(target, args=()):
        """Return a not-yet-started daemon thread running ``target(*args)``."""
        thread = threading.Thread(target=target, args=args)
        # `daemon` replaces the deprecated Thread.setDaemon().
        thread.daemon = True
        return thread

    def _request_job_kill(self):
        """Ask the worker to kill this job and wait for its reply.

        The caller must hold ``self.lock``. A missing reply (receive
        timeout on ``kill_job_socket``) is ignored on purpose because the
        process is about to exit anyway.
        """
        self.kill_job_socket.send_multipart(
            [remote_constants.KILLJOB_TAG,
             to_byte(self.job_address)])
        try:
            self.kill_job_socket.recv_multipart()
        except zmq.error.Again:
            pass

    def _create_sockets(self):
        """Create five sockets for each job.

        Runs in the daemon thread started by ``__init__``.

        (1) job_socket(functional socket): sends job_address and heartbeat_address to worker.
        (2) ping_heartbeat_socket: replies ping message of client.
        (3) worker_heartbeat_socket: replies heartbeat message of worker.
        (4) client_heartbeat_socket: replies heartbeat message of client.
        (5) kill_job_socket: sends a command to the corresponding worker to kill the job.

        """
        # Block until run() has bound the reply socket and published its
        # address and the job id through the pipes.
        self.job_address = self.job_address_receiver.recv()
        self.job_id = self.job_id_receiver.recv()

        self.ctx = zmq.Context()
        # create the job_socket
        self.job_socket = self.ctx.socket(zmq.REQ)
        self.job_socket.connect("tcp://{}".format(self.worker_address))

        # a thread that reply ping signals from the client
        ping_heartbeat_socket, ping_heartbeat_address = self._create_heartbeat_server(
            timeout=False)
        ping_thread = self._new_daemon_thread(
            self._reply_ping, args=(ping_heartbeat_socket, ))
        ping_thread.start()

        # a thread that reply heartbeat signals from the worker
        # (started below, once the worker has answered)
        worker_heartbeat_socket, worker_heartbeat_address = self._create_heartbeat_server(
        )
        worker_thread = self._new_daemon_thread(
            self._reply_worker_heartbeat, args=(worker_heartbeat_socket, ))

        # a thread that reply heartbeat signals from the client
        # (started by _reply_ping after the first ping has been answered)
        client_heartbeat_socket, client_heartbeat_address = self._create_heartbeat_server(
        )
        self.client_thread = self._new_daemon_thread(
            self._reply_client_heartbeat, args=(client_heartbeat_socket, ))

        # sends job information to the worker
        initialized_job = InitializedJob(
            self.job_address, worker_heartbeat_address,
            client_heartbeat_address, ping_heartbeat_address, None, self.pid,
            self.job_id, self.log_server_address)
        self.job_socket.send_multipart(
            [remote_constants.NORMAL_TAG,
             cloudpickle.dumps(initialized_job)])
        message = self.job_socket.recv_multipart()
        worker_thread.start()

        tag = message[0]
        assert tag == remote_constants.NORMAL_TAG
        # create the kill_job_socket; the worker's reply carries its address
        kill_job_address = to_str(message[1])
        self.kill_job_socket = self.ctx.socket(zmq.REQ)
        self.kill_job_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        self.kill_job_socket.connect("tcp://{}".format(kill_job_address))

    def _check_used_memory(self):
        """Check if the memory used by this job exceeds self.max_memory.

        Returns:
            bool: False when no limit is set (``max_memory`` is None);
            otherwise True iff the current usage is larger than
            ``max_memory + init_memory``.
        """
        if self.max_memory is None:
            return False
        # First field of psutil's memory_info(); presumably the resident set
        # size in bytes (confirm against the psutil docs). Converted to MB.
        used_bytes = psutil.Process(self.pid).memory_info()[0]
        used_memory = float(used_bytes) / (1024**2)
        return used_memory > self.max_memory + self.init_memory

    def _reply_ping(self, socket):
        """Reply to the single ping signal sent by the client.

        This signal is used to make sure that the job is still alive. The
        second frame of the ping carries the memory limit as text
        (``'None'`` means no limit). After replying, the client heartbeat
        thread is started and the ping socket is closed.

        Args:
            socket: the ping heartbeat socket created by ``_create_sockets``.
        """
        ping = socket.recv_multipart()
        memory_limit = to_str(ping[1])
        if memory_limit != 'None':
            self.max_memory = float(memory_limit)
        socket.send_multipart([remote_constants.HEARTBEAT_TAG])
        self.client_thread.start()
        socket.close(0)

    def _create_heartbeat_server(self, timeout=True):
        """Create a REP socket bound to a random port.

        Args:
            timeout (bool): when True, a receive timeout of
                ``HEARTBEAT_RCVTIMEO_S`` seconds is set on the socket, so a
                blocked receive raises ``zmq.error.Again`` instead of
                waiting forever.

        Returns:
            tuple: ``(heartbeat_socket, heartbeat_address)`` where the
            address is ``"<job_ip>:<port>"``.
        """
        heartbeat_socket = self.ctx.socket(zmq.REP)
        if timeout:
            heartbeat_socket.setsockopt(
                zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
        heartbeat_socket.linger = 0
        heartbeat_port = heartbeat_socket.bind_to_random_port(addr="tcp://*")
        heartbeat_address = "{}:{}".format(self.job_ip, heartbeat_port)
        return heartbeat_socket, heartbeat_address

    def _reply_client_heartbeat(self, socket):
        """Reply heartbeat signals from the client until the job must stop.

        Every reply tells the client whether the memory limit is exceeded.
        The process exits with status 1 either when the limit is exceeded
        or when no heartbeat arrives before the receive timeout; in the
        latter case the worker is asked to kill the job first.

        Args:
            socket: the client heartbeat socket created by ``_create_sockets``.
        """
        while True:
            try:
                socket.recv_multipart()
                stop_job = self._check_used_memory()
                socket.send_multipart([
                    remote_constants.HEARTBEAT_TAG,
                    to_byte(str(stop_job)),
                    to_byte(self.job_address)
                ])
                if stop_job:
                    logger.error(
                        "Memory used by this job exceeds {}. This job will exit."
                        .format(self.max_memory))
                    # NOTE(review): the pause presumably gives the client time
                    # to receive the reply before the process dies - confirm.
                    time.sleep(5)
                    socket.close(0)
                    os._exit(1)
            except zmq.error.Again:
                logger.warning(
                    "[Job] Cannot connect to the client. This job will exit and inform the worker."
                )
                break
        socket.close(0)
        with self.lock:
            self._request_job_kill()
        logger.warning("[Job]lost connection with the client, will exit")
        os._exit(1)

    def _reply_worker_heartbeat(self, socket):
        """Reply heartbeat signals from the worker.

        If no heartbeat arrives before the receive timeout, the worker is
        considered gone and the process exits with status 1.

        Args:
            socket: the worker heartbeat socket created by ``_create_sockets``.
        """
        while True:
            try:
                socket.recv_multipart()
                socket.send_multipart([remote_constants.HEARTBEAT_TAG])
            except zmq.error.Again:
                logger.warning("[Job] Cannot connect to the worker{}. ".format(
                    self.worker_address) + "Job will quit.")
                break
        socket.close(0)
        os._exit(1)

    @staticmethod
    def _save_files(pyfiles):
        """Write the received files into a new temporary directory.

        Args:
            pyfiles (dict): must contain ``'empty_subfolders'`` (iterable of
                relative directory paths), ``'python_files'`` and
                ``'other_files'`` (both mapping a relative path to the file
                content as bytes).

        Returns:
            str: path of the temporary directory holding all files.
        """
        envdir = tempfile.mkdtemp()

        for empty_subfolder in pyfiles['empty_subfolders']:
            empty_subfolder_path = os.path.join(envdir, empty_subfolder)
            if not os.path.exists(empty_subfolder_path):
                os.makedirs(empty_subfolder_path)

        # Python files first, then the other files (e.g. ./rom_files/...).
        for group in ('python_files', 'other_files'):
            for relative_path, content in pyfiles[group].items():
                target = os.path.join(envdir, relative_path)
                # Create the parent directory when the file lives in a
                # sub-folder that does not exist yet.
                parent = os.path.dirname(target)
                if not os.path.isdir(parent):
                    os.makedirs(parent)
                with open(target, 'wb') as f:
                    f.write(content)
        return envdir

    def wait_for_files(self, reply_socket, job_address):
        """Wait for python files from remote object.

        When a remote object receives the allocated job address, it will send
        the python files to the job. The job saves these files to a temporary
        directory and acknowledges them with ``NORMAL_TAG``; the caller is
        responsible for adding the directory to Python's search path.

        Args:
            reply_socket (socket): main socket to accept commands of remote object.
            job_address (String): address of reply_socket.

        Returns:
            A temporary directory containing the python files.

        Raises:
            NotImplementedError: if the first message is not tagged with
                ``SEND_FILE_TAG``.
        """
        message = reply_socket.recv_multipart()
        tag = message[0]
        if tag != remote_constants.SEND_FILE_TAG:
            logger.error("NotImplementedError:{}, received tag:{}".format(
                job_address, tag))
            raise NotImplementedError

        # NOTE(security): pickle.loads executes arbitrary code contained in
        # the payload; this is only safe when the sender is trusted.
        pyfiles = pickle.loads(message[1])
        envdir = self._save_files(pyfiles)
        reply_socket.send_multipart([remote_constants.NORMAL_TAG])
        return envdir

    def wait_for_connection(self, reply_socket):
        """Wait for connection from the remote object.

        The remote object will send its class information and initialization
        arguments to the job, these parameters are then used to create a
        local instance in the job process. Output printed while the instance
        is constructed is redirected to ``stdout.log`` in the job log
        directory.

        Args:
            reply_socket (socket): main socket to accept commands of remote object.

        Returns:
            A local instance of the remote class object, or None when
            creating it failed (the error has then already been reported to
            the remote object with ``EXCEPTION_TAG``).

        Raises:
            NotImplementedError: if the message is not tagged with
                ``INIT_OBJECT_TAG``.
        """
        message = reply_socket.recv_multipart()
        tag = message[0]

        if tag != remote_constants.INIT_OBJECT_TAG:
            logger.error("Message from job {}".format(message))
            reply_socket.send_multipart([
                remote_constants.EXCEPTION_TAG,
                b"[job]Unkonwn tag when tried to receive the class definition"
            ])
            raise NotImplementedError

        try:
            file_name, class_name, end_of_file = cloudpickle.loads(
                message[1])
            cls = load_remote_class(file_name, class_name, end_of_file)
            args, kwargs = cloudpickle.loads(message[2])
            logfile_path = os.path.join(self.log_dir, 'stdout.log')
            with redirect_stdout_to_file(logfile_path):
                obj = cls(*args, **kwargs)
        except Exception as e:
            traceback_str = str(traceback.format_exc())
            error_str = str(e)
            logger.error("traceback:\n{}".format(traceback_str))
            reply_socket.send_multipart([
                remote_constants.EXCEPTION_TAG,
                to_byte(error_str + "\ntraceback:\n" + traceback_str)
            ])
            return None

        # Report the attribute names of the new instance to the remote object.
        reply_socket.send_multipart([
            remote_constants.NORMAL_TAG,
            dumps_return(set(obj.__dict__.keys()))
        ])
        return obj

    def run(self, job_address_sender, job_id_sender):
        """Bind the reply socket and serve one remote object.

        The method publishes the reply-socket address and the job id through
        the two pipes, receives the source files, creates the remote class
        instance and then handles commands in ``single_task``. Any exception
        raised on the way is logged and swallowed, so the method always
        returns normally.

        Args:
            job_address_sender(sending end of multiprocessing.Pipe): send job address of reply_socket to main process.
            job_id_sender(sending end of multiprocessing.Pipe): send the job id.
        """
        ctx = zmq.Context()

        # create the reply_socket
        reply_socket = ctx.socket(zmq.REP)
        job_port = reply_socket.bind_to_random_port(addr="tcp://*")
        reply_socket.linger = 0
        job_ip = get_ip_address()
        job_address = "{}:{}".format(job_ip, job_port)

        # The job id is the address plus the start time in whole seconds.
        job_id = job_address.replace(':', '_') + '_' + str(int(time.time()))
        self.log_dir = os.path.expanduser('~/.parl_data/job/{}'.format(job_id))
        logger.set_dir(self.log_dir)
        logger.info(
            "[Job] Job {} initialized. Reply heartbeat socket Address: {}.".
            format(job_id, job_address))

        job_address_sender.send(job_address)
        job_id_sender.send(job_id)

        try:
            # receive source code from the actor, put it at the front of the
            # module search path and make it the working directory.
            envdir = self.wait_for_files(reply_socket, job_address)
            sys.path.insert(0, envdir)
            os.chdir(envdir)

            obj = self.wait_for_connection(reply_socket)
            if obj is None:
                # wait_for_connection has already reported the failure to
                # the remote object; an explicit raise (instead of `assert`)
                # keeps the check alive under `python -O`.
                raise RuntimeError(
                    "failed to create the remote class instance")
            self.single_task(obj, reply_socket, job_address)
        except Exception as e:
            logger.error(
                "Error occurs when running a single task. We will reset this job. \nReason:{}"
                .format(e))
            traceback_str = str(traceback.format_exc())
            logger.error("traceback:\n{}".format(traceback_str))

    def _execute_command(self, obj, tag, message):
        """Run one CALL/GET_ATTRIBUTE/SET_ATTRIBUTE command on ``obj``.

        Output printed while the command runs is redirected to
        ``stdout.log`` in the job log directory.

        Args:
            obj: the local instance of the remote class.
            tag: one of CALL_TAG, GET_ATTRIBUTE_TAG or SET_ATTRIBUTE_TAG.
            message (list): the frames received from the remote object.

        Returns:
            list: the frames of the ``NORMAL_TAG`` reply.
        """
        logfile_path = os.path.join(self.log_dir, 'stdout.log')

        if tag == remote_constants.CALL_TAG:
            function_name = to_str(message[1])
            args, kwargs = loads_argument(message[2])
            with redirect_stdout_to_file(logfile_path):
                ret = getattr(obj, function_name)(*args, **kwargs)
            ret = dumps_return(ret)
            return [
                remote_constants.NORMAL_TAG, ret,
                dumps_return(set(obj.__dict__.keys()))
            ]

        if tag == remote_constants.GET_ATTRIBUTE_TAG:
            attribute_name = to_str(message[1])
            with redirect_stdout_to_file(logfile_path):
                ret = getattr(obj, attribute_name)
            return [remote_constants.NORMAL_TAG, dumps_return(ret)]

        # Remaining case: SET_ATTRIBUTE_TAG.
        attribute_name = to_str(message[1])
        attribute_value = loads_return(message[2])
        with redirect_stdout_to_file(logfile_path):
            setattr(obj, attribute_name, attribute_value)
        return [
            remote_constants.NORMAL_TAG,
            dumps_return(set(obj.__dict__.keys()))
        ]

    def single_task(self, obj, reply_socket, job_address):
        """An infinite loop waiting for commands from the remote object.

        Each job will receive two kinds of message from the remote object:

        1. When the remote object calls a function (or gets/sets an
           attribute), job will run it on the local instance and return the
           results to the remote object.
        2. When the remote object is deleted, the job will quit and release
           related computation resources.

        Args:
            obj: the local instance of the remote class.
            reply_socket (socket): main socket to accept commands of remote object.
            job_address (String): address of reply_socket.

        Raises:
            AttributeError, SerializeError, DeserializeError: re-raised
                (with the original message and traceback) after the error
                has been reported to the remote object.
            NotImplementedError: if a message with an unknown tag arrives.
        """
        command_tags = (
            remote_constants.CALL_TAG,
            remote_constants.GET_ATTRIBUTE_TAG,
            remote_constants.SET_ATTRIBUTE_TAG,
        )
        # Exceptions of exactly these types are reported with a dedicated tag
        # and then re-raised; any other exception is reported with
        # EXCEPTION_TAG (including its traceback) and just ends the loop.
        dedicated_exception_tags = {
            AttributeError: remote_constants.ATTRIBUTE_EXCEPTION_TAG,
            SerializeError: remote_constants.SERIALIZE_EXCEPTION_TAG,
            DeserializeError: remote_constants.DESERIALIZE_EXCEPTION_TAG,
        }

        while True:
            message = reply_socket.recv_multipart()
            tag = message[0]

            if tag in command_tags:
                try:
                    reply = self._execute_command(obj, tag, message)
                    reply_socket.send_multipart(reply)
                except Exception as e:
                    # reset the job
                    error_str = str(e)
                    logger.error(error_str)

                    if type(e) in dedicated_exception_tags:
                        reply_socket.send_multipart([
                            dedicated_exception_tags[type(e)],
                            to_byte(error_str)
                        ])
                        # Bare raise keeps the original message and traceback.
                        raise

                    traceback_str = str(traceback.format_exc())
                    logger.error("traceback:\n{}".format(traceback_str))
                    reply_socket.send_multipart([
                        remote_constants.EXCEPTION_TAG,
                        to_byte(error_str + "\ntraceback:\n" +
                                traceback_str)
                    ])
                    break

            # receive KILLJOB_TAG from the actor: acknowledge and leave the loop
            elif tag == remote_constants.KILLJOB_TAG:
                reply_socket.send_multipart([remote_constants.NORMAL_TAG])
                logger.warning("An actor exits and this job {} will exit.".
                               format(job_address))
                break
            else:
                logger.error(
                    "The job receives an unknown message: {}".format(message))
                raise NotImplementedError


if __name__ == "__main__":
    # Command-line entry point: both addresses are mandatory and are passed
    # unchanged to Job, whose constructor runs the job.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        "--worker_address", type=str, required=True, help="worker_address")
    cli_parser.add_argument(
        "--log_server_address",
        type=str,
        required=True,
        help="log_server_address, address of the log web server on worker")
    args = cli_parser.parse_args()
    job = Job(args.worker_address, args.log_server_address)