提交 84c94cc4 编写于 作者: B Bo Zhou 提交者: Hongsheng Zeng

Cluster#2 (#124)

* fix a number of bugs

* update comments

* paddle license

* revert the change on client_not_init_test.py

* fix timeout bug & unit test

* add unit test for potential bugs

* yapf

* unit test

* logger

* logger#2

* fix job's bug that will cause dead-lock

* fix bug

* add missing files

* solve hanging

* fix comments; exit job completely

* hanging problem

* hanging problem#2

* hanging_program#3

* CMAKE

* remove unused variables
上级 a7670972
...@@ -60,7 +60,7 @@ function run_test_with_gpu() { ...@@ -60,7 +60,7 @@ function run_test_with_gpu() {
mkdir -p ${REPO_ROOT}/build mkdir -p ${REPO_ROOT}/build
cd ${REPO_ROOT}/build cd ${REPO_ROOT}/build
cmake .. cmake .. -DIS_TESTING_GPU=ON
cat <<EOF cat <<EOF
======================================== ========================================
Running unit tests with GPU... Running unit tests with GPU...
......
...@@ -3,3 +3,4 @@ paddlepaddle-gpu==1.5.1.post97 ...@@ -3,3 +3,4 @@ paddlepaddle-gpu==1.5.1.post97
gym gym
details details
parameterized parameterized
timeout_decorator
...@@ -17,26 +17,34 @@ cmake_minimum_required(VERSION 3.0) ...@@ -17,26 +17,34 @@ cmake_minimum_required(VERSION 3.0)
enable_testing() enable_testing()
option(WITH_TESTING "Include unit testing" ON) option(WITH_TESTING "Include unit testing" ON)
option(IS_TESTING_IMPORT "Whether is testing import parl" OFF) option(IS_TESTING_IMPORT "testing import parl" OFF)
option(IS_TESTING_DOCS "Whether is testing compling the docs" OFF) option(IS_TESTING_DOCS "testing compling the docs" OFF)
option(IS_TESTING_GPU "testing GPU environment" OFF)
set(PADDLE_PYTHON_PATH "" CACHE STRING "Python path to PaddlePaddle Fluid") set(PADDLE_PYTHON_PATH "" CACHE STRING "Python path to PaddlePaddle Fluid")
function(py_test TARGET_NAME) function(py3_test TARGET_NAME)
set(options "") set(options "")
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS ENVS) set(multiValueArgs SRCS DEPS ARGS ENVS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(py3_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}_with_python2 #TODO: add real python2 env.
COMMAND env PYTHONPATH=.:${py_test_ENVS}
python -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_test(NAME ${TARGET_NAME}_with_python3 add_test(NAME ${TARGET_NAME}_with_python3
COMMAND env PYTHONPATH=.:${py_test_ENVS} COMMAND python3.6 ${py3_test_SRCS} ${py3_test_ARGS}
python3.6 -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endfunction() endfunction()
#function(py2_test TARGET_NAME)
# set(options "")
# set(oneValueArgs "")
# set(multiValueArgs SRCS DEPS ARGS ENVS)
# cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
# #TODO: add real python2 env.
# add_test(NAME ${TARGET_NAME}_with_python2
# COMMAND python ${py_test_SRCS} ${py_test_ARGS}
# WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
#endfunction()
function(import_test TARGET_NAME) function(import_test TARGET_NAME)
set(options "") set(options "")
set(oneValueArgs "") set(oneValueArgs "")
...@@ -44,8 +52,7 @@ function(import_test TARGET_NAME) ...@@ -44,8 +52,7 @@ function(import_test TARGET_NAME)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}_with_empty_env add_test(NAME ${TARGET_NAME}_with_empty_env
COMMAND env PYTHONPATH=.:${py_test_ENVS} COMMAND /root/miniconda3/envs/empty_env/bin/python -u ${py_test_SRCS} ${py_test_ARGS}
/root/miniconda3/envs/empty_env/bin/python -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endfunction() endfunction()
...@@ -66,7 +73,13 @@ if (WITH_TESTING) ...@@ -66,7 +73,13 @@ if (WITH_TESTING)
file(GLOB_RECURSE TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_test.py") file(GLOB_RECURSE TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_test.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS}) foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH}) if (${src} MATCHES ".*remote.*")
if (NOT IS_TESTING_GPU)
py3_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
endif()
else()
py3_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
endif()
endforeach() endforeach()
endif() endif()
endif() endif()
...@@ -83,7 +83,7 @@ class Model(ModelBase): ...@@ -83,7 +83,7 @@ class Model(ModelBase):
Args: Args:
target_model (`parl.Model`): an instance of ``Model`` that has the same neural network architecture as the current model. target_model (`parl.Model`): an instance of ``Model`` that has the same neural network architecture as the current model.
decay (float): the rate of decline in copying parameters. 0 if no parameters decay when synchronizing the parameters. decay (float): the rate of decline in copying parameters. 0 if no parameters decay when synchronizing the parameters.
share_vars_parallel_executor (fluid.ParallelExecutor): if not None, will use fluid.ParallelExecutor share_vars_parallel_executor (fluid.ParallelExecutor): Optional. If not None, will use fluid.ParallelExecutor
to run program instead of fluid.Executor to run program instead of fluid.Executor
""" """
self.sync_weights_to( self.sync_weights_to(
...@@ -97,15 +97,15 @@ class Model(ModelBase): ...@@ -97,15 +97,15 @@ class Model(ModelBase):
share_vars_parallel_executor=None): share_vars_parallel_executor=None):
"""Synchronize parameters of current model to another model. """Synchronize parameters of current model to another model.
To speed up the synchronizing process, will create a program implicitly to finish the process. And will To speed up the synchronizing process, it will create a program implicitly to finish the process. It
also cache the program to avoid creating program repeatedly. also stores a program as the cache to avoid creating program repeatedly.
target_model_weights = decay * target_model_weights + (1 - decay) * current_model_weights target_model_weights = decay * target_model_weights + (1 - decay) * current_model_weights
Args: Args:
target_model (`parl.Model`): an instance of ``Model`` that has the same neural network architecture as the current model. target_model (`parl.Model`): an instance of ``Model`` that has the same neural network architecture as the current model.
decay (float): the rate of decline in copying parameters. 0 if no parameters decay when synchronizing the parameters. decay (float): the rate of decline in copying parameters. 0 if no parameters decay when synchronizing the parameters.
share_vars_parallel_executor (fluid.ParallelExecutor): if not None, will use ``fluid.ParallelExecutor`` share_vars_parallel_executor (fluid.ParallelExecutor): Optional. If not None, will use ``fluid.ParallelExecutor``
to run program instead of ``fluid.Executor``. to run program instead of ``fluid.Executor``.
Example: Example:
......
...@@ -101,7 +101,7 @@ class Client(object): ...@@ -101,7 +101,7 @@ class Client(object):
"address {} is correct.".format(master_address)) "address {} is correct.".format(master_address))
def _reply_heartbeat(self): def _reply_heartbeat(self):
"""Reply heartbeat signals to the Master node.""" """Reply heartbeat signals to the specific node."""
socket = self.ctx.socket(zmq.REP) socket = self.ctx.socket(zmq.REP)
socket.linger = 0 socket.linger = 0
...@@ -124,6 +124,57 @@ class Client(object): ...@@ -124,6 +124,57 @@ class Client(object):
socket.close(0) socket.close(0)
logger.warning("Client exit replying heartbeat for master.") logger.warning("Client exit replying heartbeat for master.")
def _check_and_monitor_job(self, job_heartbeat_address,
ping_heartbeat_address):
""" Sometimes the client may receive a job that is dead, thus
we have to check if this job is still alive before sending it to the actor.
"""
# job_heartbeat_socket: sends heartbeat signal to job
job_heartbeat_socket = self.ctx.socket(zmq.REQ)
job_heartbeat_socket.linger = 0
job_heartbeat_socket.setsockopt(zmq.RCVTIMEO, int(0.9 * 1000))
job_heartbeat_socket.connect("tcp://" + ping_heartbeat_address)
try:
job_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG])
job_heartbeat_socket.recv_multipart()
except zmq.error.Again:
job_heartbeat_socket.close(0)
logger.error(
"[Client] connects to a finished job, will try again, ping_heartbeat_address:{}"
.format(ping_heartbeat_address))
return False
job_heartbeat_socket.disconnect("tcp://" + ping_heartbeat_address)
job_heartbeat_socket.connect("tcp://" + job_heartbeat_address)
job_heartbeat_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
# a thread for sending heartbeat signals to job
thread = threading.Thread(
target=self._create_job_monitor, args=(job_heartbeat_socket, ))
thread.setDaemon(True)
thread.start()
return True
def _create_job_monitor(self, job_heartbeat_socket):
"""Send heartbeat signals to check target's status"""
job_is_alive = True
while job_is_alive and self.client_is_alive:
try:
job_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG])
_ = job_heartbeat_socket.recv_multipart()
time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
except zmq.error.Again as e:
job_is_alive = False
except zmq.error.ZMQError as e:
break
job_heartbeat_socket.close(0)
def submit_job(self): def submit_job(self):
"""Send a job to the Master node. """Send a job to the Master node.
...@@ -132,35 +183,44 @@ class Client(object): ...@@ -132,35 +183,44 @@ class Client(object):
a vacant job from its job pool to the remote object. a vacant job from its job pool to the remote object.
Returns: Returns:
IP address of the job. job_address(str): IP address of the job. None if there is no available CPU in the cluster.
""" """
if self.master_is_alive: if self.master_is_alive:
# A lock to prevent multiple actor submit job at the same time. while True:
self.lock.acquire() # A lock to prevent multiple actors from submitting job at the same time.
self.submit_job_socket.send_multipart([ self.lock.acquire()
remote_constants.CLIENT_SUBMIT_TAG, self.submit_job_socket.send_multipart([
to_byte(self.heartbeat_master_address) remote_constants.CLIENT_SUBMIT_TAG,
]) to_byte(self.heartbeat_master_address)
message = self.submit_job_socket.recv_multipart() ])
self.lock.release() message = self.submit_job_socket.recv_multipart()
self.lock.release()
tag = message[0]
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
job_address = to_str(message[1]) if tag == remote_constants.NORMAL_TAG:
job_address = to_str(message[1])
# no vacant CPU resources, can not submit a new job job_heartbeat_address = to_str(message[2])
elif tag == remote_constants.CPU_TAG: ping_heartbeat_address = to_str(message[3])
job_address = None
# wait 1 second to avoid requesting in a high frequency. check_result = self._check_and_monitor_job(
time.sleep(1) job_heartbeat_address, ping_heartbeat_address)
else: if check_result:
raise NotImplementedError return job_address
# no vacant CPU resources, cannot submit a new job
elif tag == remote_constants.CPU_TAG:
job_address = None
# wait 1 second to avoid requesting in a high frequency.
time.sleep(1)
return job_address
else:
raise NotImplementedError
else: else:
raise Exception("Client can not submit job to the master, " raise Exception("Client can not submit job to the master, "
"please check if master is connected.") "please check if master is connected.")
return job_address return None
GLOBAL_CLIENT = None GLOBAL_CLIENT = None
...@@ -203,5 +263,10 @@ def get_global_client(): ...@@ -203,5 +263,10 @@ def get_global_client():
def disconnect(): def disconnect():
"""Disconnect the global client from the master node.""" """Disconnect the global client from the master node."""
global GLOBAL_CLIENT global GLOBAL_CLIENT
GLOBAL_CLIENT.client_is_alive = False if GLOBAL_CLIENT is not None:
GLOBAL_CLIENT = None GLOBAL_CLIENT.client_is_alive = False
GLOBAL_CLIENT = None
else:
logger.info(
"No client to be released. Please make sure that you have call `parl.connect`"
)
...@@ -29,6 +29,7 @@ from parl.utils.communication import loads_argument, loads_return,\ ...@@ -29,6 +29,7 @@ from parl.utils.communication import loads_argument, loads_return,\
dumps_argument, dumps_return dumps_argument, dumps_return
from parl.remote import remote_constants from parl.remote import remote_constants
from parl.utils.exceptions import SerializeError, DeserializeError from parl.utils.exceptions import SerializeError, DeserializeError
from parl.remote.message import InitializedJob
class Job(object): class Job(object):
...@@ -37,74 +38,151 @@ class Job(object): ...@@ -37,74 +38,151 @@ class Job(object):
After establishing connection with the remote object, the job will After establishing connection with the remote object, the job will
create a remote class instance locally and enter an infinite loop, create a remote class instance locally and enter an infinite loop,
waiting for commands from the remote object. waiting for commands from the remote object.
""" """
def __init__(self, worker_address): def __init__(self, worker_address):
"""
Args:
worker_address(str): worker_address for sending job information(e.g, pid)
"""
self.job_is_alive = True self.job_is_alive = True
self.worker_address = worker_address self.worker_address = worker_address
self.lock = threading.Lock()
self._create_sockets() self._create_sockets()
def _create_sockets(self): def _create_sockets(self):
"""Create two sockets for each job. """Create three sockets for each job.
(1) reply_socket: receives the command(i.e, the function name and (1) reply_socket(main socket): receives the command(i.e, the function name and args)
args) from the actual class instance, and returns the result of from the actual class instance, completes the computation, and returns the result of
the function. the function.
(2) job_socket: sends job_address and heartbeat_address to worker. (2) job_socket(functional socket): sends job_address and heartbeat_address to worker.
(3) kill_job_socket: sends a command to the corresponding worker to kill the job.
""" """
self.ctx = zmq.Context() self.ctx = zmq.Context()
# reply_socket: receives class, parameters and call function from # create the reply_socket
# @remote.class and send computed results to the @remote.class.
self.reply_socket = self.ctx.socket(zmq.REP) self.reply_socket = self.ctx.socket(zmq.REP)
self.reply_socket.linger = 0
job_port = self.reply_socket.bind_to_random_port(addr="tcp://*") job_port = self.reply_socket.bind_to_random_port(addr="tcp://*")
self.reply_socket.linger = 0
self.job_ip = get_ip_address() self.job_ip = get_ip_address()
self.job_address = "{}:{}".format(self.job_ip, job_port) self.job_address = "{}:{}".format(self.job_ip, job_port)
reply_thread = threading.Thread( # create the job_socket
target=self._reply_heartbeat, self.job_socket = self.ctx.socket(zmq.REQ)
args=("worker {}".format(self.worker_address), )) self.job_socket.connect("tcp://{}".format(self.worker_address))
reply_thread.setDaemon(True)
reply_thread.start() # a thread that reply ping signals from the client
ping_heartbeat_socket, ping_heartbeat_address = self._create_heartbeat_server(
def _reply_heartbeat(self, target): timeout=False)
"""reply heartbeat signals to the target""" ping_thread = threading.Thread(
target=self._reply_ping, args=(ping_heartbeat_socket, ))
socket = self.ctx.socket(zmq.REP) ping_thread.setDaemon(True)
socket.setsockopt(zmq.RCVTIMEO, ping_thread.start()
remote_constants.HEARTBEAT_RCVTIMEO_S * 1000) self.ping_heartbeat_address = ping_heartbeat_address
socket.linger = 0
heartbeat_worker_port = socket.bind_to_random_port(addr="tcp://*") # a thread that reply heartbeat signals from the worker
heartbeat_worker_address = "{}:{}".format(self.job_ip, worker_heartbeat_socket, worker_heartbeat_address = self._create_heartbeat_server(
heartbeat_worker_port) )
worker_thread = threading.Thread(
# job_socket: sends job_address and heartbeat_address to worker target=self._reply_worker_heartbeat,
job_socket = self.ctx.socket(zmq.REQ) args=(worker_heartbeat_socket, ))
job_socket.connect("tcp://{}".format(self.worker_address)) worker_thread.setDaemon(True)
job_socket.send_multipart([ worker_thread.start()
remote_constants.NORMAL_TAG,
to_byte(self.job_address), # a thread that reply heartbeat signals from the client
to_byte(heartbeat_worker_address), client_heartbeat_socket, client_heartbeat_address = self._create_heartbeat_server(
to_byte(str(os.getpid())) )
]) self.client_thread = threading.Thread(
_ = job_socket.recv_multipart() target=self._reply_client_heartbeat,
args=(client_heartbeat_socket, ))
self.client_thread.setDaemon(True)
# sends job information to the worker
initialized_job = InitializedJob(
self.job_address, worker_heartbeat_address,
client_heartbeat_address, self.ping_heartbeat_address, None,
os.getpid())
self.job_socket.send_multipart(
[remote_constants.NORMAL_TAG,
cloudpickle.dumps(initialized_job)])
message = self.job_socket.recv_multipart()
assert message[0] == remote_constants.NORMAL_TAG
# create the kill_job_socket
kill_job_address = to_str(message[1])
self.kill_job_socket = self.ctx.socket(zmq.REQ)
self.kill_job_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
self.kill_job_socket.connect("tcp://{}".format(kill_job_address))
def _reply_ping(self, socket):
"""Create a socket server that reply the ping signal from client.
This signal is used to make sure that the job is still alive.
"""
while self.job_is_alive:
message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
socket.close(0)
def _create_heartbeat_server(self, timeout=True):
"""Create a socket server that will raises timeout exception.
"""
heartbeat_socket = self.ctx.socket(zmq.REP)
if timeout:
heartbeat_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
heartbeat_socket.linger = 0
heartbeat_port = heartbeat_socket.bind_to_random_port(addr="tcp://*")
heartbeat_address = "{}:{}".format(self.job_ip, heartbeat_port)
return heartbeat_socket, heartbeat_address
def _reply_client_heartbeat(self, socket):
"""Create a socket that replies heartbeat signals from the client.
If the job losts connection with the client, it will exit too.
"""
self.client_is_alive = True
while self.client_is_alive:
try:
message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
except zmq.error.Again as e:
logger.warning(
"[Job] Cannot connect to the client. This job will exit and inform the worker."
)
self.client_is_alive = False
socket.close(0)
with self.lock:
self.kill_job_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(self.job_address)])
_ = self.kill_job_socket.recv_multipart()
logger.warning("[Job]lost connection with the client, will exit")
os._exit(1)
def _reply_worker_heartbeat(self, socket):
"""create a socket that replies heartbeat signals from the worker.
If the worker has exited, the job will exit automatically.
"""
# a flag to decide when to exit heartbeat loop
self.worker_is_alive = True self.worker_is_alive = True
# a flag to decide when to exit heartbeat loop
while self.worker_is_alive and self.job_is_alive: while self.worker_is_alive and self.job_is_alive:
try: try:
message = socket.recv_multipart() message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG]) socket.send_multipart([remote_constants.HEARTBEAT_TAG])
except zmq.error.Again as e: except zmq.error.Again as e:
logger.warning("[Job] Cannot connect to {}. ".format(target) + logger.warning("[Job] Cannot connect to the worker{}. ".format(
"Job will quit.") self.worker_address) + "Job will quit.")
self.worker_is_alive = False self.worker_is_alive = False
self.job_is_alive = False self.job_is_alive = False
socket.close(0)
os._exit(1)
def wait_for_files(self): def wait_for_files(self):
"""Wait for python files from remote object. """Wait for python files from remote object.
...@@ -132,7 +210,8 @@ class Job(object): ...@@ -132,7 +210,8 @@ class Job(object):
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG]) self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
return envdir return envdir
else: else:
logger.warning(message) logger.error("NotImplementedError:{}, received tag:{}".format(
self.job_address, ))
raise NotImplementedError raise NotImplementedError
def wait_for_connection(self): def wait_for_connection(self):
...@@ -146,20 +225,62 @@ class Job(object): ...@@ -146,20 +225,62 @@ class Job(object):
A local instance of the remote class object. A local instance of the remote class object.
""" """
while True: message = self.reply_socket.recv_multipart()
message = self.reply_socket.recv_multipart() tag = message[0]
tag = message[0] obj = None
if tag == remote_constants.INIT_OBJECT_TAG: if tag == remote_constants.INIT_OBJECT_TAG:
cls = cloudpickle.loads(message[1]) cls = cloudpickle.loads(message[1])
args, kwargs = cloudpickle.loads(message[2]) args, kwargs = cloudpickle.loads(message[2])
try:
obj = cls(*args, **kwargs) obj = cls(*args, **kwargs)
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG]) except Exception as e:
return obj traceback_str = str(traceback.format_exc())
else: error_str = str(e)
logger.error("Message from job {}".format(message)) logger.error("traceback:\n{}".format(traceback_str))
raise NotImplementedError self.reply_socket.send_multipart([
remote_constants.EXCEPTION_TAG,
to_byte(error_str + "\ntraceback:\n" + traceback_str)
])
self.client_is_alive = False
return None
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
else:
logger.error("Message from job {}".format(message))
self.reply_socket.send_multipart([
remote_constants.EXCEPTION_TAG,
b"[job]Unkonwn tag when tried to receive the class definition"
])
raise NotImplementedError
return obj
def run(self): def run(self):
"""An infinite loop waiting for a new task.
"""
# receive source code from the actor and append them to the environment variables.
envdir = self.wait_for_files()
sys.path.append(envdir)
self.client_thread.start()
try:
obj = self.wait_for_connection()
assert obj is not None
self.single_task(obj)
except Exception as e:
logger.error(
"Error occurs when running a single task. We will reset this job. Reason:{}"
.format(e))
traceback_str = str(traceback.format_exc())
logger.error("traceback:\n{}".format(traceback_str))
with self.lock:
self.kill_job_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(self.job_address)])
_ = self.kill_job_socket.recv_multipart()
def single_task(self, obj):
"""An infinite loop waiting for commands from the remote object. """An infinite loop waiting for commands from the remote object.
Each job will receive two kinds of message from the remote object: Each job will receive two kinds of message from the remote object:
...@@ -171,18 +292,12 @@ class Job(object): ...@@ -171,18 +292,12 @@ class Job(object):
related computation resources. related computation resources.
""" """
# receive files while self.job_is_alive and self.client_is_alive:
envdir = self.wait_for_files()
sys.path.append(envdir)
obj = self.wait_for_connection()
while self.job_is_alive:
message = self.reply_socket.recv_multipart() message = self.reply_socket.recv_multipart()
tag = message[0] tag = message[0]
if tag == remote_constants.CALL_TAG: if tag == remote_constants.CALL_TAG:
assert obj is not None
try: try:
function_name = to_str(message[1]) function_name = to_str(message[1])
data = message[2] data = message[2]
...@@ -194,9 +309,11 @@ class Job(object): ...@@ -194,9 +309,11 @@ class Job(object):
[remote_constants.NORMAL_TAG, ret]) [remote_constants.NORMAL_TAG, ret])
except Exception as e: except Exception as e:
# reset the job
self.client_is_alive = False
error_str = str(e) error_str = str(e)
logger.error(error_str) logger.error(error_str)
self.job_is_alive = False
if type(e) == AttributeError: if type(e) == AttributeError:
self.reply_socket.send_multipart([ self.reply_socket.send_multipart([
...@@ -217,6 +334,7 @@ class Job(object): ...@@ -217,6 +334,7 @@ class Job(object):
remote_constants.DESERIALIZE_EXCEPTION_TAG, remote_constants.DESERIALIZE_EXCEPTION_TAG,
to_byte(error_str) to_byte(error_str)
]) ])
raise DeserializeError
else: else:
traceback_str = str(traceback.format_exc()) traceback_str = str(traceback.format_exc())
...@@ -226,15 +344,19 @@ class Job(object): ...@@ -226,15 +344,19 @@ class Job(object):
to_byte(error_str + "\ntraceback:\n" + to_byte(error_str + "\ntraceback:\n" +
traceback_str) traceback_str)
]) ])
break
# receive DELETE_TAG from actor, and stop replying worker heartbeat # receive DELETE_TAG from actor, and stop replying worker heartbeat
elif tag == remote_constants.KILLJOB_TAG: elif tag == remote_constants.KILLJOB_TAG:
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG]) self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
self.job_is_alive = False self.client_is_alive = False
logger.warning("An actor exits and will quit job {}.".format( logger.warning(
self.job_address)) "An actor exits and this job {} will exit.".format(
self.job_address))
break
else: else:
logger.error("Job message: {}".format(message)) logger.error(
"The job receives an unknown message: {}".format(message))
raise NotImplementedError raise NotImplementedError
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
class JobCenter(object):
    """The job center deals with everything related to jobs.

    All mutating methods take ``self.lock`` through a ``with`` statement, so
    the lock is released even when an exception is raised inside the
    critical section.

    Attributes:
        job_pool (dict): Maps ``job_address`` to the job object of every
            vacant job (one entry per vacant cpu).
        worker_dict (dict): Maps ``worker_address`` to the connected worker.
        lock (threading.Lock): Guards modifications of ``job_pool`` and
            ``worker_dict``.
    """

    def __init__(self):
        self.job_pool = dict()
        self.worker_dict = {}
        self.lock = threading.Lock()

    @property
    def cpu_num(self):
        """Return vacant cpu number."""
        return len(self.job_pool)

    @property
    def worker_num(self):
        """Return connected worker number."""
        return len(self.worker_dict)

    def add_worker(self, worker):
        """A new worker connects.

        Args:
            worker (InitializedWorker): New worker with initialized jobs.
                It must expose ``worker_address`` and ``initialized_jobs``.
        """
        with self.lock:
            self.worker_dict[worker.worker_address] = worker
            for job in worker.initialized_jobs:
                self.job_pool[job.job_address] = job

    def drop_worker(self, worker_address):
        """Remove a worker and its vacant jobs from the job_pool when it dies.

        Args:
            worker_address (str): Address of the worker to be removed from
                the cluster.

        Raises:
            KeyError: If ``worker_address`` is not a connected worker. The
                lock is released before the error propagates.
        """
        with self.lock:
            worker = self.worker_dict.pop(worker_address)
            for job in worker.initialized_jobs:
                # Jobs currently handed out to clients are not in the pool.
                self.job_pool.pop(job.job_address, None)

    def request_job(self):
        """Return a vacant job when the client submits a job.

        If there is no vacant CPU in the cluster, this will return None.

        Return:
            An ``InitializedJob`` that has information about available job
            address, or None when the pool is empty.
        """
        with self.lock:
            if not self.job_pool:
                return None
            _, job = self.job_pool.popitem()
            return job

    def reset_job(self, job):
        """Reset a job and add the job_address to the job_pool.

        Args:
            job(``InitializedJob``): The job information of the restarted job.
        """
        with self.lock:
            self.job_pool[job.job_address] = job

    def update_job(self, killed_job_address, new_job, worker_address):
        """When worker kill an old job, it will start a new job.

        Args:
            killed_job_address (str): The job address of the killed job.
            new_job(``InitializedJob``): Information of the new job.
            worker_address (str): The worker which kills an old job.

        Raises:
            KeyError: If ``worker_address`` is not a connected worker. In
                that case nothing is modified.
        """
        with self.lock:
            # Look the worker up first so that an unknown worker fails
            # before any state has been changed.
            jobs = self.worker_dict[worker_address].initialized_jobs

            # Drop the killed job before registering the new one; otherwise
            # a new job that reuses the killed job's address would be
            # removed from the pool right after being added.
            self.job_pool.pop(killed_job_address, None)
            self.job_pool[new_job.job_address] = new_job

            # Remove the first record of the killed job, if the worker still
            # lists it. A missing record is tolerated instead of raising.
            for idx, job in enumerate(jobs):
                if job.job_address == killed_job_address:
                    del jobs[idx]
                    break
            jobs.append(new_job)
...@@ -18,9 +18,11 @@ import threading ...@@ -18,9 +18,11 @@ import threading
import time import time
import zmq import zmq
from collections import defaultdict
from parl.utils import to_str, to_byte, logger from parl.utils import to_str, to_byte, logger
from parl.remote import remote_constants from parl.remote import remote_constants
from parl.remote.job_center import JobCenter
import cloudpickle
import time
class Master(object): class Master(object):
...@@ -40,27 +42,19 @@ class Master(object): ...@@ -40,27 +42,19 @@ class Master(object):
master node. master node.
Attributes: Attributes:
worker_pool (dict): A dict to store connected workers. job_center (JobCenter): A thread-safe data structure that stores the job address of vacant cpus.
job_pool (list): A list to store the job address of vacant cpu, when
this number is 0, the master will refuse to create
new remote object.
client_job_dict (dict): A dict of list to record the job submitted by
each client.
job_worker_dict (dict): A dict to record the job and related worker.
client_socket (zmq.Context.socket): A socket that receives submitted client_socket (zmq.Context.socket): A socket that receives submitted
job from the client, and later sends job from the client, and later sends
job_address back to the client. job_address back to the client.
worker_socket (zmq.Context.socket): A socket that receives job cpu_num(int): The number of available CPUs in the cluster.
addresses from the worker node. worker_num(int): The number of workers connected to this cluster.
cpu_num(int): the number of available CPUs in the cluster.
Args: Args:
port: the ip port that the master node binds to. port: The ip port that the master node binds to.
""" """
def __init__(self, port): def __init__(self, port):
logger.set_dir(os.path.expanduser('~/.parl_data/master/')) logger.set_dir(os.path.expanduser('~/.parl_data/master/'))
self.lock = threading.Lock()
self.ctx = zmq.Context() self.ctx = zmq.Context()
self.client_socket = self.ctx.socket(zmq.REP) self.client_socket = self.ctx.socket(zmq.REP)
...@@ -68,13 +62,7 @@ class Master(object): ...@@ -68,13 +62,7 @@ class Master(object):
self.client_socket.linger = 0 self.client_socket.linger = 0
self.port = port self.port = port
self.worker_pool = {} self.job_center = JobCenter()
self.worker_locks = {}
self.job_pool = []
self.client_job_dict = defaultdict(list)
self.worker_job_dict = defaultdict(list)
self.job_worker_dict = {}
self.master_is_alive = True self.master_is_alive = True
...@@ -96,13 +84,7 @@ class Master(object): ...@@ -96,13 +84,7 @@ class Master(object):
_ = worker_heartbeat_socket.recv_multipart() _ = worker_heartbeat_socket.recv_multipart()
time.sleep(remote_constants.HEARTBEAT_INTERVAL_S) time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
except zmq.error.Again as e: except zmq.error.Again as e:
for job in self.worker_job_dict[worker_address]: self.job_center.drop_worker(worker_address)
if job in self.job_pool:
self.job_pool.remove(job)
self.job_worker_dict.pop(job)
self.worker_job_dict.pop(worker_address)
self.worker_pool.pop(worker_address)
self.worker_locks.pop(worker_address)
logger.warning("\n[Master] Cannot connect to the worker " + logger.warning("\n[Master] Cannot connect to the worker " +
"{}. ".format(worker_address) + "{}. ".format(worker_address) +
"Worker_pool will drop this worker.") "Worker_pool will drop this worker.")
...@@ -115,7 +97,7 @@ class Master(object): ...@@ -115,7 +97,7 @@ class Master(object):
logger.warning("Exit worker monitor from master.") logger.warning("Exit worker monitor from master.")
def _create_client_monitor(self, client_heartbeat_address): def _create_client_monitor(self, client_heartbeat_address):
"""when a new client connects to the master, a socket is created to """When a new client connects to the master, a socket is created to
send heartbeat signals to the client. send heartbeat signals to the client.
""" """
...@@ -125,65 +107,43 @@ class Master(object): ...@@ -125,65 +107,43 @@ class Master(object):
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000) zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
client_heartbeat_socket.connect("tcp://" + client_heartbeat_address) client_heartbeat_socket.connect("tcp://" + client_heartbeat_address)
self.client_is_alive = True client_is_alive = True
while self.client_is_alive and self.master_is_alive: while client_is_alive and self.master_is_alive:
try: try:
client_heartbeat_socket.send_multipart( client_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG]) [remote_constants.HEARTBEAT_TAG])
_ = client_heartbeat_socket.recv_multipart() _ = client_heartbeat_socket.recv_multipart()
except zmq.error.Again as e: except zmq.error.Again as e:
self.client_is_alive = False client_is_alive = False
logger.warning("[Master] cannot connect to the client " + logger.warning("[Master] cannot connect to the client " +
"{}. ".format(client_heartbeat_address) + "{}. ".format(client_heartbeat_address) +
"Please check if it is still alive.") "Please check if it is still alive.")
self._kill_client_jobs(client_heartbeat_address)
time.sleep(remote_constants.HEARTBEAT_INTERVAL_S) time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
logger.warning("Master exits client monitor for {}.\n".format( logger.warning("Master exits client monitor for {}.\n".format(
client_heartbeat_address)) client_heartbeat_address))
logger.info( logger.info(
"Master connects to {} workers and have {} vacant CPUs.\n".format( "Master connects to {} workers and have {} vacant CPUs.\n".format(
len(self.worker_pool), len(self.job_pool))) self.worker_num, self.cpu_num))
client_heartbeat_socket.close(0) client_heartbeat_socket.close(0)
def _kill_client_jobs(self, client_address):
"""set timeout in case the worker and client quit at the same time.
"""
jobs = self.client_job_dict[client_address]
for job_address in jobs:
if job_address in self.job_worker_dict:
worker_address = self.job_worker_dict[job_address]
# ignore this worker if it has been deleted
if worker_address not in self.worker_pool:
continue
worker_socket = self.worker_pool[worker_address].worker_socket
lock = self.worker_locks[worker_address]
lock.acquire()
worker_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(job_address)])
try:
_ = worker_socket.recv_multipart()
except zmq.error.Again as e:
logger.warning("Error in recv kill_client_job")
lock.release()
self.job_worker_dict.pop(job_address)
self.client_job_dict.pop(client_address)
def _print_workers(self): def _print_workers(self):
"""Display `worker_pool` infomation.""" """Display `worker_pool` infomation."""
logger.info( logger.info(
"Master connects to {} workers and have {} vacant CPUs.\n".format( "Master connects to {} workers and have {} vacant CPUs.\n".format(
len(self.worker_pool), len(self.job_pool))) self.worker_num, self.cpu_num))
@property @property
def cpu_num(self): def cpu_num(self):
return len(self.job_pool) return self.job_center.cpu_num
@property
def worker_num(self):
return self.job_center.worker_num
def _receive_message(self): def _receive_message(self):
"""master node will receive four types of message: (1) worker """Master node will receive various types of message: (1) worker
connection; (2) worker update; (3) client connection; (4) job connection; (2) worker update; (3) client connection; (4) job
submittion. submittion; (5) reset job.
""" """
message = self.client_socket.recv_multipart() message = self.client_socket.recv_multipart()
tag = message[0] tag = message[0]
...@@ -193,35 +153,18 @@ class Master(object): ...@@ -193,35 +153,18 @@ class Master(object):
self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
elif tag == remote_constants.WORKER_INITIALIZED_TAG: elif tag == remote_constants.WORKER_INITIALIZED_TAG:
worker = pickle.loads(message[1])
worker_heartbeat_address = to_str(message[2])
# maintain job & worker relations
for job_address in worker.job_pool:
self.job_worker_dict[job_address] = worker.address
self.worker_job_dict[worker.address] = worker.job_pool
self.job_pool.extend(worker.job_pool)
# a new socket for submitting job to the worker
worker_socket = self.ctx.socket(zmq.REQ)
worker_socket.linger = 0
worker_socket.setsockopt(zmq.RCVTIMEO, 10000)
worker_socket.connect("tcp://{}".format(worker.address))
worker.worker_socket = worker_socket
self.worker_pool[worker.address] = worker
self.worker_locks[worker.address] = threading.Lock()
logger.info( initialized_worker = cloudpickle.loads(message[1])
"A new worker {} is added, ".format(worker.address) + self.job_center.add_worker(initialized_worker)
"the cluster has {} CPUs.\n".format(len(self.job_pool))) logger.info("A new worker {} is added, ".format(initialized_worker.
worker_address) +
"the cluster has {} CPUs.\n".format(self.cpu_num))
# a thread for sending heartbeat signals to `worker.address` # a thread for sending heartbeat signals to `worker.address`
thread = threading.Thread( thread = threading.Thread(
target=self._create_worker_monitor, target=self._create_worker_monitor,
args=( args=(initialized_worker.master_heartbeat_address,
worker_heartbeat_address, initialized_worker.worker_address))
worker.address,
))
thread.start() thread.start()
self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
...@@ -240,43 +183,30 @@ class Master(object): ...@@ -240,43 +183,30 @@ class Master(object):
# a client submits a job to the master # a client submits a job to the master
elif tag == remote_constants.CLIENT_SUBMIT_TAG: elif tag == remote_constants.CLIENT_SUBMIT_TAG:
client_address = to_str(message[1])
done_flag = False
# check available CPU resources # check available CPU resources
if len(self.job_pool): if self.cpu_num:
logger.info("Submitting job...") logger.info("Submitting job...")
job_address = self.job_pool.pop(0) job = self.job_center.request_job()
worker_address = self.job_worker_dict[job_address] self.client_socket.send_multipart([
self.worker_job_dict[worker_address].remove(job_address) remote_constants.NORMAL_TAG,
self.client_socket.send_multipart( to_byte(job.job_address),
[remote_constants.NORMAL_TAG, to_byte(job.client_heartbeat_address),
to_byte(job_address)]) to_byte(job.ping_heartbeat_address)
self.client_job_dict[client_address].append(job_address) ])
self._print_workers() self._print_workers()
else: else:
self.client_socket.send_multipart([remote_constants.CPU_TAG]) self.client_socket.send_multipart([remote_constants.CPU_TAG])
# a worker updates # a worker updates
elif tag == remote_constants.NEW_JOB_TAG: elif tag == remote_constants.NEW_JOB_TAG:
worker_address = to_str(message[1]) initialized_job = cloudpickle.loads(message[1])
new_job_address = to_str(message[2]) last_job_address = to_str(message[2])
killed_job_address = to_str(message[3])
self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
logger.info("A worker updated.") self.job_center.update_job(last_job_address, initialized_job,
initialized_job.worker_address)
if killed_job_address in self.job_worker_dict: logger.info("A worker updated. cpu_num:{}".format(self.cpu_num))
self.job_worker_dict.pop(killed_job_address)
if killed_job_address in self.worker_job_dict[worker_address]:
self.worker_job_dict[worker_address].remove(killed_job_address)
if killed_job_address in self.job_pool:
self.job_pool.remove(killed_job_address)
# add new job_address to job_pool
self.job_pool.append(new_job_address)
self.job_worker_dict[new_job_address] = worker_address
self.worker_job_dict[worker_address].append(new_job_address)
self._print_workers() self._print_workers()
...@@ -288,6 +218,8 @@ class Master(object): ...@@ -288,6 +218,8 @@ class Master(object):
raise NotImplementedError() raise NotImplementedError()
def exit(self): def exit(self):
""" Close the master.
"""
self.master_is_alive = False self.master_is_alive = False
def run(self): def run(self):
...@@ -313,10 +245,5 @@ class Master(object): ...@@ -313,10 +245,5 @@ class Master(object):
except zmq.error.Again as e: except zmq.error.Again as e:
#detect whether `self.master_is_alive` is True periodically #detect whether `self.master_is_alive` is True periodically
pass pass
for worker_address, worker in self.worker_pool.items():
lock = self.worker_locks[worker_address]
lock.acquire()
worker.worker_socket.close(0)
lock.release()
logger.warning("[Master] Exit master.") logger.warning("[Master] Exit master.")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class InitializedJob(object):
    """Record describing one job process that has finished initializing.

    Attributes:
        job_address(str): Job address to which the new task connects.
        worker_heartbeat_address(str): The address to which the worker sends
            heartbeat signals.
        client_heartbeat_address(str): Address to which the client sends
            heartbeat signals.
        ping_heartbeat_address(str): The server address to which the client
            sends ping signals. The signal is used to check if the job is
            alive.
        worker_address(str): Worker's server address that receives commands
            from the master.
        pid(int): Process id of the job.
        is_alive(bool): Not a constructor argument; always set to True on
            construction. Used in the worker to make sure that only alive
            jobs can be added into the worker_status.
    """

    def __init__(self, job_address, worker_heartbeat_address,
                 client_heartbeat_address, ping_heartbeat_address,
                 worker_address, pid):
        # Every new record starts out alive.
        self.is_alive = True

        # Identity of the job process and of the worker that owns it.
        self.pid = pid
        self.job_address = job_address
        self.worker_address = worker_address

        # Heartbeat / ping endpoints associated with this job.
        self.worker_heartbeat_address = worker_heartbeat_address
        self.client_heartbeat_address = client_heartbeat_address
        self.ping_heartbeat_address = ping_heartbeat_address
class InitializedWorker(object):
    """Record describing a worker that has finished initializing.

    Attributes:
        worker_address(str): Worker server address that receives commands
            from the master.
        master_heartbeat_address(str): Address to which the worker sends
            heartbeat signals.
        initialized_jobs(list): A list of ``InitializedJob`` containing the
            information for initialized jobs.
        cpu_num(int): The number of CPUs used in this worker.
    """

    def __init__(self, worker_address, master_heartbeat_address,
                 initialized_jobs, cpu_num):
        # Resources offered by the worker.
        self.cpu_num = cpu_num
        self.initialized_jobs = initialized_jobs

        # Network endpoints of the worker.
        self.master_heartbeat_address = master_heartbeat_address
        self.worker_address = worker_address
...@@ -98,27 +98,32 @@ def remote_class(cls): ...@@ -98,27 +98,32 @@ def remote_class(cls):
self.job_socket.linger = 0 self.job_socket.linger = 0
self.job_socket.connect("tcp://{}".format(job_address)) self.job_socket.connect("tcp://{}".format(job_address))
self.job_address = job_address self.job_address = job_address
self.job_shutdown = False
self.send_file(self.job_socket) self.send_file(self.job_socket)
try: self.job_socket.send_multipart([
self.job_socket.send_multipart([ remote_constants.INIT_OBJECT_TAG,
remote_constants.INIT_OBJECT_TAG, cloudpickle.dumps(cls),
cloudpickle.dumps(cls), cloudpickle.dumps([args, kwargs])
cloudpickle.dumps([args, kwargs]) ])
]) message = self.job_socket.recv_multipart()
_ = self.job_socket.recv_multipart() tag = message[0]
except zmq.error.Again as e: if tag == remote_constants.EXCEPTION_TAG:
logger.error("Job socket failed.") traceback_str = to_str(message[1])
self.job_shutdown = True
raise RemoteError('__init__', traceback_str)
def __del__(self): def __del__(self):
"""Delete the remote class object and release remote resources.""" """Delete the remote class object and release remote resources."""
try: if not self.job_shutdown:
self.job_socket.send_multipart([remote_constants.KILLJOB_TAG]) try:
_ = self.job_socket.recv_multipart() self.job_socket.send_multipart(
self.job_socket.close(0) [remote_constants.KILLJOB_TAG])
except AttributeError: _ = self.job_socket.recv_multipart()
pass self.job_socket.close(0)
except AttributeError:
pass
def send_file(self, socket): def send_file(self, socket):
try: try:
...@@ -137,7 +142,7 @@ def remote_class(cls): ...@@ -137,7 +142,7 @@ def remote_class(cls):
if job_address is not None: if job_address is not None:
return job_address return job_address
if cnt % 30 == 0: if cnt % 30 == 0:
logger.warning("No vacant cpu resources at present, " logger.warning("No vacant cpu resources at the moment, "
"will try {} times later.".format(cnt)) "will try {} times later.".format(cnt))
cnt -= 1 cnt -= 1
return None return None
...@@ -146,6 +151,9 @@ def remote_class(cls): ...@@ -146,6 +151,9 @@ def remote_class(cls):
"""Call the function of the unwrapped class.""" """Call the function of the unwrapped class."""
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if self.job_shutdown:
raise RemoteError(
attr, "This actor losts connection with the job.")
self.internal_lock.acquire() self.internal_lock.acquire()
data = dumps_argument(*args, **kwargs) data = dumps_argument(*args, **kwargs)
...@@ -161,21 +169,26 @@ def remote_class(cls): ...@@ -161,21 +169,26 @@ def remote_class(cls):
elif tag == remote_constants.EXCEPTION_TAG: elif tag == remote_constants.EXCEPTION_TAG:
error_str = to_str(message[1]) error_str = to_str(message[1])
self.job_shutdown = True
raise RemoteError(attr, error_str) raise RemoteError(attr, error_str)
elif tag == remote_constants.ATTRIBUTE_EXCEPTION_TAG: elif tag == remote_constants.ATTRIBUTE_EXCEPTION_TAG:
error_str = to_str(message[1]) error_str = to_str(message[1])
self.job_shutdown = True
raise RemoteAttributeError(attr, error_str) raise RemoteAttributeError(attr, error_str)
elif tag == remote_constants.SERIALIZE_EXCEPTION_TAG: elif tag == remote_constants.SERIALIZE_EXCEPTION_TAG:
error_str = to_str(message[1]) error_str = to_str(message[1])
self.job_shutdown = True
raise RemoteSerializeError(attr, error_str) raise RemoteSerializeError(attr, error_str)
elif tag == remote_constants.DESERIALIZE_EXCEPTION_TAG: elif tag == remote_constants.DESERIALIZE_EXCEPTION_TAG:
error_str = to_str(message[1]) error_str = to_str(message[1])
self.job_shutdown = True
raise RemoteDeserializeError(attr, error_str) raise RemoteDeserializeError(attr, error_str)
else: else:
self.job_shutdown = True
raise NotImplementedError() raise NotImplementedError()
self.internal_lock.release() self.internal_lock.release()
......
...@@ -83,11 +83,11 @@ def start_master(port, cpu_num): ...@@ -83,11 +83,11 @@ def start_master(port, cpu_num):
cpu_num = str(cpu_num) if cpu_num else '' cpu_num = str(cpu_num) if cpu_num else ''
start_file = __file__.replace('scripts.pyc', 'start.py') start_file = __file__.replace('scripts.pyc', 'start.py')
start_file = start_file.replace('scripts.py', 'start.py') start_file = start_file.replace('scripts.py', 'start.py')
command = ["python", start_file, "--name", "master", "--port", port] command = [sys.executable, start_file, "--name", "master", "--port", port]
p = subprocess.Popen(command) p = subprocess.Popen(command)
command = [ command = [
"python", start_file, "--name", "worker", "--address", sys.executable, start_file, "--name", "worker", "--address",
"localhost:" + str(port), "--cpu_num", "localhost:" + str(port), "--cpu_num",
str(cpu_num) str(cpu_num)
] ]
...@@ -112,8 +112,8 @@ def start_worker(address, cpu_num): ...@@ -112,8 +112,8 @@ def start_worker(address, cpu_num):
address) + "is correct.") address) + "is correct.")
cpu_num = str(cpu_num) if cpu_num else '' cpu_num = str(cpu_num) if cpu_num else ''
command = [ command = [
"python", "{}/start.py".format(__file__[:-11]), "--name", "worker", sys.executable, "{}/start.py".format(__file__[:-11]), "--name",
"--address", address, "--cpu_num", "worker", "--address", address, "--cpu_num",
str(cpu_num) str(cpu_num)
] ]
p = subprocess.Popen(command) p = subprocess.Popen(command)
...@@ -123,8 +123,8 @@ def start_worker(address, cpu_num): ...@@ -123,8 +123,8 @@ def start_worker(address, cpu_num):
def stop(): def stop():
command = ("pkill -f remote/start.py") command = ("pkill -f remote/start.py")
subprocess.call([command], shell=True) subprocess.call([command], shell=True)
command = ("pkill -f job.py") command = ("pkill -f remote/job.py")
p = subprocess.call([command], shell=True) subprocess.call([command], shell=True)
cli.add_command(start_worker) cli.add_command(start_worker)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from parl.utils import logger
import signal
import os
class WorkerStatus(object):
    """Maintain worker's information in a worker node.

    All accesses to ``jobs`` go through an internal lock. The lock is always
    taken with a ``with`` statement so that it is released even when the
    guarded code raises (for example the capacity assertion in ``add_job``);
    otherwise a single failure would dead-lock every later call.

    Attributes:
        cpu_num(int): The number of CPUs to be used in this worker.
        jobs(dict): Maps a job address to the initialized job stored for it.
        worker_address(str): Address of the worker.
    """

    def __init__(self, worker_address, initialized_jobs, cpu_num):
        """
        Args:
            worker_address(str): Address of the worker.
            initialized_jobs(iterable): Jobs to register initially; each must
                expose ``job_address`` (and ``pid`` for later removal).
            cpu_num(int): The number of CPUs to be used in this worker.
        """
        self.worker_address = worker_address
        self.jobs = dict()
        for job in initialized_jobs:
            self.jobs[job.job_address] = job
        self._lock = threading.Lock()
        self.cpu_num = cpu_num

    def remove_job(self, killed_job):
        """Remove a job from the internal job pool and terminate its process.

        Args:
            killed_job(str): Job address to be removed.

        Returns:
            True if the job was registered and has been removed, False if the
            address was unknown.
        """
        with self._lock:
            if killed_job not in self.jobs:
                return False
            pid = self.jobs.pop(killed_job).pid
            try:
                os.kill(pid, signal.SIGTERM)
            except OSError:
                # The process is already gone; removing the record is enough.
                logger.warning("job:{} has been killed before".format(pid))
            logger.info("[Worker] kills a job:{}".format(killed_job))
            return True

    def clear(self):
        """Terminate every registered job and empty the job pool."""
        with self._lock:
            for job in self.jobs.values():
                try:
                    os.kill(job.pid, signal.SIGTERM)
                except OSError:
                    logger.warning(
                        "job:{} has been killed before".format(job.pid))
                logger.info("[Worker] kills a job:{}".format(job.pid))
            self.jobs = dict()

    def add_job(self, new_job):
        """Add a new job to the internal job pool.

        Args:
            new_job(InitializedJob): Initialized job to be added to
                ``self.jobs``.

        Raises:
            AssertionError: If the pool holds more jobs than ``cpu_num`` after
                the insertion. The job stays registered and the lock is
                released before the error propagates.
        """
        with self._lock:
            self.jobs[new_job.job_address] = new_job
            assert len(self.jobs) <= self.cpu_num
...@@ -20,6 +20,9 @@ from parl.remote.worker import Worker ...@@ -20,6 +20,9 @@ from parl.remote.worker import Worker
import time import time
import threading import threading
from parl.remote.client import disconnect from parl.remote.client import disconnect
from parl.remote import exceptions
import timeout_decorator
import subprocess
@parl.remote_class @parl.remote_class
...@@ -55,25 +58,71 @@ class Actor(object): ...@@ -55,25 +58,71 @@ class Actor(object):
x = 1 / 0 x = 1 / 0
class TestExit(unittest.TestCase): class TestCluster(unittest.TestCase):
def test_delete_worker(self): def tearDown(self):
# start the master disconnect()
#time.sleep(20)
#command = ("pkill -f remote/job.py")
#subprocess.call([command], shell=True)
def test_actor_exception(self):
master = Master(port=1235) master = Master(port=1235)
th = threading.Thread(target=master.run) th = threading.Thread(target=master.run)
th.start() th.start()
time.sleep(1) time.sleep(1)
worker1 = Worker('localhost:1235', 1)
worker1 = Worker('localhost:1235', 4) self.assertEqual(1, master.cpu_num)
parl.connect('localhost:1235') parl.connect('localhost:1235')
for i in range(4): with self.assertRaises(exceptions.RemoteError):
actor = Actor(abcd='a bug')
actor2 = Actor()
self.assertEqual(actor2.add_one(1), 2)
self.assertEqual(0, master.cpu_num)
master.exit()
worker1.exit()
@timeout_decorator.timeout(seconds=300)
def test_actor_exception(self):
master = Master(port=1236)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1236', 1)
self.assertEqual(1, master.cpu_num)
parl.connect('localhost:1236')
actor = Actor()
try:
actor.will_raise_exception_func()
except:
pass
actor2 = Actor()
time.sleep(30)
self.assertEqual(actor2.add_one(1), 2)
self.assertEqual(0, master.cpu_num)
del actor
del actor2
worker1.exit()
master.exit()
def test_reset_actor(self):
# start the master
master = Master(port=1237)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1237', 4)
parl.connect('localhost:1237')
for i in range(10):
actor = Actor() actor = Actor()
ret = actor.add_one(1) ret = actor.add_one(1)
self.assertEqual(ret, 2) self.assertEqual(ret, 2)
del actor
time.sleep(20)
self.assertEqual(master.cpu_num, 4)
worker1.exit() worker1.exit()
time.sleep(30)
disconnect()
time.sleep(30)
master.exit() master.exit()
def test_add_worker(self): def test_add_worker(self):
...@@ -91,6 +140,7 @@ class TestExit(unittest.TestCase): ...@@ -91,6 +140,7 @@ class TestExit(unittest.TestCase):
self.assertEqual(master.cpu_num, 4) self.assertEqual(master.cpu_num, 4)
master.exit() master.exit()
worker1.exit()
if __name__ == '__main__': if __name__ == '__main__':
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from parl.remote.job_center import JobCenter
from parl.remote.message import InitializedWorker, InitializedJob
class InitializedWorker(object):
    """Lightweight stand-in for a worker record with convenient defaults.

    NOTE(review): this definition shadows the ``InitializedWorker`` imported
    from ``parl.remote.message`` just above it; presumably intentional so the
    tests can omit most arguments -- confirm.

    Args:
        worker_address(str): Address identifying the worker.
        master_heartbeat_address(str): Heartbeat address; defaults to
            'localhost:8010'.
        initialized_jobs(list): Jobs owned by the worker. Defaults to a fresh
            empty list per instance (a shared ``[]`` default would leak jobs
            between workers if the list is mutated in place).
        cpu_num(int): Number of CPUs of the worker; defaults to 4.
    """

    def __init__(self,
                 worker_address,
                 master_heartbeat_address='localhost:8010',
                 initialized_jobs=None,
                 cpu_num=4):
        self.worker_address = worker_address
        self.master_heartbeat_address = master_heartbeat_address
        # Never share one default list object between instances.
        self.initialized_jobs = ([] if initialized_jobs is None else
                                 initialized_jobs)
        self.cpu_num = cpu_num
class ImportTest(unittest.TestCase):
    """Unit tests for the worker/job bookkeeping done by ``JobCenter``."""

    @staticmethod
    def _make_jobs(first_port):
        """Return five jobs whose addresses use ports first_port..first_port+4."""
        return [
            InitializedJob(
                job_address='172.18.182.39:{}'.format(first_port + offset),
                worker_heartbeat_address='172.18.182.39:48724',
                client_heartbeat_address='172.18.182.39:48725',
                ping_heartbeat_address='172.18.182.39:48726',
                worker_address='172.18.182.39:478727',
                pid=1234) for offset in range(5)
        ]

    def setUp(self):
        # Two workers with five jobs each; job ports start at 1234 and 2234.
        self.worker1 = InitializedWorker(
            worker_address='172.18.182.39:8001',
            initialized_jobs=self._make_jobs(1234))
        self.worker2 = InitializedWorker(
            worker_address='172.18.182.39:8002',
            initialized_jobs=self._make_jobs(2234))

    def test_add_worker(self):
        center = JobCenter()
        for worker in (self.worker1, self.worker2):
            center.add_worker(worker)

        self.assertEqual(len(center.job_pool), 10)
        registered = center.worker_dict[self.worker1.worker_address]
        self.assertEqual(registered, self.worker1)

    def test_drop_worker(self):
        center = JobCenter()
        for worker in (self.worker1, self.worker2):
            center.add_worker(worker)
        center.drop_worker(self.worker2.worker_address)

        # Only the jobs of the remaining worker may stay in the pool.
        remaining_jobs = set(center.job_pool.values())
        self.assertEqual(remaining_jobs, set(self.worker1.initialized_jobs))
        self.assertEqual(len(center.worker_dict), 1)

    def test_request_job(self):
        center = JobCenter()
        # Nothing can be handed out before any worker has been added.
        self.assertIsNone(center.request_job())

        center.add_worker(self.worker1)
        granted = center.request_job()
        self.assertIn(granted, self.worker1.initialized_jobs)
        self.assertEqual(len(center.job_pool), 4)

    def test_reset_job(self):
        center = JobCenter()
        center.add_worker(self.worker1)

        granted = center.request_job()
        self.assertIn(granted, self.worker1.initialized_jobs)
        self.assertEqual(len(center.job_pool), 4)

        # Resetting the job puts it back into the pool.
        center.reset_job(granted)
        self.assertEqual(len(center.job_pool), 5)

    def test_update_job(self):
        center = JobCenter()
        for worker in (self.worker1, self.worker2):
            center.add_worker(worker)

        replacement = InitializedJob(
            job_address='172.18.182.39:{}'.format(9245),
            worker_heartbeat_address='172.18.182.39:48724',
            client_heartbeat_address='172.18.182.39:48725',
            ping_heartbeat_address='172.18.182.39:48726',
            worker_address='172.18.182.39:478727',
            pid=1234)
        center.update_job('172.18.182.39:2234', replacement,
                          '172.18.182.39:8002')

        # The second worker now owns the replacement instead of port 2234.
        worker2_jobs = center.worker_dict['172.18.182.39:8002']
        worker2_addresses = {
            item.job_address for item in worker2_jobs.initialized_jobs
        }
        self.assertEqual(
            worker2_addresses, {
                '172.18.182.39:9245', '172.18.182.39:2235',
                '172.18.182.39:2236', '172.18.182.39:2237',
                '172.18.182.39:2238'
            })

        pool_addresses = set(center.job_pool.keys())
        self.assertEqual(
            pool_addresses, {
                '172.18.182.39:9245', '172.18.182.39:2235',
                '172.18.182.39:2236', '172.18.182.39:2237',
                '172.18.182.39:2238', '172.18.182.39:1234',
                '172.18.182.39:1235', '172.18.182.39:1236',
                '172.18.182.39:1237', '172.18.182.39:1238'
            })

        center.drop_worker(self.worker2.worker_address)
        self.assertEqual(5, len(self.worker1.initialized_jobs))

    def test_cpu_num(self):
        center = JobCenter()
        center.add_worker(self.worker1)
        self.assertEqual(center.cpu_num, 5)

        center.add_worker(self.worker2)
        self.assertEqual(center.cpu_num, 10)

        # Handing out a job consumes one CPU.
        center.request_job()
        self.assertEqual(center.cpu_num, 9)

    def test_worker_num(self):
        center = JobCenter()
        center.add_worker(self.worker1)
        self.assertEqual(center.worker_num, 1)

        center.add_worker(self.worker2)
        self.assertEqual(center.worker_num, 2)

        center.drop_worker(self.worker1.worker_address)
        self.assertEqual(center.worker_num, 1)
# Run the unittest runner when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
from parl.utils import logger
import subprocess
import time
import threading
import timeout_decorator
import subprocess
import sys
@parl.remote_class
class Actor(object):
    """Small remote class used as the workload in the tests of this file."""

    def __init__(self, arg1=None, arg2=None):
        self.arg1 = arg1
        self.arg2 = arg2

    def get_arg1(self):
        """Return the stored ``arg1``."""
        return self.arg1

    def get_arg2(self):
        """Return the stored ``arg2``."""
        return self.arg2

    def set_arg1(self, value):
        """Replace the stored ``arg1`` with ``value``."""
        self.arg1 = value

    def set_arg2(self, value):
        """Replace the stored ``arg2`` with ``value``."""
        self.arg2 = value

    def get_unable_serialize_object(self):
        # NOTE(review): `UnableSerializeObject` is neither defined nor imported
        # in the visible part of this file, so calling this would presumably
        # raise NameError -- confirm whether the method is still needed here.
        return UnableSerializeObject()

    def add_one(self, value):
        """Return ``value`` incremented by one."""
        value += 1
        return value

    def add(self, x, y):
        """Return ``x + y`` after sleeping for three seconds."""
        time.sleep(3)
        return x + y

    def will_raise_exception_func(self):
        """Always raise ZeroDivisionError (deliberate failure for tests)."""
        x = 1 / 0
class TestJob(unittest.TestCase):
    """Cluster tests for job processes and clients that terminate abruptly."""

    def tearDown(self):
        # Presumably closes the client connection opened with parl.connect();
        # verify against parl.remote.client.disconnect.
        disconnect()

    def test_job_exit_exceptionally(self):
        """Kill all job processes, then create and call an actor."""
        master = Master(port=1334)
        th = threading.Thread(target=master.run)
        th.start()
        time.sleep(1)
        worker1 = Worker('localhost:1334', 4)
        # Give the worker time to start its jobs before inspecting its buffer.
        time.sleep(10)
        self.assertEqual(worker1.job_buffer.full(), True)
        time.sleep(1)
        self.assertEqual(master.cpu_num, 4)
        print("We are going to kill all the jobs.")
        # Terminates every process whose command line matches remote/job.py.
        command = ("pkill -f remote/job.py")
        subprocess.call([command], shell=True)
        parl.connect('localhost:1334')
        # An actor must still be creatable and callable after the kill.
        actor = Actor()
        self.assertEqual(actor.add_one(1), 2)
        time.sleep(20)
        master.exit()
        worker1.exit()

    @timeout_decorator.timeout(seconds=300)
    def test_acor_exit_exceptionally(self):
        """Kill a client process that holds the only CPU, then create an actor.

        Nothing is asserted after the final ``Actor()`` call; the 300 second
        timeout decorator is what bounds the test if that call never returns.
        """
        master = Master(port=1335)
        th = threading.Thread(target=master.run)
        th.start()
        time.sleep(1)
        worker1 = Worker('localhost:1335', 1)
        # The client script path is derived from this file's own path.
        file_path = __file__.replace('reset_job_test', 'simulate_client')
        command = [sys.executable, file_path]
        proc = subprocess.Popen(command)
        time.sleep(10)
        # The external client is expected to occupy the single CPU by now.
        self.assertEqual(master.cpu_num, 0)
        proc.kill()
        parl.connect('localhost:1335')
        actor = Actor()
        master.exit()
        worker1.exit()
        # tearDown calls disconnect() again after this.
        disconnect()
# Run the unittest runner when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import parl
@parl.remote_class
class Actor(object):
    """Minimal remote class used by the simulated client."""

    def add_one(self, value):
        """Return ``value`` incremented by one."""
        value += 1
        return value
def train():
    """Connect to the cluster at localhost:1335, use one actor, then idle.

    After a single ``add_one`` call the function sleeps for 100000 seconds, so
    the process keeps running until it is killed from outside.
    """
    parl.connect('localhost:1335')
    actor = Actor()
    actor.add_one(1)
    time.sleep(100000)
# Start the simulated client when this file is executed directly as a script.
if __name__ == '__main__':
    train()
...@@ -25,22 +25,14 @@ import zmq ...@@ -25,22 +25,14 @@ import zmq
from parl.utils import get_ip_address, to_byte, to_str, logger from parl.utils import get_ip_address, to_byte, to_str, logger
from parl.remote import remote_constants from parl.remote import remote_constants
from parl.remote.message import InitializedWorker
from parl.remote.status import WorkerStatus
from six.moves import queue
if sys.version_info.major == 3: if sys.version_info.major == 3:
warnings.simplefilter("ignore", ResourceWarning) warnings.simplefilter("ignore", ResourceWarning)
class WorkerInfo(object):
"""A WorkerInfo object records the computation resources of a worker.
"""
def __init__(self, address, cpu_num, job_pool):
self.address = address
self.cpu_num = cpu_num
self.job_pool = job_pool
self.worker_socket = None
class Worker(object): class Worker(object):
"""Worker provides the cpu computation resources for the cluster. """Worker provides the cpu computation resources for the cluster.
...@@ -58,7 +50,6 @@ class Worker(object): ...@@ -58,7 +50,6 @@ class Worker(object):
xparl connect --address localhost:1234 --cpu_num 8 xparl connect --address localhost:1234 --cpu_num 8
Attributes: Attributes:
job_pid (dict): A dict of subprocess id and its address.
master_address (str): Master's ip address. master_address (str): Master's ip address.
request_master_socket (zmq.Context.socket): A socket which sends job request_master_socket (zmq.Context.socket): A socket which sends job
address to the master node. address to the master node.
...@@ -67,22 +58,37 @@ class Worker(object): ...@@ -67,22 +58,37 @@ class Worker(object):
node. node.
reply_job_socket (zmq.Context.socket): A socket which receives reply_job_socket (zmq.Context.socket): A socket which receives
job_address from the job. job_address from the job.
kill_job_socket (zmq.Context.socket): A socket that receives commands to kill the job from jobs.
Args: Args:
master_address (str): IP address of the master node. master_address (str): IP address of the master node.
cpu_num (int): Number of cpu to be used on the worker. cpu_num (int): Number of cpu to be used on the worker.
job_buffer (str): A buffer that stores initialized jobs for providing new jobs in a short time.
""" """
def __init__(self, master_address, cpu_num=None): def __init__(self, master_address, cpu_num=None):
self.lock = threading.Lock() self.lock = threading.Lock()
self.heartbeat_socket_initialized = threading.Event() self.heartbeat_socket_initialized = threading.Event()
self.ctx = zmq.Context.instance() self.ctx = zmq.Context.instance()
self.job_pid = {}
self.master_address = master_address self.master_address = master_address
self.master_is_alive = True self.master_is_alive = True
self.worker_is_alive = True self.worker_is_alive = True
self.worker_status = None # initialized at `self._create_jobs`
self.lock = threading.Lock()
self._set_cpu_num(cpu_num) self._set_cpu_num(cpu_num)
self.job_buffer = queue.Queue(maxsize=self.cpu_num)
self._create_sockets() self._create_sockets()
self._create_worker()
# create a thread that waits commands from the job to kill the job.
self.kill_job_thread = threading.Thread(target=self._reply_kill_job)
self.kill_job_thread.start()
self._create_jobs()
# create a thread that initializes jobs and adds them into the job_buffer
job_thread = threading.Thread(target=self._fill_job_buffer)
job_thread.setDaemon(True)
job_thread.start()
def _set_cpu_num(self, cpu_num=None): def _set_cpu_num(self, cpu_num=None):
"""set useable cpu number for worker""" """set useable cpu number for worker"""
...@@ -95,14 +101,15 @@ class Worker(object): ...@@ -95,14 +101,15 @@ class Worker(object):
self.cpu_num = multiprocessing.cpu_count() self.cpu_num = multiprocessing.cpu_count()
def _create_sockets(self): def _create_sockets(self):
""" Each worker has three sockets at start: """ Each worker has four sockets at start:
(1) request_master_socket: sends job address to master node. (1) request_master_socket: sends job address to master node.
(2) reply_master_socket: accepts submitted job from master node. (2) reply_master_socket: accepts submitted job from master node.
(3) reply_job_socket: receives job_address from subprocess. (3) reply_job_socket: receives job_address from subprocess.
(4) kill_job_socket : receives commands to kill the job from jobs.
When a job is start, a new heartbeat socket is created to receive When a job starts, a new heartbeat socket is created to receive
heartbeat signal from the job. heartbeat signals from the job.
""" """
...@@ -131,8 +138,14 @@ class Worker(object): ...@@ -131,8 +138,14 @@ class Worker(object):
reply_job_port = self.reply_job_socket.bind_to_random_port("tcp://*") reply_job_port = self.reply_job_socket.bind_to_random_port("tcp://*")
self.reply_job_address = "{}:{}".format(self.worker_ip, reply_job_port) self.reply_job_address = "{}:{}".format(self.worker_ip, reply_job_port)
def _create_worker(self): # kill_job_socket
"""create a WorkerInfo instance and send it to the master.""" self.kill_job_socket = self.ctx.socket(zmq.REP)
self.kill_job_socket.linger = 0
kill_job_port = self.kill_job_socket.bind_to_random_port("tcp://*")
self.kill_job_address = "{}:{}".format(self.worker_ip, kill_job_port)
def _create_jobs(self):
"""Create jobs and send a instance of ``InitializedWorker`` that contains the worker information to the master."""
try: try:
self.request_master_socket.send_multipart( self.request_master_socket.send_multipart(
[remote_constants.WORKER_CONNECT_TAG]) [remote_constants.WORKER_CONNECT_TAG])
...@@ -143,24 +156,40 @@ class Worker(object): ...@@ -143,24 +156,40 @@ class Worker(object):
self.master_is_alive = False self.master_is_alive = False
return return
self._init_jobs(job_num=self.cpu_num) initialized_jobs = self._init_jobs(job_num=self.cpu_num)
self.request_master_socket.setsockopt( self.request_master_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000) zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
self.worker = WorkerInfo(self.reply_master_address, self.cpu_num, self.reply_master_hearbeat_thread = threading.Thread(
list(self.job_pid.keys()))
reply_thread = threading.Thread(
target=self._reply_heartbeat, target=self._reply_heartbeat,
args=("master {}".format(self.master_address), )) args=("master {}".format(self.master_address), ))
reply_thread.start() self.reply_master_hearbeat_thread.start()
self.heartbeat_socket_initialized.wait() self.heartbeat_socket_initialized.wait()
initialized_worker = InitializedWorker(self.reply_master_address,
self.master_heartbeat_address,
initialized_jobs, self.cpu_num)
self.request_master_socket.send_multipart([ self.request_master_socket.send_multipart([
remote_constants.WORKER_INITIALIZED_TAG, remote_constants.WORKER_INITIALIZED_TAG,
cloudpickle.dumps(self.worker), cloudpickle.dumps(initialized_worker)
to_byte(self.heartbeat_master_address)
]) ])
_ = self.request_master_socket.recv_multipart() _ = self.request_master_socket.recv_multipart()
self.worker_status = WorkerStatus(self.reply_master_address,
initialized_jobs, self.cpu_num)
def _fill_job_buffer(self):
"""An endless loop that adds initialized job into the job buffer"""
while self.worker_is_alive:
initialized_jobs = self._init_jobs(job_num=self.cpu_num)
for job in initialized_jobs:
self.job_buffer.put(job)
# release jobs if the worker is not alive
for job in initialized_jobs:
try:
os.kill(job.pid, signal.SIGTERM)
except OSError:
pass
def _init_jobs(self, job_num): def _init_jobs(self, job_num):
"""Create jobs. """Create jobs.
...@@ -175,76 +204,73 @@ class Worker(object): ...@@ -175,76 +204,73 @@ class Worker(object):
self.reply_job_address self.reply_job_address
] ]
# avoid that many jobs are killed and restarted at the same time.
self.lock.acquire()
# Redirect the output to DEVNULL # Redirect the output to DEVNULL
FNULL = open(os.devnull, 'w') FNULL = open(os.devnull, 'w')
for _ in range(job_num): for _ in range(job_num):
pid = subprocess.Popen( subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT)
command, stdout=FNULL, stderr=subprocess.STDOUT)
FNULL.close() FNULL.close()
new_job_address = [] new_jobs = []
for _ in range(job_num): for _ in range(job_num):
job_message = self.reply_job_socket.recv_multipart() job_message = self.reply_job_socket.recv_multipart()
self.reply_job_socket.send_multipart([remote_constants.NORMAL_TAG]) self.reply_job_socket.send_multipart(
job_address = to_str(job_message[1]) [remote_constants.NORMAL_TAG,
new_job_address.append(job_address) to_byte(self.kill_job_address)])
heartbeat_job_address = to_str(job_message[2]) initialized_job = cloudpickle.loads(job_message[1])
pid = to_str(job_message[3]) initialized_job.worker_address = self.reply_master_address
self.job_pid[job_address] = int(pid) new_jobs.append(initialized_job)
# a thread for sending heartbeat signals to job # a thread for sending heartbeat signals to job
thread = threading.Thread( thread = threading.Thread(
target=self._create_job_monitor, target=self._create_job_monitor, args=(initialized_job, ))
args=(
job_address,
heartbeat_job_address,
))
thread.start() thread.start()
assert len(new_job_address) > 0, "init jobs failed" self.lock.release()
if len(new_job_address) > 1: assert len(new_jobs) > 0, "init jobs failed"
return new_job_address return new_jobs
else:
return new_job_address[0]
def _kill_job(self, job_address): def _kill_job(self, job_address):
"""kill problematic job process and update worker information""" """Kill a job process and update worker information"""
if job_address in self.job_pid: success = self.worker_status.remove_job(job_address)
if success:
while True:
initialized_job = self.job_buffer.get()
if initialized_job.is_alive:
self.worker_status.add_job(initialized_job)
if not initialized_job.is_alive: # make sure that the job is still alive.
self.worker_status.remove_job(
initialized_job.job_address)
continue
else:
logger.warning(
"[Worker] a dead job found. The job buffer will not accept this one."
)
if initialized_job.is_alive:
break
self.lock.acquire() self.lock.acquire()
pid = self.job_pid[job_address] self.request_master_socket.send_multipart([
try: remote_constants.NEW_JOB_TAG,
os.kill(pid, signal.SIGTERM) cloudpickle.dumps(initialized_job),
except OSError: to_byte(job_address)
logger.warn("job:{} has been killed before".format(pid)) ])
self.job_pid.pop(job_address) _ = self.request_master_socket.recv_multipart()
logger.warning("Worker kills job process {},".format(job_address))
self.lock.release() self.lock.release()
# When a old job is killed, the worker will create a new job. def _create_job_monitor(self, job):
if self.master_is_alive: """Send heartbeat signals to check target's status"""
new_job_address = self._init_jobs(job_num=1)
self.lock.acquire()
self.request_master_socket.send_multipart([
remote_constants.NEW_JOB_TAG,
to_byte(self.reply_master_address),
to_byte(new_job_address),
to_byte(job_address)
])
_ = self.request_master_socket.recv_multipart()
self.lock.release()
def _create_job_monitor(self, job_address, heartbeat_job_address):
"""Sending heartbeat signals to check target's status"""
# job_heartbeat_socket: sends heartbeat signal to job # job_heartbeat_socket: sends heartbeat signal to job
job_heartbeat_socket = self.ctx.socket(zmq.REQ) job_heartbeat_socket = self.ctx.socket(zmq.REQ)
job_heartbeat_socket.linger = 0 job_heartbeat_socket.linger = 0
job_heartbeat_socket.setsockopt( job_heartbeat_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000) zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
job_heartbeat_socket.connect("tcp://" + heartbeat_job_address) job_heartbeat_socket.connect("tcp://" + job.worker_heartbeat_address)
job_is_alive = True job.is_alive = True
while job_is_alive and self.master_is_alive and self.worker_is_alive: while job.is_alive and self.master_is_alive and self.worker_is_alive:
try: try:
job_heartbeat_socket.send_multipart( job_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG]) [remote_constants.HEARTBEAT_TAG])
...@@ -252,17 +278,35 @@ class Worker(object): ...@@ -252,17 +278,35 @@ class Worker(object):
time.sleep(remote_constants.HEARTBEAT_INTERVAL_S) time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
except zmq.error.Again as e: except zmq.error.Again as e:
job_is_alive = False job.is_alive = False
if job_address in self.job_pid: logger.warning(
logger.warning("[Worker] No heartbeat reply from the job, " "[Worker] lost connection with the job:{}".format(
"will kill {}.".format(job_address)) job.job_address))
self._kill_job(job_address) if self.master_is_alive and self.worker_is_alive:
self._kill_job(job.job_address)
except zmq.error.ZMQError as e: except zmq.error.ZMQError as e:
break break
job_heartbeat_socket.close(0) job_heartbeat_socket.close(0)
def _reply_kill_job(self):
    """Serve kill-job requests received on ``self.kill_job_socket``.

    Runs until either ``self.worker_is_alive`` or ``self.master_is_alive``
    becomes false. Each request must carry ``KILLJOB_TAG`` followed by the
    address of the job to kill; the job is replaced through
    ``self._kill_job`` before ``NORMAL_TAG`` is sent back.
    """
    sock = self.kill_job_socket
    sock.linger = 0
    # A receive timeout makes the loop wake up regularly so the liveness
    # flags are re-evaluated even when no request arrives.
    sock.setsockopt(zmq.RCVTIMEO,
                    remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)

    while True:
        if not (self.worker_is_alive and self.master_is_alive):
            break
        try:
            frames = sock.recv_multipart()
            assert frames[0] == remote_constants.KILLJOB_TAG
            self._kill_job(to_str(frames[1]))
            sock.send_multipart([remote_constants.NORMAL_TAG])
        except zmq.error.Again:
            # Timed out: go back and re-check the liveness flags.
            continue
def _reply_heartbeat(self, target): def _reply_heartbeat(self, target):
"""Worker will kill its jobs when it lost connection with the master. """Worker will kill its jobs when it lost connection with the master.
""" """
...@@ -273,7 +317,7 @@ class Worker(object): ...@@ -273,7 +317,7 @@ class Worker(object):
remote_constants.HEARTBEAT_RCVTIMEO_S * 1000) remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
heartbeat_master_port =\ heartbeat_master_port =\
socket.bind_to_random_port("tcp://*") socket.bind_to_random_port("tcp://*")
self.heartbeat_master_address = "{}:{}".format(self.worker_ip, self.master_heartbeat_address = "{}:{}".format(self.worker_ip,
heartbeat_master_port) heartbeat_master_port)
self.heartbeat_socket_initialized.set() self.heartbeat_socket_initialized.set()
logger.info("[Worker] Connect to the master node successfully. " logger.info("[Worker] Connect to the master node successfully. "
...@@ -284,14 +328,13 @@ class Worker(object): ...@@ -284,14 +328,13 @@ class Worker(object):
socket.send_multipart([remote_constants.HEARTBEAT_TAG]) socket.send_multipart([remote_constants.HEARTBEAT_TAG])
except zmq.error.Again as e: except zmq.error.Again as e:
self.master_is_alive = False self.master_is_alive = False
for job_address in list(self.job_pid.keys()):
self._kill_job(job_address)
except zmq.error.ContextTerminated as e: except zmq.error.ContextTerminated as e:
break break
socket.close(0) socket.close(0)
logger.warning( logger.warning(
"[Worker] lost connection with the master, will exit replying heartbeat for master." "[Worker] lost connection with the master, will exit replying heartbeat for master."
) )
self.worker_status.clear()
# exit the worker # exit the worker
self.worker_is_alive = False self.worker_is_alive = False
...@@ -300,40 +343,6 @@ class Worker(object): ...@@ -300,40 +343,6 @@ class Worker(object):
self.worker_is_alive = False self.worker_is_alive = False
def run(self): def run(self):
"""An infinite loop waiting for killing job commands from """Keep running until it lost connection with the master.
the mater node.
After creating `cpu_num` jobs and sending job addresses to the master
node, a worker will keep waiting for killing job commands from master
node to release computation resources occupied by a dead client. Then
the worker will kill the jobs related to the dead client and create
new jobs and update job addresses to the master node.
""" """
self.reply_master_hearbeat_thread.join()
self.reply_master_socket.linger = 0
self.reply_master_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
while self.master_is_alive and self.worker_is_alive:
try:
message = self.reply_master_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.KILLJOB_TAG:
job_address = to_str(message[1])
self.reply_master_socket.send_multipart(
[remote_constants.NORMAL_TAG])
self._kill_job(job_address)
else:
raise NotImplementedError
except zmq.error.Again as e:
#detect whether `self.worker_is_alive` is True periodically
pass
self.reply_job_socket.close(0)
self.request_master_socket.close(0)
self.reply_master_socket.close(0)
logger.warning("[Worker] Exit Worker {}.".format(
self.reply_master_address))
self.ctx.destroy()
...@@ -18,8 +18,10 @@ import os ...@@ -18,8 +18,10 @@ import os
import os.path import os.path
import sys import sys
from termcolor import colored from termcolor import colored
import shutil
from datetime import datetime
__all__ = ['set_dir', 'get_dir', 'set_level'] __all__ = ['set_dir', 'get_dir', 'set_level', 'auto_set_dir']
# globals: logger file and directory: # globals: logger file and directory:
LOG_DIR = None LOG_DIR = None
...@@ -37,6 +39,10 @@ def _makedirs(dirname): ...@@ -37,6 +39,10 @@ def _makedirs(dirname):
raise e raise e
def _get_time_str():
return datetime.now().strftime('%m%d-%H%M%S')
class _Formatter(logging.Formatter): class _Formatter(logging.Formatter):
def format(self, record): def format(self, record):
msg = '%(message)s' msg = '%(message)s'
...@@ -82,21 +88,6 @@ def _getlogger(): ...@@ -82,21 +88,6 @@ def _getlogger():
return logger return logger
def create_file_after_first_call(func_name):
    """Build a wrapper around the logger method called ``func_name``.

    On every call the wrapper first checks whether a log directory has been
    chosen; if ``LOG_DIR`` is still ``None`` and ``mod`` has a ``__file__``
    attribute, it selects ``log_dir/<file name up to its last dot>`` through
    ``set_dir``. It then forwards all arguments to the logger method.
    """

    def call(*args, **kwargs):
        global _logger
        needs_default_dir = LOG_DIR is None and hasattr(mod, '__file__')
        if needs_default_dir:
            script_name = os.path.basename(mod.__file__)
            stem = script_name[:script_name.rfind('.')]
            set_dir(os.path.join('log_dir', stem))
        # Look the method up at call time so the current `_logger` is used.
        getattr(_logger, func_name)(*args, **kwargs)

    return call
_logger = _getlogger() _logger = _getlogger()
_LOGGING_METHOD = [ _LOGGING_METHOD = [
'info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug', 'info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug',
...@@ -105,7 +96,7 @@ _LOGGING_METHOD = [ ...@@ -105,7 +96,7 @@ _LOGGING_METHOD = [
# export logger functions # export logger functions
for func in _LOGGING_METHOD: for func in _LOGGING_METHOD:
locals()[func] = create_file_after_first_call(func) locals()[func] = getattr(_logger, func)
__all__.append(func) __all__.append(func)
# export Level information # export Level information
...@@ -151,6 +142,70 @@ def set_dir(dirname): ...@@ -151,6 +142,70 @@ def set_dir(dirname):
_set_file(os.path.join(dirname, 'log.log')) _set_file(os.path.join(dirname, 'log.log'))
def auto_set_dir(action=None):
    """Set the global logging directory automatically.

    The directory is "./train_log/{scriptname}", where "scriptname" is the
    file name, without its extension, of the main python file currently
    running. Any file handler installed earlier is detached and closed first
    so that the old log directory can be deleted safely.

    Note: This function references `https://github.com/tensorpack/tensorpack/blob/master/tensorpack/utils/logger.py#L93`

    Args:
        action(str): what to do when the directory already exists and
            contains non-hidden files. The user is asked when it is empty.
            "d": delete the directory. Note that the deletion may fail when
                the directory is used by tensorboard.
            "k": keep the directory. This is useful when you resume from a
                previous training and want the directory to look as if the
                training was not interrupted.
                Note that this option does not load old models or any other
                old states for you. It simply does nothing.
            "n": keep the old directory and log into a new one whose name is
                the default name followed by the current time stamp.
            Any other value (for example "q") aborts with ``OSError``.

    Returns:
        dirname(str): log directory used in the global logging directory.

    Raises:
        OSError: if the directory is non-empty and ``action`` is none of
            "d", "n" or "k", or if deleting the directory fails.
    """
    global LOG_DIR, _FILE_HANDLER

    mod = sys.modules['__main__']
    basename = os.path.basename(mod.__file__)
    # splitext keeps the full name of a script without extension, whereas
    # slicing at rfind('.') == -1 would silently drop its last character.
    dirname = os.path.normpath(
        os.path.join('train_log', os.path.splitext(basename)[0]))

    if _FILE_HANDLER:
        # unload and close the old file handler, so that we may safely
        # delete the logger directory
        _logger.removeHandler(_FILE_HANDLER)
        _FILE_HANDLER.close()
        # Rebind to None instead of `del`: deleting the global would make
        # the check above raise NameError on the next call whenever this
        # call fails before a new handler is installed.
        _FILE_HANDLER = None

    def dir_nonempty(path):
        # True when the directory exists and holds at least one non-hidden
        # entry.
        return os.path.isdir(path) and any(
            not name.startswith('.') for name in os.listdir(path))

    if dir_nonempty(dirname):
        if not action:
            _logger.warning(
                "Log directory {} exists! Use 'd' to delete it. ".format(
                    dirname))
            _logger.warning(
                "If you're resuming from a previous run, you can choose to "
                "keep it.\nPress any other key to exit. ")
        while not action:
            action = input("Select Action: k (keep) / d (delete) / q (quit):"
                           ).lower().strip()
        if action == 'd':
            # First try quietly; if something is left, retry and let the
            # error surface to the caller.
            shutil.rmtree(dirname, ignore_errors=True)
            if dir_nonempty(dirname):
                shutil.rmtree(dirname, ignore_errors=False)
        elif action == 'n':
            dirname = dirname + _get_time_str()
            _logger.info("Use a new log directory {}".format(dirname))
        elif action == 'k':
            pass
        else:
            raise OSError("Directory {} exists!".format(dirname))

    LOG_DIR = dirname
    _makedirs(dirname)
    _set_file(os.path.join(dirname, 'log.log'))
    return dirname
def get_dir(): def get_dir():
return LOG_DIR return LOG_DIR
......
...@@ -25,6 +25,12 @@ def create_file_after_first_call(func_name): ...@@ -25,6 +25,12 @@ def create_file_after_first_call(func_name):
def call(*args, **kwargs): def call(*args, **kwargs):
global _writer global _writer
if _writer is None: if _writer is None:
logdir = logger.get_dir()
if logdir is None:
logdir = logger.auto_set_dir(action='d')
logger.warning(
"[tensorboard] logdir is None, will save tensorboard files to {}"
.format(logdir))
_writer = SummaryWriter(logdir=logger.get_dir()) _writer = SummaryWriter(logdir=logger.get_dir())
func = getattr(_writer, func_name) func = getattr(_writer, func_name)
func(*args, **kwargs) func(*args, **kwargs)
......
...@@ -40,6 +40,11 @@ class TestLogger(unittest.TestCase): ...@@ -40,6 +40,11 @@ class TestLogger(unittest.TestCase):
for t in th_list: for t in th_list:
t.join() t.join()
def test_auto_set_dir(self):
    """Run auto_set_dir with the delete, new and keep actions in turn."""
    for action in ('d', 'n', 'k'):
        logger.auto_set_dir(action=action)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
import unittest import unittest
from parl.utils import tensorboard from parl.utils import tensorboard
import numpy as np import numpy as np
from parl.utils import logger
import os
class TestUtils(unittest.TestCase): class TestUtils(unittest.TestCase):
...@@ -24,6 +26,7 @@ class TestUtils(unittest.TestCase): ...@@ -24,6 +26,7 @@ class TestUtils(unittest.TestCase):
x = range(100) x = range(100)
for i in x: for i in x:
tensorboard.add_scalar('y=2x', i * 2, i) tensorboard.add_scalar('y=2x', i * 2, i)
self.assertTrue(os.path.exists('./train_log/tensorboard_test'))
def test_add_histogram(self): def test_add_histogram(self):
for i in range(10): for i in range(10):
...@@ -32,4 +35,5 @@ class TestUtils(unittest.TestCase): ...@@ -32,4 +35,5 @@ class TestUtils(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
logger.auto_set_dir(action='d')
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册