Commit 7318f63e · authored by fuyw · committed by Bo Zhou

Add `files` argument when creating a client. (#143)

* Add `files` argument when creating a client.

* Add `distributed_files` to parl.connect

* yapf

* Add try/except for `cls = cloudpickle.loads()`

* yapf

* Add a unit test in cluster_status_test.py for the `address is used` case.

* Solve port address duplicate problem.

* combine two try/except blocks

* check monitor for xparl status

* yapf

* fix bugs

* fix bugs

* More sleep in cluster_test.py.

* yapf

* add debug mode

* add debug mode (typo fix)

* Fix reset_job_test.py timeout problem.

* remove time decorator

* Add timeout_decorator
Parent 44d06807
client.py:

@@ -41,12 +41,16 @@ class Client(object):
     """

-    def __init__(self, master_address, process_id):
+    def __init__(self, master_address, process_id, distributed_files=[]):
         """
         Args:
             master_addr (str): ip address of the master node.
             process_id (str): id of the process that created the Client.
                 Should use os.getpid() to get the process id.
+            distributed_files (list): A list of files to be distributed to all
+                remote instances (e.g. a configuration file needed for
+                initialization).
         """
         self.master_address = master_address
         self.process_id = process_id

@@ -61,7 +65,7 @@ class Client(object):
         self.actor_num = 0

         self._create_sockets(master_address)
-        self.pyfiles = self.read_local_files()
+        self.pyfiles = self.read_local_files(distributed_files)

     def get_executable_path(self):
         """Return current executable path."""

@@ -73,20 +77,28 @@ class Client(object):
             executable_path = executable_path[:executable_path.rfind('/')]
         return executable_path

-    def read_local_files(self):
+    def read_local_files(self, distributed_files=[]):
         """Read local python code and store them in a dictionary, which will
         then be sent to the job.

+        Args:
+            distributed_files (list): A list of files to be distributed to all
+                remote instances (e.g. a configuration file needed for
+                initialization).
+
         Returns:
             A cloudpickled dictionary containing the python code in current
             working directory.
         """
         pyfiles = dict()
-        for file in os.listdir('./'):
-            if file.endswith('.py'):
-                with open(file, 'rb') as code_file:
-                    code = code_file.read()
-                    pyfiles[file] = code
+
+        code_files = filter(lambda x: x.endswith('.py'), os.listdir('./'))
+        to_distributed_files = list(code_files) + distributed_files
+
+        for file in to_distributed_files:
+            with open(file, 'rb') as code_file:
+                code = code_file.read()
+                pyfiles[file] = code
+
         return cloudpickle.dumps(pyfiles)

     def _create_sockets(self, master_address):

@@ -274,7 +286,7 @@ class Client(object):
 GLOBAL_CLIENT = None

-def connect(master_address):
+def connect(master_address, distributed_files=[]):
     """Create a global client which connects to the master node.

     .. code-block:: python

@@ -283,6 +295,9 @@ def connect(master_address):
     Args:
         master_address (str): The address of the Master node to connect to.
+        distributed_files (list): A list of files to be distributed to all
+            remote instances (e.g. a configuration file needed for
+            initialization).

     Raises:
         Exception: An exception is raised if the master node is not started.

@@ -293,10 +308,12 @@ def connect(master_address):
     global GLOBAL_CLIENT

     cur_process_id = os.getpid()
     if GLOBAL_CLIENT is None:
-        GLOBAL_CLIENT = Client(master_address, cur_process_id)
+        GLOBAL_CLIENT = Client(master_address, cur_process_id,
+                               distributed_files)
     else:
         if GLOBAL_CLIENT.process_id != cur_process_id:
-            GLOBAL_CLIENT = Client(master_address, cur_process_id)
+            GLOBAL_CLIENT = Client(master_address, cur_process_id,
+                                   distributed_files)

 def get_global_client():
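
Taken together, the client-side changes let `parl.connect` ship arbitrary data or configuration files to every remote job, in addition to the `*.py` files in the current working directory. A minimal usage sketch, assuming a cluster is already running at `localhost:8010` and a local `config.json` exists (the address, class name, and file name are illustrative):

import parl


@parl.remote_class
class ConfigReader(object):
    def read(self):
        # The distributed file can be opened by its relative name
        # inside the remote job.
        with open('config.json') as f:
            return f.read()


# Besides the local *.py files, also send config.json to every remote instance.
parl.connect('localhost:8010', distributed_files=['config.json'])

reader = ConfigReader()
print(reader.read())

This is the same flow the new test file added by this commit exercises with a numpy array and a JSON config.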
job.py:

@@ -263,14 +263,14 @@ class Job(object):
             message = self.reply_socket.recv_multipart()
             tag = message[0]
             obj = None
             if tag == remote_constants.INIT_OBJECT_TAG:
-                cls = cloudpickle.loads(message[1])
-                args, kwargs = cloudpickle.loads(message[2])
-                max_memory = to_str(message[3])
-                if max_memory != 'None':
-                    self.max_memory = float(max_memory)
                 try:
+                    cls = cloudpickle.loads(message[1])
+                    args, kwargs = cloudpickle.loads(message[2])
+                    max_memory = to_str(message[3])
+                    if max_memory != 'None':
+                        self.max_memory = float(max_memory)
                     obj = cls(*args, **kwargs)
                 except Exception as e:
                     traceback_str = str(traceback.format_exc())

@@ -282,7 +282,6 @@ class Job(object):
                     ])
                     self.client_is_alive = False
                     return None
-
                 self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
             else:
                 logger.error("Message from job {}".format(message))
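
On the job side, deserialization of the class, its arguments, and the memory limit now happens inside the same try/except that already guarded instantiation, so a payload that fails to unpickle is reported back to the client instead of crashing the job with an uncaught exception. The pattern in isolation, as a standalone sketch (the function name and dummy payloads are illustrative):

import traceback

import cloudpickle


def safe_construct(cls_bytes, args_bytes):
    """Deserialize a class and its arguments defensively; return either the
    instance or the formatted traceback of whatever went wrong."""
    try:
        cls = cloudpickle.loads(cls_bytes)
        args, kwargs = cloudpickle.loads(args_bytes)
        return cls(*args, **kwargs), None
    except Exception:
        return None, traceback.format_exc()


# A corrupted payload takes the error path instead of raising.
obj, err = safe_construct(b'not-a-pickle', b'not-a-pickle-either')
assert obj is None and err is not None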
scripts.py:

@@ -21,6 +21,7 @@ import re
 import socket
 import subprocess
 import sys
+import time
 import threading
 import warnings
 import zmq
@@ -83,6 +84,10 @@ def cli():

 @click.command("start", short_help="Start a master node.")
 @click.option("--port", help="The port to bind to.", type=str, required=True)
+@click.option(
+    "--debug",
+    help="Start parl in debug mode to show all logs.",
+    default=False)
 @click.option(
     "--cpu_num",
     type=int,

@@ -90,7 +95,10 @@ def cli():
     "cpus of this machine.")
 @click.option(
     "--monitor_port", help="The port to start a cluster monitor.", type=str)
-def start_master(port, cpu_num, monitor_port):
+def start_master(port, cpu_num, monitor_port, debug):
+    if debug:
+        os.environ['DEBUG'] = 'True'
+
     if not is_port_available(port):
         raise Exception(
             "The master address localhost:{} is already in use.".format(port))
@@ -103,52 +111,80 @@ def start_master(port, cpu_num, monitor_port, debug):
     cpu_num = cpu_num if cpu_num else multiprocessing.cpu_count()
     start_file = __file__.replace('scripts.pyc', 'start.py')
     start_file = start_file.replace('scripts.py', 'start.py')
-    command = [sys.executable, start_file, "--name", "master", "--port", port]
-    p = subprocess.Popen(command)
-    command = [
+
+    monitor_port = monitor_port if monitor_port else get_free_tcp_port()
+
+    master_command = [
+        sys.executable, start_file, "--name", "master", "--port", port
+    ]
+    worker_command = [
         sys.executable, start_file, "--name", "worker", "--address",
         "localhost:" + str(port), "--cpu_num",
         str(cpu_num)
     ]
-    # Redirect the output to DEVNULL to solve the warning log.
-    FNULL = open(os.devnull, 'w')
-    p = subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT)
-
-    monitor_port = monitor_port if monitor_port else get_free_tcp_port()
-    command = [
+    monitor_command = [
         sys.executable, '{}/monitor.py'.format(__file__[:__file__.rfind('/')]),
         "--monitor_port",
         str(monitor_port), "--address", "localhost:" + str(port)
     ]
-    p = subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT)
+
+    FNULL = open(os.devnull, 'w')
+
+    # Redirect the output to DEVNULL to solve the warning log.
+    _ = subprocess.Popen(
+        master_command, stdout=FNULL, stderr=subprocess.STDOUT)
+    _ = subprocess.Popen(
+        worker_command, stdout=FNULL, stderr=subprocess.STDOUT)
+    _ = subprocess.Popen(
+        monitor_command, stdout=FNULL, stderr=subprocess.STDOUT)
     FNULL.close()

-    master_ip = get_ip_address()
-    cluster_info = """
+    monitor_info = """
     # The Parl cluster is started at localhost:{}.

     # A local worker with {} CPUs is connected to the cluster.
+
+    # Starting the cluster monitor...""".format(
+        port,
+        cpu_num,
+    )
+    click.echo(monitor_info)
+
+    # check if monitor is started
+    cmd = r'ps -ef | grep remote/monitor.py\ --monitor_port\ {}\ --address\ localhost:{}'.format(
+        monitor_port, port)
+    monitor_is_started = False
+    for i in range(3):
+        check_monitor_is_started = os.popen(cmd).read().strip().split('\n')
+        if len(check_monitor_is_started) == 2:
+            monitor_is_started = True
+            break
+        time.sleep(3)
+
+    master_ip = get_ip_address()
+    if monitor_is_started:
+        start_info = """
     ## If you want to check cluster status, please view:
         http://{}:{}
     or call:
-        xparl status
+        xparl status""".format(master_ip, monitor_port)
+    else:
+        start_info = "# Fail to start the cluster monitor."
+
+    monitor_info = """
+    {}

     ## If you want to add more CPU resources, please call:
         xparl connect --address {}:{}

     ## If you want to shutdown the cluster, please call:
         xparl stop
-    """.format(port, cpu_num, master_ip, monitor_port, master_ip, port)
-    click.echo(cluster_info)
+    """.format(start_info, master_ip, port)
+    click.echo(monitor_info)
@click.command("connect", short_help="Start a worker node.") @click.command("connect", short_help="Start a worker node.")
...@@ -185,29 +221,50 @@ def stop(): ...@@ -185,29 +221,50 @@ def stop():
@click.command("status") @click.command("status")
def status(): def status():
cmd = r'ps -ef | grep remote/monitor.py\ --monitor_port' cmd = r'ps -ef | grep remote/start.py\ --name\ worker\ --address'
content = os.popen(cmd).read() content = os.popen(cmd).read().strip()
pattern = re.compile('--monitor_port (.*?)\n', re.S) pattern = re.compile('--address (.*?) --cpu')
monitors = pattern.findall(content) clusters = set(pattern.findall(content))
if len(monitors) == 0: if len(clusters) == 0:
click.echo('No active cluster is found.') click.echo('No active cluster is found.')
else: else:
ctx = zmq.Context() ctx = zmq.Context()
status = [] status = []
for monitor in monitors: for cluster in clusters:
monitor_port, _, master_address = monitor.split(' ') cmd = r'ps -ef | grep address\ {}'.format(cluster)
monitor_address = "{}:{}".format(get_ip_address(), monitor_port) content = os.popen(cmd).read()
socket = ctx.socket(zmq.REQ) pattern = re.compile('--monitor_port (.*?)\n', re.S)
socket.connect('tcp://{}'.format(master_address)) monitors = pattern.findall(content)
socket.send_multipart([STATUS_TAG])
cluster_info = to_str(socket.recv_multipart()[1]) if len(monitors):
msg = """ monitor_port, _, master_address = monitors[0].split(' ')
monitor_address = "{}:{}".format(get_ip_address(),
monitor_port)
socket = ctx.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 10000)
socket.connect('tcp://{}'.format(master_address))
try:
socket.send_multipart([STATUS_TAG])
monitor_info = to_str(socket.recv_multipart()[1])
except zmq.error.Again as e:
click.echo(
'Can not connect to cluster {}, please try later.'.
format(master_address))
socket.close(0)
continue
msg = """
# Cluster {} {} # Cluster {} {}
# If you want to check cluster status, please view: http://{} # If you want to check cluster status, please view: http://{}
""".format(master_address, cluster_info, monitor_address) """.format(master_address, monitor_info, monitor_address)
status.append(msg) status.append(msg)
socket.close(0) socket.close(0)
else:
msg = """
# Cluster {} fails to start the cluster monitor.
""".format(cluster)
status.append(msg)
for monitor_status in status: for monitor_status in status:
click.echo(monitor_status) click.echo(monitor_status)
......
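
Two patterns carry most of this rewrite: `xparl start` now polls `ps -ef` a few times to confirm that the monitor process actually came up before printing its URL, and `xparl status` sets a receive timeout on its REQ socket so that an unreachable master no longer blocks the command forever. The socket half, reduced to a standalone sketch (the address and tag bytes are placeholders; the real command sends the `STATUS_TAG` constant):

import zmq

ctx = zmq.Context()
socket = ctx.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 10000)  # give up after 10 s instead of hanging
socket.connect('tcp://localhost:8010')
try:
    socket.send_multipart([b'[STATUS]'])
    reply = socket.recv_multipart()
    print(reply)
except zmq.error.Again:
    print('Can not connect to the cluster, please try later.')
finally:
    socket.close(0)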
cluster_test.py:

@@ -66,26 +66,35 @@ class TestCluster(unittest.TestCase):
         master = Master(port=1235)
         th = threading.Thread(target=master.run)
         th.start()
-        time.sleep(1)
+        time.sleep(3)
         worker1 = Worker('localhost:1235', 1)
+        for _ in range(3):
+            if master.cpu_num == 1:
+                break
+            time.sleep(10)
         self.assertEqual(1, master.cpu_num)
         parl.connect('localhost:1235')
         with self.assertRaises(exceptions.RemoteError):
             actor = Actor(abcd='a bug')
         actor2 = Actor()
+        for _ in range(3):
+            if master.cpu_num == 0:
+                break
+            time.sleep(10)
         self.assertEqual(actor2.add_one(1), 2)
         self.assertEqual(0, master.cpu_num)
         master.exit()
         worker1.exit()

-    @timeout_decorator.timeout(seconds=300)
+    @timeout_decorator.timeout(seconds=500)
     def test_actor_exception(self):
         master = Master(port=1236)
         th = threading.Thread(target=master.run)
         th.start()
-        time.sleep(1)
+        time.sleep(3)
         worker1 = Worker('localhost:1236', 1)
         self.assertEqual(1, master.cpu_num)
         parl.connect('localhost:1236')

@@ -95,7 +104,10 @@ class TestCluster(unittest.TestCase):
         except:
             pass
         actor2 = Actor()
-        time.sleep(30)
+        for _ in range(5):
+            if master.cpu_num == 0:
+                break
+            time.sleep(10)
         self.assertEqual(actor2.add_one(1), 2)
         self.assertEqual(0, master.cpu_num)
         del actor

@@ -108,16 +120,21 @@ class TestCluster(unittest.TestCase):
         master = Master(port=1237)
         th = threading.Thread(target=master.run)
         th.start()
-        time.sleep(1)
+        time.sleep(3)
         worker1 = Worker('localhost:1237', 4)
         parl.connect('localhost:1237')
-        for i in range(10):
+        for _ in range(10):
             actor = Actor()
             ret = actor.add_one(1)
             self.assertEqual(ret, 2)
             del actor
-        time.sleep(20)
+
+        for _ in range(10):
+            if master.cpu_num == 4:
+                break
+            time.sleep(10)
         self.assertEqual(master.cpu_num, 4)
         worker1.exit()
         master.exit()

@@ -127,13 +144,27 @@ class TestCluster(unittest.TestCase):
         th = threading.Thread(target=master.run)
         th.start()
         time.sleep(1)
         worker1 = Worker('localhost:1234', 4)
+        for _ in range(3):
+            if master.cpu_num == 4:
+                break
+            time.sleep(10)
         self.assertEqual(master.cpu_num, 4)
         worker2 = Worker('localhost:1234', 4)
+        for _ in range(3):
+            if master.cpu_num == 8:
+                break
+            time.sleep(10)
         self.assertEqual(master.cpu_num, 8)
         worker2.exit()
-        time.sleep(50)
+
+        for _ in range(10):
+            if master.cpu_num == 4:
+                break
+            time.sleep(10)
         self.assertEqual(master.cpu_num, 4)
         master.exit()
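
The test edits all follow one idea: replace fixed `time.sleep(...)` calls with bounded polling, so each assertion runs as soon as the master reports the expected CPU count, while the raised `timeout_decorator` budget still bounds the worst case. A reusable helper expressing the same pattern (hypothetical, not part of the patch):

import time


def wait_until(condition, retries=10, interval=10):
    """Poll `condition` until it returns True or the retry budget runs out."""
    for _ in range(retries):
        if condition():
            return True
        time.sleep(interval)
    return condition()


# e.g. instead of time.sleep(50) before the assertion:
#     self.assertTrue(wait_until(lambda: master.cpu_num == 4))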
reset_job_test.py:

@@ -20,9 +20,9 @@ from parl.utils import logger
 import subprocess
 import time
 import threading
-import timeout_decorator
 import subprocess
 import sys
+import timeout_decorator

 @parl.remote_class

@@ -62,22 +62,27 @@ class TestJob(unittest.TestCase):
     def tearDown(self):
         disconnect()

-    @timeout_decorator.timeout(seconds=300)
+    @timeout_decorator.timeout(seconds=600)
     def test_acor_exit_exceptionally(self):
-        master = Master(port=1335)
+        port = 1337
+        master = Master(port)
         th = threading.Thread(target=master.run)
         th.start()
         time.sleep(1)
-        worker1 = Worker('localhost:1335', 1)
+        worker1 = Worker('localhost:{}'.format(port), 1)
         file_path = __file__.replace('reset_job_test', 'simulate_client')
         command = [sys.executable, file_path]
         proc = subprocess.Popen(command)
-        time.sleep(20)
+        for _ in range(6):
+            if master.cpu_num == 0:
+                break
+            else:
+                time.sleep(10)
         self.assertEqual(master.cpu_num, 0)
         proc.kill()
-        parl.connect('localhost:1335')
+        parl.connect('localhost:{}'.format(port))
         actor = Actor()
         master.exit()
         worker1.exit()
simulate_client.py:

@@ -23,7 +23,7 @@ class Actor(object):

 def train():
-    parl.connect('localhost:1335')
+    parl.connect('localhost:1337')
     actor = Actor()
     actor.add_one(1)
     time.sleep(100000)
New test file added by this commit (exercises distributing files with parl.connect):

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
import time
import threading
import sys
import numpy as np
import json


@parl.remote_class
class Actor(object):
    def __init__(self, random_array, config_file):
        self.random_array = random_array
        self.config_file = config_file

    def random_sum(self):
        return np.load(self.random_array).sum()

    def read_config(self):
        with open(self.config_file, 'r') as f:
            config_file = json.load(f)
        return config_file['test']


class TestConfigfile(unittest.TestCase):
    def tearDown(self):
        disconnect()

    def test_sync_config_file(self):
        master = Master(port=1335)
        th = threading.Thread(target=master.run)
        th.start()
        time.sleep(1)
        worker = Worker('localhost:1335', 1)

        random_file = 'random.npy'
        random_array = np.random.randn(3, 5)
        np.save(random_file, random_array)
        random_sum = random_array.sum()

        with open('config.json', 'w') as f:
            config_file = {'test': 1000}
            json.dump(config_file, f)

        parl.connect('localhost:1335', ['random.npy', 'config.json'])
        actor = Actor('random.npy', 'config.json')
        time.sleep(5)
        remote_sum = actor.random_sum()
        self.assertEqual(remote_sum, random_sum)

        time.sleep(10)
        remote_config = actor.read_config()
        self.assertEqual(config_file['test'], remote_config)

        del actor
        worker.exit()
        master.exit()


if __name__ == '__main__':
    unittest.main()
worker.py:

@@ -71,7 +71,6 @@ class Worker(object):
         self.worker_is_alive = True
         self.worker_status = None  # initialized at `self._create_jobs`
         self.lock = threading.Lock()
-
         self._set_cpu_num(cpu_num)
         self.job_buffer = queue.Queue(maxsize=self.cpu_num)
         self._create_sockets()
logger.py:

@@ -81,6 +81,13 @@ def _getlogger():
     logger = logging.getLogger('PARL')
     logger.propagate = False
     logger.setLevel(logging.DEBUG)
+
+    if 'DEBUG' in os.environ:
+        handler = logging.FileHandler('parl_debug.log')
+        handler.setFormatter(_Formatter(datefmt='%m-%d %H:%M:%S'))
+        logger.addHandler(handler)
+        return logger
+
     if 'XPARL' not in os.environ:
         handler = logging.StreamHandler(sys.stdout)
         handler.setFormatter(_Formatter(datefmt='%m-%d %H:%M:%S'))
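
This logger hook is what `xparl start --debug` relies on: `start_master` sets the `DEBUG` environment variable, the master, worker, and monitor subprocesses inherit it through `subprocess.Popen`, and any PARL process that sees it writes its logs to `parl_debug.log`. A generic sketch of the same mechanism, using a plain `logging.Formatter` rather than PARL's `_Formatter`:

import logging
import os

logger = logging.getLogger('demo')
logger.setLevel(logging.DEBUG)

if 'DEBUG' in os.environ:
    # Debug mode: keep every record in a file for later inspection.
    handler = logging.FileHandler('parl_debug.log')
else:
    handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
logger.addHandler(handler)

logger.info('debug mode: %s', 'on' if 'DEBUG' in os.environ else 'off')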
setup.py:

@@ -72,8 +72,8 @@ setup(
         "cloudpickle==1.2.1",
         "tensorboardX==1.8",
         "tb-nightly==1.15.0a20190801",
-        "flask==1.0.4",
         "click",
+        "flask",
         "psutil",
     ],
     classifiers=[