Commit 84c94cc4 authored by Bo Zhou, committed by Hongsheng Zeng

Cluster#2 (#124)

* fix a number of bugs

* update comments

* paddle license

* revert the change on client_not_init_test.py

* fix timeout bug & unittest

* add unit test for potential bugs

* yapf

* unit test

* logger

* logger#2

* fix a job bug that could cause deadlock

* fix bug

* add missing files

* solve hanging

* fix comments; exit job completely

* hanging problem

* hanging problem#2

* hanging_program#3

* CMAKE

* remove unused variables
Parent a7670972
......@@ -60,7 +60,7 @@ function run_test_with_gpu() {
mkdir -p ${REPO_ROOT}/build
cd ${REPO_ROOT}/build
cmake ..
cmake .. -DIS_TESTING_GPU=ON
cat <<EOF
========================================
Running unit tests with GPU...
......
......@@ -3,3 +3,4 @@ paddlepaddle-gpu==1.5.1.post97
gym
details
parameterized
timeout_decorator
......@@ -17,26 +17,34 @@ cmake_minimum_required(VERSION 3.0)
enable_testing()
option(WITH_TESTING "Include unit testing" ON)
option(IS_TESTING_IMPORT "Whether is testing import parl" OFF)
option(IS_TESTING_DOCS "Whether is testing compling the docs" OFF)
option(IS_TESTING_IMPORT "testing import parl" OFF)
option(IS_TESTING_DOCS "testing compling the docs" OFF)
option(IS_TESTING_GPU "testing GPU environment" OFF)
set(PADDLE_PYTHON_PATH "" CACHE STRING "Python path to PaddlePaddle Fluid")
function(py_test TARGET_NAME)
function(py3_test TARGET_NAME)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS ENVS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}_with_python2
COMMAND env PYTHONPATH=.:${py_test_ENVS}
python -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
cmake_parse_arguments(py3_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
#TODO: add real python2 env.
add_test(NAME ${TARGET_NAME}_with_python3
COMMAND env PYTHONPATH=.:${py_test_ENVS}
python3.6 -u ${py_test_SRCS} ${py_test_ARGS}
COMMAND python3.6 ${py3_test_SRCS} ${py3_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endfunction()
#function(py2_test TARGET_NAME)
# set(options "")
# set(oneValueArgs "")
# set(multiValueArgs SRCS DEPS ARGS ENVS)
# cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
# #TODO: add real python2 env.
# add_test(NAME ${TARGET_NAME}_with_python2
# COMMAND python ${py_test_SRCS} ${py_test_ARGS}
# WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
#endfunction()
function(import_test TARGET_NAME)
set(options "")
set(oneValueArgs "")
......@@ -44,8 +52,7 @@ function(import_test TARGET_NAME)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}_with_empty_env
COMMAND env PYTHONPATH=.:${py_test_ENVS}
/root/miniconda3/envs/empty_env/bin/python -u ${py_test_SRCS} ${py_test_ARGS}
COMMAND /root/miniconda3/envs/empty_env/bin/python -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endfunction()
......@@ -66,7 +73,13 @@ if (WITH_TESTING)
file(GLOB_RECURSE TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_test.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
if (${src} MATCHES ".*remote.*")
if (NOT IS_TESTING_GPU)
py3_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
endif()
else()
py3_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
endif()
endforeach()
endif()
endif()
......@@ -83,7 +83,7 @@ class Model(ModelBase):
Args:
target_model (`parl.Model`): an instance of ``Model`` that has the same neural network architecture as the current model.
decay (float): the rate of decline in copying parameters. 0 if no parameters decay when synchronizing the parameters.
share_vars_parallel_executor (fluid.ParallelExecutor): if not None, will use fluid.ParallelExecutor
share_vars_parallel_executor (fluid.ParallelExecutor): Optional. If not None, will use fluid.ParallelExecutor
to run program instead of fluid.Executor
"""
self.sync_weights_to(
......@@ -97,15 +97,15 @@ class Model(ModelBase):
share_vars_parallel_executor=None):
"""Synchronize parameters of current model to another model.
To speed up the synchronizing process, will create a program implicitly to finish the process. And will
also cache the program to avoid creating program repeatedly.
To speed up the synchronizing process, it will create a program implicitly to finish the process. It
also caches the program to avoid creating it repeatedly.
target_model_weights = decay * target_model_weights + (1 - decay) * current_model_weights
Args:
target_model (`parl.Model`): an instance of ``Model`` that has the same neural network architecture as the current model.
decay (float): the rate of decline in copying parameters. 0 if no parameters decay when synchronizing the parameters.
share_vars_parallel_executor (fluid.ParallelExecutor): if not None, will use ``fluid.ParallelExecutor``
share_vars_parallel_executor (fluid.ParallelExecutor): Optional. If not None, will use ``fluid.ParallelExecutor``
to run program instead of ``fluid.Executor``.
Example:
......
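The decay rule above is easy to sanity-check numerically. A minimal sketch, assuming plain NumPy arrays stand in for the fluid parameters (soft_update is an illustrative helper, not part of the PARL API):

import numpy as np

def soft_update(target_weights, current_weights, decay):
    # target = decay * target + (1 - decay) * current
    return [decay * t + (1.0 - decay) * c
            for t, c in zip(target_weights, current_weights)]

target = [np.zeros(3)]
current = [np.ones(3)]
# decay=0.9 moves the target 10% of the way toward the current weights
print(soft_update(target, current, decay=0.9))  # [array([0.1, 0.1, 0.1])]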
......@@ -101,7 +101,7 @@ class Client(object):
"address {} is correct.".format(master_address))
def _reply_heartbeat(self):
"""Reply heartbeat signals to the Master node."""
"""Reply heartbeat signals to the specific node."""
socket = self.ctx.socket(zmq.REP)
socket.linger = 0
......@@ -124,6 +124,57 @@ class Client(object):
socket.close(0)
logger.warning("Client exit replying heartbeat for master.")
def _check_and_monitor_job(self, job_heartbeat_address,
ping_heartbeat_address):
""" Sometimes the client may receive a job that is dead, thus
we have to check if this job is still alive before sending it to the actor.
"""
# job_heartbeat_socket: sends heartbeat signal to job
job_heartbeat_socket = self.ctx.socket(zmq.REQ)
job_heartbeat_socket.linger = 0
job_heartbeat_socket.setsockopt(zmq.RCVTIMEO, int(0.9 * 1000))
job_heartbeat_socket.connect("tcp://" + ping_heartbeat_address)
try:
job_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG])
job_heartbeat_socket.recv_multipart()
except zmq.error.Again:
job_heartbeat_socket.close(0)
logger.error(
"[Client] connects to a finished job, will try again, ping_heartbeat_address:{}"
.format(ping_heartbeat_address))
return False
job_heartbeat_socket.disconnect("tcp://" + ping_heartbeat_address)
job_heartbeat_socket.connect("tcp://" + job_heartbeat_address)
job_heartbeat_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
# a thread for sending heartbeat signals to job
thread = threading.Thread(
target=self._create_job_monitor, args=(job_heartbeat_socket, ))
thread.setDaemon(True)
thread.start()
return True
def _create_job_monitor(self, job_heartbeat_socket):
"""Send heartbeat signals to check target's status"""
job_is_alive = True
while job_is_alive and self.client_is_alive:
try:
job_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG])
_ = job_heartbeat_socket.recv_multipart()
time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
except zmq.error.Again as e:
job_is_alive = False
except zmq.error.ZMQError as e:
break
job_heartbeat_socket.close(0)
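The client-side monitor above and the job's heartbeat servers further below are all instances of the same ZeroMQ REQ/REP ping pattern. A self-contained sketch of that pattern, with a made-up port and a 1-second timeout:

import threading
import time
import zmq

ctx = zmq.Context()

def heartbeat_server():
    # REP side: answer each ping it receives, three times, then stop.
    rep = ctx.socket(zmq.REP)
    rep.bind("tcp://127.0.0.1:5555")
    for _ in range(3):
        rep.recv_multipart()
        rep.send_multipart([b"[HEARTBEAT]"])
    rep.close(0)

threading.Thread(target=heartbeat_server, daemon=True).start()

req = ctx.socket(zmq.REQ)
req.setsockopt(zmq.RCVTIMEO, 1000)  # recv raises zmq.error.Again after 1s
req.connect("tcp://127.0.0.1:5555")
for _ in range(4):
    try:
        req.send_multipart([b"[HEARTBEAT]"])
        req.recv_multipart()
        time.sleep(0.1)
    except zmq.error.Again:
        # the fourth ping goes unanswered, so the peer is treated as dead
        print("peer looks dead, stop monitoring")
        break
req.close(0)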
def submit_job(self):
"""Send a job to the Master node.
......@@ -132,11 +183,12 @@ class Client(object):
a vacant job from its job pool to the remote object.
Returns:
IP address of the job.
job_address(str): IP address of the job. None if there is no available CPU in the cluster.
"""
if self.master_is_alive:
# A lock to prevent multiple actor submit job at the same time.
while True:
# A lock to prevent multiple actors from submitting jobs at the same time.
self.lock.acquire()
self.submit_job_socket.send_multipart([
remote_constants.CLIENT_SUBMIT_TAG,
......@@ -149,18 +201,26 @@ class Client(object):
if tag == remote_constants.NORMAL_TAG:
job_address = to_str(message[1])
job_heartbeat_address = to_str(message[2])
ping_heartbeat_address = to_str(message[3])
# no vacant CPU resources, can not submit a new job
check_result = self._check_and_monitor_job(
job_heartbeat_address, ping_heartbeat_address)
if check_result:
return job_address
# no vacant CPU resources, cannot submit a new job
elif tag == remote_constants.CPU_TAG:
job_address = None
# wait 1 second to avoid requesting in a high frequency.
time.sleep(1)
return job_address
else:
raise NotImplementedError
else:
raise Exception("Client can not submit job to the master, "
"please check if master is connected.")
return job_address
return None
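From an actor's point of view this retry loop is invisible: parl.connect creates the global client once, and each @parl.remote_class instantiation triggers one submit_job call. A usage sketch, assuming a master already runs at the example address:

import parl
from parl.remote.client import disconnect

@parl.remote_class
class Counter(object):
    def add_one(self, value):
        return value + 1

parl.connect('localhost:8010')  # example master address
counter = Counter()             # submits a job; waits while no CPU is vacant
assert counter.add_one(1) == 2
disconnect()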
GLOBAL_CLIENT = None
......@@ -203,5 +263,10 @@ def get_global_client():
def disconnect():
"""Disconnect the global client from the master node."""
global GLOBAL_CLIENT
if GLOBAL_CLIENT is not None:
GLOBAL_CLIENT.client_is_alive = False
GLOBAL_CLIENT = None
else:
logger.info(
"No client to be released. Please make sure that you have call `parl.connect`"
)
......@@ -29,6 +29,7 @@ from parl.utils.communication import loads_argument, loads_return,\
dumps_argument, dumps_return
from parl.remote import remote_constants
from parl.utils.exceptions import SerializeError, DeserializeError
from parl.remote.message import InitializedJob
class Job(object):
......@@ -37,74 +38,151 @@ class Job(object):
After establishing connection with the remote object, the job will
create a remote class instance locally and enter an infinite loop,
waiting for commands from the remote object.
"""
def __init__(self, worker_address):
"""
Args:
worker_address(str): Address of the worker, used for sending job information (e.g., pid).
"""
self.job_is_alive = True
self.worker_address = worker_address
self.lock = threading.Lock()
self._create_sockets()
def _create_sockets(self):
"""Create two sockets for each job.
"""Create three sockets for each job.
(1) reply_socket: receives the command(i.e, the function name and
args) from the actual class instance, and returns the result of
(1) reply_socket(main socket): receives the command (i.e., the function name and args)
from the actual class instance, completes the computation, and returns the result of
the function.
(2) job_socket: sends job_address and heartbeat_address to worker.
(2) job_socket(functional socket): sends job_address and heartbeat_address to worker.
(3) kill_job_socket: sends a command to the corresponding worker to kill the job.
"""
self.ctx = zmq.Context()
# reply_socket: receives class, parameters and call function from
# @remote.class and send computed results to the @remote.class.
# create the reply_socket
self.reply_socket = self.ctx.socket(zmq.REP)
self.reply_socket.linger = 0
job_port = self.reply_socket.bind_to_random_port(addr="tcp://*")
self.reply_socket.linger = 0
self.job_ip = get_ip_address()
self.job_address = "{}:{}".format(self.job_ip, job_port)
reply_thread = threading.Thread(
target=self._reply_heartbeat,
args=("worker {}".format(self.worker_address), ))
reply_thread.setDaemon(True)
reply_thread.start()
def _reply_heartbeat(self, target):
"""reply heartbeat signals to the target"""
socket = self.ctx.socket(zmq.REP)
socket.setsockopt(zmq.RCVTIMEO,
remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
socket.linger = 0
heartbeat_worker_port = socket.bind_to_random_port(addr="tcp://*")
heartbeat_worker_address = "{}:{}".format(self.job_ip,
heartbeat_worker_port)
# job_socket: sends job_address and heartbeat_address to worker
job_socket = self.ctx.socket(zmq.REQ)
job_socket.connect("tcp://{}".format(self.worker_address))
job_socket.send_multipart([
remote_constants.NORMAL_TAG,
to_byte(self.job_address),
to_byte(heartbeat_worker_address),
to_byte(str(os.getpid()))
])
_ = job_socket.recv_multipart()
# create the job_socket
self.job_socket = self.ctx.socket(zmq.REQ)
self.job_socket.connect("tcp://{}".format(self.worker_address))
# a thread that replies to ping signals from the client
ping_heartbeat_socket, ping_heartbeat_address = self._create_heartbeat_server(
timeout=False)
ping_thread = threading.Thread(
target=self._reply_ping, args=(ping_heartbeat_socket, ))
ping_thread.setDaemon(True)
ping_thread.start()
self.ping_heartbeat_address = ping_heartbeat_address
# a thread that replies to heartbeat signals from the worker
worker_heartbeat_socket, worker_heartbeat_address = self._create_heartbeat_server(
)
worker_thread = threading.Thread(
target=self._reply_worker_heartbeat,
args=(worker_heartbeat_socket, ))
worker_thread.setDaemon(True)
worker_thread.start()
# a thread that replies to heartbeat signals from the client
client_heartbeat_socket, client_heartbeat_address = self._create_heartbeat_server(
)
self.client_thread = threading.Thread(
target=self._reply_client_heartbeat,
args=(client_heartbeat_socket, ))
self.client_thread.setDaemon(True)
# sends job information to the worker
initialized_job = InitializedJob(
self.job_address, worker_heartbeat_address,
client_heartbeat_address, self.ping_heartbeat_address, None,
os.getpid())
self.job_socket.send_multipart(
[remote_constants.NORMAL_TAG,
cloudpickle.dumps(initialized_job)])
message = self.job_socket.recv_multipart()
assert message[0] == remote_constants.NORMAL_TAG
# create the kill_job_socket
kill_job_address = to_str(message[1])
self.kill_job_socket = self.ctx.socket(zmq.REQ)
self.kill_job_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
self.kill_job_socket.connect("tcp://{}".format(kill_job_address))
def _reply_ping(self, socket):
"""Create a socket server that reply the ping signal from client.
This signal is used to make sure that the job is still alive.
"""
while self.job_is_alive:
message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
socket.close(0)
def _create_heartbeat_server(self, timeout=True):
"""Create a socket server that will raises timeout exception.
"""
heartbeat_socket = self.ctx.socket(zmq.REP)
if timeout:
heartbeat_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
heartbeat_socket.linger = 0
heartbeat_port = heartbeat_socket.bind_to_random_port(addr="tcp://*")
heartbeat_address = "{}:{}".format(self.job_ip, heartbeat_port)
return heartbeat_socket, heartbeat_address
def _reply_client_heartbeat(self, socket):
"""Create a socket that replies heartbeat signals from the client.
If the job losts connection with the client, it will exit too.
"""
self.client_is_alive = True
while self.client_is_alive:
try:
message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
except zmq.error.Again as e:
logger.warning(
"[Job] Cannot connect to the client. This job will exit and inform the worker."
)
self.client_is_alive = False
socket.close(0)
with self.lock:
self.kill_job_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(self.job_address)])
_ = self.kill_job_socket.recv_multipart()
logger.warning("[Job]lost connection with the client, will exit")
os._exit(1)
def _reply_worker_heartbeat(self, socket):
"""create a socket that replies heartbeat signals from the worker.
If the worker has exited, the job will exit automatically.
"""
# a flag to decide when to exit heartbeat loop
self.worker_is_alive = True
# a flag to decide when to exit heartbeat loop
while self.worker_is_alive and self.job_is_alive:
try:
message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
except zmq.error.Again as e:
logger.warning("[Job] Cannot connect to {}. ".format(target) +
"Job will quit.")
logger.warning("[Job] Cannot connect to the worker{}. ".format(
self.worker_address) + "Job will quit.")
self.worker_is_alive = False
self.job_is_alive = False
socket.close(0)
os._exit(1)
def wait_for_files(self):
"""Wait for python files from remote object.
......@@ -132,7 +210,8 @@ class Job(object):
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
return envdir
else:
logger.warning(message)
logger.error("NotImplementedError:{}, received tag:{}".format(
self.job_address, tag))
raise NotImplementedError
def wait_for_connection(self):
......@@ -146,20 +225,62 @@ class Job(object):
A local instance of the remote class object.
"""
while True:
message = self.reply_socket.recv_multipart()
tag = message[0]
obj = None
if tag == remote_constants.INIT_OBJECT_TAG:
cls = cloudpickle.loads(message[1])
args, kwargs = cloudpickle.loads(message[2])
try:
obj = cls(*args, **kwargs)
except Exception as e:
traceback_str = str(traceback.format_exc())
error_str = str(e)
logger.error("traceback:\n{}".format(traceback_str))
self.reply_socket.send_multipart([
remote_constants.EXCEPTION_TAG,
to_byte(error_str + "\ntraceback:\n" + traceback_str)
])
self.client_is_alive = False
return None
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
return obj
else:
logger.error("Message from job {}".format(message))
self.reply_socket.send_multipart([
remote_constants.EXCEPTION_TAG,
b"[job]Unkonwn tag when tried to receive the class definition"
])
raise NotImplementedError
return obj
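The INIT_OBJECT_TAG branch works because cloudpickle can serialize the class object itself, not just instances. The round trip in isolation, with a hypothetical user class:

import cloudpickle

class Greeter(object):
    def __init__(self, name):
        self.name = name

    def hello(self):
        return "hello, " + self.name

# actor side: ship the class definition and the constructor arguments
payload_cls = cloudpickle.dumps(Greeter)
payload_args = cloudpickle.dumps((["world"], {}))

# job side: rebuild the class and instantiate it locally
cls = cloudpickle.loads(payload_cls)
args, kwargs = cloudpickle.loads(payload_args)
obj = cls(*args, **kwargs)
assert obj.hello() == "hello, world"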
def run(self):
"""An infinite loop waiting for a new task.
"""
# receive source code from the actor and append them to the environment variables.
envdir = self.wait_for_files()
sys.path.append(envdir)
self.client_thread.start()
try:
obj = self.wait_for_connection()
assert obj is not None
self.single_task(obj)
except Exception as e:
logger.error(
"Error occurs when running a single task. We will reset this job. Reason:{}"
.format(e))
traceback_str = str(traceback.format_exc())
logger.error("traceback:\n{}".format(traceback_str))
with self.lock:
self.kill_job_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(self.job_address)])
_ = self.kill_job_socket.recv_multipart()
def single_task(self, obj):
"""An infinite loop waiting for commands from the remote object.
Each job will receive two kinds of messages from the remote object:
......@@ -171,18 +292,12 @@ class Job(object):
related computation resources.
"""
# receive files
envdir = self.wait_for_files()
sys.path.append(envdir)
obj = self.wait_for_connection()
while self.job_is_alive:
while self.job_is_alive and self.client_is_alive:
message = self.reply_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.CALL_TAG:
assert obj is not None
try:
function_name = to_str(message[1])
data = message[2]
......@@ -194,9 +309,11 @@ class Job(object):
[remote_constants.NORMAL_TAG, ret])
except Exception as e:
# reset the job
self.client_is_alive = False
error_str = str(e)
logger.error(error_str)
self.job_is_alive = False
if type(e) == AttributeError:
self.reply_socket.send_multipart([
......@@ -217,6 +334,7 @@ class Job(object):
remote_constants.DESERIALIZE_EXCEPTION_TAG,
to_byte(error_str)
])
raise DeserializeError
else:
traceback_str = str(traceback.format_exc())
......@@ -226,15 +344,19 @@ class Job(object):
to_byte(error_str + "\ntraceback:\n" +
traceback_str)
])
break
# receive KILLJOB_TAG from the actor, and stop replying to the worker heartbeat
elif tag == remote_constants.KILLJOB_TAG:
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
self.job_is_alive = False
logger.warning("An actor exits and will quit job {}.".format(
self.client_is_alive = False
logger.warning(
"An actor exits and this job {} will exit.".format(
self.job_address))
break
else:
logger.error("Job message: {}".format(message))
logger.error(
"The job receives an unknown message: {}".format(message))
raise NotImplementedError
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
class JobCenter(object):
"""The job center deals with everythin related to jobs.
Attributes:
job_pool (dict): A dict that maps the job address of each vacant CPU to its ``InitializedJob``.
worker_dict (dict): A dict to store connected workers.
"""
def __init__(self):
self.job_pool = dict()
self.worker_dict = {}
self.lock = threading.Lock()
@property
def cpu_num(self):
""""Return vacant cpu number."""
return len(self.job_pool)
@property
def worker_num(self):
"""Return connected worker number."""
return len(self.worker_dict)
def add_worker(self, worker):
"""A new worker connects.
Args:
worker (InitializedWorker): New worker with initialized jobs.
"""
self.lock.acquire()
self.worker_dict[worker.worker_address] = worker
for job in worker.initialized_jobs:
self.job_pool[job.job_address] = job
self.lock.release()
def drop_worker(self, worker_address):
"""Remove jobs from job_pool when a worker dies.
Args:
worker_address (str): Address of the dead worker to be removed from the cluster.
"""
self.lock.acquire()
worker = self.worker_dict[worker_address]
for job in worker.initialized_jobs:
if job.job_address in self.job_pool:
self.job_pool.pop(job.job_address)
self.worker_dict.pop(worker_address)
self.lock.release()
def request_job(self):
"""Return a job_address when the client submits a job.
If there is no vacant CPU in the cluster, this will return None.
Returns:
An ``InitializedJob`` carrying the address of an available job, or None.
"""
self.lock.acquire()
job = None
if len(self.job_pool):
job_address, job = self.job_pool.popitem()
self.lock.release()
return job
def reset_job(self, job):
"""Reset a job and add the job_address to the job_pool.
Args:
job(``InitializedJob``): The job information of the restarted job.
"""
self.lock.acquire()
self.job_pool[job.job_address] = job
self.lock.release()
def update_job(self, killed_job_address, new_job, worker_address):
"""When worker kill an old job, it will start a new job.
Args:
killed_job_address (str): The job address of the killed job.
new_job(``InitializedJob``): Information of the new job.
worker_address (str): Address of the worker that killed the old job.
"""
self.lock.acquire()
self.job_pool[new_job.job_address] = new_job
if killed_job_address in self.job_pool:
self.job_pool.pop(killed_job_address)
to_del_idx = None
for i, job in enumerate(
self.worker_dict[worker_address].initialized_jobs):
if job.job_address == killed_job_address:
to_del_idx = i
break
del self.worker_dict[worker_address].initialized_jobs[to_del_idx]
self.worker_dict[worker_address].initialized_jobs.append(new_job)
self.lock.release()
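A compact sketch of the JobCenter life cycle with placeholder addresses (the unit tests further down exercise the same API in more depth):

from parl.remote.job_center import JobCenter
from parl.remote.message import InitializedJob, InitializedWorker

job = InitializedJob(
    job_address='127.0.0.1:1234',
    worker_heartbeat_address='127.0.0.1:1235',
    client_heartbeat_address='127.0.0.1:1236',
    ping_heartbeat_address='127.0.0.1:1237',
    worker_address='127.0.0.1:8000',
    pid=0)
worker = InitializedWorker('127.0.0.1:8000', '127.0.0.1:8001', [job], cpu_num=1)

center = JobCenter()
center.add_worker(worker)      # one vacant CPU now
vacant = center.request_job()  # pops the job from the pool
assert center.cpu_num == 0
center.reset_job(vacant)       # the restarted job goes back to the pool
assert center.cpu_num == 1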
......@@ -18,9 +18,11 @@ import threading
import time
import zmq
from collections import defaultdict
from parl.utils import to_str, to_byte, logger
from parl.remote import remote_constants
from parl.remote.job_center import JobCenter
import cloudpickle
import time
class Master(object):
......@@ -40,27 +42,19 @@ class Master(object):
master node.
Attributes:
worker_pool (dict): A dict to store connected workers.
job_pool (list): A list to store the job address of vacant cpu, when
this number is 0, the master will refuse to create
new remote object.
client_job_dict (dict): A dict of list to record the job submitted by
each client.
job_worker_dict (dict): A dict to record the job and related worker.
job_center (JobCenter): A thread-safe data structure that stores the job addresses of vacant CPUs.
client_socket (zmq.Context.socket): A socket that receives submitted
job from the client, and later sends
job_address back to the client.
worker_socket (zmq.Context.socket): A socket that receives job
addresses from the worker node.
cpu_num(int): the number of available CPUs in the cluster.
cpu_num(int): The number of available CPUs in the cluster.
worker_num(int): The number of workers connected to this cluster.
Args:
port: the ip port that the master node binds to.
port: The ip port that the master node binds to.
"""
def __init__(self, port):
logger.set_dir(os.path.expanduser('~/.parl_data/master/'))
self.lock = threading.Lock()
self.ctx = zmq.Context()
self.client_socket = self.ctx.socket(zmq.REP)
......@@ -68,13 +62,7 @@ class Master(object):
self.client_socket.linger = 0
self.port = port
self.worker_pool = {}
self.worker_locks = {}
self.job_pool = []
self.client_job_dict = defaultdict(list)
self.worker_job_dict = defaultdict(list)
self.job_worker_dict = {}
self.job_center = JobCenter()
self.master_is_alive = True
......@@ -96,13 +84,7 @@ class Master(object):
_ = worker_heartbeat_socket.recv_multipart()
time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
except zmq.error.Again as e:
for job in self.worker_job_dict[worker_address]:
if job in self.job_pool:
self.job_pool.remove(job)
self.job_worker_dict.pop(job)
self.worker_job_dict.pop(worker_address)
self.worker_pool.pop(worker_address)
self.worker_locks.pop(worker_address)
self.job_center.drop_worker(worker_address)
logger.warning("\n[Master] Cannot connect to the worker " +
"{}. ".format(worker_address) +
"Worker_pool will drop this worker.")
......@@ -115,7 +97,7 @@ class Master(object):
logger.warning("Exit worker monitor from master.")
def _create_client_monitor(self, client_heartbeat_address):
"""when a new client connects to the master, a socket is created to
"""When a new client connects to the master, a socket is created to
send heartbeat signals to the client.
"""
......@@ -125,65 +107,43 @@ class Master(object):
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
client_heartbeat_socket.connect("tcp://" + client_heartbeat_address)
self.client_is_alive = True
while self.client_is_alive and self.master_is_alive:
client_is_alive = True
while client_is_alive and self.master_is_alive:
try:
client_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG])
_ = client_heartbeat_socket.recv_multipart()
except zmq.error.Again as e:
self.client_is_alive = False
client_is_alive = False
logger.warning("[Master] cannot connect to the client " +
"{}. ".format(client_heartbeat_address) +
"Please check if it is still alive.")
self._kill_client_jobs(client_heartbeat_address)
time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
logger.warning("Master exits client monitor for {}.\n".format(
client_heartbeat_address))
logger.info(
"Master connects to {} workers and have {} vacant CPUs.\n".format(
len(self.worker_pool), len(self.job_pool)))
self.worker_num, self.cpu_num))
client_heartbeat_socket.close(0)
def _kill_client_jobs(self, client_address):
"""set timeout in case the worker and client quit at the same time.
"""
jobs = self.client_job_dict[client_address]
for job_address in jobs:
if job_address in self.job_worker_dict:
worker_address = self.job_worker_dict[job_address]
# ignore this worker if it has been deleted
if worker_address not in self.worker_pool:
continue
worker_socket = self.worker_pool[worker_address].worker_socket
lock = self.worker_locks[worker_address]
lock.acquire()
worker_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(job_address)])
try:
_ = worker_socket.recv_multipart()
except zmq.error.Again as e:
logger.warning("Error in recv kill_client_job")
lock.release()
self.job_worker_dict.pop(job_address)
self.client_job_dict.pop(client_address)
def _print_workers(self):
"""Display `worker_pool` infomation."""
logger.info(
"Master connects to {} workers and have {} vacant CPUs.\n".format(
len(self.worker_pool), len(self.job_pool)))
self.worker_num, self.cpu_num))
@property
def cpu_num(self):
return len(self.job_pool)
return self.job_center.cpu_num
@property
def worker_num(self):
return self.job_center.worker_num
def _receive_message(self):
"""master node will receive four types of message: (1) worker
"""Master node will receive various types of message: (1) worker
connection; (2) worker update; (3) client connection; (4) job
submission.
submission; (5) reset job.
"""
message = self.client_socket.recv_multipart()
tag = message[0]
......@@ -193,35 +153,18 @@ class Master(object):
self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
elif tag == remote_constants.WORKER_INITIALIZED_TAG:
worker = pickle.loads(message[1])
worker_heartbeat_address = to_str(message[2])
# maintain job & worker relations
for job_address in worker.job_pool:
self.job_worker_dict[job_address] = worker.address
self.worker_job_dict[worker.address] = worker.job_pool
self.job_pool.extend(worker.job_pool)
# a new socket for submitting job to the worker
worker_socket = self.ctx.socket(zmq.REQ)
worker_socket.linger = 0
worker_socket.setsockopt(zmq.RCVTIMEO, 10000)
worker_socket.connect("tcp://{}".format(worker.address))
worker.worker_socket = worker_socket
self.worker_pool[worker.address] = worker
self.worker_locks[worker.address] = threading.Lock()
logger.info(
"A new worker {} is added, ".format(worker.address) +
"the cluster has {} CPUs.\n".format(len(self.job_pool)))
initialized_worker = cloudpickle.loads(message[1])
self.job_center.add_worker(initialized_worker)
logger.info("A new worker {} is added, ".format(initialized_worker.
worker_address) +
"the cluster has {} CPUs.\n".format(self.cpu_num))
# a thread for sending heartbeat signals to `worker.address`
thread = threading.Thread(
target=self._create_worker_monitor,
args=(
worker_heartbeat_address,
worker.address,
))
args=(initialized_worker.master_heartbeat_address,
initialized_worker.worker_address))
thread.start()
self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
......@@ -240,43 +183,30 @@ class Master(object):
# a client submits a job to the master
elif tag == remote_constants.CLIENT_SUBMIT_TAG:
client_address = to_str(message[1])
done_flag = False
# check available CPU resources
if len(self.job_pool):
if self.cpu_num:
logger.info("Submitting job...")
job_address = self.job_pool.pop(0)
worker_address = self.job_worker_dict[job_address]
self.worker_job_dict[worker_address].remove(job_address)
self.client_socket.send_multipart(
[remote_constants.NORMAL_TAG,
to_byte(job_address)])
self.client_job_dict[client_address].append(job_address)
job = self.job_center.request_job()
self.client_socket.send_multipart([
remote_constants.NORMAL_TAG,
to_byte(job.job_address),
to_byte(job.client_heartbeat_address),
to_byte(job.ping_heartbeat_address)
])
self._print_workers()
else:
self.client_socket.send_multipart([remote_constants.CPU_TAG])
# a worker updates
elif tag == remote_constants.NEW_JOB_TAG:
worker_address = to_str(message[1])
new_job_address = to_str(message[2])
killed_job_address = to_str(message[3])
initialized_job = cloudpickle.loads(message[1])
last_job_address = to_str(message[2])
self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
logger.info("A worker updated.")
if killed_job_address in self.job_worker_dict:
self.job_worker_dict.pop(killed_job_address)
if killed_job_address in self.worker_job_dict[worker_address]:
self.worker_job_dict[worker_address].remove(killed_job_address)
if killed_job_address in self.job_pool:
self.job_pool.remove(killed_job_address)
# add new job_address to job_pool
self.job_pool.append(new_job_address)
self.job_worker_dict[new_job_address] = worker_address
self.worker_job_dict[worker_address].append(new_job_address)
self.job_center.update_job(last_job_address, initialized_job,
initialized_job.worker_address)
logger.info("A worker updated. cpu_num:{}".format(self.cpu_num))
self._print_workers()
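All master traffic follows the same framing: a constant tag as the first frame, then payload frames encoded with to_byte/to_str or cloudpickle. The pattern in miniature (the tag value and the port are illustrative, not the real remote_constants):

import zmq
from parl.utils import to_byte, to_str

NORMAL_TAG = b'[NORMAL]'  # stand-in for remote_constants.NORMAL_TAG

ctx = zmq.Context()
rep = ctx.socket(zmq.REP)
rep.bind("tcp://127.0.0.1:5556")
req = ctx.socket(zmq.REQ)
req.connect("tcp://127.0.0.1:5556")

req.send_multipart([NORMAL_TAG, to_byte("127.0.0.1:1234")])
tag, payload = rep.recv_multipart()
assert tag == NORMAL_TAG
print(to_str(payload))            # 127.0.0.1:1234
rep.send_multipart([NORMAL_TAG])  # complete the REQ/REP cycle
req.recv_multipart()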
......@@ -288,6 +218,8 @@ class Master(object):
raise NotImplementedError()
def exit(self):
""" Close the master.
"""
self.master_is_alive = False
def run(self):
......@@ -313,10 +245,5 @@ class Master(object):
except zmq.error.Again as e:
#detect whether `self.master_is_alive` is True periodically
pass
for worker_address, worker in self.worker_pool.items():
lock = self.worker_locks[worker_address]
lock.acquire()
worker.worker_socket.close(0)
lock.release()
logger.warning("[Master] Exit master.")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class InitializedJob(object):
def __init__(self, job_address, worker_heartbeat_address,
client_heartbeat_address, ping_heartbeat_address,
worker_address, pid):
"""
Args:
job_address(str): Job address to which the new task connects.
worker_heartbeat_address(str): Optional. The address to which the worker sends heartbeat signals.
client_heartbeat_address(str): Address to which the client sends heartbeat signals.
ping_heartbeat_address(str): The server address to which the client sends ping signals.
The signal is used to check if the job is alive.
worker_address(str): Worker's server address that receives commands from the master.
pid(int): Optional. Process id of the job.
is_alive(bool): Optional. Defaults to True. This flag is used in the worker to make sure that only alive jobs can be added into the worker_status.
"""
self.job_address = job_address
self.worker_heartbeat_address = worker_heartbeat_address
self.client_heartbeat_address = client_heartbeat_address
self.ping_heartbeat_address = ping_heartbeat_address
self.worker_address = worker_address
self.pid = pid
self.is_alive = True
class InitializedWorker(object):
def __init__(self, worker_address, master_heartbeat_address,
initialized_jobs, cpu_num):
"""
Args:
worker_address(str): Worker server address that receives commands from the master.
master_heartbeat_address(str): Address to which the worker sends heartbeat signals.
initialized_jobs(list): A list of ``InitializedJob`` containing the information for initialized jobs.
cpu_num(int): The number of CPUs used in this worker.
"""
self.worker_address = worker_address
self.master_heartbeat_address = master_heartbeat_address
self.initialized_jobs = initialized_jobs
self.cpu_num = cpu_num
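Both message classes are plain data carriers; they cross process boundaries as a single cloudpickle frame. A sketch with placeholder addresses:

import cloudpickle
from parl.remote.message import InitializedJob

job = InitializedJob('127.0.0.1:1234', '127.0.0.1:1235', '127.0.0.1:1236',
                     '127.0.0.1:1237', '127.0.0.1:8000', pid=0)
frame = cloudpickle.dumps(job)       # what the job/worker puts on the wire
restored = cloudpickle.loads(frame)  # what the peer reads back
assert restored.job_address == job.job_address and restored.is_alive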
......@@ -98,23 +98,28 @@ def remote_class(cls):
self.job_socket.linger = 0
self.job_socket.connect("tcp://{}".format(job_address))
self.job_address = job_address
self.job_shutdown = False
self.send_file(self.job_socket)
try:
self.job_socket.send_multipart([
remote_constants.INIT_OBJECT_TAG,
cloudpickle.dumps(cls),
cloudpickle.dumps([args, kwargs])
])
_ = self.job_socket.recv_multipart()
except zmq.error.Again as e:
logger.error("Job socket failed.")
message = self.job_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.EXCEPTION_TAG:
traceback_str = to_str(message[1])
self.job_shutdown = True
raise RemoteError('__init__', traceback_str)
def __del__(self):
"""Delete the remote class object and release remote resources."""
if not self.job_shutdown:
try:
self.job_socket.send_multipart([remote_constants.KILLJOB_TAG])
self.job_socket.send_multipart(
[remote_constants.KILLJOB_TAG])
_ = self.job_socket.recv_multipart()
self.job_socket.close(0)
except AttributeError:
......@@ -137,7 +142,7 @@ def remote_class(cls):
if job_address is not None:
return job_address
if cnt % 30 == 0:
logger.warning("No vacant cpu resources at present, "
logger.warning("No vacant cpu resources at the moment, "
"will try {} times later.".format(cnt))
cnt -= 1
return None
......@@ -146,6 +151,9 @@ def remote_class(cls):
"""Call the function of the unwrapped class."""
def wrapper(*args, **kwargs):
if self.job_shutdown:
raise RemoteError(
attr, "This actor losts connection with the job.")
self.internal_lock.acquire()
data = dumps_argument(*args, **kwargs)
......@@ -161,21 +169,26 @@ def remote_class(cls):
elif tag == remote_constants.EXCEPTION_TAG:
error_str = to_str(message[1])
self.job_shutdown = True
raise RemoteError(attr, error_str)
elif tag == remote_constants.ATTRIBUTE_EXCEPTION_TAG:
error_str = to_str(message[1])
self.job_shutdown = True
raise RemoteAttributeError(attr, error_str)
elif tag == remote_constants.SERIALIZE_EXCEPTION_TAG:
error_str = to_str(message[1])
self.job_shutdown = True
raise RemoteSerializeError(attr, error_str)
elif tag == remote_constants.DESERIALIZE_EXCEPTION_TAG:
error_str = to_str(message[1])
self.job_shutdown = True
raise RemoteDeserializeError(attr, error_str)
else:
self.job_shutdown = True
raise NotImplementedError()
self.internal_lock.release()
......
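Once any exception tag arrives, job_shutdown acts as a fuse: every later call on the proxy fails fast instead of touching the dead job. A condensed sketch of that behavior, assuming a running master at the example address (Fragile is a hypothetical actor):

import parl
from parl.remote import exceptions
from parl.remote.client import disconnect

@parl.remote_class
class Fragile(object):
    def boom(self):
        return 1 / 0

parl.connect('localhost:8010')  # example master address
actor = Fragile()
try:
    actor.boom()                # remote ZeroDivisionError -> RemoteError
except exceptions.RemoteError:
    pass                        # job_shutdown is now True
try:
    actor.boom()                # fails fast: the job is already gone
except exceptions.RemoteError as e:
    print(e)
disconnect()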
......@@ -83,11 +83,11 @@ def start_master(port, cpu_num):
cpu_num = str(cpu_num) if cpu_num else ''
start_file = __file__.replace('scripts.pyc', 'start.py')
start_file = start_file.replace('scripts.py', 'start.py')
command = ["python", start_file, "--name", "master", "--port", port]
command = [sys.executable, start_file, "--name", "master", "--port", port]
p = subprocess.Popen(command)
command = [
"python", start_file, "--name", "worker", "--address",
sys.executable, start_file, "--name", "worker", "--address",
"localhost:" + str(port), "--cpu_num",
str(cpu_num)
]
......@@ -112,8 +112,8 @@ def start_worker(address, cpu_num):
address) + "is correct.")
cpu_num = str(cpu_num) if cpu_num else ''
command = [
"python", "{}/start.py".format(__file__[:-11]), "--name", "worker",
"--address", address, "--cpu_num",
sys.executable, "{}/start.py".format(__file__[:-11]), "--name",
"worker", "--address", address, "--cpu_num",
str(cpu_num)
]
p = subprocess.Popen(command)
......@@ -123,8 +123,8 @@ def start_worker(address, cpu_num):
def stop():
command = ("pkill -f remote/start.py")
subprocess.call([command], shell=True)
command = ("pkill -f job.py")
p = subprocess.call([command], shell=True)
command = ("pkill -f remote/job.py")
subprocess.call([command], shell=True)
cli.add_command(start_worker)
......
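Switching from the hard-coded "python" to sys.executable makes each subprocess inherit the exact interpreter of the parent, which matters under conda or virtualenv. A minimal illustration:

import subprocess
import sys

# the child reports the version of the very interpreter running this script
out = subprocess.check_output(
    [sys.executable, "-c", "import sys; print(sys.version)"])
print(out.decode())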
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from parl.utils import logger
import signal
import os
class WorkerStatus(object):
"""Maintain worker's information in a worker node.
Attributes:
cpu_num(int): The number of CPUs to be used in this worker.
jobs(dict): A dict that maps job addresses provided to the master to their ``InitializedJob`` objects.
worker_address(str): Address of the worker.
"""
def __init__(self, worker_address, initialized_jobs, cpu_num):
self.worker_address = worker_address
self.jobs = dict()
for job in initialized_jobs:
self.jobs[job.job_address] = job
self._lock = threading.Lock()
self.cpu_num = cpu_num
def remove_job(self, killed_job):
"""Rmove a job from internal job pool.
Args:
killed_job(str): Job address to be removed.
Returns: True if removing the job succeeds.
"""
ret = False
self._lock.acquire()
if killed_job in self.jobs:
pid = self.jobs[killed_job].pid
self.jobs.pop(killed_job)
ret = True
try:
os.kill(pid, signal.SIGTERM)
except OSError:
logger.warning("job:{} has been killed before".format(pid))
logger.info("[Worker] kills a job:{}".format(killed_job))
self._lock.release()
return ret
def clear(self):
"""Remove all the jobs"""
self._lock.acquire()
for job in self.jobs.values():
try:
os.kill(job.pid, signal.SIGTERM)
except OSError:
logger.warning("job:{} has been killed before".format(job.pid))
logger.info("[Worker] kills a job:{}".format(job.pid))
self.jobs = dict()
self._lock.release()
def add_job(self, new_job):
"""Add a new job to interal job pool.
Args:
new_job(InitializedJob): Initialized job to be added to the self.jobs.
"""
self._lock.acquire()
self.jobs[new_job.job_address] = new_job
assert len(self.jobs) <= self.cpu_num
self._lock.release()
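A sketch of the WorkerStatus bookkeeping. Since remove_job sends SIGTERM to the recorded pid, the example spawns a throwaway child process to play the role of the job:

import subprocess
import sys
from parl.remote.message import InitializedJob
from parl.remote.status import WorkerStatus

child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
job = InitializedJob('127.0.0.1:1234', '127.0.0.1:1235', '127.0.0.1:1236',
                     '127.0.0.1:1237', '127.0.0.1:8000', pid=child.pid)

status = WorkerStatus('127.0.0.1:8000', [job], cpu_num=1)
assert status.remove_job(job.job_address)      # SIGTERMs the child
assert not status.remove_job(job.job_address)  # already removed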
......@@ -20,6 +20,9 @@ from parl.remote.worker import Worker
import time
import threading
from parl.remote.client import disconnect
from parl.remote import exceptions
import timeout_decorator
import subprocess
@parl.remote_class
......@@ -55,25 +58,71 @@ class Actor(object):
x = 1 / 0
class TestExit(unittest.TestCase):
def test_delete_worker(self):
# start the master
class TestCluster(unittest.TestCase):
def tearDown(self):
disconnect()
#time.sleep(20)
#command = ("pkill -f remote/job.py")
#subprocess.call([command], shell=True)
def test_actor_exception_in_init(self):
master = Master(port=1235)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1235', 4)
worker1 = Worker('localhost:1235', 1)
self.assertEqual(1, master.cpu_num)
parl.connect('localhost:1235')
for i in range(4):
with self.assertRaises(exceptions.RemoteError):
actor = Actor(abcd='a bug')
actor2 = Actor()
self.assertEqual(actor2.add_one(1), 2)
self.assertEqual(0, master.cpu_num)
master.exit()
worker1.exit()
@timeout_decorator.timeout(seconds=300)
def test_actor_exception(self):
master = Master(port=1236)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1236', 1)
self.assertEqual(1, master.cpu_num)
parl.connect('localhost:1236')
actor = Actor()
try:
actor.will_raise_exception_func()
except:
pass
actor2 = Actor()
time.sleep(30)
self.assertEqual(actor2.add_one(1), 2)
self.assertEqual(0, master.cpu_num)
del actor
del actor2
worker1.exit()
master.exit()
def test_reset_actor(self):
# start the master
master = Master(port=1237)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1237', 4)
parl.connect('localhost:1237')
for i in range(10):
actor = Actor()
ret = actor.add_one(1)
self.assertEqual(ret, 2)
del actor
time.sleep(20)
self.assertEqual(master.cpu_num, 4)
worker1.exit()
time.sleep(30)
disconnect()
time.sleep(30)
master.exit()
def test_add_worker(self):
......@@ -91,6 +140,7 @@ class TestExit(unittest.TestCase):
self.assertEqual(master.cpu_num, 4)
master.exit()
worker1.exit()
if __name__ == '__main__':
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from parl.remote.job_center import JobCenter
from parl.remote.message import InitializedWorker, InitializedJob
class InitializedWorker(object):
def __init__(self,
worker_address,
master_heartbeat_address='localhost:8010',
initialized_jobs=[],
cpu_num=4):
self.worker_address = worker_address
self.master_heartbeat_address = master_heartbeat_address
self.initialized_jobs = initialized_jobs
self.cpu_num = cpu_num
class JobCenterTest(unittest.TestCase):
def setUp(self):
jobs = []
for i in range(5):
job = InitializedJob(
job_address='172.18.182.39:{}'.format(1234 + i),
worker_heartbeat_address='172.18.182.39:48724',
client_heartbeat_address='172.18.182.39:48725',
ping_heartbeat_address='172.18.182.39:48726',
worker_address='172.18.182.39:478727',
pid=1234)
jobs.append(job)
self.worker1 = InitializedWorker(
worker_address='172.18.182.39:8001', initialized_jobs=jobs)
jobs = []
for i in range(5):
job = InitializedJob(
job_address='172.18.182.39:{}'.format(2234 + i),
worker_heartbeat_address='172.18.182.39:48724',
client_heartbeat_address='172.18.182.39:48725',
ping_heartbeat_address='172.18.182.39:48726',
worker_address='172.18.182.39:478727',
pid=1234)
jobs.append(job)
self.worker2 = InitializedWorker(
worker_address='172.18.182.39:8002', initialized_jobs=jobs)
def test_add_worker(self):
job_center = JobCenter()
job_center.add_worker(self.worker1)
job_center.add_worker(self.worker2)
self.assertEqual(len(job_center.job_pool), 10)
self.assertEqual(job_center.worker_dict[self.worker1.worker_address],
self.worker1)
def test_drop_worker(self):
job_center = JobCenter()
job_center.add_worker(self.worker1)
job_center.add_worker(self.worker2)
job_center.drop_worker(self.worker2.worker_address)
self.assertEqual(
set(job_center.job_pool.values()),
set(self.worker1.initialized_jobs))
self.assertEqual(len(job_center.worker_dict), 1)
def test_request_job(self):
job_center = JobCenter()
job_address1 = job_center.request_job()
self.assertTrue(job_address1 is None)
job_center.add_worker(self.worker1)
job_address2 = job_center.request_job()
self.assertTrue(job_address2 in self.worker1.initialized_jobs)
self.assertEqual(len(job_center.job_pool), 4)
def test_reset_job(self):
job_center = JobCenter()
job_center.add_worker(self.worker1)
job_address = job_center.request_job()
self.assertTrue(job_address in self.worker1.initialized_jobs)
self.assertEqual(len(job_center.job_pool), 4)
job_center.reset_job(job_address)
self.assertEqual(len(job_center.job_pool), 5)
def test_update_job(self):
job_center = JobCenter()
job_center.add_worker(self.worker1)
job_center.add_worker(self.worker2)
job = InitializedJob(
job_address='172.18.182.39:{}'.format(9245),
worker_heartbeat_address='172.18.182.39:48724',
client_heartbeat_address='172.18.182.39:48725',
ping_heartbeat_address='172.18.182.39:48726',
worker_address='172.18.182.39:478727',
pid=1234)
job_center.update_job('172.18.182.39:2234', job, '172.18.182.39:8002')
current_job_address = set([
job.job_address for job in job_center.
worker_dict['172.18.182.39:8002'].initialized_jobs
])
self.assertEqual(
current_job_address,
set([
'172.18.182.39:9245', '172.18.182.39:2235',
'172.18.182.39:2236', '172.18.182.39:2237',
'172.18.182.39:2238'
]))
job_pool_address = set(job_center.job_pool.keys())
self.assertEqual(
job_pool_address,
set([
'172.18.182.39:9245', '172.18.182.39:2235',
'172.18.182.39:2236', '172.18.182.39:2237',
'172.18.182.39:2238', '172.18.182.39:1234',
'172.18.182.39:1235', '172.18.182.39:1236',
'172.18.182.39:1237', '172.18.182.39:1238'
]))
job_center.drop_worker(self.worker2.worker_address)
self.assertEqual(5, len(self.worker1.initialized_jobs))
def test_cpu_num(self):
job_center = JobCenter()
job_center.add_worker(self.worker1)
self.assertEqual(job_center.cpu_num, 5)
job_center.add_worker(self.worker2)
self.assertEqual(job_center.cpu_num, 10)
job_center.request_job()
self.assertEqual(job_center.cpu_num, 9)
def test_worker_num(self):
job_center = JobCenter()
job_center.add_worker(self.worker1)
self.assertEqual(job_center.worker_num, 1)
job_center.add_worker(self.worker2)
self.assertEqual(job_center.worker_num, 2)
job_center.drop_worker(self.worker1.worker_address)
self.assertEqual(job_center.worker_num, 1)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
from parl.utils import logger
import subprocess
import time
import threading
import timeout_decorator
import subprocess
import sys
@parl.remote_class
class Actor(object):
def __init__(self, arg1=None, arg2=None):
self.arg1 = arg1
self.arg2 = arg2
def get_arg1(self):
return self.arg1
def get_arg2(self):
return self.arg2
def set_arg1(self, value):
self.arg1 = value
def set_arg2(self, value):
self.arg2 = value
def get_unable_serialize_object(self):
return UnableSerializeObject()
def add_one(self, value):
value += 1
return value
def add(self, x, y):
time.sleep(3)
return x + y
def will_raise_exception_func(self):
x = 1 / 0
class TestJob(unittest.TestCase):
def tearDown(self):
disconnect()
def test_job_exit_exceptionally(self):
master = Master(port=1334)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1334', 4)
time.sleep(10)
self.assertEqual(worker1.job_buffer.full(), True)
time.sleep(1)
self.assertEqual(master.cpu_num, 4)
print("We are going to kill all the jobs.")
command = ("pkill -f remote/job.py")
subprocess.call([command], shell=True)
parl.connect('localhost:1334')
actor = Actor()
self.assertEqual(actor.add_one(1), 2)
time.sleep(20)
master.exit()
worker1.exit()
@timeout_decorator.timeout(seconds=300)
def test_actor_exit_exceptionally(self):
master = Master(port=1335)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1335', 1)
file_path = __file__.replace('reset_job_test', 'simulate_client')
command = [sys.executable, file_path]
proc = subprocess.Popen(command)
time.sleep(10)
self.assertEqual(master.cpu_num, 0)
proc.kill()
parl.connect('localhost:1335')
actor = Actor()
master.exit()
worker1.exit()
disconnect()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import parl
@parl.remote_class
class Actor(object):
def add_one(self, value):
value += 1
return value
def train():
parl.connect('localhost:1335')
actor = Actor()
actor.add_one(1)
time.sleep(100000)
if __name__ == '__main__':
train()
......@@ -25,22 +25,14 @@ import zmq
from parl.utils import get_ip_address, to_byte, to_str, logger
from parl.remote import remote_constants
from parl.remote.message import InitializedWorker
from parl.remote.status import WorkerStatus
from six.moves import queue
if sys.version_info.major == 3:
warnings.simplefilter("ignore", ResourceWarning)
class WorkerInfo(object):
"""A WorkerInfo object records the computation resources of a worker.
"""
def __init__(self, address, cpu_num, job_pool):
self.address = address
self.cpu_num = cpu_num
self.job_pool = job_pool
self.worker_socket = None
class Worker(object):
"""Worker provides the cpu computation resources for the cluster.
......@@ -58,7 +50,6 @@ class Worker(object):
xparl connect --address localhost:1234 --cpu_num 8
Attributes:
job_pid (dict): A dict of subprocess id and its address.
master_address (str): Master's ip address.
request_master_socket (zmq.Context.socket): A socket which sends job
address to the master node.
......@@ -67,22 +58,37 @@ class Worker(object):
node.
reply_job_socket (zmq.Context.socket): A socket which receives
job_address from the job.
kill_job_socket (zmq.Context.socket): A socket that receives commands to kill the job from jobs.
Args:
master_address (str): IP address of the master node.
cpu_num (int): Number of cpu to be used on the worker.
job_buffer (queue.Queue): A bounded buffer that stores initialized jobs so that new jobs can be provided quickly.
"""
def __init__(self, master_address, cpu_num=None):
self.lock = threading.Lock()
self.heartbeat_socket_initialized = threading.Event()
self.ctx = zmq.Context.instance()
self.job_pid = {}
self.master_address = master_address
self.master_is_alive = True
self.worker_is_alive = True
self.worker_status = None # initialized at `self._create_jobs`
self.lock = threading.Lock()
self._set_cpu_num(cpu_num)
self.job_buffer = queue.Queue(maxsize=self.cpu_num)
self._create_sockets()
self._create_worker()
# create a thread that waits for commands from jobs to kill a job.
self.kill_job_thread = threading.Thread(target=self._reply_kill_job)
self.kill_job_thread.start()
self._create_jobs()
# create a thread that initializes jobs and adds them into the job_buffer
job_thread = threading.Thread(target=self._fill_job_buffer)
job_thread.setDaemon(True)
job_thread.start()
def _set_cpu_num(self, cpu_num=None):
"""set useable cpu number for worker"""
......@@ -95,14 +101,15 @@ class Worker(object):
self.cpu_num = multiprocessing.cpu_count()
def _create_sockets(self):
""" Each worker has three sockets at start:
""" Each worker has four sockets at start:
(1) request_master_socket: sends job address to master node.
(2) reply_master_socket: accepts submitted job from master node.
(3) reply_job_socket: receives job_address from subprocess.
(4) kill_job_socket: receives commands to kill the job from jobs.
When a job is start, a new heartbeat socket is created to receive
heartbeat signal from the job.
When a job starts, a new heartbeat socket is created to receive
heartbeat signals from the job.
"""
......@@ -131,8 +138,14 @@ class Worker(object):
reply_job_port = self.reply_job_socket.bind_to_random_port("tcp://*")
self.reply_job_address = "{}:{}".format(self.worker_ip, reply_job_port)
def _create_worker(self):
"""create a WorkerInfo instance and send it to the master."""
# kill_job_socket
self.kill_job_socket = self.ctx.socket(zmq.REP)
self.kill_job_socket.linger = 0
kill_job_port = self.kill_job_socket.bind_to_random_port("tcp://*")
self.kill_job_address = "{}:{}".format(self.worker_ip, kill_job_port)
def _create_jobs(self):
"""Create jobs and send a instance of ``InitializedWorker`` that contains the worker information to the master."""
try:
self.request_master_socket.send_multipart(
[remote_constants.WORKER_CONNECT_TAG])
......@@ -143,24 +156,40 @@ class Worker(object):
self.master_is_alive = False
return
self._init_jobs(job_num=self.cpu_num)
initialized_jobs = self._init_jobs(job_num=self.cpu_num)
self.request_master_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
self.worker = WorkerInfo(self.reply_master_address, self.cpu_num,
list(self.job_pid.keys()))
reply_thread = threading.Thread(
self.reply_master_heartbeat_thread = threading.Thread(
target=self._reply_heartbeat,
args=("master {}".format(self.master_address), ))
reply_thread.start()
self.reply_master_heartbeat_thread.start()
self.heartbeat_socket_initialized.wait()
initialized_worker = InitializedWorker(self.reply_master_address,
self.master_heartbeat_address,
initialized_jobs, self.cpu_num)
self.request_master_socket.send_multipart([
remote_constants.WORKER_INITIALIZED_TAG,
cloudpickle.dumps(self.worker),
to_byte(self.heartbeat_master_address)
cloudpickle.dumps(initialized_worker)
])
_ = self.request_master_socket.recv_multipart()
self.worker_status = WorkerStatus(self.reply_master_address,
initialized_jobs, self.cpu_num)
def _fill_job_buffer(self):
"""An endless loop that adds initialized job into the job buffer"""
while self.worker_is_alive:
initialized_jobs = self._init_jobs(job_num=self.cpu_num)
for job in initialized_jobs:
self.job_buffer.put(job)
# release jobs if the worker is not alive
for job in initialized_jobs:
try:
os.kill(job.pid, signal.SIGTERM)
except OSError:
pass
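The buffer is a standard bounded queue.Queue: the daemon thread parks on put() once cpu_num jobs are waiting, and a get() in _kill_job wakes it to start a replacement. The same back-pressure pattern in isolation:

import queue
import threading
import time

buf = queue.Queue(maxsize=2)  # stands in for self.job_buffer

def filler():
    n = 0
    while True:
        buf.put(n)  # blocks while the buffer is full
        n += 1

threading.Thread(target=filler, daemon=True).start()
time.sleep(0.1)
assert buf.full()  # the filler is parked on put()
print(buf.get())   # consuming one item lets the filler top it up again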
def _init_jobs(self, job_num):
"""Create jobs.
......@@ -175,76 +204,73 @@ class Worker(object):
self.reply_job_address
]
# avoid killing and restarting many jobs at the same time.
self.lock.acquire()
# Redirect the output to DEVNULL
FNULL = open(os.devnull, 'w')
for _ in range(job_num):
pid = subprocess.Popen(
command, stdout=FNULL, stderr=subprocess.STDOUT)
subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT)
FNULL.close()
new_job_address = []
new_jobs = []
for _ in range(job_num):
job_message = self.reply_job_socket.recv_multipart()
self.reply_job_socket.send_multipart([remote_constants.NORMAL_TAG])
job_address = to_str(job_message[1])
new_job_address.append(job_address)
heartbeat_job_address = to_str(job_message[2])
pid = to_str(job_message[3])
self.job_pid[job_address] = int(pid)
self.reply_job_socket.send_multipart(
[remote_constants.NORMAL_TAG,
to_byte(self.kill_job_address)])
initialized_job = cloudpickle.loads(job_message[1])
initialized_job.worker_address = self.reply_master_address
new_jobs.append(initialized_job)
# a thread for sending heartbeat signals to job
thread = threading.Thread(
target=self._create_job_monitor,
args=(
job_address,
heartbeat_job_address,
))
target=self._create_job_monitor, args=(initialized_job, ))
thread.start()
assert len(new_job_address) > 0, "init jobs failed"
if len(new_job_address) > 1:
return new_job_address
else:
return new_job_address[0]
def _kill_job(self, job_address):
"""kill problematic job process and update worker information"""
if job_address in self.job_pid:
self.lock.acquire()
pid = self.job_pid[job_address]
try:
os.kill(pid, signal.SIGTERM)
except OSError:
logger.warn("job:{} has been killed before".format(pid))
self.job_pid.pop(job_address)
logger.warning("Worker kills job process {},".format(job_address))
self.lock.release()
assert len(new_jobs) > 0, "init jobs failed"
return new_jobs
        # When an old job is killed, the worker will create a new job.
if self.master_is_alive:
new_job_address = self._init_jobs(job_num=1)
def _kill_job(self, job_address):
"""Kill a job process and update worker information"""
success = self.worker_status.remove_job(job_address)
if success:
while True:
initialized_job = self.job_buffer.get()
if initialized_job.is_alive:
self.worker_status.add_job(initialized_job)
if not initialized_job.is_alive: # make sure that the job is still alive.
self.worker_status.remove_job(
initialized_job.job_address)
continue
else:
logger.warning(
"[Worker] a dead job found. The job buffer will not accept this one."
)
if initialized_job.is_alive:
break
self.lock.acquire()
self.request_master_socket.send_multipart([
remote_constants.NEW_JOB_TAG,
to_byte(self.reply_master_address),
to_byte(new_job_address),
cloudpickle.dumps(initialized_job),
to_byte(job_address)
])
_ = self.request_master_socket.recv_multipart()
self.lock.release()
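The lock around the `NEW_JOB` round trip is load-bearing: a ZMQ REQ socket enforces a strict send/recv alternation and is not thread-safe, while `_kill_job` can be invoked concurrently from several job-monitor threads. A sketch of the idiom (endpoint and tag are illustrative):

import threading
import zmq

lock = threading.Lock()
ctx = zmq.Context()
request_socket = ctx.socket(zmq.REQ)
request_socket.connect("tcp://localhost:8010")  # hypothetical master endpoint

def report_new_job(payload: bytes):
    # Hold the lock for the whole send/recv round trip; interleaving two
    # sends on one REQ socket raises zmq.error.ZMQError (EFSM).
    with lock:
        request_socket.send_multipart([b"[NEW_JOB]", payload])
        request_socket.recv_multipart()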
def _create_job_monitor(self, job_address, heartbeat_job_address):
"""Sending heartbeat signals to check target's status"""
def _create_job_monitor(self, job):
"""Send heartbeat signals to check target's status"""
# job_heartbeat_socket: sends heartbeat signal to job
job_heartbeat_socket = self.ctx.socket(zmq.REQ)
job_heartbeat_socket.linger = 0
job_heartbeat_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
job_heartbeat_socket.connect("tcp://" + heartbeat_job_address)
job_heartbeat_socket.connect("tcp://" + job.worker_heartbeat_address)
job_is_alive = True
while job_is_alive and self.master_is_alive and self.worker_is_alive:
job.is_alive = True
while job.is_alive and self.master_is_alive and self.worker_is_alive:
try:
job_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG])
......@@ -252,17 +278,35 @@ class Worker(object):
time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
except zmq.error.Again as e:
job_is_alive = False
if job_address in self.job_pid:
logger.warning("[Worker] No heartbeat reply from the job, "
"will kill {}.".format(job_address))
self._kill_job(job_address)
job.is_alive = False
logger.warning(
"[Worker] lost connection with the job:{}".format(
job.job_address))
if self.master_is_alive and self.worker_is_alive:
self._kill_job(job.job_address)
except zmq.error.ZMQError as e:
break
job_heartbeat_socket.close(0)
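The monitor loop is the REQ side of the heartbeat protocol: `RCVTIMEO` turns a silent peer into a `zmq.error.Again` exception instead of a permanent block, which is what lets the worker notice a dead job at all. A standalone sketch, with illustrative constants in place of `remote_constants`:

import time
import zmq

HEARTBEAT_INTERVAL_S = 10  # illustrative values, not the real constants
HEARTBEAT_TIMEOUT_S = 30

ctx = zmq.Context()
socket = ctx.socket(zmq.REQ)
socket.linger = 0
socket.setsockopt(zmq.RCVTIMEO, HEARTBEAT_TIMEOUT_S * 1000)
socket.connect("tcp://localhost:8011")  # hypothetical job heartbeat endpoint

peer_is_alive = True
while peer_is_alive:
    try:
        socket.send_multipart([b"[HEARTBEAT]"])
        socket.recv_multipart()  # raises zmq.error.Again on timeout
        time.sleep(HEARTBEAT_INTERVAL_S)
    except zmq.error.Again:
        peer_is_alive = False  # peer presumed dead; leave the loop
socket.close(0)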
def _reply_kill_job(self):
"""Worker starts a thread to wait jobs' commands to kill the job"""
self.kill_job_socket.linger = 0
self.kill_job_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
while self.worker_is_alive and self.master_is_alive:
try:
message = self.kill_job_socket.recv_multipart()
assert message[0] == remote_constants.KILLJOB_TAG
to_kill_job_address = to_str(message[1])
self._kill_job(to_kill_job_address)
self.kill_job_socket.send_multipart(
[remote_constants.NORMAL_TAG])
except zmq.error.Again as e:
                # timeout: periodically re-check whether `self.worker_is_alive` is still True
pass
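`_reply_kill_job` shows the matching REP-side idiom: without `RCVTIMEO`, `recv_multipart` would block forever and the thread could never observe `worker_is_alive` flipping to False; with it, the `zmq.error.Again` branch doubles as a periodic liveness check. A sketch under the same assumptions:

import zmq

ctx = zmq.Context()
rep_socket = ctx.socket(zmq.REP)
rep_socket.setsockopt(zmq.RCVTIMEO, 3000)  # wake up every 3 seconds
rep_socket.bind("tcp://*:8012")  # hypothetical kill-job port

worker_is_alive = True
while worker_is_alive:
    try:
        message = rep_socket.recv_multipart()
        # ... kill the job named in message[1] ...
        rep_socket.send_multipart([b"[NORMAL]"])
    except zmq.error.Again:
        pass  # timeout: loop around and re-check worker_is_alive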
def _reply_heartbeat(self, target):
"""Worker will kill its jobs when it lost connection with the master.
"""
......@@ -273,7 +317,7 @@ class Worker(object):
remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
heartbeat_master_port =\
socket.bind_to_random_port("tcp://*")
self.heartbeat_master_address = "{}:{}".format(self.worker_ip,
self.master_heartbeat_address = "{}:{}".format(self.worker_ip,
heartbeat_master_port)
self.heartbeat_socket_initialized.set()
logger.info("[Worker] Connect to the master node successfully. "
......@@ -284,14 +328,13 @@ class Worker(object):
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
except zmq.error.Again as e:
self.master_is_alive = False
for job_address in list(self.job_pid.keys()):
self._kill_job(job_address)
except zmq.error.ContextTerminated as e:
break
socket.close(0)
logger.warning(
"[Worker] lost connection with the master, will exit replying heartbeat for master."
)
self.worker_status.clear()
# exit the worker
self.worker_is_alive = False
......@@ -300,40 +343,6 @@ class Worker(object):
self.worker_is_alive = False
def run(self):
"""An infinite loop waiting for killing job commands from
        the master node.
After creating `cpu_num` jobs and sending job addresses to the master
node, a worker will keep waiting for killing job commands from master
node to release computation resources occupied by a dead client. Then
the worker will kill the jobs related to the dead client and create
new jobs and update job addresses to the master node.
"""Keep running until it lost connection with the master.
"""
self.reply_master_socket.linger = 0
self.reply_master_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
while self.master_is_alive and self.worker_is_alive:
try:
message = self.reply_master_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.KILLJOB_TAG:
job_address = to_str(message[1])
self.reply_master_socket.send_multipart(
[remote_constants.NORMAL_TAG])
self._kill_job(job_address)
else:
raise NotImplementedError
except zmq.error.Again as e:
                # detect periodically whether `self.worker_is_alive` is still True
pass
self.reply_job_socket.close(0)
self.request_master_socket.close(0)
self.reply_master_socket.close(0)
logger.warning("[Worker] Exit Worker {}.".format(
self.reply_master_address))
self.ctx.destroy()
self.reply_master_hearbeat_thread.join()
......@@ -18,8 +18,10 @@ import os
import os.path
import sys
from termcolor import colored
import shutil
from datetime import datetime
__all__ = ['set_dir', 'get_dir', 'set_level']
__all__ = ['set_dir', 'get_dir', 'set_level', 'auto_set_dir']
# globals: logger file and directory:
LOG_DIR = None
......@@ -37,6 +39,10 @@ def _makedirs(dirname):
raise e
def _get_time_str():
return datetime.now().strftime('%m%d-%H%M%S')
class _Formatter(logging.Formatter):
def format(self, record):
msg = '%(message)s'
......@@ -82,21 +88,6 @@ def _getlogger():
return logger
def create_file_after_first_call(func_name):
def call(*args, **kwargs):
global _logger
if LOG_DIR is None and hasattr(mod, '__file__'):
basename = os.path.basename(mod.__file__)
auto_dirname = os.path.join('log_dir',
basename[:basename.rfind('.')])
set_dir(auto_dirname)
func = getattr(_logger, func_name)
func(*args, **kwargs)
return call
_logger = _getlogger()
_LOGGING_METHOD = [
'info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug',
......@@ -105,7 +96,7 @@ _LOGGING_METHOD = [
# export logger functions
for func in _LOGGING_METHOD:
locals()[func] = create_file_after_first_call(func)
locals()[func] = getattr(_logger, func)
__all__.append(func)
# export Level information
......@@ -151,6 +142,70 @@ def set_dir(dirname):
_set_file(os.path.join(dirname, 'log.log'))
def auto_set_dir(action=None):
"""Set the global logging directory automatically. The default path is "./train_log/{scriptname}". "scriptname" is the name of the main python file currently running"
Note: This function references `https://github.com/tensorpack/tensorpack/blob/master/tensorpack/utils/logger.py#L93`
Args:
        action(str): one of ["k", "d", "n", "q"] to be performed
            when the directory already exists. Will ask the user by default.
            "d": delete the directory. Note that the deletion may fail when
            the directory is used by tensorboard.
            "k": keep the directory. This is useful when you resume from a
            previous training and want the directory to look as if the
            training was not interrupted.
            Note that this option does not load old models or any other
            old states for you. It simply does nothing.
            "n": use a new directory by appending the current time to the
            name.
            "q": quit and raise an error.
Returns:
        dirname(str): the directory actually used as the global logging directory.
"""
mod = sys.modules['__main__']
basename = os.path.basename(mod.__file__)
dirname = os.path.join('train_log', basename[:basename.rfind('.')])
dirname = os.path.normpath(dirname)
global LOG_DIR, _FILE_HANDLER
if _FILE_HANDLER:
# unload and close the old file handler, so that we may safely delete the logger directory
_logger.removeHandler(_FILE_HANDLER)
del _FILE_HANDLER
def dir_nonempty(dirname):
# If directory exists and nonempty (ignore hidden files), prompt for action
return os.path.isdir(dirname) and len(
[x for x in os.listdir(dirname) if x[0] != '.'])
if dir_nonempty(dirname):
if not action:
_logger.warning("""\
Log directory {} exists! Use 'd' to delete it. """.format(dirname))
_logger.warning("""\
If you're resuming from a previous run, you can choose to keep it.
Press any other key to exit. """)
while not action:
action = input("Select Action: k (keep) / d (delete) / q (quit):"
).lower().strip()
act = action
if act == 'd':
shutil.rmtree(dirname, ignore_errors=True)
if dir_nonempty(dirname):
shutil.rmtree(dirname, ignore_errors=False)
elif act == 'n':
dirname = dirname + _get_time_str()
info("Use a new log directory {}".format(dirname)) # noqa: F821
elif act == 'k':
pass
else:
raise OSError("Directory {} exits!".format(dirname))
LOG_DIR = dirname
_makedirs(dirname)
_set_file(os.path.join(dirname, 'log.log'))
return dirname
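A typical use of the new function (hypothetical script name; with `action='d'` any existing directory is removed so each run starts clean):

# my_train.py -- logs end up in ./train_log/my_train/log.log
from parl.utils import logger

logger.auto_set_dir(action='d')  # 'd': delete a pre-existing directory
logger.info("log directory: {}".format(logger.get_dir()))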
def get_dir():
return LOG_DIR
......
......@@ -25,6 +25,12 @@ def create_file_after_first_call(func_name):
def call(*args, **kwargs):
global _writer
if _writer is None:
logdir = logger.get_dir()
if logdir is None:
logdir = logger.auto_set_dir(action='d')
logger.warning(
"[tensorboard] logdir is None, will save tensorboard files to {}"
.format(logdir))
_writer = SummaryWriter(logdir=logger.get_dir())
func = getattr(_writer, func_name)
func(*args, **kwargs)
......
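The tensorboard wrapper keeps the module-level writer lazy: nothing touches the filesystem until the first `add_*` call, at which point a log directory is resolved (falling back to `auto_set_dir`). A self-contained sketch of the same lazy-initialization decorator, with a dummy class standing in for the real `SummaryWriter`:

_writer = None

class DummyWriter:  # stand-in for tensorboardX's SummaryWriter
    def __init__(self, logdir):
        self.logdir = logdir

    def add_scalar(self, tag, value, step):
        print(self.logdir, tag, value, step)

def create_file_after_first_call(func_name):
    def call(*args, **kwargs):
        global _writer
        if _writer is None:  # first call: create the writer on demand
            _writer = DummyWriter(logdir="./train_log/example")
        getattr(_writer, func_name)(*args, **kwargs)
    return call

add_scalar = create_file_after_first_call('add_scalar')
add_scalar('loss', 0.5, 1)  # the writer is created here, not at import time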
......@@ -40,6 +40,11 @@ class TestLogger(unittest.TestCase):
for t in th_list:
t.join()
def test_auto_set_dir(self):
logger.auto_set_dir(action='d')
logger.auto_set_dir(action='n')
logger.auto_set_dir(action='k')
if __name__ == '__main__':
unittest.main()
......@@ -14,6 +14,8 @@
import unittest
from parl.utils import tensorboard
import numpy as np
from parl.utils import logger
import os
class TestUtils(unittest.TestCase):
......@@ -24,6 +26,7 @@ class TestUtils(unittest.TestCase):
x = range(100)
for i in x:
tensorboard.add_scalar('y=2x', i * 2, i)
self.assertTrue(os.path.exists('./train_log/tensorboard_test'))
def test_add_histogram(self):
for i in range(10):
......@@ -32,4 +35,5 @@ class TestUtils(unittest.TestCase):
if __name__ == '__main__':
logger.auto_set_dir(action='d')
unittest.main()