worker.py 14.4 KB
Newer Older
F
fuyw 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cloudpickle
import multiprocessing
import os
F
fuyw 已提交
18
import psutil
B
Bo Zhou 已提交
19
import signal
F
fuyw 已提交
20
import socket
F
fuyw 已提交
21 22 23 24
import subprocess
import sys
import time
import threading
B
Bo Zhou 已提交
25
import warnings
F
fuyw 已提交
26
import zmq
F
fuyw 已提交
27
from datetime import datetime
F
fuyw 已提交
28 29 30

from parl.utils import get_ip_address, to_byte, to_str, logger
from parl.remote import remote_constants
B
Bo Zhou 已提交
31 32 33
from parl.remote.message import InitializedWorker
from parl.remote.status import WorkerStatus
from six.moves import queue
F
fuyw 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57


class Worker(object):
    """Worker provides the cpu computation resources for the cluster.

    A worker node is connected to the master node and will send its
    computation resources information to the master node. When a worker
    node is created, it will start `cpu_num` empty jobs and these jobs'
    ip addresses will be sent to the master node. Further, when an old
    job is killed, worker will start a new job and send the new job ip
    address to the master node.

    To start a worker, we use the following xparl command line api:

    .. code-block:: python

        xparl connect --address localhost:1234 --cpu_num 8

    Attributes:
        master_address (str): Master's ip address.
        request_master_socket (zmq.Context.socket): A socket which sends job
                                                    address to the master node.
        reply_job_socket (zmq.Context.socket): A socket which receives
                                               job_address from the job.
        kill_job_socket (zmq.Context.socket): A socket that receives commands
                                              to kill the job from jobs.
        job_buffer (queue.Queue): A buffer (bounded by ``cpu_num``) that
                                  stores initialized jobs, so that a new job
                                  can be provided quickly when one is killed.

    Args:
        master_address (str): IP address of the master node.
        cpu_num (int): Number of cpu to be used on the worker. When ``None``,
            ``multiprocessing.cpu_count()`` is used.
    """

    # Seconds the buffer-filling thread waits before re-checking a full buffer.
    _FILL_POLL_INTERVAL_S = 0.1

    def __init__(self, master_address, cpu_num=None):
        self.lock = threading.Lock()
        self.heartbeat_socket_initialized = threading.Event()
        self.ctx = zmq.Context.instance()
        self.master_address = master_address
        self.master_is_alive = True
        self.worker_is_alive = True
        self.worker_status = None  # initialized at `self._create_jobs`
        # Created in `self._create_jobs` only when the master is reachable.
        self.reply_master_hearbeat_thread = None
        self._set_cpu_num(cpu_num)
        self.job_buffer = queue.Queue(maxsize=self.cpu_num)
        self._create_sockets()

        # create a thread that waits commands from the job to kill the job.
        self.kill_job_thread = threading.Thread(target=self._reply_kill_job)
        self.kill_job_thread.start()

        self._create_jobs()

        # `_create_jobs` clears `master_is_alive` when the master cannot be
        # reached; spawning buffer jobs nobody will consume is pointless then.
        if self.master_is_alive:
            # create a thread that initializes jobs and adds them into the
            # job_buffer
            job_thread = threading.Thread(target=self._fill_job_buffer)
            job_thread.daemon = True
            job_thread.start()

    def _set_cpu_num(self, cpu_num=None):
        """Set the usable cpu number for the worker.

        Args:
            cpu_num (int or None): explicit cpu count; ``None`` means use
                ``multiprocessing.cpu_count()``.
        """
        if cpu_num is not None:
            assert isinstance(
                cpu_num, int
            ), "cpu_num should be INT type, please check the input type."
            self.cpu_num = cpu_num
        else:
            self.cpu_num = multiprocessing.cpu_count()

    def _create_sockets(self):
        """ Each worker has three sockets at start:

        (1) request_master_socket: sends job address to master node.
        (2) reply_job_socket: receives job_address from subprocess.
        (3) kill_job_socket : receives commands to kill the job from jobs.

        When a job starts, a new heartbeat socket is created to receive
        heartbeat signals from the job.

        """
        self.worker_ip = get_ip_address()

        # request_master_socket: sends job address to master
        self.request_master_socket = self.ctx.socket(zmq.REQ)
        self.request_master_socket.linger = 0

        # wait for 0.5 second to check whether master is started
        self.request_master_socket.setsockopt(zmq.RCVTIMEO, 500)
        self.request_master_socket.connect("tcp://" + self.master_address)

        # reply_job_socket: receives job_address from subprocess
        self.reply_job_socket = self.ctx.socket(zmq.REP)
        self.reply_job_socket.linger = 0
        reply_job_port = self.reply_job_socket.bind_to_random_port("tcp://*")
        self.reply_job_address = "{}:{}".format(self.worker_ip, reply_job_port)

        # kill_job_socket
        self.kill_job_socket = self.ctx.socket(zmq.REP)
        self.kill_job_socket.linger = 0
        kill_job_port = self.kill_job_socket.bind_to_random_port("tcp://*")
        self.kill_job_address = "{}:{}".format(self.worker_ip, kill_job_port)

    def _create_jobs(self):
        """Create jobs and send an instance of ``InitializedWorker`` that
        contains the worker information to the master.

        Sets ``self.master_is_alive`` to False and returns early when the
        master does not answer the connect request within 0.5 second.
        """
        try:
            self.request_master_socket.send_multipart(
                [remote_constants.WORKER_CONNECT_TAG])
            _ = self.request_master_socket.recv_multipart()
        except zmq.error.Again:
            logger.error("Can not connect to the master, "
                         "please check if master is started.")
            self.master_is_alive = False
            return

        initialized_jobs = self._init_jobs(job_num=self.cpu_num)
        self.request_master_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)

        self.reply_master_hearbeat_thread = threading.Thread(
            target=self._reply_heartbeat,
            args=("master {}".format(self.master_address), ))
        self.reply_master_hearbeat_thread.start()
        # `master_heartbeat_address` is assigned by the heartbeat thread.
        self.heartbeat_socket_initialized.wait()

        for job in initialized_jobs:
            job.worker_address = self.master_heartbeat_address

        initialized_worker = InitializedWorker(self.master_heartbeat_address,
                                               initialized_jobs, self.cpu_num,
                                               socket.gethostname())
        self.request_master_socket.send_multipart([
            remote_constants.WORKER_INITIALIZED_TAG,
            cloudpickle.dumps(initialized_worker)
        ])

        _ = self.request_master_socket.recv_multipart()
        self.worker_status = WorkerStatus(self.master_heartbeat_address,
                                          initialized_jobs, self.cpu_num)

    def _fill_job_buffer(self):
        """Keep the job buffer topped up with initialized jobs.

        Loops while ``self.worker_is_alive`` is True. Once it turns False,
        every job still waiting in the buffer is sent SIGTERM.
        """
        while self.worker_is_alive:
            free_slots = self.job_buffer.maxsize - self.job_buffer.qsize()
            if free_slots <= 0:
                # Buffer is full: sleep instead of spinning on the check.
                time.sleep(self._FILL_POLL_INTERVAL_S)
                continue
            # This is the only place that puts into the buffer, so the free
            # slots can only grow before the puts below; none of them blocks.
            for job in self._init_jobs(job_num=free_slots):
                self.job_buffer.put(job)

        # release jobs if the worker is not alive
        while True:
            try:
                job = self.job_buffer.get_nowait()
            except queue.Empty:
                break
            try:
                os.kill(job.pid, signal.SIGTERM)
            except OSError:
                # The job process has already exited.
                pass

    def _init_jobs(self, job_num):
        """Create jobs.

        Spawns ``job_num`` job subprocesses, waits for each of them to
        report back on ``reply_job_socket`` (answering with the kill-job
        address), and starts one heartbeat monitor thread per job.

        Args:
            job_num(int): the number of jobs to create.

        Returns:
            list: the initialized jobs unpickled from the subprocess replies.
        """
        # job.py lives next to this module (this also covers the compiled
        # `worker.pyc` path under Python 2).
        job_file = os.path.join(os.path.dirname(__file__), 'job.py')
        command = [
            sys.executable, job_file, "--worker_address",
            self.reply_job_address
        ]

        if sys.version_info.major == 3:
            warnings.simplefilter("ignore", ResourceWarning)

        new_jobs = []
        # avoid that many jobs are killed and restarted at the same time.
        # `with` releases the lock even if a socket call raises.
        with self.lock:
            # Redirect the output to DEVNULL
            with open(os.devnull, 'w') as devnull:
                for _ in range(job_num):
                    subprocess.Popen(
                        command, stdout=devnull, stderr=subprocess.STDOUT)

            for _ in range(job_num):
                job_message = self.reply_job_socket.recv_multipart()
                self.reply_job_socket.send_multipart(
                    [remote_constants.NORMAL_TAG,
                     to_byte(self.kill_job_address)])
                initialized_job = cloudpickle.loads(job_message[1])
                new_jobs.append(initialized_job)

                # a thread for sending heartbeat signals to job
                thread = threading.Thread(
                    target=self._create_job_monitor, args=(initialized_job, ))
                thread.start()
        assert len(new_jobs) > 0, "init jobs failed"
        return new_jobs

    def _kill_job(self, job_address):
        """Kill a job process and update worker information.

        Removes ``job_address`` from the worker status, replaces it with a
        live job taken from ``job_buffer`` and notifies the master of the
        replacement. Does nothing if the job was not registered.

        Args:
            job_address (str): address of the job to remove.
        """
        success = self.worker_status.remove_job(job_address)
        if not success:
            return

        while True:
            initialized_job = self.job_buffer.get()
            initialized_job.worker_address = self.master_heartbeat_address
            if not initialized_job.is_alive:
                logger.warning(
                    "[Worker] a dead job found. The job buffer will not accept this one."
                )
                continue
            self.worker_status.add_job(initialized_job)
            # make sure that the job is still alive after registering it.
            if initialized_job.is_alive:
                break
            self.worker_status.remove_job(initialized_job.job_address)

        # `with` releases the lock even if the master reply times out.
        with self.lock:
            self.request_master_socket.send_multipart([
                remote_constants.NEW_JOB_TAG,
                cloudpickle.dumps(initialized_job),
                to_byte(job_address)
            ])
            _ = self.request_master_socket.recv_multipart()

    def _create_job_monitor(self, job):
        """Send heartbeat signals to check target's status.

        Marks ``job.is_alive`` False and kills the job when a heartbeat
        reply times out.
        """

        # job_heartbeat_socket: sends heartbeat signal to job
        job_heartbeat_socket = self.ctx.socket(zmq.REQ)
        job_heartbeat_socket.linger = 0
        job_heartbeat_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        job_heartbeat_socket.connect("tcp://" + job.worker_heartbeat_address)

        job.is_alive = True
        while job.is_alive and self.master_is_alive and self.worker_is_alive:
            try:
                job_heartbeat_socket.send_multipart(
                    [remote_constants.HEARTBEAT_TAG])
                _ = job_heartbeat_socket.recv_multipart()
                time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
            except zmq.error.Again:
                job.is_alive = False
                logger.warning(
                    "[Worker] lost connection with the job:{}".format(
                        job.job_address))
                if self.master_is_alive and self.worker_is_alive:
                    self._kill_job(job.job_address)

            except zmq.error.ZMQError:
                break

        job_heartbeat_socket.close(0)

    def _reply_kill_job(self):
        """Worker starts a thread to wait jobs' commands to kill the job"""
        self.kill_job_socket.linger = 0
        self.kill_job_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
        while self.worker_is_alive and self.master_is_alive:
            try:
                message = self.kill_job_socket.recv_multipart()
                tag = message[0]
                assert tag == remote_constants.KILLJOB_TAG
                to_kill_job_address = to_str(message[1])
                self._kill_job(to_kill_job_address)
                self.kill_job_socket.send_multipart(
                    [remote_constants.NORMAL_TAG])
            except zmq.error.Again:
                # detect whether `self.worker_is_alive` is True periodically
                pass

    def _get_worker_status(self):
        """Collect the resource usage reported with each master heartbeat.

        Returns:
            tuple: (vacant memory in GB, used memory in GB, current time as
            'HH:MM:SS', 1-minute load average), numbers rounded to 2 digits.
        """
        now = datetime.strftime(datetime.now(), '%H:%M:%S')
        virtual_memory = psutil.virtual_memory()
        # float() keeps true division under Python 2 as well.
        gigabyte = float(1024**3)
        total_memory = round(virtual_memory.total / gigabyte, 2)
        used_memory = round(virtual_memory.used / gigabyte, 2)
        vacant_memory = round(total_memory - used_memory, 2)
        load_average = round(os.getloadavg()[0], 2)
        return (vacant_memory, used_memory, now, load_average)

    def _reply_heartbeat(self, target):
        """Worker will kill its jobs when it lost connection with the master.

        Binds a REP socket, publishes its address as
        ``self.master_heartbeat_address`` and answers heartbeats with the
        worker status until the master stops sending them.

        Args:
            target (str): label of the heartbeat peer (not used by the body).
        """

        heartbeat_socket = self.ctx.socket(zmq.REP)
        heartbeat_socket.linger = 0
        heartbeat_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
        heartbeat_master_port = heartbeat_socket.bind_to_random_port(
            "tcp://*")
        self.master_heartbeat_address = "{}:{}".format(self.worker_ip,
                                                       heartbeat_master_port)

        logger.set_dir(
            os.path.expanduser('~/.parl_data/worker/{}'.format(
                self.master_heartbeat_address)))

        self.heartbeat_socket_initialized.set()
        logger.info("[Worker] Connect to the master node successfully. "
                    "({} CPUs)".format(self.cpu_num))
        while self.master_is_alive and self.worker_is_alive:
            try:
                heartbeat_socket.recv_multipart()
                worker_status = self._get_worker_status()
                heartbeat_socket.send_multipart([
                    remote_constants.HEARTBEAT_TAG,
                    to_byte(str(worker_status[0])),
                    to_byte(str(worker_status[1])),
                    to_byte(worker_status[2]),
                    to_byte(str(worker_status[3]))
                ])
            except zmq.error.Again:
                self.master_is_alive = False
            except zmq.error.ContextTerminated:
                break
        heartbeat_socket.close(0)
        logger.warning(
            "[Worker] lost connection with the master, will exit replying heartbeat for master."
        )
        # The master may be lost before `_create_jobs` built the status.
        if self.worker_status is not None:
            self.worker_status.clear()
        # exit the worker
        self.worker_is_alive = False

    def exit(self):
        """close the worker"""
        self.worker_is_alive = False

    def run(self):
        """Keep running until it lost connection with the master.

        Returns immediately when no heartbeat thread was started (the master
        was unreachable at start-up).
        """
        if self.reply_master_hearbeat_thread is not None:
            self.reply_master_hearbeat_thread.join()