diff --git a/parl/remote/client.py b/parl/remote/client.py index 11865f8ff2afe9368768c701a0c00d232a54c4e1..946493421b58e38f4e19ff543a820e03aedea3b2 100644 --- a/parl/remote/client.py +++ b/parl/remote/client.py @@ -19,7 +19,7 @@ import socket import sys import threading import zmq -from parl.utils import to_str, to_byte, get_ip_address, logger +from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook from parl.remote import remote_constants import time @@ -94,11 +94,14 @@ class Client(object): pyfiles['python_files'] = {} pyfiles['other_files'] = {} - main_file = sys.argv[0] - main_folder = './' - sep = os.sep - if sep in main_file: - main_folder = sep.join(main_file.split(sep)[:-1]) + if isnotebook(): + main_folder = './' + else: + main_file = sys.argv[0] + main_folder = './' + sep = os.sep + if sep in main_file: + main_folder = sep.join(main_file.split(sep)[:-1]) code_files = filter(lambda x: x.endswith('.py'), os.listdir(main_folder)) @@ -108,16 +111,6 @@ class Client(object): with open(file_path, 'rb') as code_file: code = code_file.read() pyfiles['python_files'][file_name] = code - # append entry file to code list - assert os.path.isfile( - main_file - ), "[xparl] error occurs when distributing files. cannot find the entry file:{} in current working directory: {}".format( - main_file, os.getcwd()) - with open(main_file, 'rb') as code_file: - code = code_file.read() - # parl/remote/remote_decorator.py -> remote_decorator.py - file_name = main_file.split(os.sep)[-1] - pyfiles['python_files'][file_name] = code for file_name in distributed_files: assert os.path.exists(file_name) diff --git a/parl/remote/utils.py b/parl/remote/utils.py index 63c94c1a8022256bec382a348a11466c93d0ecc8..5d6368e21d2f9d19a72f28eea0d214df8a913664 100644 --- a/parl/remote/utils.py +++ b/parl/remote/utils.py @@ -14,6 +14,7 @@ import sys from contextlib import contextmanager import os +from parl.utils import isnotebook __all__ = [ 'load_remote_class', 'redirect_stdout_to_file', 'locate_remote_file' @@ -107,6 +108,7 @@ def redirect_stdout_to_file(file_path): def locate_remote_file(module_path): """xparl has to locate the file that has the class decorated by parl.remote_class. This function returns the relative path between this file and the entry file. + Note that this function should support the jupyter-notebook environment. Args: module_path: Absolute path of the module. @@ -116,14 +118,17 @@ def locate_remote_file(module_path): entry_file: /home/user/dir/main.py --------> relative_path: subdir/my_module """ - entry_file = sys.argv[0] - entry_file = entry_file.split(os.sep)[-1] - entry_path = None - for path in sys.path: - to_check_path = os.path.join(path, entry_file) - if os.path.isfile(to_check_path): - entry_path = path - break + if isnotebook(): + entry_path = os.getcwd() + else: + entry_file = sys.argv[0] + entry_file = entry_file.split(os.sep)[-1] + entry_path = None + for path in sys.path: + to_check_path = os.path.join(path, entry_file) + if os.path.isfile(to_check_path): + entry_path = path + break if entry_path is None or \ (module_path.startswith(os.sep) and entry_path != module_path[:len(entry_path)]): raise FileNotFoundError("cannot locate the remote file") diff --git a/parl/utils/utils.py b/parl/utils/utils.py index effb572811a6f72b9c6e363901fa632bdc75a625..69af1511a437cce56a9dc80885c6eb7086463160 100644 --- a/parl/utils/utils.py +++ b/parl/utils/utils.py @@ -20,7 +20,7 @@ import numpy as np __all__ = [ 'has_func', 'to_str', 'to_byte', 'is_PY2', 'is_PY3', 'MAX_INT32', '_HAS_FLUID', '_HAS_TORCH', '_IS_WINDOWS', '_IS_MAC', 'kill_process', - 'get_fluid_version' + 'get_fluid_version', 'isnotebook' ] @@ -101,3 +101,19 @@ def kill_process(regex_pattern): command = "ps aux | grep {} | awk '{{print $2}}' | xargs kill -9".format( regex_pattern) subprocess.call([command], shell=True) + + +def isnotebook(): + """check if the code is excuted in the IPython notebook + Reference: https://stackoverflow.com/a/39662359 + """ + try: + shell = get_ipython().__class__.__name__ + if shell == 'ZMQInteractiveShell': + return True # Jupyter notebook or qtconsole + elif shell == 'TerminalInteractiveShell': + return False # Terminal running IPython + else: + return False # Other type (?) + except NameError: + return False # Probably standard Python interpreter