未验证 提交 ec77a4ba 编写于 作者: Y Yuecheng Liu 提交者: GitHub

support RegExp input of distributed_files list (#377)

* add support of RegExp input of distributed_files[]

* add support of RegExp input of distributed_files[]

* add RegExp input of distributed files (together with unittest)

* support RegExp input of distributed files (failing case unittest added)

* support RegExp input of distributed files, update setup.py

* support RegExp input of distributed files

* add RegExp support for distributed files input

* add more test cases
上级 74d4facb
......@@ -22,6 +22,7 @@ import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook
from parl.remote import remote_constants
import time
import glob
class Client(object):
......@@ -84,12 +85,28 @@ class Client(object):
Args:
distributed_files (list): A list of files to be distributed at all
remote instances(e,g. the configuration
file for initialization) .
file for initialization) . RegExp of file
names is supported.
e.g.
distributed_files = ['./*.npy', './test*']
Returns:
A cloudpickled dictionary containing the python code in current
working directory.
"""
parsed_distributed_files = set()
for distributed_file in distributed_files:
parsed_list = glob.glob(distributed_file)
if not parsed_list:
raise ValueError(
"no local file is matched with '{}', please check your input"
.format(distributed_file))
# exclude the directiories
for pathname in parsed_list:
if not os.path.isdir(pathname):
parsed_distributed_files.add(pathname)
pyfiles = dict()
pyfiles['python_files'] = {}
pyfiles['other_files'] = {}
......@@ -112,7 +129,7 @@ class Client(object):
code = code_file.read()
pyfiles['python_files'][file_name] = code
for file_name in distributed_files:
for file_name in parsed_distributed_files:
assert os.path.exists(file_name)
assert not os.path.isabs(
file_name
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import shutil
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
import time
import threading
from parl.remote.client import disconnect
from parl.remote import exceptions
from parl.utils import logger
@parl.remote_class
class Actor(object):
def file_exists(self, filename):
return os.path.exists(filename)
class TestCluster(unittest.TestCase):
def tearDown(self):
disconnect()
def test_distributed_files_with_RegExp(self):
if os.path.exists('distribute_test_dir'):
shutil.rmtree('distribute_test_dir')
os.mkdir('distribute_test_dir')
f = open('distribute_test_dir/test1.txt', 'wb')
f.close()
f = open('distribute_test_dir/test2.txt', 'wb')
f.close()
f = open('distribute_test_dir/data1.npy', 'wb')
f.close()
f = open('distribute_test_dir/data2.npy', 'wb')
f.close()
logger.info("running:test_distributed_files_with_RegExp")
master = Master(port=8605)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:8605', 1)
parl.connect(
'localhost:8605',
distributed_files=[
'distribute_test_dir/test*',
'distribute_test_dir/*npy',
])
actor = Actor()
self.assertTrue(actor.file_exists('distribute_test_dir/test1.txt'))
self.assertTrue(actor.file_exists('distribute_test_dir/test2.txt'))
self.assertTrue(actor.file_exists('distribute_test_dir/data1.npy'))
self.assertTrue(actor.file_exists('distribute_test_dir/data2.npy'))
self.assertFalse(actor.file_exists('distribute_test_dir/data3.npy'))
shutil.rmtree('distribute_test_dir')
master.exit()
worker1.exit()
def test_miss_match_case(self):
if os.path.exists('distribute_test_dir_2'):
shutil.rmtree('distribute_test_dir_2')
os.mkdir('distribute_test_dir_2')
f = open('distribute_test_dir_2/test1.txt', 'wb')
f.close()
f = open('distribute_test_dir_2/data1.npy', 'wb')
f.close()
logger.info("running:test_distributed_files_with_RegExp_error_case")
master = Master(port=8606)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:8606', 1)
def connect_test():
parl.connect(
'localhost:8606',
distributed_files=['distribute_test_dir_2/miss_match*'])
self.assertRaises(ValueError, connect_test)
shutil.rmtree('distribute_test_dir_2')
master.exit()
worker1.exit()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册