job.py 14.2 KB
Newer Older
F
fuyw 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

B
Bo Zhou 已提交
15 16 17
import os
# An empty CUDA_VISIBLE_DEVICES hides all CUDA devices from this process.
# NOTE(review): presumably set before the remaining imports on purpose, so
# that libraries imported later already see it -- confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# NOTE(review): the consumer of the XPARL variable is not visible in this
# file; presumably it marks the process as an xparl job -- confirm.
os.environ['XPARL'] = 'True'
F
fuyw 已提交
18 19 20 21 22 23 24 25 26 27 28 29 30 31
import argparse
import cloudpickle
import pickle
import sys
import tempfile
import threading
import time
import traceback
import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.utils.communication import loads_argument, loads_return,\
    dumps_argument, dumps_return
from parl.remote import remote_constants
from parl.utils.exceptions import SerializeError, DeserializeError
B
Bo Zhou 已提交
32
from parl.remote.message import InitializedJob
F
fuyw 已提交
33 34 35 36 37 38 39 40


class Job(object):
    """Base class for the job.

    After establishing connection with the remote object, the job will
    create a remote class instance locally and enter an infinite loop,
    waiting for commands from the remote object.
B
Bo Zhou 已提交
41

F
fuyw 已提交
42 43 44
    """

    def __init__(self, worker_address):
B
Bo Zhou 已提交
45 46 47 48
        """
        Args:
            worker_address(str): worker_address for sending job information(e.g, pid)
        """
F
fuyw 已提交
49 50
        self.job_is_alive = True
        self.worker_address = worker_address
B
Bo Zhou 已提交
51
        self.lock = threading.Lock()
F
fuyw 已提交
52 53 54
        self._create_sockets()

    def _create_sockets(self):
        """Create the sockets and heartbeat threads used by the job.

        Main sockets:

        (1) reply_socket(main socket): receives the command(i.e, the function name and args)
            from the actual class instance, completes the computation, and returns the result of
            the function.
        (2) job_socket(functional socket): sends job_address and heartbeat_address to worker.
        (3) kill_job_socket: sends a command to the corresponding worker to kill the job.

        In addition, three heartbeat server sockets are created (ping, worker
        and client), each served by its own daemon thread. The ping and worker
        threads are started here; the client heartbeat thread is only created
        and stored in `self.client_thread`, it is started in `run`.
        """

        self.ctx = zmq.Context()

        # create the reply_socket
        self.reply_socket = self.ctx.socket(zmq.REP)
        job_port = self.reply_socket.bind_to_random_port(addr="tcp://*")
        self.reply_socket.linger = 0
        self.job_ip = get_ip_address()
        self.job_address = "{}:{}".format(self.job_ip, job_port)

        # create the job_socket
        self.job_socket = self.ctx.socket(zmq.REQ)
        self.job_socket.connect("tcp://{}".format(self.worker_address))

        # a thread that replies to ping signals from the client
        # (no receive timeout: this socket simply waits for pings)
        ping_heartbeat_socket, ping_heartbeat_address = \
            self._create_heartbeat_server(timeout=False)
        ping_thread = threading.Thread(
            target=self._reply_ping, args=(ping_heartbeat_socket, ))
        # `daemon` attribute instead of the deprecated `setDaemon()`
        ping_thread.daemon = True
        ping_thread.start()
        self.ping_heartbeat_address = ping_heartbeat_address

        # a thread that replies to heartbeat signals from the worker
        worker_heartbeat_socket, worker_heartbeat_address = \
            self._create_heartbeat_server()
        worker_thread = threading.Thread(
            target=self._reply_worker_heartbeat,
            args=(worker_heartbeat_socket, ))
        worker_thread.daemon = True
        worker_thread.start()

        # a thread that replies to heartbeat signals from the client;
        # it is created here but deliberately not started yet
        client_heartbeat_socket, client_heartbeat_address = \
            self._create_heartbeat_server()
        self.client_thread = threading.Thread(
            target=self._reply_client_heartbeat,
            args=(client_heartbeat_socket, ))
        self.client_thread.daemon = True

        # sends job information to the worker
        initialized_job = InitializedJob(
            self.job_address, worker_heartbeat_address,
            client_heartbeat_address, self.ping_heartbeat_address, None,
            os.getpid())
        self.job_socket.send_multipart(
            [remote_constants.NORMAL_TAG,
             cloudpickle.dumps(initialized_job)])
        message = self.job_socket.recv_multipart()

        assert message[0] == remote_constants.NORMAL_TAG, \
            "unexpected reply from the worker: {}".format(message)
        # create the kill_job_socket; the worker's reply carries its address
        kill_job_address = to_str(message[1])
        self.kill_job_socket = self.ctx.socket(zmq.REQ)
        # bound the wait for the worker's answer to a kill request
        self.kill_job_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        self.kill_job_socket.connect("tcp://{}".format(kill_job_address))

    def _reply_ping(self, socket):
        """Answer ping requests from the client while the job is alive.

        Every received message is answered with a HEARTBEAT_TAG reply, which
        lets the client verify that the job is still alive. The socket is
        closed once `self.job_is_alive` turns false.

        Args:
            socket: the socket on which ping requests are served.
        """
        while True:
            if not self.job_is_alive:
                break
            # the content of the ping itself is not inspected
            socket.recv_multipart()
            socket.send_multipart([remote_constants.HEARTBEAT_TAG])
        socket.close(0)

    def _create_heartbeat_server(self, timeout=True):
        """Create a REP socket bound to a random port for serving heartbeats.

        Args:
            timeout(bool): if True, a receive timeout derived from
                `remote_constants.HEARTBEAT_RCVTIMEO_S` is set on the socket.

        Returns:
            A tuple (socket, address), where address is "<job_ip>:<port>".
        """
        server = self.ctx.socket(zmq.REP)
        if timeout:
            rcv_timeout = remote_constants.HEARTBEAT_RCVTIMEO_S * 1000
            server.setsockopt(zmq.RCVTIMEO, rcv_timeout)
        server.linger = 0
        port = server.bind_to_random_port(addr="tcp://*")
        return server, "{}:{}".format(self.job_ip, port)

    def _reply_client_heartbeat(self, socket):
        """Answer heartbeat requests from the client.

        When receiving times out (or `self.client_is_alive` is cleared
        elsewhere), the loop ends, the worker is asked to kill this job and
        the process exits with status 1.

        Args:
            socket: the socket on which client heartbeats are served.
        """
        self.client_is_alive = True
        while self.client_is_alive:
            try:
                socket.recv_multipart()
                socket.send_multipart([remote_constants.HEARTBEAT_TAG])

            except zmq.error.Again:
                logger.warning(
                    "[Job] Cannot connect to the client. This job will exit and inform the worker."
                )
                self.client_is_alive = False
        socket.close(0)
        try:
            # the lock serializes use of kill_job_socket with other threads
            with self.lock:
                self.kill_job_socket.send_multipart(
                    [remote_constants.KILLJOB_TAG,
                     to_byte(self.job_address)])
                _ = self.kill_job_socket.recv_multipart()
            logger.warning("[Job]lost connection with the client, will exit")
        except zmq.error.ZMQError as e:
            # kill_job_socket has a receive timeout, so an unresponsive worker
            # raises here; the job must still exit in that case.
            logger.warning(
                "[Job] Failed to inform the worker before exiting: {}".format(
                    e))
        finally:
            os._exit(1)

    def _reply_worker_heartbeat(self, socket):
        """Answer heartbeat requests from the worker.

        The loop ends when receiving a heartbeat times out, or when the job
        is flagged as no longer alive; afterwards the socket is closed and
        the process exits with status 1.

        Args:
            socket: the socket on which worker heartbeats are served.
        """
        self.worker_is_alive = True
        while True:
            # both flags must hold for the heartbeat loop to continue
            if not (self.worker_is_alive and self.job_is_alive):
                break
            try:
                socket.recv_multipart()
                socket.send_multipart([remote_constants.HEARTBEAT_TAG])
            except zmq.error.Again:
                lost_worker_msg = "[Job] Cannot connect to the worker{}. ".format(
                    self.worker_address) + "Job will quit."
                logger.warning(lost_worker_msg)
                self.worker_is_alive = False
                self.job_is_alive = False
        socket.close(0)
        os._exit(1)
F
fuyw 已提交
186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212

    def wait_for_files(self):
        """Wait for python files from remote object.

        When a remote object receives the allocated job address, it will send
        the python files to the job. The job saves these files to a newly
        created temporary directory and acknowledges with NORMAL_TAG. Adding
        the directory to the import path is left to the caller.

        Returns:
            A temporary directory containing the python files.

        Raises:
            NotImplementedError: if the received message is not tagged with
                SEND_FILE_TAG.
        """
        message = self.reply_socket.recv_multipart()
        tag = message[0]
        if tag != remote_constants.SEND_FILE_TAG:
            # Both placeholders need a value; with only one argument the
            # format call itself failed with IndexError.
            logger.error("NotImplementedError:{}, received tag:{}".format(
                self.job_address, tag))
            raise NotImplementedError

        # NOTE(review): pickle.loads on data received over the network can
        # execute arbitrary code; only safe when the peer is trusted.
        pyfiles = pickle.loads(message[1])
        envdir = tempfile.mkdtemp()
        for file_name in pyfiles:
            code = pyfiles[file_name]
            with open(os.path.join(envdir, file_name), 'wb') as code_file:
                code_file.write(code)
        self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
        return envdir

    def wait_for_connection(self):
        """Wait for the remote object to send its class and init arguments.

        The received class definition and initialization arguments are used
        to build a local instance in the job process. On success NORMAL_TAG
        is sent back; if the constructor raises, the error and its traceback
        are sent back with EXCEPTION_TAG and `self.client_is_alive` is
        cleared.

        Returns:
            A local instance of the remote class object, or None if creating
            the instance failed.

        Raises:
            NotImplementedError: if the message tag is not INIT_OBJECT_TAG.
        """
        message = self.reply_socket.recv_multipart()
        tag = message[0]

        # Guard clause: anything but an INIT_OBJECT_TAG request is rejected.
        if tag != remote_constants.INIT_OBJECT_TAG:
            logger.error("Message from job {}".format(message))
            self.reply_socket.send_multipart([
                remote_constants.EXCEPTION_TAG,
                b"[job]Unkonwn tag when tried to receive the class definition"
            ])
            raise NotImplementedError

        remote_cls = cloudpickle.loads(message[1])
        init_args, init_kwargs = cloudpickle.loads(message[2])

        try:
            instance = remote_cls(*init_args, **init_kwargs)
        except Exception as e:
            traceback_str = str(traceback.format_exc())
            error_str = str(e)
            logger.error("traceback:\n{}".format(traceback_str))
            failure_report = to_byte(error_str + "\ntraceback:\n" +
                                     traceback_str)
            self.reply_socket.send_multipart(
                [remote_constants.EXCEPTION_TAG, failure_report])
            self.client_is_alive = False
            return None

        self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
        return instance
F
fuyw 已提交
258 259

    def run(self):
        """Run the job: receive files, build the instance, serve commands.

        Steps:
          1. wait for the python files and add their directory to `sys.path`;
          2. start the client heartbeat thread;
          3. create the local instance and serve commands until the command
             loop ends or fails (failures are logged, not propagated);
          4. ask the worker to kill this job.
        """
        # receive source code from the actor and make it importable
        envdir = self.wait_for_files()
        sys.path.append(envdir)
        self.client_thread.start()

        try:
            obj = self.wait_for_connection()
            if obj is None:
                # explicit check rather than `assert`, which is removed when
                # Python runs with -O
                raise RuntimeError(
                    "failed to create the remote class instance")
            self.single_task(obj)
        except Exception as e:
            logger.error(
                "Error occurs when running a single task. We will reset this job. Reason:{}"
                .format(e))
            traceback_str = str(traceback.format_exc())
            logger.error("traceback:\n{}".format(traceback_str))
        # the lock serializes use of kill_job_socket with the heartbeat thread
        with self.lock:
            self.kill_job_socket.send_multipart(
                [remote_constants.KILLJOB_TAG,
                 to_byte(self.job_address)])
            _ = self.kill_job_socket.recv_multipart()

    def single_task(self, obj):
        """An infinite loop waiting for commands from the remote object.

        Each job will receive two kinds of message from the remote object:

        1. When the remote object calls a function, job will run the
           function on the local instance and return the results to the
           remote object.
        2. When the remote object is deleted, the job will quit and release
           related computation resources.

        Args:
            obj: the local instance on which the requested functions are
                called.

        Raises:
            AttributeError, SerializeError, DeserializeError: re-raised after
                being reported to the remote object with a dedicated tag.
            NotImplementedError: if a message with an unknown tag arrives.
        """
        # Exceptions that are reported with a dedicated tag. The lookup is by
        # exact type, so subclasses are handled by the generic branch.
        dedicated_tags = {
            AttributeError: remote_constants.ATTRIBUTE_EXCEPTION_TAG,
            SerializeError: remote_constants.SERIALIZE_EXCEPTION_TAG,
            DeserializeError: remote_constants.DESERIALIZE_EXCEPTION_TAG,
        }

        while self.job_is_alive and self.client_is_alive:
            message = self.reply_socket.recv_multipart()
            tag = message[0]

            if tag == remote_constants.CALL_TAG:
                try:
                    function_name = to_str(message[1])
                    data = message[2]
                    args, kwargs = loads_argument(data)
                    ret = getattr(obj, function_name)(*args, **kwargs)
                    ret = dumps_return(ret)

                    self.reply_socket.send_multipart(
                        [remote_constants.NORMAL_TAG, ret])

                except Exception as e:
                    # reset the job
                    self.client_is_alive = False

                    error_str = str(e)
                    logger.error(error_str)

                    exception_tag = dedicated_tags.get(type(e))
                    if exception_tag is not None:
                        self.reply_socket.send_multipart(
                            [exception_tag,
                             to_byte(error_str)])
                        # Re-raise the caught exception itself so that its
                        # message and traceback are kept for the caller.
                        raise

                    traceback_str = str(traceback.format_exc())
                    logger.error("traceback:\n{}".format(traceback_str))
                    self.reply_socket.send_multipart([
                        remote_constants.EXCEPTION_TAG,
                        to_byte(error_str + "\ntraceback:\n" + traceback_str)
                    ])
                    break

            # receive KILLJOB_TAG from actor: acknowledge and leave the loop
            elif tag == remote_constants.KILLJOB_TAG:
                self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
                self.client_is_alive = False
                logger.warning(
                    "An actor exits and this job {} will exit.".format(
                        self.job_address))
                break
            else:
                logger.error(
                    "The job receives an unknown message: {}".format(message))
                raise NotImplementedError


if __name__ == "__main__":
    # Script entry point: the address of the worker this job reports to is
    # passed on the command line.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--worker_address", type=str, required=True, help="worker_address")
    cli_args = arg_parser.parse_args()
    Job(cli_args.worker_address).run()