未验证 提交 74d4facb 编写于 作者: B Bo Zhou 提交者: GitHub

support jupyter (#385)

* support jupyter

* Update utils.py

* Update utils.py
上级 d6e82f01
...@@ -19,7 +19,7 @@ import socket ...@@ -19,7 +19,7 @@ import socket
import sys import sys
import threading import threading
import zmq import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook
from parl.remote import remote_constants from parl.remote import remote_constants
import time import time
...@@ -94,11 +94,14 @@ class Client(object): ...@@ -94,11 +94,14 @@ class Client(object):
pyfiles['python_files'] = {} pyfiles['python_files'] = {}
pyfiles['other_files'] = {} pyfiles['other_files'] = {}
main_file = sys.argv[0] if isnotebook():
main_folder = './' main_folder = './'
sep = os.sep else:
if sep in main_file: main_file = sys.argv[0]
main_folder = sep.join(main_file.split(sep)[:-1]) main_folder = './'
sep = os.sep
if sep in main_file:
main_folder = sep.join(main_file.split(sep)[:-1])
code_files = filter(lambda x: x.endswith('.py'), code_files = filter(lambda x: x.endswith('.py'),
os.listdir(main_folder)) os.listdir(main_folder))
...@@ -108,16 +111,6 @@ class Client(object): ...@@ -108,16 +111,6 @@ class Client(object):
with open(file_path, 'rb') as code_file: with open(file_path, 'rb') as code_file:
code = code_file.read() code = code_file.read()
pyfiles['python_files'][file_name] = code pyfiles['python_files'][file_name] = code
# append entry file to code list
assert os.path.isfile(
main_file
), "[xparl] error occurs when distributing files. cannot find the entry file:{} in current working directory: {}".format(
main_file, os.getcwd())
with open(main_file, 'rb') as code_file:
code = code_file.read()
# parl/remote/remote_decorator.py -> remote_decorator.py
file_name = main_file.split(os.sep)[-1]
pyfiles['python_files'][file_name] = code
for file_name in distributed_files: for file_name in distributed_files:
assert os.path.exists(file_name) assert os.path.exists(file_name)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
import os import os
from parl.utils import isnotebook
__all__ = [ __all__ = [
'load_remote_class', 'redirect_stdout_to_file', 'locate_remote_file' 'load_remote_class', 'redirect_stdout_to_file', 'locate_remote_file'
...@@ -107,6 +108,7 @@ def redirect_stdout_to_file(file_path): ...@@ -107,6 +108,7 @@ def redirect_stdout_to_file(file_path):
def locate_remote_file(module_path): def locate_remote_file(module_path):
"""xparl has to locate the file that has the class decorated by parl.remote_class. """xparl has to locate the file that has the class decorated by parl.remote_class.
This function returns the relative path between this file and the entry file. This function returns the relative path between this file and the entry file.
Note that this function should support the jupyter-notebook environment.
Args: Args:
module_path: Absolute path of the module. module_path: Absolute path of the module.
...@@ -116,14 +118,17 @@ def locate_remote_file(module_path): ...@@ -116,14 +118,17 @@ def locate_remote_file(module_path):
entry_file: /home/user/dir/main.py entry_file: /home/user/dir/main.py
--------> relative_path: subdir/my_module --------> relative_path: subdir/my_module
""" """
entry_file = sys.argv[0] if isnotebook():
entry_file = entry_file.split(os.sep)[-1] entry_path = os.getcwd()
entry_path = None else:
for path in sys.path: entry_file = sys.argv[0]
to_check_path = os.path.join(path, entry_file) entry_file = entry_file.split(os.sep)[-1]
if os.path.isfile(to_check_path): entry_path = None
entry_path = path for path in sys.path:
break to_check_path = os.path.join(path, entry_file)
if os.path.isfile(to_check_path):
entry_path = path
break
if entry_path is None or \ if entry_path is None or \
(module_path.startswith(os.sep) and entry_path != module_path[:len(entry_path)]): (module_path.startswith(os.sep) and entry_path != module_path[:len(entry_path)]):
raise FileNotFoundError("cannot locate the remote file") raise FileNotFoundError("cannot locate the remote file")
......
...@@ -20,7 +20,7 @@ import numpy as np ...@@ -20,7 +20,7 @@ import numpy as np
__all__ = [ __all__ = [
'has_func', 'to_str', 'to_byte', 'is_PY2', 'is_PY3', 'MAX_INT32', 'has_func', 'to_str', 'to_byte', 'is_PY2', 'is_PY3', 'MAX_INT32',
'_HAS_FLUID', '_HAS_TORCH', '_IS_WINDOWS', '_IS_MAC', 'kill_process', '_HAS_FLUID', '_HAS_TORCH', '_IS_WINDOWS', '_IS_MAC', 'kill_process',
'get_fluid_version' 'get_fluid_version', 'isnotebook'
] ]
...@@ -101,3 +101,19 @@ def kill_process(regex_pattern): ...@@ -101,3 +101,19 @@ def kill_process(regex_pattern):
command = "ps aux | grep {} | awk '{{print $2}}' | xargs kill -9".format( command = "ps aux | grep {} | awk '{{print $2}}' | xargs kill -9".format(
regex_pattern) regex_pattern)
subprocess.call([command], shell=True) subprocess.call([command], shell=True)
def isnotebook():
"""check if the code is excuted in the IPython notebook
Reference: https://stackoverflow.com/a/39662359
"""
try:
shell = get_ipython().__class__.__name__
if shell == 'ZMQInteractiveShell':
return True # Jupyter notebook or qtconsole
elif shell == 'TerminalInteractiveShell':
return False # Terminal running IPython
else:
return False # Other type (?)
except NameError:
return False # Probably standard Python interpreter
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册