提交 bc2c3ad3 编写于 作者: B Bo Zhou 提交者: Hongsheng Zeng

fix bug: removing the job from the cluster twice; add test scripts (#121)

* fix bug: removing the job from the cluster twice; add test scripts

* fix attribute error

* fix comment

* rename

* fix the bug in unit test

* remove queue
上级 c1646351
...@@ -47,11 +47,12 @@ class Master(object): ...@@ -47,11 +47,12 @@ class Master(object):
client_job_dict (dict): A dict of list to record the job submitted by client_job_dict (dict): A dict of list to record the job submitted by
each client. each client.
job_worker_dict (dict): A dict to record the job and related worker. job_worker_dict (dict): A dict to record the job and related worker.
client_socket (zmq.Context.socket): A socket which receives submitted client_socket (zmq.Context.socket): A socket that receives submitted
job from the client, and later sends job from the client, and later sends
job_address back to the client. job_address back to the client.
worker_socket (zmq.Context.socket): A socket which receives job worker_socket (zmq.Context.socket): A socket that receives job
addresses from the worker node. addresses from the worker node.
cpu_num(int): the number of available CPUs in the cluster.
Args: Args:
port: the ip port that the master node binds to. port: the ip port that the master node binds to.
...@@ -101,6 +102,7 @@ class Master(object): ...@@ -101,6 +102,7 @@ class Master(object):
self.job_worker_dict.pop(job) self.job_worker_dict.pop(job)
self.worker_job_dict.pop(worker_address) self.worker_job_dict.pop(worker_address)
self.worker_pool.pop(worker_address) self.worker_pool.pop(worker_address)
self.worker_locks.pop(worker_address)
logger.warning("\n[Master] Cannot connect to the worker " + logger.warning("\n[Master] Cannot connect to the worker " +
"{}. ".format(worker_address) + "{}. ".format(worker_address) +
"Worker_pool will drop this worker.") "Worker_pool will drop this worker.")
...@@ -151,8 +153,12 @@ class Master(object): ...@@ -151,8 +153,12 @@ class Master(object):
for job_address in jobs: for job_address in jobs:
if job_address in self.job_worker_dict: if job_address in self.job_worker_dict:
worker_address = self.job_worker_dict[job_address] worker_address = self.job_worker_dict[job_address]
# ignore this worker if it has been deleted
if worker_address not in self.worker_pool:
continue
worker_socket = self.worker_pool[worker_address].worker_socket worker_socket = self.worker_pool[worker_address].worker_socket
self.worker_locks[worker_address].acquire() lock = self.worker_locks[worker_address]
lock.acquire()
worker_socket.send_multipart( worker_socket.send_multipart(
[remote_constants.KILLJOB_TAG, [remote_constants.KILLJOB_TAG,
to_byte(job_address)]) to_byte(job_address)])
...@@ -160,7 +166,7 @@ class Master(object): ...@@ -160,7 +166,7 @@ class Master(object):
_ = worker_socket.recv_multipart() _ = worker_socket.recv_multipart()
except zmq.error.Again as e: except zmq.error.Again as e:
logger.warning("Error in recv kill_client_job") logger.warning("Error in recv kill_client_job")
self.worker_locks[worker_address].release() lock.release()
self.job_worker_dict.pop(job_address) self.job_worker_dict.pop(job_address)
self.client_job_dict.pop(client_address) self.client_job_dict.pop(client_address)
...@@ -170,6 +176,10 @@ class Master(object): ...@@ -170,6 +176,10 @@ class Master(object):
"Master connects to {} workers and have {} vacant CPUs.\n".format( "Master connects to {} workers and have {} vacant CPUs.\n".format(
len(self.worker_pool), len(self.job_pool))) len(self.worker_pool), len(self.job_pool)))
@property
def cpu_num(self):
    """int: number of vacant CPUs, i.e. how many entries ``job_pool`` holds."""
    vacant_jobs = self.job_pool
    return len(vacant_jobs)
def _receive_message(self): def _receive_message(self):
"""master node will receive four types of message: (1) worker """master node will receive four types of message: (1) worker
connection; (2) worker update; (3) client connection; (4) job connection; (2) worker update; (3) client connection; (4) job
...@@ -212,7 +222,6 @@ class Master(object): ...@@ -212,7 +222,6 @@ class Master(object):
worker_heartbeat_address, worker_heartbeat_address,
worker.address, worker.address,
)) ))
thread.setDaemon(True)
thread.start() thread.start()
self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
...@@ -226,7 +235,6 @@ class Master(object): ...@@ -226,7 +235,6 @@ class Master(object):
thread = threading.Thread( thread = threading.Thread(
target=self._create_client_monitor, target=self._create_client_monitor,
args=(client_heartbeat_address, )) args=(client_heartbeat_address, ))
thread.setDaemon(True)
thread.start() thread.start()
self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
...@@ -281,7 +289,6 @@ class Master(object): ...@@ -281,7 +289,6 @@ class Master(object):
def exit(self): def exit(self):
self.master_is_alive = False self.master_is_alive = False
self.ctx.destroy()
def run(self): def run(self):
"""An infinite loop waiting for messages from the workers and """An infinite loop waiting for messages from the workers and
...@@ -295,10 +302,21 @@ class Master(object): ...@@ -295,10 +302,21 @@ class Master(object):
3. A new client connects to the master node. 3. A new client connects to the master node.
4. A connected client submits a job after a remote object is created. 4. A connected client submits a job after a remote object is created.
""" """
self.client_socket.linger = 0
self.client_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
while self.master_is_alive: while self.master_is_alive:
try: try:
self._receive_message() self._receive_message()
except zmq.error.ContextTerminated as e:
pass pass
except zmq.error.Again as e:
#detect whether `self.master_is_alive` is True periodically
pass
for worker_address, worker in self.worker_pool.items():
lock = self.worker_locks[worker_address]
lock.acquire()
worker.worker_socket.close(0)
lock.release()
logger.warning("[Master] Exit master.") logger.warning("[Master] Exit master.")
...@@ -113,9 +113,12 @@ def remote_class(cls): ...@@ -113,9 +113,12 @@ def remote_class(cls):
def __del__(self): def __del__(self):
"""Delete the remote class object and release remote resources.""" """Delete the remote class object and release remote resources."""
self.job_socket.send_multipart([remote_constants.KILLJOB_TAG]) try:
_ = self.job_socket.recv_multipart() self.job_socket.send_multipart([remote_constants.KILLJOB_TAG])
self.job_socket.close(0) _ = self.job_socket.recv_multipart()
self.job_socket.close(0)
except AttributeError:
pass
def send_file(self, socket): def send_file(self, socket):
try: try:
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
@parl.remote_class
class Actor(object):
    """Empty remote class; the test below only needs something to instantiate."""
class TestClient(unittest.TestCase):
    """Client-side checks that run without connecting to a master."""

    def test_not_init(self):
        """Instantiating a remote class with no prior ``parl.connect`` call
        is expected to fail with AssertionError (the master has not been
        started)."""
        with self.assertRaises(AssertionError):
            Actor()
# Run the test cases of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
import time
import threading
from parl.remote.client import disconnect
@parl.remote_class
class Actor(object):
    """Remote class with two stored arguments and a few simple methods."""

    def __init__(self, arg1=None, arg2=None):
        self.arg1, self.arg2 = arg1, arg2

    # --- plain accessors for the two stored arguments ---

    def get_arg1(self):
        return self.arg1

    def get_arg2(self):
        return self.arg2

    def set_arg1(self, value):
        self.arg1 = value

    def set_arg2(self, value):
        self.arg2 = value

    def get_unable_serialize_object(self):
        # NOTE(review): `UnableSerializeObject` is neither defined nor imported
        # in the visible part of this file; calling this method would raise
        # NameError unless the name is provided elsewhere -- confirm.
        return UnableSerializeObject()

    def add_one(self, value):
        """Increment ``value`` by one and return it."""
        value += 1
        return value

    def add(self, x, y):
        """Sleep for three seconds, then return ``x + y``."""
        time.sleep(3)
        return x + y

    def will_raise_exception_func(self):
        """Always raises ZeroDivisionError."""
        _ = 1 / 0
class TestExit(unittest.TestCase):
    """Cluster tests covering workers joining and leaving a running master.

    Each test starts a master in a background thread. The master and every
    worker created by a test are shut down in a ``finally`` clause, so a
    failing assertion or a setup error cannot leave the (non-daemon)
    ``master.run`` thread running and keep the test process alive.
    """

    def test_delete_worker(self):
        """Create actors on a worker, then shut the worker down before the
        client disconnects."""
        # start the master
        master = Master(port=1235)
        th = threading.Thread(target=master.run)
        th.start()
        worker1 = None
        try:
            time.sleep(1)
            worker1 = Worker('localhost:1235', 4)
            parl.connect('localhost:1235')
            for _ in range(4):
                actor = Actor()
                ret = actor.add_one(1)
                self.assertEqual(ret, 2)
            worker1.exit()
            time.sleep(30)
            disconnect()
            time.sleep(30)
        finally:
            # Always stop the worker and the master, even when a step above
            # raised. On the success path worker1.exit() has already been
            # called once; the repeated call is expected to be harmless.
            if worker1 is not None:
                worker1.exit()
            master.exit()

    def test_add_worker(self):
        """``master.cpu_num`` follows workers being added and removed."""
        master = Master(port=1234)
        th = threading.Thread(target=master.run)
        th.start()
        workers = []
        try:
            time.sleep(1)
            worker1 = Worker('localhost:1234', 4)
            workers.append(worker1)
            self.assertEqual(master.cpu_num, 4)
            worker2 = Worker('localhost:1234', 4)
            workers.append(worker2)
            self.assertEqual(master.cpu_num, 8)
            worker2.exit()
            time.sleep(30)
            self.assertEqual(master.cpu_num, 4)
        finally:
            # Shut down every worker that was created (worker1 was previously
            # never stopped) and then the master, regardless of the outcome.
            for worker in workers:
                worker.exit()
            master.exit()
# Run the test cases of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
...@@ -152,7 +152,6 @@ class Worker(object): ...@@ -152,7 +152,6 @@ class Worker(object):
reply_thread = threading.Thread( reply_thread = threading.Thread(
target=self._reply_heartbeat, target=self._reply_heartbeat,
args=("master {}".format(self.master_address), )) args=("master {}".format(self.master_address), ))
reply_thread.setDaemon(True)
reply_thread.start() reply_thread.start()
self.heartbeat_socket_initialized.wait() self.heartbeat_socket_initialized.wait()
...@@ -172,7 +171,8 @@ class Worker(object): ...@@ -172,7 +171,8 @@ class Worker(object):
job_file = __file__.replace('worker.pyc', 'job.py') job_file = __file__.replace('worker.pyc', 'job.py')
job_file = job_file.replace('worker.py', 'job.py') job_file = job_file.replace('worker.py', 'job.py')
command = [ command = [
"python", job_file, "--worker_address", self.reply_job_address sys.executable, job_file, "--worker_address",
self.reply_job_address
] ]
# Redirect the output to DEVNULL # Redirect the output to DEVNULL
...@@ -199,7 +199,6 @@ class Worker(object): ...@@ -199,7 +199,6 @@ class Worker(object):
job_address, job_address,
heartbeat_job_address, heartbeat_job_address,
)) ))
thread.setDaemon(True)
thread.start() thread.start()
assert len(new_job_address) > 0, "init jobs failed" assert len(new_job_address) > 0, "init jobs failed"
if len(new_job_address) > 1: if len(new_job_address) > 1:
...@@ -245,7 +244,7 @@ class Worker(object): ...@@ -245,7 +244,7 @@ class Worker(object):
job_heartbeat_socket.connect("tcp://" + heartbeat_job_address) job_heartbeat_socket.connect("tcp://" + heartbeat_job_address)
job_is_alive = True job_is_alive = True
while job_is_alive and self.master_is_alive: while job_is_alive and self.master_is_alive and self.worker_is_alive:
try: try:
job_heartbeat_socket.send_multipart( job_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG]) [remote_constants.HEARTBEAT_TAG])
...@@ -279,7 +278,7 @@ class Worker(object): ...@@ -279,7 +278,7 @@ class Worker(object):
self.heartbeat_socket_initialized.set() self.heartbeat_socket_initialized.set()
logger.info("[Worker] Connect to the master node successfully. " logger.info("[Worker] Connect to the master node successfully. "
"({} CPUs)".format(self.cpu_num)) "({} CPUs)".format(self.cpu_num))
while self.master_is_alive: while self.master_is_alive and self.worker_is_alive:
try: try:
message = socket.recv_multipart() message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG]) socket.send_multipart([remote_constants.HEARTBEAT_TAG])
...@@ -290,14 +289,15 @@ class Worker(object): ...@@ -290,14 +289,15 @@ class Worker(object):
except zmq.error.ContextTerminated as e: except zmq.error.ContextTerminated as e:
break break
socket.close(0) socket.close(0)
logger.warning("Worker exit replying heartbeat for master.") logger.warning(
if self.worker_is_alive: "[Worker] lost connection with the master, will exit replying heartbeat for master."
self.exit() )
# exit the worker
self.worker_is_alive = False
def exit(self): def exit(self):
"""Exit all zmq sockets related to the worker.""" """close the worker"""
self.worker_is_alive = False self.worker_is_alive = False
self.ctx.destroy()
def run(self): def run(self):
"""An infinite loop waiting for killing job commands from """An infinite loop waiting for killing job commands from
...@@ -310,6 +310,10 @@ class Worker(object): ...@@ -310,6 +310,10 @@ class Worker(object):
new jobs and update job addresses to the master node. new jobs and update job addresses to the master node.
""" """
self.reply_master_socket.linger = 0
self.reply_master_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
while self.master_is_alive and self.worker_is_alive: while self.master_is_alive and self.worker_is_alive:
try: try:
message = self.reply_master_socket.recv_multipart() message = self.reply_master_socket.recv_multipart()
...@@ -323,8 +327,13 @@ class Worker(object): ...@@ -323,8 +327,13 @@ class Worker(object):
else: else:
raise NotImplementedError raise NotImplementedError
except zmq.error.ZMQError as e: except zmq.error.Again as e:
self.worker_is_alive = False #detect whether `self.worker_is_alive` is True periodically
pass
self.reply_job_socket.close(0)
self.request_master_socket.close(0)
self.reply_master_socket.close(0)
logger.warning("[Worker] Exit Worker {}.".format( logger.warning("[Worker] Exit Worker {}.".format(
self.reply_master_address)) self.reply_master_address))
self.ctx.destroy()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册