From 74d4facbb4c381ad4240f12d88e71c1aa35b410b Mon Sep 17 00:00:00 2001 From: Bo Zhou <2466956298@qq.com> Date: Mon, 17 Aug 2020 12:22:07 +0800 Subject: [PATCH] support jupyter (#385) * support jupyter * Update utils.py * Update utils.py --- parl/remote/client.py | 25 +++++++++---------------- parl/remote/utils.py | 21 +++++++++++++-------- parl/utils/utils.py | 18 +++++++++++++++++- 3 files changed, 39 insertions(+), 25 deletions(-) diff --git a/parl/remote/client.py b/parl/remote/client.py index 11865f8..9464934 100644 --- a/parl/remote/client.py +++ b/parl/remote/client.py @@ -19,7 +19,7 @@ import socket import sys import threading import zmq -from parl.utils import to_str, to_byte, get_ip_address, logger +from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook from parl.remote import remote_constants import time @@ -94,11 +94,14 @@ class Client(object): pyfiles['python_files'] = {} pyfiles['other_files'] = {} - main_file = sys.argv[0] - main_folder = './' - sep = os.sep - if sep in main_file: - main_folder = sep.join(main_file.split(sep)[:-1]) + if isnotebook(): + main_folder = './' + else: + main_file = sys.argv[0] + main_folder = './' + sep = os.sep + if sep in main_file: + main_folder = sep.join(main_file.split(sep)[:-1]) code_files = filter(lambda x: x.endswith('.py'), os.listdir(main_folder)) @@ -108,16 +111,6 @@ class Client(object): with open(file_path, 'rb') as code_file: code = code_file.read() pyfiles['python_files'][file_name] = code - # append entry file to code list - assert os.path.isfile( - main_file - ), "[xparl] error occurs when distributing files. cannot find the entry file:{} in current working directory: {}".format( - main_file, os.getcwd()) - with open(main_file, 'rb') as code_file: - code = code_file.read() - # parl/remote/remote_decorator.py -> remote_decorator.py - file_name = main_file.split(os.sep)[-1] - pyfiles['python_files'][file_name] = code for file_name in distributed_files: assert os.path.exists(file_name) diff --git a/parl/remote/utils.py b/parl/remote/utils.py index 63c94c1..5d6368e 100644 --- a/parl/remote/utils.py +++ b/parl/remote/utils.py @@ -14,6 +14,7 @@ import sys from contextlib import contextmanager import os +from parl.utils import isnotebook __all__ = [ 'load_remote_class', 'redirect_stdout_to_file', 'locate_remote_file' @@ -107,6 +108,7 @@ def redirect_stdout_to_file(file_path): def locate_remote_file(module_path): """xparl has to locate the file that has the class decorated by parl.remote_class. This function returns the relative path between this file and the entry file. + Note that this function should support the jupyter-notebook environment. Args: module_path: Absolute path of the module. @@ -116,14 +118,17 @@ def locate_remote_file(module_path): entry_file: /home/user/dir/main.py --------> relative_path: subdir/my_module """ - entry_file = sys.argv[0] - entry_file = entry_file.split(os.sep)[-1] - entry_path = None - for path in sys.path: - to_check_path = os.path.join(path, entry_file) - if os.path.isfile(to_check_path): - entry_path = path - break + if isnotebook(): + entry_path = os.getcwd() + else: + entry_file = sys.argv[0] + entry_file = entry_file.split(os.sep)[-1] + entry_path = None + for path in sys.path: + to_check_path = os.path.join(path, entry_file) + if os.path.isfile(to_check_path): + entry_path = path + break if entry_path is None or \ (module_path.startswith(os.sep) and entry_path != module_path[:len(entry_path)]): raise FileNotFoundError("cannot locate the remote file") diff --git a/parl/utils/utils.py b/parl/utils/utils.py index effb572..69af151 100644 --- a/parl/utils/utils.py +++ b/parl/utils/utils.py @@ -20,7 +20,7 @@ import numpy as np __all__ = [ 'has_func', 'to_str', 'to_byte', 'is_PY2', 'is_PY3', 'MAX_INT32', '_HAS_FLUID', '_HAS_TORCH', '_IS_WINDOWS', '_IS_MAC', 'kill_process', - 'get_fluid_version' + 'get_fluid_version', 'isnotebook' ] @@ -101,3 +101,19 @@ def kill_process(regex_pattern): command = "ps aux | grep {} | awk '{{print $2}}' | xargs kill -9".format( regex_pattern) subprocess.call([command], shell=True) + + +def isnotebook(): + """check if the code is excuted in the IPython notebook + Reference: https://stackoverflow.com/a/39662359 + """ + try: + shell = get_ipython().__class__.__name__ + if shell == 'ZMQInteractiveShell': + return True # Jupyter notebook or qtconsole + elif shell == 'TerminalInteractiveShell': + return False # Terminal running IPython + else: + return False # Other type (?) + except NameError: + return False # Probably standard Python interpreter -- GitLab