未验证 提交 2cb2c8bf 编写于 作者: H Hongsheng Zeng 提交者: GitHub

support windows system (#215)

* first version; unit tests in parl/remote pass

* add Windows local unit-test shell script; fix some incompatibility problems

* fix save api

* refine comments and remove log of xparl stop

* fix bug

* Update scripts.py

* Update utils.py

* Update agent_base_test.py
上级 117b1c38
...@@ -3,4 +3,3 @@ paddlepaddle-gpu==1.6.1.post97 ...@@ -3,4 +3,3 @@ paddlepaddle-gpu==1.6.1.post97
gym gym
details details
parameterized parameterized
timeout_decorator
...@@ -2,4 +2,3 @@ ...@@ -2,4 +2,3 @@
gym gym
details details
parameterized parameterized
timeout_decorator
#!/usr/bin/env bash
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: You need to install mingw-cmake.
# Define the ANSI escape sequences used for colored output and remember the
# directory the script was started from as the repository root.
#
# Globals written:
#   RED, BLUE, BOLD, NONE - escape sequences (stored unexpanded, as literals).
#   REPO_ROOT             - the current working directory at call time.
init() {
    # Text attributes first, then colors; NONE resets everything.
    NONE='\033[0m'
    BOLD='\033[1m'
    RED='\033[0;31m'
    BLUE='\033[0;34m'

    REPO_ROOT="$(pwd)"
}
# Print the code-style failure notice to stderr and terminate the shell with
# exit status 1.
abort() {
    printf '%s\n' \
        "Your change doesn't follow PaddlePaddle's code style." \
        "Please use pre-commit to check what is wrong." >&2
    exit 1
}
# Configure the project with CMake (MinGW Makefiles generator) in a scratch
# build directory, run the unit tests with ctest while CUDA_VISIBLE_DEVICES is
# set to -1, then remove the build directory again.
#
# Arguments:
#   $1 - label of the environment; only used in the banner that is printed.
#   $2 - optional CMake switch WITHOUT its leading dash, for example
#        "DIS_TESTING_SERIALLY". It is handed to cmake as "-$2=ON"
#        (i.e. -DIS_TESTING_SERIALLY=ON). When more than one argument is
#        given, ctest is run without the -j10 parallelism flag.
# Globals read:
#   REPO_ROOT - repository root; must be non-empty.
function run_test_with_cpu() {
    # Refuse to run with an empty REPO_ROOT: otherwise the cleanup below would
    # expand to "rm -rf /build".
    local build_dir="${REPO_ROOT:?REPO_ROOT is not set; call init first}/build"

    export CUDA_VISIBLE_DEVICES="-1"

    # All path expansions are quoted so a checkout located under a directory
    # containing spaces (common on Windows) still works.
    mkdir -p "${build_dir}"
    cd "${build_dir}"
    if [ $# -eq 1 ]; then
        cmake -G "MinGW Makefiles" ..
    else
        cmake -G "MinGW Makefiles" .. "-$2=ON"
    fi
    cat <<EOF
=====================================================
Running unit tests with CPU in the environment: $1
=====================================================
EOF
    if [ $# -eq 1 ]; then
        ctest --output-on-failure -j10
    else
        ctest --output-on-failure
    fi
    cd "${REPO_ROOT}"
    rm -rf "${build_dir}"
}
# Entry point: run the unit tests on a local Windows machine.
#
# Installs the package from the current directory and the test dependencies
# with pip, then calls run_test_with_cpu three times: once without an extra
# CMake switch, once with DIS_TESTING_SERIALLY after installing paddlepaddle,
# and once with DIS_TESTING_TORCH after swapping paddlepaddle for the CPU
# build of torch.
#
# Arguments: accepted but currently not used.
# Globals written: env (plus everything set by init).
function main() {
    # Stop at the first failing command.
    set -e
    init
    # Only shown in the banner printed by run_test_with_cpu.
    env="unused_variable"

    # run unittest in windows (used in local machine)
    pip install -i https://pypi.tuna.tsinghua.edu.cn/simple .
    pip uninstall -y torch torchvision
    pip install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle==1.6.1 gym details parameterized
    run_test_with_cpu "$env"
    run_test_with_cpu "$env" "DIS_TESTING_SERIALLY"

    pip uninstall -y paddlepaddle
    pip install torch==1.4.0+cpu torchvision==0.5.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
    run_test_with_cpu "$env" "DIS_TESTING_TORCH"
}

# Quote "$@" so arguments containing whitespace are forwarded unchanged.
main "$@"
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import warnings import warnings
warnings.simplefilter('default') warnings.simplefilter('default')
import os
import paddle.fluid as fluid import paddle.fluid as fluid
from parl.core.fluid import layers from parl.core.fluid import layers
from parl.core.agent_base import AgentBase from parl.core.agent_base import AgentBase
...@@ -152,8 +153,8 @@ class Agent(AgentBase): ...@@ -152,8 +153,8 @@ class Agent(AgentBase):
""" """
if program is None: if program is None:
program = self.learn_program program = self.learn_program
dirname = '/'.join(save_path.split('/')[:-1]) dirname = os.sep.join(save_path.split(os.sep)[:-1])
filename = save_path.split('/')[-1] filename = save_path.split(os.sep)[-1]
fluid.io.save_params( fluid.io.save_params(
executor=self.fluid_executor, executor=self.fluid_executor,
dirname=dirname, dirname=dirname,
...@@ -186,8 +187,8 @@ class Agent(AgentBase): ...@@ -186,8 +187,8 @@ class Agent(AgentBase):
program = self.learn_program program = self.learn_program
if type(program) is fluid.compiler.CompiledProgram: if type(program) is fluid.compiler.CompiledProgram:
program = program._init_program program = program._init_program
dirname = '/'.join(save_path.split('/')[:-1]) dirname = os.sep.join(save_path.split(os.sep)[:-1])
filename = save_path.split('/')[-1] filename = save_path.split(os.sep)[-1]
fluid.io.load_params( fluid.io.load_params(
executor=self.fluid_executor, executor=self.fluid_executor,
dirname=dirname, dirname=dirname,
......
...@@ -45,7 +45,7 @@ class TestParamSharing(unittest.TestCase): ...@@ -45,7 +45,7 @@ class TestParamSharing(unittest.TestCase):
dict_size = 100 dict_size = 100
input_cx = np.random.uniform(0, 1, [batch_size, 100]).astype("float32") input_cx = np.random.uniform(0, 1, [batch_size, 100]).astype("float32")
input_x = np.random.randint( input_x = np.random.randint(
dict_size, size=(batch_size, 1)).astype("int") dict_size, size=(batch_size, 1)).astype("int64")
################################# #################################
main_program1 = fluid.Program() main_program1 = fluid.Program()
...@@ -59,7 +59,7 @@ class TestParamSharing(unittest.TestCase): ...@@ -59,7 +59,7 @@ class TestParamSharing(unittest.TestCase):
main_program2 = fluid.Program() main_program2 = fluid.Program()
with fluid.program_guard(main_program2): with fluid.program_guard(main_program2):
x_ = layers.data(name='x', shape=[1], dtype="int") x_ = layers.data(name='x', shape=[1], dtype="int64")
cx_ = layers.cast( cx_ = layers.cast(
x=layers.one_hot(input=x_, depth=dict_size), dtype="float32") x=layers.one_hot(input=x_, depth=dict_size), dtype="float32")
y1_ = net.fc1(input=cx_) y1_ = net.fc1(input=cx_)
......
...@@ -46,8 +46,8 @@ class TestAlgorithm(parl.Algorithm): ...@@ -46,8 +46,8 @@ class TestAlgorithm(parl.Algorithm):
class TestAgent(parl.Agent): class TestAgent(parl.Agent):
def __init__(self, algorithm, gpu_id=None): def __init__(self, algorithm):
super(TestAgent, self).__init__(algorithm, gpu_id) super(TestAgent, self).__init__(algorithm)
def build_program(self): def build_program(self):
self.predict_program = fluid.Program() self.predict_program = fluid.Program()
...@@ -92,8 +92,8 @@ class AgentBaseTest(unittest.TestCase): ...@@ -92,8 +92,8 @@ class AgentBaseTest(unittest.TestCase):
agent = TestAgent(self.algorithm) agent = TestAgent(self.algorithm)
obs = np.random.random([3, 10]).astype('float32') obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs) output_np = agent.predict(obs)
save_path1 = './model.ckpt' save_path1 = 'model.ckpt'
save_path2 = './my_model/model-2.ckpt' save_path2 = os.path.join('my_model', 'model-2.ckpt')
agent.save(save_path1) agent.save(save_path1)
agent.save(save_path2) agent.save(save_path2)
self.assertTrue(os.path.exists(save_path1)) self.assertTrue(os.path.exists(save_path1))
...@@ -103,7 +103,7 @@ class AgentBaseTest(unittest.TestCase): ...@@ -103,7 +103,7 @@ class AgentBaseTest(unittest.TestCase):
agent = TestAgent(self.algorithm) agent = TestAgent(self.algorithm)
obs = np.random.random([3, 10]).astype('float32') obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs) output_np = agent.predict(obs)
save_path1 = './model.ckpt' save_path1 = 'model.ckpt'
previous_output = agent.predict(obs) previous_output = agent.predict(obs)
agent.save(save_path1) agent.save(save_path1)
agent.restore(save_path1) agent.restore(save_path1)
...@@ -121,7 +121,7 @@ class AgentBaseTest(unittest.TestCase): ...@@ -121,7 +121,7 @@ class AgentBaseTest(unittest.TestCase):
agent.learn_program = parl.compile(agent.learn_program) agent.learn_program = parl.compile(agent.learn_program)
obs = np.random.random([3, 10]).astype('float32') obs = np.random.random([3, 10]).astype('float32')
previous_output = agent.predict(obs) previous_output = agent.predict(obs)
save_path1 = './model.ckpt' save_path1 = 'model.ckpt'
agent.save(save_path1) agent.save(save_path1)
agent.restore(save_path1) agent.restore(save_path1)
......
...@@ -113,8 +113,9 @@ class Agent(AgentBase): ...@@ -113,8 +113,9 @@ class Agent(AgentBase):
""" """
if model is None: if model is None:
model = self.algorithm.model model = self.algorithm.model
dirname = '/'.join(save_path.split('/')[:-1]) sep = os.sep
if not os.path.exists(dirname): dirname = sep.join(save_path.split(sep)[:-1])
if dirname != '' and not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
torch.save(model.state_dict(), save_path) torch.save(model.state_dict(), save_path)
......
...@@ -77,8 +77,8 @@ class AgentBaseTest(unittest.TestCase): ...@@ -77,8 +77,8 @@ class AgentBaseTest(unittest.TestCase):
def test_save(self): def test_save(self):
agent = TestAgent(self.alg) agent = TestAgent(self.alg)
obs = torch.randn(3, 10) obs = torch.randn(3, 10)
save_path1 = './model.ckpt' save_path1 = 'model.ckpt'
save_path2 = './my_model/model-2.ckpt' save_path2 = os.path.join('my_model', 'model-2.ckpt')
agent.save(save_path1) agent.save(save_path1)
agent.save(save_path2) agent.save(save_path2)
self.assertTrue(os.path.exists(save_path1)) self.assertTrue(os.path.exists(save_path1))
...@@ -88,7 +88,7 @@ class AgentBaseTest(unittest.TestCase): ...@@ -88,7 +88,7 @@ class AgentBaseTest(unittest.TestCase):
agent = TestAgent(self.alg) agent = TestAgent(self.alg)
obs = torch.randn(3, 10) obs = torch.randn(3, 10)
output = agent.predict(obs) output = agent.predict(obs)
save_path1 = './model.ckpt' save_path1 = 'model.ckpt'
previous_output = agent.predict(obs).detach().cpu().numpy() previous_output = agent.predict(obs).detach().cpu().numpy()
agent.save(save_path1) agent.save(save_path1)
agent.restore(save_path1) agent.restore(save_path1)
......
...@@ -63,12 +63,15 @@ class Job(object): ...@@ -63,12 +63,15 @@ class Job(object):
self.worker_address = worker_address self.worker_address = worker_address
self.job_ip = get_ip_address() self.job_ip = get_ip_address()
self.pid = os.getpid() self.pid = os.getpid()
self.lock = threading.Lock()
self.run_job_process = Process( self.run_job_process = Process(
target=self.run, args=(job_address_sender, )) target=self.run, args=(job_address_sender, ))
self.run_job_process.start() self.run_job_process.start()
"""
NOTE:
In Windows, it will raise errors when creating threading.Lock before starting multiprocess.Process.
"""
self.lock = threading.Lock()
self._create_sockets() self._create_sockets()
process = psutil.Process(self.pid) process = psutil.Process(self.pid)
......
...@@ -61,7 +61,7 @@ class Master(object): ...@@ -61,7 +61,7 @@ class Master(object):
self.ctx = zmq.Context() self.ctx = zmq.Context()
self.master_ip = get_ip_address() self.master_ip = get_ip_address()
logger.set_dir( logger.set_dir(
os.path.expanduser('~/.parl_data/master/{}:{}'.format( os.path.expanduser('~/.parl_data/master/{}_{}'.format(
self.master_ip, port))) self.master_ip, port)))
self.client_socket = self.ctx.socket(zmq.REP) self.client_socket = self.ctx.socket(zmq.REP)
self.client_socket.bind("tcp://*:{}".format(port)) self.client_socket.bind("tcp://*:{}".format(port))
......
...@@ -42,7 +42,7 @@ class ClusterMonitor(object): ...@@ -42,7 +42,7 @@ class ClusterMonitor(object):
def __init__(self, master_address): def __init__(self, master_address):
ctx = zmq.Context() ctx = zmq.Context()
self.socket = ctx.socket(zmq.REQ) self.socket = ctx.socket(zmq.REQ)
self.socket.setsockopt(zmq.RCVTIMEO, 10000) self.socket.setsockopt(zmq.RCVTIMEO, 30000)
self.socket.connect('tcp://{}'.format(master_address)) self.socket.connect('tcp://{}'.format(master_address))
self.data = None self.data = None
...@@ -100,6 +100,10 @@ def cluster(): ...@@ -100,6 +100,10 @@ def cluster():
if __name__ == "__main__": if __name__ == "__main__":
import logging
log = logging.getLogger('werkzeug')
log.disabled = True
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--monitor_port', default=1234, type=int) parser.add_argument('--monitor_port', default=1234, type=int)
parser.add_argument('--address', default='localhost:8010', type=str) parser.add_argument('--address', default='localhost:8010', type=str)
......
...@@ -23,10 +23,11 @@ import subprocess ...@@ -23,10 +23,11 @@ import subprocess
import sys import sys
import time import time
import threading import threading
import tempfile
import warnings import warnings
import zmq import zmq
from multiprocessing import Process from multiprocessing import Process
from parl.utils import get_ip_address, to_str from parl.utils import get_ip_address, to_str, _IS_WINDOWS
from parl.remote.remote_constants import STATUS_TAG from parl.remote.remote_constants import STATUS_TAG
# A flag to mark if parl is started from a command line # A flag to mark if parl is started from a command line
...@@ -34,9 +35,11 @@ os.environ['XPARL'] = 'True' ...@@ -34,9 +35,11 @@ os.environ['XPARL'] = 'True'
# Solve `Click will abort further execution because Python 3 was configured # Solve `Click will abort further execution because Python 3 was configured
# to use ASCII as encoding for the environment` error. # to use ASCII as encoding for the environment` error.
try:
if not _IS_WINDOWS:
try:
locale.setlocale(locale.LC_ALL, "en_US.UTF-8") locale.setlocale(locale.LC_ALL, "en_US.UTF-8")
except: except:
pass pass
#TODO: this line will cause error in python2/macOS #TODO: this line will cause error in python2/macOS
...@@ -115,6 +118,9 @@ def start_master(port, cpu_num, monitor_port, debug): ...@@ -115,6 +118,9 @@ def start_master(port, cpu_num, monitor_port, debug):
cpu_num) if cpu_num is not None else multiprocessing.cpu_count() cpu_num) if cpu_num is not None else multiprocessing.cpu_count()
start_file = __file__.replace('scripts.pyc', 'start.py') start_file = __file__.replace('scripts.pyc', 'start.py')
start_file = start_file.replace('scripts.py', 'start.py') start_file = start_file.replace('scripts.py', 'start.py')
monitor_file = __file__.replace('scripts.pyc', 'monitor.py')
monitor_file = monitor_file.replace('scripts.py', 'monitor.py')
monitor_port = monitor_port if monitor_port else get_free_tcp_port() monitor_port = monitor_port if monitor_port else get_free_tcp_port()
master_command = [ master_command = [
...@@ -126,8 +132,7 @@ def start_master(port, cpu_num, monitor_port, debug): ...@@ -126,8 +132,7 @@ def start_master(port, cpu_num, monitor_port, debug):
str(cpu_num) str(cpu_num)
] ]
monitor_command = [ monitor_command = [
sys.executable, '{}/monitor.py'.format(__file__[:__file__.rfind('/')]), sys.executable, monitor_file, "--monitor_port",
"--monitor_port",
str(monitor_port), "--address", "localhost:" + str(port) str(monitor_port), "--address", "localhost:" + str(port)
] ]
...@@ -136,9 +141,19 @@ def start_master(port, cpu_num, monitor_port, debug): ...@@ -136,9 +141,19 @@ def start_master(port, cpu_num, monitor_port, debug):
# Redirect the output to DEVNULL to solve the warning log. # Redirect the output to DEVNULL to solve the warning log.
_ = subprocess.Popen( _ = subprocess.Popen(
master_command, stdout=FNULL, stderr=subprocess.STDOUT) master_command, stdout=FNULL, stderr=subprocess.STDOUT)
if cpu_num > 0: if cpu_num > 0:
# Sleep 1s for master ready
time.sleep(1)
_ = subprocess.Popen( _ = subprocess.Popen(
worker_command, stdout=FNULL, stderr=subprocess.STDOUT) worker_command, stdout=FNULL, stderr=subprocess.STDOUT)
if _IS_WINDOWS:
# TODO(@zenghsh3) redirecting stdout of monitor subprocess to FNULL will cause occasional failure
tmp_file = tempfile.TemporaryFile()
_ = subprocess.Popen(monitor_command, stdout=tmp_file)
tmp_file.close()
else:
_ = subprocess.Popen( _ = subprocess.Popen(
monitor_command, stdout=FNULL, stderr=subprocess.STDOUT) monitor_command, stdout=FNULL, stderr=subprocess.STDOUT)
FNULL.close() FNULL.close()
...@@ -161,16 +176,20 @@ def start_master(port, cpu_num, monitor_port, debug): ...@@ -161,16 +176,20 @@ def start_master(port, cpu_num, monitor_port, debug):
click.echo(monitor_info) click.echo(monitor_info)
# check if monitor is started # check if monitor is started
cmd = r'ps -ef | grep remote/monitor.py\ --monitor_port\ {}\ --address\ localhost:{}'.format(
monitor_port, port)
monitor_is_started = False monitor_is_started = False
if _IS_WINDOWS:
cmd = r'''wmic process where "commandline like '%remote\\monitor.py --monitor_port {} --address localhost:{}%'" get commandline /format:list | findstr /V wmic | findstr CommandLine='''.format(
monitor_port, port)
else:
cmd = r'ps -ef | grep -v grep | grep remote/monitor.py\ --monitor_port\ {}\ --address\ localhost:{}'.format(
monitor_port, port)
for i in range(3): for i in range(3):
check_monitor_is_started = os.popen(cmd).read().strip().split('\n') check_monitor_is_started = os.popen(cmd).read()
if len(check_monitor_is_started) == 2: if len(check_monitor_is_started) > 0:
monitor_is_started = True monitor_is_started = True
break break
time.sleep(3) time.sleep(3)
master_ip = get_ip_address() master_ip = get_ip_address()
if monitor_is_started: if monitor_is_started:
start_info = """ start_info = """
...@@ -212,9 +231,12 @@ def start_worker(address, cpu_num): ...@@ -212,9 +231,12 @@ def start_worker(address, cpu_num):
"please check if the input address {} ".format( "please check if the input address {} ".format(
address) + "is correct.") address) + "is correct.")
cpu_num = str(cpu_num) if cpu_num else '' cpu_num = str(cpu_num) if cpu_num else ''
start_file = __file__.replace('scripts.pyc', 'start.py')
start_file = start_file.replace('scripts.py', 'start.py')
command = [ command = [
sys.executable, "{}/start.py".format(__file__[:-11]), "--name", sys.executable, start_file, "--name", "worker", "--address", address,
"worker", "--address", address, "--cpu_num", "--cpu_num",
str(cpu_num) str(cpu_num)
] ]
p = subprocess.Popen(command) p = subprocess.Popen(command)
...@@ -222,6 +244,16 @@ def start_worker(address, cpu_num): ...@@ -222,6 +244,16 @@ def start_worker(address, cpu_num):
@click.command("stop", help="Exit the cluster.") @click.command("stop", help="Exit the cluster.")
def stop(): def stop():
if _IS_WINDOWS:
command = r'''for /F "skip=2 tokens=2 delims=," %a in ('wmic process where "commandline like '%remote\\job.py%'" get processid^,status /format:csv') do taskkill /F /T /pid %a'''
os.popen(command).read()
command = r'''for /F "skip=2 tokens=2 delims=," %a in ('wmic process where "commandline like '%remote\\start.py%'" get processid^,status /format:csv') do taskkill /F /pid %a'''
os.popen(command).read()
command = r'''for /F "skip=2 tokens=2 delims=," %a in ('wmic process where "commandline like '%remote\\monitor.py%'" get processid^,status /format:csv') do taskkill /F /pid %a'''
os.popen(command).read()
else:
command = ( command = (
"ps aux | grep remote/start.py | awk '{print $2}' | xargs kill -9") "ps aux | grep remote/start.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True) subprocess.call([command], shell=True)
...@@ -229,13 +261,18 @@ def stop(): ...@@ -229,13 +261,18 @@ def stop():
"ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9") "ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True) subprocess.call([command], shell=True)
command = ( command = (
"ps aux | grep remote/monitor.py | awk '{print $2}' | xargs kill -9") "ps aux | grep remote/monitor.py | awk '{print $2}' | xargs kill -9"
)
subprocess.call([command], shell=True) subprocess.call([command], shell=True)
@click.command("status") @click.command("status")
def status(): def status():
if _IS_WINDOWS:
cmd = r'''wmic process where "commandline like '%remote\\start.py --name worker --address%'" get commandline /format:list | findstr /V wmic | findstr CommandLine='''
else:
cmd = r'ps -ef | grep remote/start.py\ --name\ worker\ --address' cmd = r'ps -ef | grep remote/start.py\ --name\ worker\ --address'
content = os.popen(cmd).read().strip() content = os.popen(cmd).read().strip()
pattern = re.compile('--address (.*?) --cpu') pattern = re.compile('--address (.*?) --cpu')
clusters = set(pattern.findall(content)) clusters = set(pattern.findall(content))
...@@ -245,6 +282,10 @@ def status(): ...@@ -245,6 +282,10 @@ def status():
ctx = zmq.Context() ctx = zmq.Context()
status = [] status = []
for cluster in clusters: for cluster in clusters:
if _IS_WINDOWS:
cmd = r'''wmic process where "commandline like '%address {}%'" get commandline /format:list | findstr /V wmic | findstr CommandLine='''.format(
cluster)
else:
cmd = r'ps -ef | grep address\ {}'.format(cluster) cmd = r'ps -ef | grep address\ {}'.format(cluster)
content = os.popen(cmd).read() content = os.popen(cmd).read()
pattern = re.compile('--monitor_port (.*?)\n', re.S) pattern = re.compile('--monitor_port (.*?)\n', re.S)
......
...@@ -45,7 +45,10 @@ class TestMaxMemory(unittest.TestCase): ...@@ -45,7 +45,10 @@ class TestMaxMemory(unittest.TestCase):
def tearDown(self): def tearDown(self):
disconnect() disconnect()
def actor(self): #In windows, multiprocessing.Process cannot run the method of class, but static method is ok.
@staticmethod
def actor(cluster_addr):
parl.connect(cluster_addr)
actor1 = Actor() actor1 = Actor()
time.sleep(10) time.sleep(10)
actor1.add_500mb() actor1.add_500mb()
...@@ -56,16 +59,17 @@ class TestMaxMemory(unittest.TestCase): ...@@ -56,16 +59,17 @@ class TestMaxMemory(unittest.TestCase):
th = threading.Thread(target=master.run) th = threading.Thread(target=master.run)
th.start() th.start()
time.sleep(5) time.sleep(5)
worker = Worker('localhost:{}'.format(port), 1) cluster_addr = 'localhost:{}'.format(port)
cluster_monitor = ClusterMonitor('localhost:{}'.format(port)) worker = Worker(cluster_addr, 1)
cluster_monitor = ClusterMonitor(cluster_addr)
time.sleep(5) time.sleep(5)
parl.connect('localhost:{}'.format(port)) parl.connect(cluster_addr)
actor = Actor() actor = Actor()
time.sleep(20) time.sleep(20)
self.assertEqual(1, cluster_monitor.data['clients'][0]['actor_num']) self.assertEqual(1, cluster_monitor.data['clients'][0]['actor_num'])
del actor del actor
time.sleep(10) time.sleep(10)
p = Process(target=self.actor) p = Process(target=self.actor, args=(cluster_addr, ))
p.start() p.start()
for _ in range(6): for _ in range(6):
......
...@@ -22,7 +22,6 @@ import time ...@@ -22,7 +22,6 @@ import time
import threading import threading
from parl.remote.client import disconnect from parl.remote.client import disconnect
from parl.remote import exceptions from parl.remote import exceptions
import timeout_decorator
import subprocess import subprocess
......
...@@ -22,7 +22,6 @@ import time ...@@ -22,7 +22,6 @@ import time
import threading import threading
from parl.remote.client import disconnect from parl.remote.client import disconnect
from parl.remote import exceptions from parl.remote import exceptions
import timeout_decorator
import subprocess import subprocess
......
...@@ -22,7 +22,6 @@ import time ...@@ -22,7 +22,6 @@ import time
import threading import threading
from parl.remote.client import disconnect from parl.remote.client import disconnect
from parl.remote import exceptions from parl.remote import exceptions
import timeout_decorator
import subprocess import subprocess
......
...@@ -21,7 +21,6 @@ import time ...@@ -21,7 +21,6 @@ import time
import threading import threading
from parl.remote.client import disconnect from parl.remote.client import disconnect
from parl.remote import exceptions from parl.remote import exceptions
import timeout_decorator
import subprocess import subprocess
...@@ -89,7 +88,6 @@ class TestCluster(unittest.TestCase): ...@@ -89,7 +88,6 @@ class TestCluster(unittest.TestCase):
master.exit() master.exit()
worker1.exit() worker1.exit()
@timeout_decorator.timeout(seconds=800)
def test_actor_exception(self): def test_actor_exception(self):
master = Master(port=1236) master = Master(port=1236)
th = threading.Thread(target=master.run) th = threading.Thread(target=master.run)
......
...@@ -16,7 +16,6 @@ import unittest ...@@ -16,7 +16,6 @@ import unittest
import parl import parl
import time import time
import threading import threading
import timeout_decorator
import multiprocessing import multiprocessing
from parl.remote.master import Master from parl.remote.master import Master
...@@ -39,12 +38,14 @@ class TestCluster(unittest.TestCase): ...@@ -39,12 +38,14 @@ class TestCluster(unittest.TestCase):
def tearDown(self): def tearDown(self):
disconnect() disconnect()
def _connect_and_create_actor(self, cluster_addr): #In windows, multiprocessing.Process cannot run the method of class, but static method is ok.
@staticmethod
def _connect_and_create_actor(cluster_addr):
parl.connect(cluster_addr) parl.connect(cluster_addr)
for _ in range(2): for _ in range(2):
actor = Actor() actor = Actor()
ret = actor.add_one(1) ret = actor.add_one(1)
self.assertEqual(ret, 2) assert ret == 2
disconnect() disconnect()
def _create_actor(self): def _create_actor(self):
...@@ -53,7 +54,6 @@ class TestCluster(unittest.TestCase): ...@@ -53,7 +54,6 @@ class TestCluster(unittest.TestCase):
ret = actor.add_one(1) ret = actor.add_one(1)
self.assertEqual(ret, 2) self.assertEqual(ret, 2)
@timeout_decorator.timeout(seconds=300)
def test_connect_and_create_actor_in_multiprocessing_with_connected_in_main_process( def test_connect_and_create_actor_in_multiprocessing_with_connected_in_main_process(
self): self):
# start the master # start the master
......
...@@ -16,7 +16,6 @@ import unittest ...@@ -16,7 +16,6 @@ import unittest
import parl import parl
import time import time
import threading import threading
import timeout_decorator
import multiprocessing import multiprocessing
from parl.remote.master import Master from parl.remote.master import Master
...@@ -39,12 +38,14 @@ class TestCluster(unittest.TestCase): ...@@ -39,12 +38,14 @@ class TestCluster(unittest.TestCase):
def tearDown(self): def tearDown(self):
disconnect() disconnect()
def _connect_and_create_actor(self, cluster_addr): #In windows, multiprocessing.Process cannot run the method of class, but static method is ok.
@staticmethod
def _connect_and_create_actor(cluster_addr):
parl.connect(cluster_addr) parl.connect(cluster_addr)
for _ in range(2): for _ in range(2):
actor = Actor() actor = Actor()
ret = actor.add_one(1) ret = actor.add_one(1)
self.assertEqual(ret, 2) assert ret == 2
disconnect() disconnect()
def _create_actor(self): def _create_actor(self):
...@@ -53,7 +54,6 @@ class TestCluster(unittest.TestCase): ...@@ -53,7 +54,6 @@ class TestCluster(unittest.TestCase):
ret = actor.add_one(1) ret = actor.add_one(1)
self.assertEqual(ret, 2) self.assertEqual(ret, 2)
@timeout_decorator.timeout(seconds=300)
def test_connect_and_create_actor_in_multiprocessing_without_connected_in_main_process( def test_connect_and_create_actor_in_multiprocessing_without_connected_in_main_process(
self): self):
# start the master # start the master
......
...@@ -4,8 +4,7 @@ ...@@ -4,8 +4,7 @@
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0 #
#
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -16,12 +15,12 @@ import unittest ...@@ -16,12 +15,12 @@ import unittest
import parl import parl
import time import time
import threading import threading
import timeout_decorator
import multiprocessing import multiprocessing
from parl.remote.master import Master from parl.remote.master import Master
from parl.remote.worker import Worker from parl.remote.worker import Worker
from parl.remote.client import disconnect from parl.remote.client import disconnect
from parl.utils import _IS_WINDOWS
@parl.remote_class @parl.remote_class
...@@ -39,21 +38,14 @@ class TestCluster(unittest.TestCase): ...@@ -39,21 +38,14 @@ class TestCluster(unittest.TestCase):
def tearDown(self): def tearDown(self):
disconnect() disconnect()
def _connect_and_create_actor(self, cluster_addr): #In windows, multiprocessing.Process cannot run the method of class, but static method is ok.
parl.connect(cluster_addr) @staticmethod
for _ in range(2): def _create_actor():
actor = Actor()
ret = actor.add_one(1)
self.assertEqual(ret, 2)
disconnect()
def _create_actor(self):
for _ in range(2): for _ in range(2):
actor = Actor() actor = Actor()
ret = actor.add_one(1) ret = actor.add_one(1)
self.assertEqual(ret, 2) assert ret == 2
@timeout_decorator.timeout(seconds=300)
def test_create_actor_in_multiprocessing(self): def test_create_actor_in_multiprocessing(self):
# start the master # start the master
master = Master(port=8240) master = Master(port=8240)
...@@ -64,6 +56,7 @@ class TestCluster(unittest.TestCase): ...@@ -64,6 +56,7 @@ class TestCluster(unittest.TestCase):
worker1 = Worker('localhost:8240', 4) worker1 = Worker('localhost:8240', 4)
parl.connect('localhost:8240') parl.connect('localhost:8240')
if not _IS_WINDOWS: # In windows, fork process cannot access client created in main process.
proc1 = multiprocessing.Process(target=self._create_actor) proc1 = multiprocessing.Process(target=self._create_actor)
proc2 = multiprocessing.Process(target=self._create_actor) proc2 = multiprocessing.Process(target=self._create_actor)
proc1.start() proc1.start()
......
...@@ -23,7 +23,6 @@ import time ...@@ -23,7 +23,6 @@ import time
import threading import threading
import subprocess import subprocess
import sys import sys
import timeout_decorator
@parl.remote_class @parl.remote_class
...@@ -63,7 +62,6 @@ class TestJob(unittest.TestCase): ...@@ -63,7 +62,6 @@ class TestJob(unittest.TestCase):
def tearDown(self): def tearDown(self):
disconnect() disconnect()
@timeout_decorator.timeout(seconds=600)
def test_acor_exit_exceptionally(self): def test_acor_exit_exceptionally(self):
port = 1337 port = 1337
master = Master(port) master = Master(port)
......
...@@ -16,7 +16,8 @@ import parl ...@@ -16,7 +16,8 @@ import parl
from parl.remote.master import Master from parl.remote.master import Master
from parl.remote.worker import Worker from parl.remote.worker import Worker
from parl.remote.client import disconnect from parl.remote.client import disconnect
from parl.utils import logger from parl.utils import logger, _IS_WINDOWS
import os
import threading import threading
import time import time
import subprocess import subprocess
...@@ -70,8 +71,13 @@ class TestJobAlone(unittest.TestCase): ...@@ -70,8 +71,13 @@ class TestJobAlone(unittest.TestCase):
time.sleep(1) time.sleep(1)
self.assertEqual(master.cpu_num, 4) self.assertEqual(master.cpu_num, 4)
print("We are going to kill all the jobs.") print("We are going to kill all the jobs.")
if _IS_WINDOWS:
command = r'''for /F "skip=2 tokens=2 delims=," %a in ('wmic process where "commandline like '%remote\\job.py%'" get processid^,status /format:csv') do taskkill /F /T /pid %a'''
print(os.popen(command).read())
else:
command = ( command = (
"ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9") "ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9"
)
subprocess.call([command], shell=True) subprocess.call([command], shell=True)
parl.connect('localhost:1334') parl.connect('localhost:1334')
actor = Actor() actor = Actor()
......
...@@ -21,6 +21,7 @@ import threading ...@@ -21,6 +21,7 @@ import threading
from parl.remote.master import Master from parl.remote.master import Master
from parl.remote.worker import Worker from parl.remote.worker import Worker
from parl.remote.client import disconnect from parl.remote.client import disconnect
from parl.utils import _IS_WINDOWS
@parl.remote_class @parl.remote_class
...@@ -44,12 +45,15 @@ class TestSendFile(unittest.TestCase): ...@@ -44,12 +45,15 @@ class TestSendFile(unittest.TestCase):
worker = Worker('localhost:{}'.format(port), 1) worker = Worker('localhost:{}'.format(port), 1)
time.sleep(2) time.sleep(2)
os.system('mkdir ./rom_files') tmp_dir = 'rom_files'
os.system('touch ./rom_files/pong.bin') tmp_file = os.path.join(tmp_dir, 'pong.bin')
assert os.path.exists('./rom_files/pong.bin') os.system('mkdir {}'.format(tmp_dir))
parl.connect( if _IS_WINDOWS:
'localhost:{}'.format(port), os.system('type NUL >> {}'.format(tmp_file))
distributed_files=['./rom_files/pong.bin']) else:
os.system('touch {}'.format(tmp_file))
assert os.path.exists(tmp_file)
parl.connect('localhost:{}'.format(port), distributed_files=[tmp_file])
time.sleep(5) time.sleep(5)
actor = Actor() actor = Actor()
for _ in range(10): for _ in range(10):
...@@ -70,8 +74,9 @@ class TestSendFile(unittest.TestCase): ...@@ -70,8 +74,9 @@ class TestSendFile(unittest.TestCase):
worker = Worker('localhost:{}'.format(port), 1) worker = Worker('localhost:{}'.format(port), 1)
time.sleep(2) time.sleep(2)
tmp_file = os.path.join('rom_files', 'no_pong.bin')
self.assertRaises(Exception, parl.connect, 'localhost:{}'.format(port), self.assertRaises(Exception, parl.connect, 'localhost:{}'.format(port),
['./rom_files/no_pong.bin']) [tmp_file])
worker.exit() worker.exit()
master.exit() master.exit()
......
...@@ -26,7 +26,7 @@ import warnings ...@@ -26,7 +26,7 @@ import warnings
import zmq import zmq
from datetime import datetime from datetime import datetime
from parl.utils import get_ip_address, to_byte, to_str, logger from parl.utils import get_ip_address, to_byte, to_str, logger, _IS_WINDOWS
from parl.remote import remote_constants from parl.remote import remote_constants
from parl.remote.message import InitializedWorker from parl.remote.message import InitializedWorker
from parl.remote.status import WorkerStatus from parl.remote.status import WorkerStatus
...@@ -311,6 +311,9 @@ class Worker(object): ...@@ -311,6 +311,9 @@ class Worker(object):
total_memory = round(virtual_memory[0] / (1024**3), 2) total_memory = round(virtual_memory[0] / (1024**3), 2)
used_memory = round(virtual_memory[3] / (1024**3), 2) used_memory = round(virtual_memory[3] / (1024**3), 2)
vacant_memory = round(total_memory - used_memory, 2) vacant_memory = round(total_memory - used_memory, 2)
if _IS_WINDOWS:
load_average = round(psutil.getloadavg()[0], 2)
else:
load_average = round(os.getloadavg()[0], 2) load_average = round(os.getloadavg()[0], 2)
return (vacant_memory, used_memory, now, load_average) return (vacant_memory, used_memory, now, load_average)
...@@ -329,7 +332,7 @@ class Worker(object): ...@@ -329,7 +332,7 @@ class Worker(object):
logger.set_dir( logger.set_dir(
os.path.expanduser('~/.parl_data/worker/{}'.format( os.path.expanduser('~/.parl_data/worker/{}'.format(
self.master_heartbeat_address))) self.master_heartbeat_address.replace(':', '_'))))
self.heartbeat_socket_initialized.set() self.heartbeat_socket_initialized.set()
logger.info("[Worker] Connect to the master node successfully. " logger.info("[Worker] Connect to the master node successfully. "
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
import os import os
import platform import platform
import socket
import subprocess import subprocess
from parl.utils import logger from parl.utils import logger, _HAS_FLUID, _IS_WINDOWS
from parl.utils import utils
__all__ = ['get_gpu_count', 'get_ip_address', 'is_gpu_available'] __all__ = ['get_gpu_count', 'get_ip_address', 'is_gpu_available']
...@@ -25,17 +25,13 @@ def get_ip_address(): ...@@ -25,17 +25,13 @@ def get_ip_address():
""" """
get the IP address of the host. get the IP address of the host.
""" """
platform_sys = platform.system()
# Only support Linux and MacOS
if platform_sys != 'Linux' and platform_sys != 'Darwin':
logger.warning(
'get_ip_address only support Linux and MacOS, please set ip address manually.'
)
return None
# Windows
if _IS_WINDOWS:
local_ip = socket.gethostbyname(socket.gethostname())
else:
# Linux and MacOS
local_ip = None local_ip = None
import socket
try: try:
# First way, tested in Ubuntu and MacOS # First way, tested in Ubuntu and MacOS
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
...@@ -97,7 +93,7 @@ def is_gpu_available(): ...@@ -97,7 +93,7 @@ def is_gpu_available():
True if a gpu device can be found. True if a gpu device can be found.
""" """
ret = get_gpu_count() > 0 ret = get_gpu_count() > 0
if utils._HAS_FLUID: if _HAS_FLUID:
from paddle import fluid from paddle import fluid
if ret is True and not fluid.is_compiled_with_cuda(): if ret is True and not fluid.is_compiled_with_cuda():
logger.warning("Found non-empty CUDA_VISIBLE_DEVICES. \ logger.warning("Found non-empty CUDA_VISIBLE_DEVICES. \
......
...@@ -16,7 +16,7 @@ import sys ...@@ -16,7 +16,7 @@ import sys
__all__ = [ __all__ = [
'has_func', 'action_mapping', 'to_str', 'to_byte', 'is_PY2', 'is_PY3', 'has_func', 'action_mapping', 'to_str', 'to_byte', 'is_PY2', 'is_PY3',
'MAX_INT32', '_HAS_FLUID', '_HAS_TORCH', '_IS_MAC' 'MAX_INT32', '_HAS_FLUID', '_HAS_TORCH', '_IS_WINDOWS', '_IS_MAC'
] ]
...@@ -93,4 +93,5 @@ try: ...@@ -93,4 +93,5 @@ try:
except ImportError: except ImportError:
_HAS_TORCH = False _HAS_TORCH = False
_IS_WINDOWS = (sys.platform == 'win32')
_IS_MAC = (sys.platform == 'darwin') _IS_MAC = (sys.platform == 'darwin')
...@@ -31,7 +31,12 @@ def _find_packages(prefix=''): ...@@ -31,7 +31,12 @@ def _find_packages(prefix=''):
prefix = prefix prefix = prefix
for root, _, files in os.walk(path): for root, _, files in os.walk(path):
if '__init__.py' in files: if '__init__.py' in files:
packages.append(re.sub('^[^A-z0-9_]', '', root.replace('/', '.'))) if sys.platform == 'win32':
packages.append(
re.sub('^[^A-z0-9_]', '', root.replace('\\', '.')))
else:
packages.append(
re.sub('^[^A-z0-9_]', '', root.replace('/', '.')))
return packages return packages
...@@ -74,7 +79,7 @@ setup( ...@@ -74,7 +79,7 @@ setup(
"tb-nightly==1.15.0a20190801", "tb-nightly==1.15.0a20190801",
"flask==1.0.4", "flask==1.0.4",
"click", "click",
"psutil", "psutil>=5.6.2",
], ],
classifiers=[ classifiers=[
'Intended Audience :: Developers', 'Intended Audience :: Developers',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册