提交 bc2c3ad3 编写于 作者: B Bo Zhou 提交者: Hongsheng Zeng

fix bug: removing the job from the cluster twice; add test scripts (#121)

* fix bug: removing the job from the cluster twice; add test scripts

* fix attribute error

* fix comment

* rename

* fix the bug in unit test

* remove queue
上级 c1646351
......@@ -47,11 +47,12 @@ class Master(object):
client_job_dict (dict): A dict of list to record the job submitted by
each client.
job_worker_dict (dict): A dict to record the job and related worker.
client_socket (zmq.Context.socket): A socket which receives submitted
client_socket (zmq.Context.socket): A socket that receives submitted
job from the client, and later sends
job_address back to the client.
worker_socket (zmq.Context.socket): A socket which receives job
worker_socket (zmq.Context.socket): A socket that receives job
addresses from the worker node.
cpu_num(int): the number of available CPUs in the cluster.
Args:
port: the ip port that the master node binds to.
......@@ -101,6 +102,7 @@ class Master(object):
self.job_worker_dict.pop(job)
self.worker_job_dict.pop(worker_address)
self.worker_pool.pop(worker_address)
self.worker_locks.pop(worker_address)
logger.warning("\n[Master] Cannot connect to the worker " +
"{}. ".format(worker_address) +
"Worker_pool will drop this worker.")
......@@ -151,8 +153,12 @@ class Master(object):
for job_address in jobs:
if job_address in self.job_worker_dict:
worker_address = self.job_worker_dict[job_address]
# ignore this worker if it has been deleted
if worker_address not in self.worker_pool:
continue
worker_socket = self.worker_pool[worker_address].worker_socket
self.worker_locks[worker_address].acquire()
lock = self.worker_locks[worker_address]
lock.acquire()
worker_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(job_address)])
......@@ -160,7 +166,7 @@ class Master(object):
_ = worker_socket.recv_multipart()
except zmq.error.Again as e:
logger.warning("Error in recv kill_client_job")
self.worker_locks[worker_address].release()
lock.release()
self.job_worker_dict.pop(job_address)
self.client_job_dict.pop(client_address)
......@@ -170,6 +176,10 @@ class Master(object):
"Master connects to {} workers and have {} vacant CPUs.\n".format(
len(self.worker_pool), len(self.job_pool)))
@property
def cpu_num(self):
return len(self.job_pool)
def _receive_message(self):
"""master node will receive four types of message: (1) worker
connection; (2) worker update; (3) client connection; (4) job
......@@ -212,7 +222,6 @@ class Master(object):
worker_heartbeat_address,
worker.address,
))
thread.setDaemon(True)
thread.start()
self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
......@@ -226,7 +235,6 @@ class Master(object):
thread = threading.Thread(
target=self._create_client_monitor,
args=(client_heartbeat_address, ))
thread.setDaemon(True)
thread.start()
self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
......@@ -281,7 +289,6 @@ class Master(object):
def exit(self):
self.master_is_alive = False
self.ctx.destroy()
def run(self):
"""An infinite loop waiting for messages from the workers and
......@@ -295,10 +302,21 @@ class Master(object):
3. A new client connects to the master node.
4. A connected client submits a job after a remote object is created.
"""
self.client_socket.linger = 0
self.client_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
while self.master_is_alive:
try:
self._receive_message()
except zmq.error.ContextTerminated as e:
pass
except zmq.error.Again as e:
#detect whether `self.master_is_alive` is True periodically
pass
for worker_address, worker in self.worker_pool.items():
lock = self.worker_locks[worker_address]
lock.acquire()
worker.worker_socket.close(0)
lock.release()
logger.warning("[Master] Exit master.")
......@@ -113,9 +113,12 @@ def remote_class(cls):
def __del__(self):
"""Delete the remote class object and release remote resources."""
self.job_socket.send_multipart([remote_constants.KILLJOB_TAG])
_ = self.job_socket.recv_multipart()
self.job_socket.close(0)
try:
self.job_socket.send_multipart([remote_constants.KILLJOB_TAG])
_ = self.job_socket.recv_multipart()
self.job_socket.close(0)
except AttributeError:
pass
def send_file(self, socket):
try:
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
@parl.remote_class
class Actor(object):
pass
class TestClient(unittest.TestCase):
def test_not_init(self):
"""client is expected to raise an error and say that the master has not been started"""
def create_actor():
actor = Actor()
self.assertRaises(AssertionError, create_actor)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
import time
import threading
from parl.remote.client import disconnect
@parl.remote_class
class Actor(object):
def __init__(self, arg1=None, arg2=None):
self.arg1 = arg1
self.arg2 = arg2
def get_arg1(self):
return self.arg1
def get_arg2(self):
return self.arg2
def set_arg1(self, value):
self.arg1 = value
def set_arg2(self, value):
self.arg2 = value
def get_unable_serialize_object(self):
return UnableSerializeObject()
def add_one(self, value):
value += 1
return value
def add(self, x, y):
time.sleep(3)
return x + y
def will_raise_exception_func(self):
x = 1 / 0
class TestExit(unittest.TestCase):
def test_delete_worker(self):
# start the master
master = Master(port=1235)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1235', 4)
parl.connect('localhost:1235')
for i in range(4):
actor = Actor()
ret = actor.add_one(1)
self.assertEqual(ret, 2)
worker1.exit()
time.sleep(30)
disconnect()
time.sleep(30)
master.exit()
def test_add_worker(self):
master = Master(port=1234)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1234', 4)
self.assertEqual(master.cpu_num, 4)
worker2 = Worker('localhost:1234', 4)
self.assertEqual(master.cpu_num, 8)
worker2.exit()
time.sleep(30)
self.assertEqual(master.cpu_num, 4)
master.exit()
if __name__ == '__main__':
unittest.main()
......@@ -152,7 +152,6 @@ class Worker(object):
reply_thread = threading.Thread(
target=self._reply_heartbeat,
args=("master {}".format(self.master_address), ))
reply_thread.setDaemon(True)
reply_thread.start()
self.heartbeat_socket_initialized.wait()
......@@ -172,7 +171,8 @@ class Worker(object):
job_file = __file__.replace('worker.pyc', 'job.py')
job_file = job_file.replace('worker.py', 'job.py')
command = [
"python", job_file, "--worker_address", self.reply_job_address
sys.executable, job_file, "--worker_address",
self.reply_job_address
]
# Redirect the output to DEVNULL
......@@ -199,7 +199,6 @@ class Worker(object):
job_address,
heartbeat_job_address,
))
thread.setDaemon(True)
thread.start()
assert len(new_job_address) > 0, "init jobs failed"
if len(new_job_address) > 1:
......@@ -245,7 +244,7 @@ class Worker(object):
job_heartbeat_socket.connect("tcp://" + heartbeat_job_address)
job_is_alive = True
while job_is_alive and self.master_is_alive:
while job_is_alive and self.master_is_alive and self.worker_is_alive:
try:
job_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG])
......@@ -279,7 +278,7 @@ class Worker(object):
self.heartbeat_socket_initialized.set()
logger.info("[Worker] Connect to the master node successfully. "
"({} CPUs)".format(self.cpu_num))
while self.master_is_alive:
while self.master_is_alive and self.worker_is_alive:
try:
message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
......@@ -290,14 +289,15 @@ class Worker(object):
except zmq.error.ContextTerminated as e:
break
socket.close(0)
logger.warning("Worker exit replying heartbeat for master.")
if self.worker_is_alive:
self.exit()
logger.warning(
"[Worker] lost connection with the master, will exit replying heartbeat for master."
)
# exit the worker
self.worker_is_alive = False
def exit(self):
"""Exit all zmq sockets related to the worker."""
"""close the worker"""
self.worker_is_alive = False
self.ctx.destroy()
def run(self):
"""An infinite loop waiting for killing job commands from
......@@ -310,6 +310,10 @@ class Worker(object):
new jobs and update job addresses to the master node.
"""
self.reply_master_socket.linger = 0
self.reply_master_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
while self.master_is_alive and self.worker_is_alive:
try:
message = self.reply_master_socket.recv_multipart()
......@@ -323,8 +327,13 @@ class Worker(object):
else:
raise NotImplementedError
except zmq.error.ZMQError as e:
self.worker_is_alive = False
except zmq.error.Again as e:
#detect whether `self.worker_is_alive` is True periodically
pass
self.reply_job_socket.close(0)
self.request_master_socket.close(0)
self.reply_master_socket.close(0)
logger.warning("[Worker] Exit Worker {}.".format(
self.reply_master_address))
self.ctx.destroy()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册