未验证 提交 dd812118 编写于 作者: Y Yuecheng Liu 提交者: GitHub

support distributing a folder along with its subfiles/subfolders (#398)

* support distributing a folder along with its subfiles/subfolders

* job_shutdown check

* remove outdated annotation

* len(dirs)

* yapf

* add get_subfiles_recursively function to parl.utils

* hmm...

* modify a little

* rewrite

* modify
上级 1cbcfb15
...@@ -21,6 +21,7 @@ import threading ...@@ -21,6 +21,7 @@ import threading
import zmq import zmq
import parl import parl
from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook
from parl.remote.utils import get_subfiles_recursively
from parl.remote import remote_constants from parl.remote import remote_constants
import time import time
import glob import glob
...@@ -96,22 +97,28 @@ class Client(object): ...@@ -96,22 +97,28 @@ class Client(object):
A cloudpickled dictionary containing the python code in current A cloudpickled dictionary containing the python code in current
working directory. working directory.
""" """
pyfiles = dict()
pyfiles['python_files'] = {}
pyfiles['other_files'] = {}
user_files = []
user_empty_subfolders = []
parsed_distributed_files = set()
for distributed_file in distributed_files: for distributed_file in distributed_files:
parsed_list = glob.glob(distributed_file) parsed_list = glob.glob(distributed_file)
if not parsed_list: if not parsed_list:
raise ValueError( raise ValueError(
"no local file is matched with '{}', please check your input" "no local file is matched with '{}', please check your input"
.format(distributed_file)) .format(distributed_file))
# exclude the directiories
for pathname in parsed_list: for pathname in parsed_list:
if not os.path.isdir(pathname): if os.path.isdir(pathname):
parsed_distributed_files.add(pathname) pythonfiles, otherfiles, emptysubfolders = get_subfiles_recursively(
pathname)
pyfiles = dict() user_files.extend(pythonfiles)
pyfiles['python_files'] = {} user_files.extend(otherfiles)
pyfiles['other_files'] = {} user_empty_subfolders.extend(emptysubfolders)
else:
user_files.append(pathname)
if isnotebook(): if isnotebook():
main_folder = './' main_folder = './'
...@@ -131,7 +138,7 @@ class Client(object): ...@@ -131,7 +138,7 @@ class Client(object):
code = code_file.read() code = code_file.read()
pyfiles['python_files'][file_name] = code pyfiles['python_files'][file_name] = code
for file_name in parsed_distributed_files: for file_name in set(user_files):
assert os.path.exists(file_name) assert os.path.exists(file_name)
assert not os.path.isabs( assert not os.path.isabs(
file_name file_name
...@@ -139,6 +146,8 @@ class Client(object): ...@@ -139,6 +146,8 @@ class Client(object):
with open(file_name, 'rb') as f: with open(file_name, 'rb') as f:
content = f.read() content = f.read()
pyfiles['other_files'][file_name] = content pyfiles['other_files'][file_name] = content
pyfiles['empty_subfolders'] = set(user_empty_subfolders)
return cloudpickle.dumps(pyfiles) return cloudpickle.dumps(pyfiles)
def _create_sockets(self, master_address): def _create_sockets(self, master_address):
......
...@@ -259,8 +259,14 @@ class Job(object): ...@@ -259,8 +259,14 @@ class Job(object):
tag = message[0] tag = message[0]
if tag == remote_constants.SEND_FILE_TAG: if tag == remote_constants.SEND_FILE_TAG:
pyfiles = pickle.loads(message[1]) pyfiles = pickle.loads(message[1])
# save python files to temporary directory
envdir = tempfile.mkdtemp() envdir = tempfile.mkdtemp()
for empty_subfolder in pyfiles['empty_subfolders']:
empty_subfolder_path = os.path.join(envdir, empty_subfolder)
if not os.path.exists(empty_subfolder_path):
os.makedirs(empty_subfolder_path)
# save python files to temporary directory
for file, code in pyfiles['python_files'].items(): for file, code in pyfiles['python_files'].items():
file = os.path.join(envdir, file) file = os.path.join(envdir, file)
with open(file, 'wb') as code_file: with open(file, 'wb') as code_file:
......
...@@ -217,6 +217,9 @@ def remote_class(*args, **kwargs): ...@@ -217,6 +217,9 @@ def remote_class(*args, **kwargs):
is_attribute = attr in self.remote_attribute_keys_set is_attribute = attr in self.remote_attribute_keys_set
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if self.job_shutdown:
raise RemoteError(
attr, "This actor losts connection with the job.")
self.internal_lock.acquire() self.internal_lock.acquire()
if is_attribute: if is_attribute:
self.job_socket.send_multipart([ self.job_socket.send_multipart([
...@@ -224,10 +227,6 @@ def remote_class(*args, **kwargs): ...@@ -224,10 +227,6 @@ def remote_class(*args, **kwargs):
to_byte(attr) to_byte(attr)
]) ])
else: else:
if self.job_shutdown:
raise RemoteError(
attr,
"This actor losts connection with the job.")
data = dumps_argument(*args, **kwargs) data = dumps_argument(*args, **kwargs)
self.job_socket.send_multipart( self.job_socket.send_multipart(
[remote_constants.CALL_TAG, [remote_constants.CALL_TAG,
......
...@@ -48,13 +48,13 @@ class TestCluster(unittest.TestCase): ...@@ -48,13 +48,13 @@ class TestCluster(unittest.TestCase):
f = open('distribute_test_dir/data2.npy', 'wb') f = open('distribute_test_dir/data2.npy', 'wb')
f.close() f.close()
logger.info("running:test_distributed_files_with_RegExp") logger.info("running:test_distributed_files_with_RegExp")
master = Master(port=8605) master = Master(port=8435)
th = threading.Thread(target=master.run) th = threading.Thread(target=master.run)
th.start() th.start()
time.sleep(3) time.sleep(3)
worker1 = Worker('localhost:8605', 1) worker1 = Worker('localhost:8435', 1)
parl.connect( parl.connect(
'localhost:8605', 'localhost:8435',
distributed_files=[ distributed_files=[
'distribute_test_dir/test*', 'distribute_test_dir/test*',
'distribute_test_dir/*npy', 'distribute_test_dir/*npy',
...@@ -78,15 +78,15 @@ class TestCluster(unittest.TestCase): ...@@ -78,15 +78,15 @@ class TestCluster(unittest.TestCase):
f = open('distribute_test_dir_2/data1.npy', 'wb') f = open('distribute_test_dir_2/data1.npy', 'wb')
f.close() f.close()
logger.info("running:test_distributed_files_with_RegExp_error_case") logger.info("running:test_distributed_files_with_RegExp_error_case")
master = Master(port=8606) master = Master(port=8436)
th = threading.Thread(target=master.run) th = threading.Thread(target=master.run)
th.start() th.start()
time.sleep(3) time.sleep(3)
worker1 = Worker('localhost:8606', 1) worker1 = Worker('localhost:8436', 1)
def connect_test(): def connect_test():
parl.connect( parl.connect(
'localhost:8606', 'localhost:8436',
distributed_files=['distribute_test_dir_2/miss_match*']) distributed_files=['distribute_test_dir_2/miss_match*'])
self.assertRaises(ValueError, connect_test) self.assertRaises(ValueError, connect_test)
...@@ -94,6 +94,39 @@ class TestCluster(unittest.TestCase): ...@@ -94,6 +94,39 @@ class TestCluster(unittest.TestCase):
master.exit() master.exit()
worker1.exit() worker1.exit()
def test_distribute_folder(self):
    # Checks that a whole folder passed through `distributed_files` is
    # visible to the actor: two files in a nested subfolder plus one empty
    # subfolder. (Presumably `Actor.file_exists` checks the path on the
    # remote side -- `Actor` is defined outside this block.)
    if os.path.exists('distribute_test_dir_3'):
        shutil.rmtree('distribute_test_dir_3')
    os.mkdir('distribute_test_dir_3')
    os.mkdir('distribute_test_dir_3/subfolder_test')
    os.mkdir('distribute_test_dir_3/empty_folder')
    # Create two empty files; `with` guarantees the handles are closed.
    with open('distribute_test_dir_3/subfolder_test/test1.txt', 'wb'):
        pass
    with open('distribute_test_dir_3/subfolder_test/data1.npy', 'wb'):
        pass
    logger.info("running:test_distributed_folder")
    master = Master(port=8437)
    th = threading.Thread(target=master.run)
    th.start()
    # Give the master thread time to start before the worker connects.
    time.sleep(3)
    worker1 = Worker('localhost:8437', 1)
    try:
        parl.connect(
            'localhost:8437', distributed_files=[
                'distribute_test_dir_3',
            ])
        actor = Actor()
        self.assertTrue(
            actor.file_exists(
                'distribute_test_dir_3/subfolder_test/test1.txt'))
        self.assertTrue(
            actor.file_exists(
                'distribute_test_dir_3/subfolder_test/data1.npy'))
        self.assertTrue(
            actor.file_exists('distribute_test_dir_3/empty_folder'))
    finally:
        # Clean up even when an assertion fails, so a failing run does not
        # leave the test folder, the master or the worker behind.
        shutil.rmtree('distribute_test_dir_3')
        master.exit()
        worker1.exit()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -17,7 +17,8 @@ import os ...@@ -17,7 +17,8 @@ import os
from parl.utils import isnotebook from parl.utils import isnotebook
__all__ = [ __all__ = [
'load_remote_class', 'redirect_stdout_to_file', 'locate_remote_file' 'load_remote_class', 'redirect_stdout_to_file', 'locate_remote_file',
'get_subfiles_recursively'
] ]
...@@ -140,3 +141,36 @@ def locate_remote_file(module_path): ...@@ -140,3 +141,36 @@ def locate_remote_file(module_path):
else: else:
relative_module_path = module_path relative_module_path = module_path
return relative_module_path return relative_module_path
def get_subfiles_recursively(folder_path):
    '''
    Get the files and the empty folders under 'folder_path' recursively.

    The tree is traversed with `os.walk`. Every returned path is the walked
    directory joined with the entry name and passed through
    `os.path.normpath`.

    Args:
        folder_path: A folder(dir) whose subfiles/subfolders will be returned.

    Returns:
        python_files: A list of the subfiles whose names end with '.py'.
        other_files: A list of the subfiles whose names do not end with '.py'.
        empty_subfolders: A list of the folders that contain neither files
            nor subfolders ('folder_path' itself is included when it is
            empty).

    Raises:
        ValueError: If 'folder_path' does not exist or is not a folder.
    '''
    if not os.path.exists(folder_path):
        raise ValueError("Path '{}' don't exist.".format(folder_path))
    if not os.path.isdir(folder_path):
        raise ValueError('Input should be a folder, not a file.')

    python_files = []
    other_files = []
    empty_subfolders = []
    for root, dirs, files in os.walk(folder_path):
        for sub_file in files:
            file_path = os.path.normpath(os.path.join(root, sub_file))
            if sub_file.endswith('.py'):
                python_files.append(file_path)
            else:
                other_files.append(file_path)
        # A folder is recorded only when it holds no files and no
        # subfolders; a folder containing only subfolders is skipped
        # because its subfolders are visited on their own.
        if not files and not dirs:
            empty_subfolders.append(os.path.normpath(root))
    return python_files, other_files, empty_subfolders
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册