From 736727099056fe6a385fe0e5a38831f0b85a3896 Mon Sep 17 00:00:00 2001 From: Zheyue Tan Date: Thu, 11 Jun 2020 13:08:55 +0800 Subject: [PATCH] redirect stdout to log file while initializing the remote actor instance (#294) * redirect stdout to log file while initializing the remote actor instance * add test for catching output in `Actor.__init__` --- parl/remote/job.py | 11 +++++----- parl/remote/tests/log_server_test.py | 6 ++++-- parl/remote/utils.py | 30 +++++++++++++++++++++++++++- 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/parl/remote/job.py b/parl/remote/job.py index 794496d..d835e53 100644 --- a/parl/remote/job.py +++ b/parl/remote/job.py @@ -36,7 +36,7 @@ from parl.utils.communication import loads_argument, loads_return,\ from parl.remote import remote_constants from parl.utils.exceptions import SerializeError, DeserializeError from parl.remote.message import InitializedJob -from parl.remote.utils import load_remote_class +from parl.remote.utils import load_remote_class, redirect_stdout_to_file class Job(object): @@ -315,7 +315,9 @@ class Job(object): file_name = file_name.split(os.sep)[-1] cls = load_remote_class(file_name, class_name, end_of_file) args, kwargs = cloudpickle.loads(message[2]) - obj = cls(*args, **kwargs) + logfile_path = os.path.join(self.log_dir, 'stdout.log') + with redirect_stdout_to_file(logfile_path): + obj = cls(*args, **kwargs) except Exception as e: traceback_str = str(traceback.format_exc()) error_str = str(e) @@ -406,11 +408,8 @@ class Job(object): # Redirect stdout to stdout.log temporarily logfile_path = os.path.join(self.log_dir, 'stdout.log') - with open(logfile_path, 'a') as f: - tmp = sys.stdout - sys.stdout = f + with redirect_stdout_to_file(logfile_path): ret = getattr(obj, function_name)(*args, **kwargs) - sys.stdout = tmp ret = dumps_return(ret) diff --git a/parl/remote/tests/log_server_test.py b/parl/remote/tests/log_server_test.py index 6b6aad4..931fc29 100644 --- a/parl/remote/tests/log_server_test.py +++ b/parl/remote/tests/log_server_test.py @@ -29,7 +29,7 @@ import parl from parl.remote.client import disconnect, get_global_client from parl.remote.master import Master from parl.remote.worker import Worker -from parl.utils import _IS_WINDOWS, get_free_tcp_port +from parl.utils import _IS_WINDOWS @parl.remote_class @@ -38,6 +38,8 @@ class Actor(object): self.number = number self.arg1 = arg1 self.arg2 = arg2 + print("Init actor...") + self.init_output = "Init actor...\n" def sim_output(self, start, end): output = "" @@ -48,7 +50,7 @@ class Actor(object): print(i) output += str(i) output += "\n" - return output + return self.init_output + output class TestLogServer(unittest.TestCase): diff --git a/parl/remote/utils.py b/parl/remote/utils.py index 2cd36e5..9a2ece8 100644 --- a/parl/remote/utils.py +++ b/parl/remote/utils.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys +from contextlib import contextmanager -__all__ = ['load_remote_class'] +__all__ = ['load_remote_class', 'redirect_stdout_to_file'] def simplify_code(code, end_of_file): @@ -66,3 +68,29 @@ def load_remote_class(file_name, class_name, end_of_file): mod = __import__(module_name) cls = getattr(mod, class_name) return cls + + +@contextmanager +def redirect_stdout_to_file(file_path): + """Redirect stdout (e.g., `print`) to specified file. + + Example: + >>> print('test') + test + >>> with redirect_stdout_to_file('test.log'): + ... print('test') # Output nothing, `test` is printed to `test.log`. + >>> print('test') + test + + Args: + file_path: Path of the file to output the stdout. + + """ + tmp = sys.stdout + f = open(file_path, 'a') + sys.stdout = f + try: + yield + finally: + sys.stdout = tmp + f.close() -- GitLab