未验证 提交 73672709 编写于 作者: Z Zheyue Tan 提交者: GitHub

redirect stdout to log file while initializing the remote actor instance (#294)

* redirect stdout to log file while initializing the remote actor instance

* add test for catching output in  `Actor.__init__`
上级 8c7f1922
...@@ -36,7 +36,7 @@ from parl.utils.communication import loads_argument, loads_return,\ ...@@ -36,7 +36,7 @@ from parl.utils.communication import loads_argument, loads_return,\
from parl.remote import remote_constants from parl.remote import remote_constants
from parl.utils.exceptions import SerializeError, DeserializeError from parl.utils.exceptions import SerializeError, DeserializeError
from parl.remote.message import InitializedJob from parl.remote.message import InitializedJob
from parl.remote.utils import load_remote_class from parl.remote.utils import load_remote_class, redirect_stdout_to_file
class Job(object): class Job(object):
...@@ -315,7 +315,9 @@ class Job(object): ...@@ -315,7 +315,9 @@ class Job(object):
file_name = file_name.split(os.sep)[-1] file_name = file_name.split(os.sep)[-1]
cls = load_remote_class(file_name, class_name, end_of_file) cls = load_remote_class(file_name, class_name, end_of_file)
args, kwargs = cloudpickle.loads(message[2]) args, kwargs = cloudpickle.loads(message[2])
obj = cls(*args, **kwargs) logfile_path = os.path.join(self.log_dir, 'stdout.log')
with redirect_stdout_to_file(logfile_path):
obj = cls(*args, **kwargs)
except Exception as e: except Exception as e:
traceback_str = str(traceback.format_exc()) traceback_str = str(traceback.format_exc())
error_str = str(e) error_str = str(e)
...@@ -406,11 +408,8 @@ class Job(object): ...@@ -406,11 +408,8 @@ class Job(object):
# Redirect stdout to stdout.log temporarily # Redirect stdout to stdout.log temporarily
logfile_path = os.path.join(self.log_dir, 'stdout.log') logfile_path = os.path.join(self.log_dir, 'stdout.log')
with open(logfile_path, 'a') as f: with redirect_stdout_to_file(logfile_path):
tmp = sys.stdout
sys.stdout = f
ret = getattr(obj, function_name)(*args, **kwargs) ret = getattr(obj, function_name)(*args, **kwargs)
sys.stdout = tmp
ret = dumps_return(ret) ret = dumps_return(ret)
......
...@@ -29,7 +29,7 @@ import parl ...@@ -29,7 +29,7 @@ import parl
from parl.remote.client import disconnect, get_global_client from parl.remote.client import disconnect, get_global_client
from parl.remote.master import Master from parl.remote.master import Master
from parl.remote.worker import Worker from parl.remote.worker import Worker
from parl.utils import _IS_WINDOWS, get_free_tcp_port from parl.utils import _IS_WINDOWS
@parl.remote_class @parl.remote_class
...@@ -38,6 +38,8 @@ class Actor(object): ...@@ -38,6 +38,8 @@ class Actor(object):
self.number = number self.number = number
self.arg1 = arg1 self.arg1 = arg1
self.arg2 = arg2 self.arg2 = arg2
print("Init actor...")
self.init_output = "Init actor...\n"
def sim_output(self, start, end): def sim_output(self, start, end):
output = "" output = ""
...@@ -48,7 +50,7 @@ class Actor(object): ...@@ -48,7 +50,7 @@ class Actor(object):
print(i) print(i)
output += str(i) output += str(i)
output += "\n" output += "\n"
return output return self.init_output + output
class TestLogServer(unittest.TestCase): class TestLogServer(unittest.TestCase):
......
...@@ -11,8 +11,10 @@ ...@@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
from contextlib import contextmanager
__all__ = ['load_remote_class'] __all__ = ['load_remote_class', 'redirect_stdout_to_file']
def simplify_code(code, end_of_file): def simplify_code(code, end_of_file):
...@@ -66,3 +68,29 @@ def load_remote_class(file_name, class_name, end_of_file): ...@@ -66,3 +68,29 @@ def load_remote_class(file_name, class_name, end_of_file):
mod = __import__(module_name) mod = __import__(module_name)
cls = getattr(mod, class_name) cls = getattr(mod, class_name)
return cls return cls
@contextmanager
def redirect_stdout_to_file(file_path):
"""Redirect stdout (e.g., `print`) to specified file.
Example:
>>> print('test')
test
>>> with redirect_stdout_to_file('test.log'):
... print('test') # Output nothing, `test` is printed to `test.log`.
>>> print('test')
test
Args:
file_path: Path of the file to output the stdout.
"""
tmp = sys.stdout
f = open(file_path, 'a')
sys.stdout = f
try:
yield
finally:
sys.stdout = tmp
f.close()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册