Commit 64aebb6d authored by Hongsheng Zeng, committed by Bo Zhou

make job run task in a separate process (#170)

* make job run task in a separate process

* fix typo

* add more debug info in xparl client

* refine control flow of different processes in xparl job

* refine control flow of different processes in xparl job

* remove tsinghua source

* remove tsinghua source

* remove unnecessary logic

* fix typo

* refine comments and some logic

* fix bug: `decay=0` means fully synchronizing the weights of the source model to the target model
Parent ee36f15b
......@@ -173,7 +173,7 @@ function main() {
run_test_with_gpu
#
/root/miniconda3/envs/empty_env/bin/pip install -i https://pypi.tuna.tsinghua.edu.cn/simple .
/root/miniconda3/envs/empty_env/bin/pip install .
run_import_test
run_docs_test
;;
......
......@@ -4,5 +4,5 @@ source ~/.bashrc
export PATH="/root/miniconda3/bin:$PATH"
source deactivate
source activate docs
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple /work/
pip install /work/
make html
......@@ -40,7 +40,7 @@ For final submission, we test our model in 500 CPUs, running 10 episodes per CPU
2. Download the model file from the online storage service, [Baidu Pan](https://pan.baidu.com/s/1NN1auY2eDblGzUiqR8Bfqw) or [Google Drive](https://drive.google.com/open?id=1DQHrwtXzgFbl9dE7jGOe9ZbY0G9-qfq3)
3. Unpack the file by using:
`tar zxvf saved_model.tar.gz`
4. Launch test scription:
4. Launch the test script:
`python test.py`
## Part2: Curriculum learning
......
......@@ -59,7 +59,7 @@ class OpenSimAgent(parl.Agent):
# Attention: In the beginning, sync target model totally.
self.alg.sync_target(
model_id=i,
decay=1.0,
decay=0,
share_vars_parallel_executor=self.learn_pe[i])
# Do cache, will create ParallelExecutor of sync params in advance
# If not, there are some issues when ensemble_num > 1
......
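The `decay=1.0` → `decay=0` change above matches the last bullet of the commit message: under the usual soft-update convention, `decay=0` copies the source weights into the target model completely. A minimal sketch of that convention (not PARL's actual `sync_target` implementation), with toy parameter names:

```python
import numpy as np

def soft_sync(source_params, target_params, decay):
    """Soft update: target = decay * target + (1 - decay) * source."""
    return {
        name: decay * target_params[name] + (1.0 - decay) * source_params[name]
        for name in target_params
    }

src = {"w": np.array([1.0, 2.0])}
tgt = {"w": np.array([5.0, 5.0])}
# decay=0 fully synchronizes the source weights to the target model
assert np.allclose(soft_sync(src, tgt, decay=0)["w"], src["w"])
```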
......@@ -14,5 +14,5 @@
2. Download the model file from the online storage service: [Baidu Pan](https://pan.baidu.com/s/12LIPspckCT8-Q5U1QX69Fg) (password: `b5ck`) or [Google Drive](https://drive.google.com/file/d/1jJtOcOVJ6auz3s-TyWgUJvofPXI94yxy/view?usp=sharing)
3. Unpack the file:
`tar zxvf saved_models.tar.gz`
4. Launch test scription:
4. Launch the test script:
`python test.py`
......@@ -91,20 +91,27 @@ class Client(object):
working directory.
"""
pyfiles = dict()
pyfiles['python_files'] = {}
pyfiles['other_files'] = {}
code_files = filter(lambda x: x.endswith('.py'), os.listdir('./'))
to_distributed_files = list(code_files) + distributed_files
for file in to_distributed_files:
try:
try:
for file in code_files:
assert os.path.exists(file)
with open(file, 'rb') as code_file:
code = code_file.read()
pyfiles[file] = code
except AssertionError as e:
raise Exception(
'Failed to create the client, the file {} does not exist.'.
format(file))
pyfiles['python_files'][file] = code
for file in distributed_files:
assert os.path.exists(file)
with open(file, 'rb') as f:
content = f.read()
pyfiles['other_files'][file] = content
except AssertionError as e:
raise Exception(
'Failed to create the client, the file {} does not exist.'.
format(file))
return cloudpickle.dumps(pyfiles)
def _create_sockets(self, master_address):
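The change above splits the payload built by the client into `python_files` (unpacked into a temporary directory on the job side) and `other_files` (written into the job's working directory), all serialized with `cloudpickle`. A self-contained sketch of building that payload; the data file name in the usage comment is hypothetical:

```python
import os
import cloudpickle

def pack_files(distributed_files):
    """Pack local .py files plus user-specified data files into one payload."""
    pyfiles = {'python_files': {}, 'other_files': {}}
    for fname in os.listdir('./'):
        if fname.endswith('.py'):
            with open(fname, 'rb') as f:
                pyfiles['python_files'][fname] = f.read()
    for fname in distributed_files:
        if not os.path.exists(fname):
            raise Exception(
                'Failed to create the client, the file {} does not exist.'.format(fname))
        with open(fname, 'rb') as f:
            pyfiles['other_files'][fname] = f.read()
    return cloudpickle.dumps(pyfiles)

# payload = pack_files(['rom_files/pong.bin'])  # hypothetical data file
```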
......@@ -173,7 +180,7 @@ class Client(object):
logger.warning("Client exit replying heartbeat for master.")
def _check_and_monitor_job(self, job_heartbeat_address,
ping_heartbeat_address):
ping_heartbeat_address, max_memory):
""" Sometimes the client may receive a job that is dead, thus
we have to check if this job is still alive before sending it to the actor.
"""
......@@ -184,7 +191,8 @@ class Client(object):
job_heartbeat_socket.connect("tcp://" + ping_heartbeat_address)
try:
job_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG])
[remote_constants.HEARTBEAT_TAG,
to_byte(str(max_memory))])
job_heartbeat_socket.recv_multipart()
except zmq.error.Again:
job_heartbeat_socket.close(0)
......@@ -231,6 +239,9 @@ class Client(object):
job_is_alive = False
self.lock.acquire()
self.actor_num -= 1
logger.error(
'[xparl] lost connection with a job, current actor num: {}'
.format(self.actor_num))
self.lock.release()
except zmq.error.ZMQError as e:
......@@ -238,13 +249,18 @@ class Client(object):
job_heartbeat_socket.close(0)
def submit_job(self):
def submit_job(self, max_memory):
"""Send a job to the Master node.
When a `@parl.remote_class` object is created, the global client
sends a job to the master node. Then the master node will allocate
a vacant job from its job pool to the remote object.
Args:
max_memory (float): Maximum memory (in MB) that can be used by each remote
instance; the default is None (unlimited).
Returns:
job_address(str): IP address of the job. None if there is no available CPU in the cluster.
"""
......@@ -268,7 +284,8 @@ class Client(object):
ping_heartbeat_address = to_str(message[3])
check_result = self._check_and_monitor_job(
job_heartbeat_address, ping_heartbeat_address)
job_heartbeat_address, ping_heartbeat_address,
max_memory)
if check_result:
self.lock.acquire()
self.actor_num += 1
......
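`_check_and_monitor_job` now sends `max_memory` along with the first heartbeat, so the job learns its memory budget before running any task. A rough sketch of that client-side ping, assuming a short receive timeout; the tag constant here is a placeholder for PARL's `remote_constants.HEARTBEAT_TAG`:

```python
import zmq

HEARTBEAT_TAG = b'[HEARTBEAT]'  # placeholder for remote_constants.HEARTBEAT_TAG

def ping_job(ping_heartbeat_address, max_memory, timeout_ms=3000):
    """Return True if the job answers the ping, delivering its memory budget."""
    ctx = zmq.Context()
    socket = ctx.socket(zmq.REQ)
    socket.setsockopt(zmq.RCVTIMEO, timeout_ms)
    socket.linger = 0
    socket.connect("tcp://" + ping_heartbeat_address)
    try:
        socket.send_multipart([HEARTBEAT_TAG, str(max_memory).encode()])
        socket.recv_multipart()
        return True
    except zmq.error.Again:
        return False  # the job is dead; the client asks the master for another one
    finally:
        socket.close(0)
```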
......@@ -26,6 +26,7 @@ import threading
import time
import traceback
import zmq
from multiprocessing import Process, Pipe
from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.utils.communication import loads_argument, loads_return,\
dumps_argument, dumps_return
......@@ -38,8 +39,8 @@ class Job(object):
"""Base class for the job.
After establishing connection with the remote object, the job will
create a remote class instance locally and enter an infinite loop,
waiting for commands from the remote object.
create a remote class instance locally and enter an infinite loop
in a separate process, waiting for commands from the remote object.
"""
......@@ -52,36 +53,50 @@ class Job(object):
pid (int): Job process ID.
max_memory (float): Maximum memory (in MB) that can be used by each remote instance.
"""
self.job_is_alive = True
self.max_memory = None
self.job_address_receiver, job_address_sender = Pipe()
self.worker_address = worker_address
self.job_ip = get_ip_address()
self.pid = os.getpid()
self.max_memory = None
self.lock = threading.Lock()
self.run_job_process = Process(
target=self.run, args=(job_address_sender, ))
self.run_job_process.start()
self._create_sockets()
process = psutil.Process(self.pid)
self.init_memory = float(process.memory_info()[0]) / (1024**2)
self.run_job_process.join()
with self.lock:
self.kill_job_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(self.job_address)])
try:
_ = self.kill_job_socket.recv_multipart()
except zmq.error.Again as e:
pass
os._exit(1)
def _create_sockets(self):
"""Create three sockets for each job.
"""Create five sockets for each job in main process.
(1) reply_socket(main socket): receives the command(i.e, the function name and args)
from the actual class instance, completes the computation, and returns the result of
the function.
(2) job_socket(functional socket): sends job_address and heartbeat_address to worker.
(3) kill_job_socket: sends a command to the corresponding worker to kill the job.
(1) job_socket(functional socket): sends job_address and heartbeat_address to worker.
(2) ping_heartbeat_socket: replies ping message of client.
(3) worker_heartbeat_socket: replies heartbeat message of worker.
(4) client_heartbeat_socket: replies heartbeat message of client.
(5) kill_job_socket: sends a command to the corresponding worker to kill the job.
"""
# wait for another process to create reply socket
self.job_address = self.job_address_receiver.recv()
self.ctx = zmq.Context()
# create the reply_socket
self.reply_socket = self.ctx.socket(zmq.REP)
job_port = self.reply_socket.bind_to_random_port(addr="tcp://*")
self.reply_socket.linger = 0
self.job_ip = get_ip_address()
self.job_address = "{}:{}".format(self.job_ip, job_port)
# create the job_socket
self.job_socket = self.ctx.socket(zmq.REQ)
self.job_socket.connect("tcp://{}".format(self.worker_address))
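The heart of this commit is visible in the `__init__` above: the task loop now runs in a `multiprocessing.Process`, which binds the reply socket to a random port and sends the resulting address back to the main process through a `Pipe`. A stripped-down sketch of that handoff, independent of the `Job` class (the echo reply stands in for real task execution):

```python
import zmq
from multiprocessing import Process, Pipe
from parl.utils import get_ip_address  # same helper job.py imports above

def task_process(address_sender):
    """Child process: bind a REP socket and report its address to the parent."""
    ctx = zmq.Context()
    reply_socket = ctx.socket(zmq.REP)
    port = reply_socket.bind_to_random_port(addr="tcp://*")
    address_sender.send("{}:{}".format(get_ip_address(), port))
    message = reply_socket.recv_multipart()  # one command, standing in for the task loop
    reply_socket.send_multipart(message)     # echo it back

if __name__ == "__main__":
    receiver, sender = Pipe()
    p = Process(target=task_process, args=(sender, ))
    p.start()
    job_address = receiver.recv()            # the parent learns where the task listens
    req = zmq.Context().socket(zmq.REQ)
    req.connect("tcp://" + job_address)
    req.send_multipart([b"ping"])
    print(req.recv_multipart())
    p.join()                                 # like the job above, the parent waits for the task process
```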
......@@ -93,7 +108,6 @@ class Job(object):
target=self._reply_ping, args=(ping_heartbeat_socket, ))
ping_thread.setDaemon(True)
ping_thread.start()
self.ping_heartbeat_address = ping_heartbeat_address
# a thread that reply heartbeat signals from the worker
worker_heartbeat_socket, worker_heartbeat_address = self._create_heartbeat_server(
......@@ -114,8 +128,7 @@ class Job(object):
# sends job information to the worker
initialized_job = InitializedJob(
self.job_address, worker_heartbeat_address,
client_heartbeat_address, self.ping_heartbeat_address, None,
self.pid)
client_heartbeat_address, ping_heartbeat_address, None, self.pid)
self.job_socket.send_multipart(
[remote_constants.NORMAL_TAG,
cloudpickle.dumps(initialized_job)])
......@@ -145,9 +158,12 @@ class Job(object):
"""Create a socket server that reply the ping signal from client.
This signal is used to make sure that the job is still alive.
"""
while self.job_is_alive:
message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
message = socket.recv_multipart()
max_memory = to_str(message[1])
if max_memory != 'None':
self.max_memory = float(max_memory)
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
self.client_thread.start()
socket.close(0)
def _create_heartbeat_server(self, timeout=True):
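The `max_memory` received in the ping above is later compared against the job's resident memory (`__init__` records `init_memory` with `psutil` for this purpose). A rough sketch of such a check, assuming the budget is the extra memory allowed on top of the initial footprint; the actual `_check_used_memory` is not shown in this diff:

```python
import os
import psutil

def used_memory_mb(pid=None):
    """Resident memory of a process in MB, computed as in Job.__init__ above."""
    process = psutil.Process(pid or os.getpid())
    return float(process.memory_info()[0]) / (1024 ** 2)

def exceeds_budget(init_memory_mb, max_memory_mb):
    """True when the memory used since start-up goes over the budget."""
    if max_memory_mb is None:
        return False
    return used_memory_mb() - init_memory_mb > max_memory_mb
```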
......@@ -166,8 +182,7 @@ class Job(object):
"""Create a socket that replies heartbeat signals from the client.
If the job loses connection with the client, it will exit too.
"""
self.client_is_alive = True
while self.client_is_alive and self.job_is_alive:
while True:
try:
message = socket.recv_multipart()
stop_job = self._check_used_memory()
......@@ -187,7 +202,7 @@ class Job(object):
logger.warning(
"[Job] Cannot connect to the client. This job will exit and inform the worker."
)
self.client_is_alive = False
break
socket.close(0)
with self.lock:
self.kill_job_socket.send_multipart(
......@@ -204,73 +219,77 @@ class Job(object):
"""create a socket that replies heartbeat signals from the worker.
If the worker has exited, the job will exit automatically.
"""
self.worker_is_alive = True
# a flag to decide when to exit heartbeat loop
while self.worker_is_alive and self.job_is_alive:
while True:
try:
message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
except zmq.error.Again as e:
logger.warning("[Job] Cannot connect to the worker{}. ".format(
self.worker_address) + "Job will quit.")
self.worker_is_alive = False
self.job_is_alive = False
break
socket.close(0)
os._exit(1)
def wait_for_files(self):
def wait_for_files(self, reply_socket, job_address):
"""Wait for python files from remote object.
When a remote object receives the allocated job address, it will send
the python files to the job. Later, the job will save these files to a
temporary directory and add that temporary directory to Python's
module search path.
Args:
reply_socket (socket): main socket that accepts commands from the remote object.
job_address (str): address of reply_socket.
Returns:
A temporary directory containing the python files.
"""
while True:
message = self.reply_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.SEND_FILE_TAG:
pyfiles = pickle.loads(message[1])
envdir = tempfile.mkdtemp()
for file in pyfiles:
code = pyfiles[file]
# create directory (i.e. ./rom_files/)
if '/' in file:
try:
os.makedirs(
os.path.join(envdir,
*file.rsplit('/')[:-1]))
except OSError as e:
pass
file = os.path.join(envdir, file)
with open(file, 'wb') as code_file:
code_file.write(code)
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
return envdir
else:
logger.error("NotImplementedError:{}, received tag:{}".format(
self.job_address, ))
raise NotImplementedError
message = reply_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.SEND_FILE_TAG:
pyfiles = pickle.loads(message[1])
# save python files to temporary directory
envdir = tempfile.mkdtemp()
for file, code in pyfiles['python_files'].items():
file = os.path.join(envdir, file)
with open(file, 'wb') as code_file:
code_file.write(code)
# save other files to current directory
for file, content in pyfiles['other_files'].items():
# create directory (i.e. ./rom_files/)
if '/' in file:
try:
os.makedirs(os.path.join(*file.rsplit('/')[:-1]))
except OSError as e:
pass
with open(file, 'wb') as f:
f.write(content)
logger.info('[job] reply')
reply_socket.send_multipart([remote_constants.NORMAL_TAG])
return envdir
else:
logger.error("NotImplementedError:{}, received tag:{}".format(
job_address, tag))
raise NotImplementedError
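`wait_for_files` now writes the received `python_files` into a fresh temporary directory (which `run` later appends to `sys.path`) and puts `other_files` into the current working directory so relative data paths keep working. A small sketch of why the `sys.path` step makes the distributed modules importable; the module name is hypothetical:

```python
import importlib
import os
import sys
import tempfile

envdir = tempfile.mkdtemp()

# pretend this arrived over the wire as pyfiles['python_files']['my_actor.py']
code = b"GREETING = 'hello from the distributed file'\n"
with open(os.path.join(envdir, "my_actor.py"), "wb") as f:
    f.write(code)

sys.path.append(envdir)  # same step as in Job.run below
my_actor = importlib.import_module("my_actor")
print(my_actor.GREETING)
```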
def wait_for_connection(self):
def wait_for_connection(self, reply_socket):
"""Wait for connection from the remote object.
The remote object will send its class information and initialization
arguments to the job, these parameters are then used to create a
local instance in the job process.
Args:
reply_socket (socket): main socket that accepts commands from the remote object.
Returns:
A local instance of the remote class object.
"""
message = self.reply_socket.recv_multipart()
message = reply_socket.recv_multipart()
tag = message[0]
obj = None
......@@ -278,24 +297,20 @@ class Job(object):
try:
cls = cloudpickle.loads(message[1])
args, kwargs = cloudpickle.loads(message[2])
max_memory = to_str(message[3])
if max_memory != 'None':
self.max_memory = float(max_memory)
obj = cls(*args, **kwargs)
except Exception as e:
traceback_str = str(traceback.format_exc())
error_str = str(e)
logger.error("traceback:\n{}".format(traceback_str))
self.reply_socket.send_multipart([
reply_socket.send_multipart([
remote_constants.EXCEPTION_TAG,
to_byte(error_str + "\ntraceback:\n" + traceback_str)
])
self.client_is_alive = False
return None
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
reply_socket.send_multipart([remote_constants.NORMAL_TAG])
else:
logger.error("Message from job {}".format(message))
self.reply_socket.send_multipart([
reply_socket.send_multipart([
remote_constants.EXCEPTION_TAG,
b"[job]Unkonwn tag when tried to receive the class definition"
])
......@@ -303,36 +318,39 @@ class Job(object):
return obj
def run(self):
def run(self, job_address_sender):
"""An infinite loop waiting for a new task.
Args:
job_address_sender (multiprocessing.Pipe connection): sends the address of reply_socket to the main process.
"""
# receive source code from the actor and append them to the environment variables.
envdir = self.wait_for_files()
sys.path.append(envdir)
self.client_is_alive = True
self.client_thread.start()
ctx = zmq.Context()
# create the reply_socket
reply_socket = ctx.socket(zmq.REP)
job_port = reply_socket.bind_to_random_port(addr="tcp://*")
reply_socket.linger = 0
job_ip = get_ip_address()
job_address = "{}:{}".format(job_ip, job_port)
job_address_sender.send(job_address)
try:
obj = self.wait_for_connection()
# receive source code from the actor and append them to the environment variables.
envdir = self.wait_for_files(reply_socket, job_address)
sys.path.append(envdir)
obj = self.wait_for_connection(reply_socket)
assert obj is not None
self.single_task(obj)
self.single_task(obj, reply_socket, job_address)
except Exception as e:
logger.error(
"Error occurs when running a single task. We will reset this job. Reason:{}"
.format(e))
traceback_str = str(traceback.format_exc())
logger.error("traceback:\n{}".format(traceback_str))
with self.lock:
self.kill_job_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(self.job_address)])
try:
_ = self.kill_job_socket.recv_multipart()
except zmq.error.Again as e:
pass
os._exit(1)
def single_task(self, obj):
def single_task(self, obj, reply_socket, job_address):
"""An infinite loop waiting for commands from the remote object.
Each job will receive two kinds of message from the remote object:
......@@ -342,10 +360,14 @@ class Job(object):
remote object.
2. When the remote object is deleted, the job will quit and release
related computation resources.
Args:
reply_socket (socket): main socket that accepts commands from the remote object.
job_address (str): address of reply_socket.
"""
while self.job_is_alive and self.client_is_alive:
message = self.reply_socket.recv_multipart()
while True:
message = reply_socket.recv_multipart()
tag = message[0]
......@@ -357,32 +379,31 @@ class Job(object):
ret = getattr(obj, function_name)(*args, **kwargs)
ret = dumps_return(ret)
self.reply_socket.send_multipart(
reply_socket.send_multipart(
[remote_constants.NORMAL_TAG, ret])
except Exception as e:
# reset the job
self.client_is_alive = False
error_str = str(e)
logger.error(error_str)
if type(e) == AttributeError:
self.reply_socket.send_multipart([
reply_socket.send_multipart([
remote_constants.ATTRIBUTE_EXCEPTION_TAG,
to_byte(error_str)
])
raise AttributeError
elif type(e) == SerializeError:
self.reply_socket.send_multipart([
reply_socket.send_multipart([
remote_constants.SERIALIZE_EXCEPTION_TAG,
to_byte(error_str)
])
raise SerializeError
elif type(e) == DeserializeError:
self.reply_socket.send_multipart([
reply_socket.send_multipart([
remote_constants.DESERIALIZE_EXCEPTION_TAG,
to_byte(error_str)
])
......@@ -391,7 +412,7 @@ class Job(object):
else:
traceback_str = str(traceback.format_exc())
logger.error("traceback:\n{}".format(traceback_str))
self.reply_socket.send_multipart([
reply_socket.send_multipart([
remote_constants.EXCEPTION_TAG,
to_byte(error_str + "\ntraceback:\n" +
traceback_str)
......@@ -400,11 +421,9 @@ class Job(object):
# receive DELETE_TAG from actor, and stop replying worker heartbeat
elif tag == remote_constants.KILLJOB_TAG:
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
self.client_is_alive = False
logger.warning(
"An actor exits and this job {} will exit.".format(
self.job_address))
reply_socket.send_multipart([remote_constants.NORMAL_TAG])
logger.warning("An actor exits and this job {} will exit.".
format(job_address))
break
else:
logger.error(
......@@ -418,4 +437,3 @@ if __name__ == "__main__":
"--worker_address", required=True, type=str, help="worker_address")
args = parser.parse_args()
job = Job(args.worker_address)
job.run()
......@@ -92,7 +92,8 @@ def remote_class(*args, **kwargs):
# GLOBAL_CLIENT will set `master_is_alive` to False when the heartbeat
# finds the master is dead.
if self.GLOBAL_CLIENT.master_is_alive:
job_address = self.request_cpu_resource(self.GLOBAL_CLIENT)
job_address = self.request_cpu_resource(
self.GLOBAL_CLIENT, max_memory)
else:
raise Exception("Can not submit job to the master. "
"Please check if master is still alive.")
......@@ -117,7 +118,6 @@ def remote_class(*args, **kwargs):
remote_constants.INIT_OBJECT_TAG,
cloudpickle.dumps(cls),
cloudpickle.dumps([args, kwargs]),
to_byte(str(max_memory))
])
message = self.job_socket.recv_multipart()
tag = message[0]
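On the actor side above, the proxy ships the class object and its constructor arguments with `cloudpickle`; on the job side, `wait_for_connection` loads them and builds the local instance. A minimal round-trip sketch of that handshake without the ZMQ transport (the `Actor` class is a hypothetical user class):

```python
import cloudpickle

class Actor(object):  # hypothetical class that would carry @parl.remote_class
    def __init__(self, seed):
        self.seed = seed

    def step(self, x):
        return x + self.seed

# actor side: serialize the class definition and the constructor arguments
payload = [cloudpickle.dumps(Actor), cloudpickle.dumps([(10, ), {}])]

# job side: rebuild the class, then create the local instance
cls = cloudpickle.loads(payload[0])
args, kwargs = cloudpickle.loads(payload[1])
obj = cls(*args, **kwargs)
print(obj.step(1))  # -> 11
```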
......@@ -149,11 +149,11 @@ def remote_class(*args, **kwargs):
except zmq.error.Again as e:
logger.error("Send python files failed.")
def request_cpu_resource(self, global_client):
def request_cpu_resource(self, global_client, max_memory):
"""Try to request cpu resource for 1 second/time for 300 times."""
cnt = 300
while cnt > 0:
job_address = global_client.submit_job()
job_address = global_client.submit_job(max_memory)
if job_address is not None:
return job_address
if cnt % 30 == 0:
......
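`request_cpu_resource` now forwards `max_memory` and keeps polling the master roughly once per second, for up to 300 attempts, before giving up. A compact sketch of that retry pattern; the client object and the `submit_job(max_memory)` signature are taken from the diff above:

```python
import time

def request_cpu_resource(global_client, max_memory, max_attempts=300):
    """Poll the master for a vacant job; return its address or None on timeout."""
    for attempt in range(max_attempts):
        job_address = global_client.submit_job(max_memory)
        if job_address is not None:
            return job_address
        if attempt % 30 == 0:
            print("Waiting for a vacant CPU in the cluster...")
        time.sleep(1)
    return None
```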
......@@ -86,8 +86,8 @@ def cli():
@click.option("--port", help="The port to bind to.", type=str, required=True)
@click.option(
"--debug",
help="Start parl in debug mode to show all logs.",
default=False)
help="Start parl in the debugging mode to print all running log.",
is_flag=True)
@click.option(
"--cpu_num",
type=int,
......
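The `--debug` option above is changed from a value option with `default=False` into a real boolean flag via `is_flag=True`, so `xparl start --debug` enables it without an explicit value. A tiny standalone `click` example of the same pattern:

```python
import click

@click.command()
@click.option("--debug", is_flag=True,
              help="Start parl in debugging mode to print all running logs.")
def start(debug):
    click.echo("debug mode on" if debug else "debug mode off")

if __name__ == "__main__":
    start()  # e.g. `python start.py --debug`
```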
......@@ -56,10 +56,11 @@ class Worker(object):
reply_job_socket (zmq.Context.socket): A socket which receives
job_address from the job.
kill_job_socket (zmq.Context.socket): A socket that receives commands to kill the job from jobs.
job_buffer (str): A buffer that stores initialized jobs for providing new jobs in a short time.
Args:
master_address (str): IP address of the master node.
cpu_num (int): Number of cpu to be used on the worker.
job_buffer (str): A buffer that stores initialized jobs for providing new jobs in a short time.
"""
def __init__(self, master_address, cpu_num=None):
......@@ -170,9 +171,13 @@ class Worker(object):
"""An endless loop that adds initialized job into the job buffer"""
while self.worker_is_alive:
if self.job_buffer.full() is False:
initialized_jobs = self._init_jobs(job_num=self.cpu_num)
for job in initialized_jobs:
self.job_buffer.put(job)
job_num = self.cpu_num - self.job_buffer.qsize()
if job_num > 0:
initialized_jobs = self._init_jobs(job_num=job_num)
for job in initialized_jobs:
self.job_buffer.put(job)
time.sleep(0.02)
# release jobs if the worker is not alive
for job in initialized_jobs:
......
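The worker change above stops refilling the buffer with a fixed `cpu_num` batch and instead tops it up only by the number of missing jobs (`cpu_num - qsize()`), avoiding over-allocation of job processes. A small sketch of that top-up loop, with a plain `queue.Queue` standing in for the job buffer and stub callables for the worker internals:

```python
import queue
import time

def fill_job_buffer(job_buffer, cpu_num, init_jobs, worker_is_alive):
    """Keep the buffer topped up to cpu_num jobs while the worker is alive.

    init_jobs(n) stands in for Worker._init_jobs and returns n new jobs.
    """
    while worker_is_alive():
        if not job_buffer.full():
            job_num = cpu_num - job_buffer.qsize()
            if job_num > 0:
                for job in init_jobs(job_num):
                    job_buffer.put(job)
        time.sleep(0.02)

# usage sketch:
# buffer = queue.Queue(maxsize=4)
# fill_job_buffer(buffer, cpu_num=4, init_jobs=lambda n: ["job"] * n,
#                 worker_is_alive=lambda: buffer.qsize() < 4)
```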