提交 7318f63e 编写于 作者: F fuyw 提交者: Bo Zhou

Add `files` argument when creating a client. (#143)

* Add `files` argument when creating a client.

* Add `distributed_files` to parl.connect

* yapf

* Add try exception for `cls=cloudpickle.loads()`

* yapf

* unittest cluster_status_test.py `address is used`.

* Solve port address duplicate problem.

* combine two try except

* check monitor for xparl status

* yapf

* fix bugs

* fix bugs

* More sleep in cluster_test.py.

* yapf

* add debug mode

* add debug model

* Fix reset_job_test.py timeout problem.

* remove time decorator

* Add timeout_decorator
上级 44d06807
......@@ -41,12 +41,16 @@ class Client(object):
"""
def __init__(self, master_address, process_id):
def __init__(self, master_address, process_id, distributed_files=[]):
"""
Args:
master_addr (str): ip address of the master node.
process_id (str): id of the process that created the Client.
Should use os.getpid() to get the process id.
distributed_files (list): A list of files to be distributed at all
remote instances(e,g. the configuration
file for initialization) .
"""
self.master_address = master_address
self.process_id = process_id
......@@ -61,7 +65,7 @@ class Client(object):
self.actor_num = 0
self._create_sockets(master_address)
self.pyfiles = self.read_local_files()
self.pyfiles = self.read_local_files(distributed_files)
def get_executable_path(self):
"""Return current executable path."""
......@@ -73,20 +77,28 @@ class Client(object):
executable_path = executable_path[:executable_path.rfind('/')]
return executable_path
def read_local_files(self):
def read_local_files(self, distributed_files=[]):
"""Read local python code and store them in a dictionary, which will
then be sent to the job.
Args:
distributed_files (list): A list of files to be distributed at all
remote instances(e,g. the configuration
file for initialization) .
Returns:
A cloudpickled dictionary containing the python code in current
working directory.
"""
pyfiles = dict()
for file in os.listdir('./'):
if file.endswith('.py'):
with open(file, 'rb') as code_file:
code = code_file.read()
pyfiles[file] = code
code_files = filter(lambda x: x.endswith('.py'), os.listdir('./'))
to_distributed_files = list(code_files) + distributed_files
for file in to_distributed_files:
with open(file, 'rb') as code_file:
code = code_file.read()
pyfiles[file] = code
return cloudpickle.dumps(pyfiles)
def _create_sockets(self, master_address):
......@@ -274,7 +286,7 @@ class Client(object):
GLOBAL_CLIENT = None
def connect(master_address):
def connect(master_address, distributed_files=[]):
"""Create a global client which connects to the master node.
.. code-block:: python
......@@ -283,6 +295,9 @@ def connect(master_address):
Args:
master_address (str): The address of the Master node to connect to.
distributed_files (list): A list of files to be distributed at all
remote instances(e,g. the configuration
file for initialization) .
Raises:
Exception: An exception is raised if the master node is not started.
......@@ -293,10 +308,12 @@ def connect(master_address):
global GLOBAL_CLIENT
cur_process_id = os.getpid()
if GLOBAL_CLIENT is None:
GLOBAL_CLIENT = Client(master_address, cur_process_id)
GLOBAL_CLIENT = Client(master_address, cur_process_id,
distributed_files)
else:
if GLOBAL_CLIENT.process_id != cur_process_id:
GLOBAL_CLIENT = Client(master_address, cur_process_id)
GLOBAL_CLIENT = Client(master_address, cur_process_id,
distributed_files)
def get_global_client():
......
......@@ -263,14 +263,14 @@ class Job(object):
message = self.reply_socket.recv_multipart()
tag = message[0]
obj = None
if tag == remote_constants.INIT_OBJECT_TAG:
cls = cloudpickle.loads(message[1])
args, kwargs = cloudpickle.loads(message[2])
max_memory = to_str(message[3])
if max_memory != 'None':
self.max_memory = float(max_memory)
if tag == remote_constants.INIT_OBJECT_TAG:
try:
cls = cloudpickle.loads(message[1])
args, kwargs = cloudpickle.loads(message[2])
max_memory = to_str(message[3])
if max_memory != 'None':
self.max_memory = float(max_memory)
obj = cls(*args, **kwargs)
except Exception as e:
traceback_str = str(traceback.format_exc())
......@@ -282,7 +282,6 @@ class Job(object):
])
self.client_is_alive = False
return None
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
else:
logger.error("Message from job {}".format(message))
......
......@@ -21,6 +21,7 @@ import re
import socket
import subprocess
import sys
import time
import threading
import warnings
import zmq
......@@ -83,6 +84,10 @@ def cli():
@click.command("start", short_help="Start a master node.")
@click.option("--port", help="The port to bind to.", type=str, required=True)
@click.option(
"--debug",
help="Start parl in debug mode to show all logs.",
default=False)
@click.option(
"--cpu_num",
type=int,
......@@ -90,7 +95,10 @@ def cli():
"cpus of this machine.")
@click.option(
"--monitor_port", help="The port to start a cluster monitor.", type=str)
def start_master(port, cpu_num, monitor_port):
def start_master(port, cpu_num, monitor_port, debug):
if debug:
os.environ['DEBUG'] = 'True'
if not is_port_available(port):
raise Exception(
"The master address localhost:{} is already in use.".format(port))
......@@ -103,52 +111,80 @@ def start_master(port, cpu_num, monitor_port):
cpu_num = cpu_num if cpu_num else multiprocessing.cpu_count()
start_file = __file__.replace('scripts.pyc', 'start.py')
start_file = start_file.replace('scripts.py', 'start.py')
command = [sys.executable, start_file, "--name", "master", "--port", port]
monitor_port = monitor_port if monitor_port else get_free_tcp_port()
p = subprocess.Popen(command)
command = [
master_command = [
sys.executable, start_file, "--name", "master", "--port", port
]
worker_command = [
sys.executable, start_file, "--name", "worker", "--address",
"localhost:" + str(port), "--cpu_num",
str(cpu_num)
]
# Redirect the output to DEVNULL to solve the warning log.
FNULL = open(os.devnull, 'w')
p = subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT)
monitor_port = monitor_port if monitor_port else get_free_tcp_port()
command = [
monitor_command = [
sys.executable, '{}/monitor.py'.format(__file__[:__file__.rfind('/')]),
"--monitor_port",
str(monitor_port), "--address", "localhost:" + str(port)
]
p = subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT)
FNULL = open(os.devnull, 'w')
# Redirect the output to DEVNULL to solve the warning log.
_ = subprocess.Popen(
master_command, stdout=FNULL, stderr=subprocess.STDOUT)
_ = subprocess.Popen(
worker_command, stdout=FNULL, stderr=subprocess.STDOUT)
_ = subprocess.Popen(
monitor_command, stdout=FNULL, stderr=subprocess.STDOUT)
FNULL.close()
master_ip = get_ip_address()
cluster_info = """
monitor_info = """
# The Parl cluster is started at localhost:{}.
# A local worker with {} CPUs is connected to the cluster.
# A local worker with {} CPUs is connected to the cluster.
# Starting the cluster monitor...""".format(
port,
cpu_num,
)
click.echo(monitor_info)
# check if monitor is started
cmd = r'ps -ef | grep remote/monitor.py\ --monitor_port\ {}\ --address\ localhost:{}'.format(
monitor_port, port)
monitor_is_started = False
for i in range(3):
check_monitor_is_started = os.popen(cmd).read().strip().split('\n')
if len(check_monitor_is_started) == 2:
monitor_is_started = True
break
time.sleep(3)
master_ip = get_ip_address()
if monitor_is_started:
start_info = """
## If you want to check cluster status, please view:
http://{}:{}
or call:
xparl status
xparl status""".format(master_ip, monitor_port)
else:
start_info = "# Fail to start the cluster monitor."
monitor_info = """
{}
## If you want to add more CPU resources, please call:
xparl connect --address {}:{}
## If you want to shutdown the cluster, please call:
xparl stop
""".format(port, cpu_num, master_ip, monitor_port, master_ip, port)
click.echo(cluster_info)
xparl stop
""".format(start_info, master_ip, port)
click.echo(monitor_info)
@click.command("connect", short_help="Start a worker node.")
......@@ -185,29 +221,50 @@ def stop():
@click.command("status")
def status():
cmd = r'ps -ef | grep remote/monitor.py\ --monitor_port'
content = os.popen(cmd).read()
pattern = re.compile('--monitor_port (.*?)\n', re.S)
monitors = pattern.findall(content)
if len(monitors) == 0:
cmd = r'ps -ef | grep remote/start.py\ --name\ worker\ --address'
content = os.popen(cmd).read().strip()
pattern = re.compile('--address (.*?) --cpu')
clusters = set(pattern.findall(content))
if len(clusters) == 0:
click.echo('No active cluster is found.')
else:
ctx = zmq.Context()
status = []
for monitor in monitors:
monitor_port, _, master_address = monitor.split(' ')
monitor_address = "{}:{}".format(get_ip_address(), monitor_port)
socket = ctx.socket(zmq.REQ)
socket.connect('tcp://{}'.format(master_address))
socket.send_multipart([STATUS_TAG])
cluster_info = to_str(socket.recv_multipart()[1])
msg = """
for cluster in clusters:
cmd = r'ps -ef | grep address\ {}'.format(cluster)
content = os.popen(cmd).read()
pattern = re.compile('--monitor_port (.*?)\n', re.S)
monitors = pattern.findall(content)
if len(monitors):
monitor_port, _, master_address = monitors[0].split(' ')
monitor_address = "{}:{}".format(get_ip_address(),
monitor_port)
socket = ctx.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 10000)
socket.connect('tcp://{}'.format(master_address))
try:
socket.send_multipart([STATUS_TAG])
monitor_info = to_str(socket.recv_multipart()[1])
except zmq.error.Again as e:
click.echo(
'Can not connect to cluster {}, please try later.'.
format(master_address))
socket.close(0)
continue
msg = """
# Cluster {} {}
# If you want to check cluster status, please view: http://{}
""".format(master_address, cluster_info, monitor_address)
status.append(msg)
socket.close(0)
""".format(master_address, monitor_info, monitor_address)
status.append(msg)
socket.close(0)
else:
msg = """
# Cluster {} fails to start the cluster monitor.
""".format(cluster)
status.append(msg)
for monitor_status in status:
click.echo(monitor_status)
......
......@@ -66,26 +66,35 @@ class TestCluster(unittest.TestCase):
master = Master(port=1235)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
time.sleep(3)
worker1 = Worker('localhost:1235', 1)
for _ in range(3):
if master.cpu_num == 1:
break
time.sleep(10)
self.assertEqual(1, master.cpu_num)
parl.connect('localhost:1235')
with self.assertRaises(exceptions.RemoteError):
actor = Actor(abcd='a bug')
actor2 = Actor()
for _ in range(3):
if master.cpu_num == 0:
break
time.sleep(10)
self.assertEqual(actor2.add_one(1), 2)
self.assertEqual(0, master.cpu_num)
master.exit()
worker1.exit()
@timeout_decorator.timeout(seconds=300)
@timeout_decorator.timeout(seconds=500)
def test_actor_exception(self):
master = Master(port=1236)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
time.sleep(3)
worker1 = Worker('localhost:1236', 1)
self.assertEqual(1, master.cpu_num)
parl.connect('localhost:1236')
......@@ -95,7 +104,10 @@ class TestCluster(unittest.TestCase):
except:
pass
actor2 = Actor()
time.sleep(30)
for _ in range(5):
if master.cpu_num == 0:
break
time.sleep(10)
self.assertEqual(actor2.add_one(1), 2)
self.assertEqual(0, master.cpu_num)
del actor
......@@ -108,16 +120,21 @@ class TestCluster(unittest.TestCase):
master = Master(port=1237)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
time.sleep(3)
worker1 = Worker('localhost:1237', 4)
parl.connect('localhost:1237')
for i in range(10):
for _ in range(10):
actor = Actor()
ret = actor.add_one(1)
self.assertEqual(ret, 2)
del actor
time.sleep(20)
for _ in range(10):
if master.cpu_num == 4:
break
time.sleep(10)
self.assertEqual(master.cpu_num, 4)
worker1.exit()
master.exit()
......@@ -127,13 +144,27 @@ class TestCluster(unittest.TestCase):
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1234', 4)
for _ in range(3):
if master.cpu_num == 4:
break
time.sleep(10)
self.assertEqual(master.cpu_num, 4)
worker2 = Worker('localhost:1234', 4)
for _ in range(3):
if master.cpu_num == 8:
break
time.sleep(10)
self.assertEqual(master.cpu_num, 8)
worker2.exit()
time.sleep(50)
for _ in range(10):
if master.cpu_num == 4:
break
time.sleep(10)
self.assertEqual(master.cpu_num, 4)
master.exit()
......
......@@ -20,9 +20,9 @@ from parl.utils import logger
import subprocess
import time
import threading
import timeout_decorator
import subprocess
import sys
import timeout_decorator
@parl.remote_class
......@@ -62,22 +62,27 @@ class TestJob(unittest.TestCase):
def tearDown(self):
disconnect()
@timeout_decorator.timeout(seconds=300)
@timeout_decorator.timeout(seconds=600)
def test_acor_exit_exceptionally(self):
master = Master(port=1335)
port = 1337
master = Master(port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1335', 1)
worker1 = Worker('localhost:{}'.format(port), 1)
file_path = __file__.replace('reset_job_test', 'simulate_client')
command = [sys.executable, file_path]
proc = subprocess.Popen(command)
time.sleep(20)
for _ in range(6):
if master.cpu_num == 0:
break
else:
time.sleep(10)
self.assertEqual(master.cpu_num, 0)
proc.kill()
parl.connect('localhost:1335')
parl.connect('localhost:{}'.format(port))
actor = Actor()
master.exit()
worker1.exit()
......
......@@ -23,7 +23,7 @@ class Actor(object):
def train():
parl.connect('localhost:1335')
parl.connect('localhost:1337')
actor = Actor()
actor.add_one(1)
time.sleep(100000)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
import time
import threading
import sys
import numpy as np
import json
@parl.remote_class
class Actor(object):
def __init__(self, random_array, config_file):
self.random_array = random_array
self.config_file = config_file
def random_sum(self):
return np.load(self.random_array).sum()
def read_config(self):
with open(self.config_file, 'r') as f:
config_file = json.load(f)
return config_file['test']
class TestConfigfile(unittest.TestCase):
def tearDown(self):
disconnect()
def test_sync_config_file(self):
master = Master(port=1335)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker = Worker('localhost:1335', 1)
random_file = 'random.npy'
random_array = np.random.randn(3, 5)
np.save(random_file, random_array)
random_sum = random_array.sum()
with open('config.json', 'w') as f:
config_file = {'test': 1000}
json.dump(config_file, f)
parl.connect('localhost:1335', ['random.npy', 'config.json'])
actor = Actor('random.npy', 'config.json')
time.sleep(5)
remote_sum = actor.random_sum()
self.assertEqual(remote_sum, random_sum)
time.sleep(10)
remote_config = actor.read_config()
self.assertEqual(config_file['test'], remote_config)
del actor
worker.exit()
master.exit()
if __name__ == '__main__':
unittest.main()
......@@ -71,7 +71,6 @@ class Worker(object):
self.worker_is_alive = True
self.worker_status = None # initialized at `self._create_jobs`
self.lock = threading.Lock()
self._set_cpu_num(cpu_num)
self.job_buffer = queue.Queue(maxsize=self.cpu_num)
self._create_sockets()
......
......@@ -81,6 +81,13 @@ def _getlogger():
logger = logging.getLogger('PARL')
logger.propagate = False
logger.setLevel(logging.DEBUG)
if 'DEBUG' in os.environ:
handler = logging.FileHandler('parl_debug.log')
handler.setFormatter(_Formatter(datefmt='%m-%d %H:%M:%S'))
logger.addHandler(handler)
return logger
if 'XPARL' not in os.environ:
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(_Formatter(datefmt='%m-%d %H:%M:%S'))
......
......@@ -72,8 +72,8 @@ setup(
"cloudpickle==1.2.1",
"tensorboardX==1.8",
"tb-nightly==1.15.0a20190801",
"flask==1.0.4",
"click",
"flask",
"psutil",
],
classifiers=[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册