diff --git a/parl/remote/client.py b/parl/remote/client.py index 946493421b58e38f4e19ff543a820e03aedea3b2..9d8f30b48adb2c28eaceb4d65a14bac76b1f0a1f 100644 --- a/parl/remote/client.py +++ b/parl/remote/client.py @@ -22,6 +22,7 @@ import zmq from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook from parl.remote import remote_constants import time +import glob class Client(object): @@ -84,12 +85,28 @@ class Client(object): Args: distributed_files (list): A list of files to be distributed at all remote instances(e,g. the configuration - file for initialization) . - + file for initialization) . RegExp of file + names is supported. + e.g. + distributed_files = ['./*.npy', './test*'] + Returns: A cloudpickled dictionary containing the python code in current working directory. """ + + parsed_distributed_files = set() + for distributed_file in distributed_files: + parsed_list = glob.glob(distributed_file) + if not parsed_list: + raise ValueError( + "no local file is matched with '{}', please check your input" + .format(distributed_file)) + # exclude the directiories + for pathname in parsed_list: + if not os.path.isdir(pathname): + parsed_distributed_files.add(pathname) + pyfiles = dict() pyfiles['python_files'] = {} pyfiles['other_files'] = {} @@ -112,7 +129,7 @@ class Client(object): code = code_file.read() pyfiles['python_files'][file_name] = code - for file_name in distributed_files: + for file_name in parsed_distributed_files: assert os.path.exists(file_name) assert not os.path.isabs( file_name diff --git a/parl/remote/tests/support_RegExp_test.py b/parl/remote/tests/support_RegExp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c51c3526262c3bd281c94c9e4f2fd55ca4776bec --- /dev/null +++ b/parl/remote/tests/support_RegExp_test.py @@ -0,0 +1,99 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import shutil +import parl +from parl.remote.master import Master +from parl.remote.worker import Worker +import time +import threading +from parl.remote.client import disconnect +from parl.remote import exceptions +from parl.utils import logger + + +@parl.remote_class +class Actor(object): + def file_exists(self, filename): + return os.path.exists(filename) + + +class TestCluster(unittest.TestCase): + def tearDown(self): + disconnect() + + def test_distributed_files_with_RegExp(self): + if os.path.exists('distribute_test_dir'): + shutil.rmtree('distribute_test_dir') + os.mkdir('distribute_test_dir') + f = open('distribute_test_dir/test1.txt', 'wb') + f.close() + f = open('distribute_test_dir/test2.txt', 'wb') + f.close() + f = open('distribute_test_dir/data1.npy', 'wb') + f.close() + f = open('distribute_test_dir/data2.npy', 'wb') + f.close() + logger.info("running:test_distributed_files_with_RegExp") + master = Master(port=8605) + th = threading.Thread(target=master.run) + th.start() + time.sleep(3) + worker1 = Worker('localhost:8605', 1) + parl.connect( + 'localhost:8605', + distributed_files=[ + 'distribute_test_dir/test*', + 'distribute_test_dir/*npy', + ]) + actor = Actor() + self.assertTrue(actor.file_exists('distribute_test_dir/test1.txt')) + self.assertTrue(actor.file_exists('distribute_test_dir/test2.txt')) + self.assertTrue(actor.file_exists('distribute_test_dir/data1.npy')) + self.assertTrue(actor.file_exists('distribute_test_dir/data2.npy')) + self.assertFalse(actor.file_exists('distribute_test_dir/data3.npy')) + shutil.rmtree('distribute_test_dir') + master.exit() + worker1.exit() + + def test_miss_match_case(self): + if os.path.exists('distribute_test_dir_2'): + shutil.rmtree('distribute_test_dir_2') + os.mkdir('distribute_test_dir_2') + f = open('distribute_test_dir_2/test1.txt', 'wb') + f.close() + f = open('distribute_test_dir_2/data1.npy', 'wb') + f.close() + logger.info("running:test_distributed_files_with_RegExp_error_case") + master = Master(port=8606) + th = threading.Thread(target=master.run) + th.start() + time.sleep(3) + worker1 = Worker('localhost:8606', 1) + + def connect_test(): + parl.connect( + 'localhost:8606', + distributed_files=['distribute_test_dir_2/miss_match*']) + + self.assertRaises(ValueError, connect_test) + shutil.rmtree('distribute_test_dir_2') + master.exit() + worker1.exit() + + +if __name__ == '__main__': + unittest.main()