From b818e5ed5279b7d5783da0c20161109b3f9089bb Mon Sep 17 00:00:00 2001 From: TomorrowIsAnOtherDay <2466956298@qq.com> Date: Tue, 9 Jun 2020 19:49:20 +0800 Subject: [PATCH] fix the bug of distributing files --- parl/remote/job.py | 8 +++++--- parl/remote/tests/recursive_actor_test.py | 9 +++++---- parl/remote/tests/sync_config_file_test.py | 7 +++---- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/parl/remote/job.py b/parl/remote/job.py index e5c6965..17a280f 100644 --- a/parl/remote/job.py +++ b/parl/remote/job.py @@ -269,13 +269,15 @@ class Job(object): # create directory (i.e. ./rom_files/) if '/' in file: try: - os.makedirs(os.path.join(*file.rsplit('/')[:-1])) + sep = os.sep + recursive_dirs = os.path.join(*(file.split(sep)[:-1])) + recursive_dirs = os.path.join(envdir, recursive_dirs) + os.makedirs(recursive_dirs) except OSError as e: pass file = os.path.join(envdir, file) with open(file, 'wb') as f: f.write(content) - logger.info('[job] reply') reply_socket.send_multipart([remote_constants.NORMAL_TAG]) return envdir else: @@ -358,7 +360,7 @@ class Job(object): self.single_task(obj, reply_socket, job_address) except Exception as e: logger.error( - "Error occurs when running a single task. We will reset this job. Reason:{}" + "Error occurs when running a single task. We will reset this job. \nReason:{}" .format(e)) traceback_str = str(traceback.format_exc()) logger.error("traceback:\n{}".format(traceback_str)) diff --git a/parl/remote/tests/recursive_actor_test.py b/parl/remote/tests/recursive_actor_test.py index 2d0665c..5e9613b 100644 --- a/parl/remote/tests/recursive_actor_test.py +++ b/parl/remote/tests/recursive_actor_test.py @@ -22,10 +22,11 @@ import threading c = 10 port = 3002 -master = Master(port=port) -th = threading.Thread(target=master.run) -th.setDaemon(True) -th.start() +if __name__ == '__main__': + master = Master(port=port) + th = threading.Thread(target=master.run) + th.setDaemon(True) + th.start() time.sleep(5) cluster_addr = 'localhost:{}'.format(port) parl.connect(cluster_addr) diff --git a/parl/remote/tests/sync_config_file_test.py b/parl/remote/tests/sync_config_file_test.py index a4d131d..c8be194 100644 --- a/parl/remote/tests/sync_config_file_test.py +++ b/parl/remote/tests/sync_config_file_test.py @@ -17,12 +17,10 @@ import parl from parl.remote.master import Master from parl.remote.worker import Worker from parl.remote.client import disconnect - +import os import time import threading - import sys - import numpy as np import json @@ -65,7 +63,8 @@ class TestConfigfile(unittest.TestCase): parl.connect('localhost:1335', ['random.npy', 'config.json']) actor = Actor('random.npy', 'config.json') time.sleep(5) - + os.remove('./random.npy') + os.remove('./config.json') remote_sum = actor.random_sum() self.assertEqual(remote_sum, random_sum) time.sleep(10) -- GitLab