未验证 提交 2cb2c8bf 编写于 作者: H Hongsheng Zeng 提交者: GitHub

support windows system (#215)

* first version; unittest in parl/remote is passed

* add windows local unittest shell; fix some incompatible problem

* fix save api

* refine comments and remove log of xparl stop

* fix bug

* Update scripts.py

* Update utils.py

* Update agent_base_test.py
上级 117b1c38
......@@ -3,4 +3,3 @@ paddlepaddle-gpu==1.6.1.post97
gym
details
parameterized
timeout_decorator
......@@ -2,4 +2,3 @@
gym
details
parameterized
timeout_decorator
#!/usr/bin/env bash
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: You need install mingw-cmake.
function init() {
RED='\033[0;31m'
BLUE='\033[0;34m'
BOLD='\033[1m'
NONE='\033[0m'
REPO_ROOT=`pwd`
}
function abort(){
echo "Your change doesn't follow PaddlePaddle's code style." 1>&2
echo "Please use pre-commit to check what is wrong." 1>&2
exit 1
}
function run_test_with_cpu() {
export CUDA_VISIBLE_DEVICES="-1"
mkdir -p ${REPO_ROOT}/build
cd ${REPO_ROOT}/build
if [ $# -eq 1 ];then
cmake -G "MinGW Makefiles" ..
else
cmake -G "MinGW Makefiles" .. -$2=ON
fi
cat <<EOF
=====================================================
Running unit tests with CPU in the environment: $1
=====================================================
EOF
if [ $# -eq 1 ];then
ctest --output-on-failure -j10
else
ctest --output-on-failure
fi
cd ${REPO_ROOT}
rm -rf ${REPO_ROOT}/build
}
function main() {
set -e
local CMD=$1
init
env="unused_variable"
# run unittest in windows (used in local machine)
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple .
pip uninstall -y torch torchvision
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle==1.6.1 gym details parameterized
run_test_with_cpu $env
run_test_with_cpu $env "DIS_TESTING_SERIALLY"
pip uninstall -y paddlepaddle
pip install torch==1.4.0+cpu torchvision==0.5.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
run_test_with_cpu $env "DIS_TESTING_TORCH"
}
main $@
......@@ -15,6 +15,7 @@
import warnings
warnings.simplefilter('default')
import os
import paddle.fluid as fluid
from parl.core.fluid import layers
from parl.core.agent_base import AgentBase
......@@ -152,8 +153,8 @@ class Agent(AgentBase):
"""
if program is None:
program = self.learn_program
dirname = '/'.join(save_path.split('/')[:-1])
filename = save_path.split('/')[-1]
dirname = os.sep.join(save_path.split(os.sep)[:-1])
filename = save_path.split(os.sep)[-1]
fluid.io.save_params(
executor=self.fluid_executor,
dirname=dirname,
......@@ -186,8 +187,8 @@ class Agent(AgentBase):
program = self.learn_program
if type(program) is fluid.compiler.CompiledProgram:
program = program._init_program
dirname = '/'.join(save_path.split('/')[:-1])
filename = save_path.split('/')[-1]
dirname = os.sep.join(save_path.split(os.sep)[:-1])
filename = save_path.split(os.sep)[-1]
fluid.io.load_params(
executor=self.fluid_executor,
dirname=dirname,
......
......@@ -45,7 +45,7 @@ class TestParamSharing(unittest.TestCase):
dict_size = 100
input_cx = np.random.uniform(0, 1, [batch_size, 100]).astype("float32")
input_x = np.random.randint(
dict_size, size=(batch_size, 1)).astype("int")
dict_size, size=(batch_size, 1)).astype("int64")
#################################
main_program1 = fluid.Program()
......@@ -59,7 +59,7 @@ class TestParamSharing(unittest.TestCase):
main_program2 = fluid.Program()
with fluid.program_guard(main_program2):
x_ = layers.data(name='x', shape=[1], dtype="int")
x_ = layers.data(name='x', shape=[1], dtype="int64")
cx_ = layers.cast(
x=layers.one_hot(input=x_, depth=dict_size), dtype="float32")
y1_ = net.fc1(input=cx_)
......
......@@ -46,8 +46,8 @@ class TestAlgorithm(parl.Algorithm):
class TestAgent(parl.Agent):
def __init__(self, algorithm, gpu_id=None):
super(TestAgent, self).__init__(algorithm, gpu_id)
def __init__(self, algorithm):
super(TestAgent, self).__init__(algorithm)
def build_program(self):
self.predict_program = fluid.Program()
......@@ -92,8 +92,8 @@ class AgentBaseTest(unittest.TestCase):
agent = TestAgent(self.algorithm)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
save_path1 = './model.ckpt'
save_path2 = './my_model/model-2.ckpt'
save_path1 = 'model.ckpt'
save_path2 = os.path.join('my_model', 'model-2.ckpt')
agent.save(save_path1)
agent.save(save_path2)
self.assertTrue(os.path.exists(save_path1))
......@@ -103,7 +103,7 @@ class AgentBaseTest(unittest.TestCase):
agent = TestAgent(self.algorithm)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
save_path1 = './model.ckpt'
save_path1 = 'model.ckpt'
previous_output = agent.predict(obs)
agent.save(save_path1)
agent.restore(save_path1)
......@@ -121,7 +121,7 @@ class AgentBaseTest(unittest.TestCase):
agent.learn_program = parl.compile(agent.learn_program)
obs = np.random.random([3, 10]).astype('float32')
previous_output = agent.predict(obs)
save_path1 = './model.ckpt'
save_path1 = 'model.ckpt'
agent.save(save_path1)
agent.restore(save_path1)
......
......@@ -113,8 +113,9 @@ class Agent(AgentBase):
"""
if model is None:
model = self.algorithm.model
dirname = '/'.join(save_path.split('/')[:-1])
if not os.path.exists(dirname):
sep = os.sep
dirname = sep.join(save_path.split(sep)[:-1])
if dirname != '' and not os.path.exists(dirname):
os.makedirs(dirname)
torch.save(model.state_dict(), save_path)
......
......@@ -77,8 +77,8 @@ class AgentBaseTest(unittest.TestCase):
def test_save(self):
agent = TestAgent(self.alg)
obs = torch.randn(3, 10)
save_path1 = './model.ckpt'
save_path2 = './my_model/model-2.ckpt'
save_path1 = 'model.ckpt'
save_path2 = os.path.join('my_model', 'model-2.ckpt')
agent.save(save_path1)
agent.save(save_path2)
self.assertTrue(os.path.exists(save_path1))
......@@ -88,7 +88,7 @@ class AgentBaseTest(unittest.TestCase):
agent = TestAgent(self.alg)
obs = torch.randn(3, 10)
output = agent.predict(obs)
save_path1 = './model.ckpt'
save_path1 = 'model.ckpt'
previous_output = agent.predict(obs).detach().cpu().numpy()
agent.save(save_path1)
agent.restore(save_path1)
......
......@@ -63,12 +63,15 @@ class Job(object):
self.worker_address = worker_address
self.job_ip = get_ip_address()
self.pid = os.getpid()
self.lock = threading.Lock()
self.run_job_process = Process(
target=self.run, args=(job_address_sender, ))
self.run_job_process.start()
"""
NOTE:
In Windows, it will raise errors when creating threading.Lock before starting multiprocess.Process.
"""
self.lock = threading.Lock()
self._create_sockets()
process = psutil.Process(self.pid)
......
......@@ -61,7 +61,7 @@ class Master(object):
self.ctx = zmq.Context()
self.master_ip = get_ip_address()
logger.set_dir(
os.path.expanduser('~/.parl_data/master/{}:{}'.format(
os.path.expanduser('~/.parl_data/master/{}_{}'.format(
self.master_ip, port)))
self.client_socket = self.ctx.socket(zmq.REP)
self.client_socket.bind("tcp://*:{}".format(port))
......
......@@ -42,7 +42,7 @@ class ClusterMonitor(object):
def __init__(self, master_address):
ctx = zmq.Context()
self.socket = ctx.socket(zmq.REQ)
self.socket.setsockopt(zmq.RCVTIMEO, 10000)
self.socket.setsockopt(zmq.RCVTIMEO, 30000)
self.socket.connect('tcp://{}'.format(master_address))
self.data = None
......@@ -100,6 +100,10 @@ def cluster():
if __name__ == "__main__":
import logging
log = logging.getLogger('werkzeug')
log.disabled = True
parser = argparse.ArgumentParser()
parser.add_argument('--monitor_port', default=1234, type=int)
parser.add_argument('--address', default='localhost:8010', type=str)
......
......@@ -23,10 +23,11 @@ import subprocess
import sys
import time
import threading
import tempfile
import warnings
import zmq
from multiprocessing import Process
from parl.utils import get_ip_address, to_str
from parl.utils import get_ip_address, to_str, _IS_WINDOWS
from parl.remote.remote_constants import STATUS_TAG
# A flag to mark if parl is started from a command line
......@@ -34,10 +35,12 @@ os.environ['XPARL'] = 'True'
# Solve `Click will abort further execution because Python 3 was configured
# to use ASCII as encoding for the environment` error.
try:
locale.setlocale(locale.LC_ALL, "en_US.UTF-8")
except:
pass
if not _IS_WINDOWS:
try:
locale.setlocale(locale.LC_ALL, "en_US.UTF-8")
except:
pass
#TODO: this line will cause error in python2/macOS
if sys.version_info.major == 3:
......@@ -115,6 +118,9 @@ def start_master(port, cpu_num, monitor_port, debug):
cpu_num) if cpu_num is not None else multiprocessing.cpu_count()
start_file = __file__.replace('scripts.pyc', 'start.py')
start_file = start_file.replace('scripts.py', 'start.py')
monitor_file = __file__.replace('scripts.pyc', 'monitor.py')
monitor_file = monitor_file.replace('scripts.py', 'monitor.py')
monitor_port = monitor_port if monitor_port else get_free_tcp_port()
master_command = [
......@@ -126,8 +132,7 @@ def start_master(port, cpu_num, monitor_port, debug):
str(cpu_num)
]
monitor_command = [
sys.executable, '{}/monitor.py'.format(__file__[:__file__.rfind('/')]),
"--monitor_port",
sys.executable, monitor_file, "--monitor_port",
str(monitor_port), "--address", "localhost:" + str(port)
]
......@@ -136,11 +141,21 @@ def start_master(port, cpu_num, monitor_port, debug):
# Redirect the output to DEVNULL to solve the warning log.
_ = subprocess.Popen(
master_command, stdout=FNULL, stderr=subprocess.STDOUT)
if cpu_num > 0:
# Sleep 1s for master ready
time.sleep(1)
_ = subprocess.Popen(
worker_command, stdout=FNULL, stderr=subprocess.STDOUT)
_ = subprocess.Popen(
monitor_command, stdout=FNULL, stderr=subprocess.STDOUT)
if _IS_WINDOWS:
# TODO(@zenghsh3) redirecting stdout of monitor subprocess to FNULL will cause occasional failure
tmp_file = tempfile.TemporaryFile()
_ = subprocess.Popen(monitor_command, stdout=tmp_file)
tmp_file.close()
else:
_ = subprocess.Popen(
monitor_command, stdout=FNULL, stderr=subprocess.STDOUT)
FNULL.close()
if cpu_num > 0:
......@@ -161,16 +176,20 @@ def start_master(port, cpu_num, monitor_port, debug):
click.echo(monitor_info)
# check if monitor is started
cmd = r'ps -ef | grep remote/monitor.py\ --monitor_port\ {}\ --address\ localhost:{}'.format(
monitor_port, port)
monitor_is_started = False
if _IS_WINDOWS:
cmd = r'''wmic process where "commandline like '%remote\\monitor.py --monitor_port {} --address localhost:{}%'" get commandline /format:list | findstr /V wmic | findstr CommandLine='''.format(
monitor_port, port)
else:
cmd = r'ps -ef | grep -v grep | grep remote/monitor.py\ --monitor_port\ {}\ --address\ localhost:{}'.format(
monitor_port, port)
for i in range(3):
check_monitor_is_started = os.popen(cmd).read().strip().split('\n')
if len(check_monitor_is_started) == 2:
check_monitor_is_started = os.popen(cmd).read()
if len(check_monitor_is_started) > 0:
monitor_is_started = True
break
time.sleep(3)
master_ip = get_ip_address()
if monitor_is_started:
start_info = """
......@@ -212,9 +231,12 @@ def start_worker(address, cpu_num):
"please check if the input address {} ".format(
address) + "is correct.")
cpu_num = str(cpu_num) if cpu_num else ''
start_file = __file__.replace('scripts.pyc', 'start.py')
start_file = start_file.replace('scripts.py', 'start.py')
command = [
sys.executable, "{}/start.py".format(__file__[:-11]), "--name",
"worker", "--address", address, "--cpu_num",
sys.executable, start_file, "--name", "worker", "--address", address,
"--cpu_num",
str(cpu_num)
]
p = subprocess.Popen(command)
......@@ -222,20 +244,35 @@ def start_worker(address, cpu_num):
@click.command("stop", help="Exit the cluster.")
def stop():
command = (
"ps aux | grep remote/start.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
command = (
"ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
command = (
"ps aux | grep remote/monitor.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
if _IS_WINDOWS:
command = r'''for /F "skip=2 tokens=2 delims=," %a in ('wmic process where "commandline like '%remote\\job.py%'" get processid^,status /format:csv') do taskkill /F /T /pid %a'''
os.popen(command).read()
command = r'''for /F "skip=2 tokens=2 delims=," %a in ('wmic process where "commandline like '%remote\\start.py%'" get processid^,status /format:csv') do taskkill /F /pid %a'''
os.popen(command).read()
command = r'''for /F "skip=2 tokens=2 delims=," %a in ('wmic process where "commandline like '%remote\\monitor.py%'" get processid^,status /format:csv') do taskkill /F /pid %a'''
os.popen(command).read()
else:
command = (
"ps aux | grep remote/start.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
command = (
"ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
command = (
"ps aux | grep remote/monitor.py | awk '{print $2}' | xargs kill -9"
)
subprocess.call([command], shell=True)
@click.command("status")
def status():
cmd = r'ps -ef | grep remote/start.py\ --name\ worker\ --address'
if _IS_WINDOWS:
cmd = r'''wmic process where "commandline like '%remote\\start.py --name worker --address%'" get commandline /format:list | findstr /V wmic | findstr CommandLine='''
else:
cmd = r'ps -ef | grep remote/start.py\ --name\ worker\ --address'
content = os.popen(cmd).read().strip()
pattern = re.compile('--address (.*?) --cpu')
clusters = set(pattern.findall(content))
......@@ -245,7 +282,11 @@ def status():
ctx = zmq.Context()
status = []
for cluster in clusters:
cmd = r'ps -ef | grep address\ {}'.format(cluster)
if _IS_WINDOWS:
cmd = r'''wmic process where "commandline like '%address {}%'" get commandline /format:list | findstr /V wmic | findstr CommandLine='''.format(
cluster)
else:
cmd = r'ps -ef | grep address\ {}'.format(cluster)
content = os.popen(cmd).read()
pattern = re.compile('--monitor_port (.*?)\n', re.S)
monitors = pattern.findall(content)
......
......@@ -45,7 +45,10 @@ class TestMaxMemory(unittest.TestCase):
def tearDown(self):
disconnect()
def actor(self):
#In windows, multiprocessing.Process cannot run the method of class, but static method is ok.
@staticmethod
def actor(cluster_addr):
parl.connect(cluster_addr)
actor1 = Actor()
time.sleep(10)
actor1.add_500mb()
......@@ -56,16 +59,17 @@ class TestMaxMemory(unittest.TestCase):
th = threading.Thread(target=master.run)
th.start()
time.sleep(5)
worker = Worker('localhost:{}'.format(port), 1)
cluster_monitor = ClusterMonitor('localhost:{}'.format(port))
cluster_addr = 'localhost:{}'.format(port)
worker = Worker(cluster_addr, 1)
cluster_monitor = ClusterMonitor(cluster_addr)
time.sleep(5)
parl.connect('localhost:{}'.format(port))
parl.connect(cluster_addr)
actor = Actor()
time.sleep(20)
self.assertEqual(1, cluster_monitor.data['clients'][0]['actor_num'])
del actor
time.sleep(10)
p = Process(target=self.actor)
p = Process(target=self.actor, args=(cluster_addr, ))
p.start()
for _ in range(6):
......
......@@ -22,7 +22,6 @@ import time
import threading
from parl.remote.client import disconnect
from parl.remote import exceptions
import timeout_decorator
import subprocess
......
......@@ -22,7 +22,6 @@ import time
import threading
from parl.remote.client import disconnect
from parl.remote import exceptions
import timeout_decorator
import subprocess
......
......@@ -22,7 +22,6 @@ import time
import threading
from parl.remote.client import disconnect
from parl.remote import exceptions
import timeout_decorator
import subprocess
......
......@@ -21,7 +21,6 @@ import time
import threading
from parl.remote.client import disconnect
from parl.remote import exceptions
import timeout_decorator
import subprocess
......@@ -89,7 +88,6 @@ class TestCluster(unittest.TestCase):
master.exit()
worker1.exit()
@timeout_decorator.timeout(seconds=800)
def test_actor_exception(self):
master = Master(port=1236)
th = threading.Thread(target=master.run)
......
......@@ -16,7 +16,6 @@ import unittest
import parl
import time
import threading
import timeout_decorator
import multiprocessing
from parl.remote.master import Master
......@@ -39,12 +38,14 @@ class TestCluster(unittest.TestCase):
def tearDown(self):
disconnect()
def _connect_and_create_actor(self, cluster_addr):
#In windows, multiprocessing.Process cannot run the method of class, but static method is ok.
@staticmethod
def _connect_and_create_actor(cluster_addr):
parl.connect(cluster_addr)
for _ in range(2):
actor = Actor()
ret = actor.add_one(1)
self.assertEqual(ret, 2)
assert ret == 2
disconnect()
def _create_actor(self):
......@@ -53,7 +54,6 @@ class TestCluster(unittest.TestCase):
ret = actor.add_one(1)
self.assertEqual(ret, 2)
@timeout_decorator.timeout(seconds=300)
def test_connect_and_create_actor_in_multiprocessing_with_connected_in_main_process(
self):
# start the master
......
......@@ -16,7 +16,6 @@ import unittest
import parl
import time
import threading
import timeout_decorator
import multiprocessing
from parl.remote.master import Master
......@@ -39,12 +38,14 @@ class TestCluster(unittest.TestCase):
def tearDown(self):
disconnect()
def _connect_and_create_actor(self, cluster_addr):
#In windows, multiprocessing.Process cannot run the method of class, but static method is ok.
@staticmethod
def _connect_and_create_actor(cluster_addr):
parl.connect(cluster_addr)
for _ in range(2):
actor = Actor()
ret = actor.add_one(1)
self.assertEqual(ret, 2)
assert ret == 2
disconnect()
def _create_actor(self):
......@@ -53,7 +54,6 @@ class TestCluster(unittest.TestCase):
ret = actor.add_one(1)
self.assertEqual(ret, 2)
@timeout_decorator.timeout(seconds=300)
def test_connect_and_create_actor_in_multiprocessing_without_connected_in_main_process(
self):
# start the master
......
......@@ -4,8 +4,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# http://www.apache.org/licenses/LICENSE-2.0 #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -16,12 +15,12 @@ import unittest
import parl
import time
import threading
import timeout_decorator
import multiprocessing
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
from parl.utils import _IS_WINDOWS
@parl.remote_class
......@@ -39,21 +38,14 @@ class TestCluster(unittest.TestCase):
def tearDown(self):
disconnect()
def _connect_and_create_actor(self, cluster_addr):
parl.connect(cluster_addr)
for _ in range(2):
actor = Actor()
ret = actor.add_one(1)
self.assertEqual(ret, 2)
disconnect()
def _create_actor(self):
#In windows, multiprocessing.Process cannot run the method of class, but static method is ok.
@staticmethod
def _create_actor():
for _ in range(2):
actor = Actor()
ret = actor.add_one(1)
self.assertEqual(ret, 2)
assert ret == 2
@timeout_decorator.timeout(seconds=300)
def test_create_actor_in_multiprocessing(self):
# start the master
master = Master(port=8240)
......@@ -64,14 +56,15 @@ class TestCluster(unittest.TestCase):
worker1 = Worker('localhost:8240', 4)
parl.connect('localhost:8240')
proc1 = multiprocessing.Process(target=self._create_actor)
proc2 = multiprocessing.Process(target=self._create_actor)
proc1.start()
proc2.start()
if not _IS_WINDOWS: # In windows, fork process cannot access client created in main process.
proc1 = multiprocessing.Process(target=self._create_actor)
proc2 = multiprocessing.Process(target=self._create_actor)
proc1.start()
proc2.start()
proc1.join()
proc2.join()
print("[test_create_actor_in_multiprocessing] Join")
proc1.join()
proc2.join()
print("[test_create_actor_in_multiprocessing] Join")
# make sure that the client of the main process still works
self._create_actor()
......
......@@ -23,7 +23,6 @@ import time
import threading
import subprocess
import sys
import timeout_decorator
@parl.remote_class
......@@ -63,7 +62,6 @@ class TestJob(unittest.TestCase):
def tearDown(self):
disconnect()
@timeout_decorator.timeout(seconds=600)
def test_acor_exit_exceptionally(self):
port = 1337
master = Master(port)
......
......@@ -16,7 +16,8 @@ import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
from parl.utils import logger
from parl.utils import logger, _IS_WINDOWS
import os
import threading
import time
import subprocess
......@@ -70,9 +71,14 @@ class TestJobAlone(unittest.TestCase):
time.sleep(1)
self.assertEqual(master.cpu_num, 4)
print("We are going to kill all the jobs.")
command = (
"ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
if _IS_WINDOWS:
command = r'''for /F "skip=2 tokens=2 delims=," %a in ('wmic process where "commandline like '%remote\\job.py%'" get processid^,status /format:csv') do taskkill /F /T /pid %a'''
print(os.popen(command).read())
else:
command = (
"ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9"
)
subprocess.call([command], shell=True)
parl.connect('localhost:1334')
actor = Actor()
self.assertEqual(actor.add_one(1), 2)
......
......@@ -21,6 +21,7 @@ import threading
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
from parl.utils import _IS_WINDOWS
@parl.remote_class
......@@ -44,12 +45,15 @@ class TestSendFile(unittest.TestCase):
worker = Worker('localhost:{}'.format(port), 1)
time.sleep(2)
os.system('mkdir ./rom_files')
os.system('touch ./rom_files/pong.bin')
assert os.path.exists('./rom_files/pong.bin')
parl.connect(
'localhost:{}'.format(port),
distributed_files=['./rom_files/pong.bin'])
tmp_dir = 'rom_files'
tmp_file = os.path.join(tmp_dir, 'pong.bin')
os.system('mkdir {}'.format(tmp_dir))
if _IS_WINDOWS:
os.system('type NUL >> {}'.format(tmp_file))
else:
os.system('touch {}'.format(tmp_file))
assert os.path.exists(tmp_file)
parl.connect('localhost:{}'.format(port), distributed_files=[tmp_file])
time.sleep(5)
actor = Actor()
for _ in range(10):
......@@ -70,8 +74,9 @@ class TestSendFile(unittest.TestCase):
worker = Worker('localhost:{}'.format(port), 1)
time.sleep(2)
tmp_file = os.path.join('rom_files', 'no_pong.bin')
self.assertRaises(Exception, parl.connect, 'localhost:{}'.format(port),
['./rom_files/no_pong.bin'])
[tmp_file])
worker.exit()
master.exit()
......
......@@ -26,7 +26,7 @@ import warnings
import zmq
from datetime import datetime
from parl.utils import get_ip_address, to_byte, to_str, logger
from parl.utils import get_ip_address, to_byte, to_str, logger, _IS_WINDOWS
from parl.remote import remote_constants
from parl.remote.message import InitializedWorker
from parl.remote.status import WorkerStatus
......@@ -311,7 +311,10 @@ class Worker(object):
total_memory = round(virtual_memory[0] / (1024**3), 2)
used_memory = round(virtual_memory[3] / (1024**3), 2)
vacant_memory = round(total_memory - used_memory, 2)
load_average = round(os.getloadavg()[0], 2)
if _IS_WINDOWS:
load_average = round(psutil.getloadavg()[0], 2)
else:
load_average = round(os.getloadavg()[0], 2)
return (vacant_memory, used_memory, now, load_average)
def _reply_heartbeat(self, target):
......@@ -329,7 +332,7 @@ class Worker(object):
logger.set_dir(
os.path.expanduser('~/.parl_data/worker/{}'.format(
self.master_heartbeat_address)))
self.master_heartbeat_address.replace(':', '_'))))
self.heartbeat_socket_initialized.set()
logger.info("[Worker] Connect to the master node successfully. "
......
......@@ -14,9 +14,9 @@
import os
import platform
import socket
import subprocess
from parl.utils import logger
from parl.utils import utils
from parl.utils import logger, _HAS_FLUID, _IS_WINDOWS
__all__ = ['get_gpu_count', 'get_ip_address', 'is_gpu_available']
......@@ -25,29 +25,25 @@ def get_ip_address():
"""
get the IP address of the host.
"""
platform_sys = platform.system()
# Only support Linux and MacOS
if platform_sys != 'Linux' and platform_sys != 'Darwin':
logger.warning(
'get_ip_address only support Linux and MacOS, please set ip address manually.'
)
return None
local_ip = None
import socket
try:
# First way, tested in Ubuntu and MacOS
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
local_ip = s.getsockname()[0]
s.close()
except:
# Second way, tested in CentOS
# Windows
if _IS_WINDOWS:
local_ip = socket.gethostbyname(socket.gethostname())
else:
# Linux and MacOS
local_ip = None
try:
local_ip = socket.gethostbyname(socket.gethostname())
# First way, tested in Ubuntu and MacOS
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
local_ip = s.getsockname()[0]
s.close()
except:
pass
# Second way, tested in CentOS
try:
local_ip = socket.gethostbyname(socket.gethostname())
except:
pass
if local_ip == None or local_ip == '127.0.0.1' or local_ip == '127.0.1.1':
logger.warning(
......@@ -97,7 +93,7 @@ def is_gpu_available():
True if a gpu device can be found.
"""
ret = get_gpu_count() > 0
if utils._HAS_FLUID:
if _HAS_FLUID:
from paddle import fluid
if ret is True and not fluid.is_compiled_with_cuda():
logger.warning("Found non-empty CUDA_VISIBLE_DEVICES. \
......
......@@ -16,7 +16,7 @@ import sys
__all__ = [
'has_func', 'action_mapping', 'to_str', 'to_byte', 'is_PY2', 'is_PY3',
'MAX_INT32', '_HAS_FLUID', '_HAS_TORCH', '_IS_MAC'
'MAX_INT32', '_HAS_FLUID', '_HAS_TORCH', '_IS_WINDOWS', '_IS_MAC'
]
......@@ -93,4 +93,5 @@ try:
except ImportError:
_HAS_TORCH = False
_IS_WINDOWS = (sys.platform == 'win32')
_IS_MAC = (sys.platform == 'darwin')
......@@ -31,7 +31,12 @@ def _find_packages(prefix=''):
prefix = prefix
for root, _, files in os.walk(path):
if '__init__.py' in files:
packages.append(re.sub('^[^A-z0-9_]', '', root.replace('/', '.')))
if sys.platform == 'win32':
packages.append(
re.sub('^[^A-z0-9_]', '', root.replace('\\', '.')))
else:
packages.append(
re.sub('^[^A-z0-9_]', '', root.replace('/', '.')))
return packages
......@@ -74,7 +79,7 @@ setup(
"tb-nightly==1.15.0a20190801",
"flask==1.0.4",
"click",
"psutil",
"psutil>=5.6.2",
],
classifiers=[
'Intended Audience :: Developers',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册