diff --git a/.teamcity/requirements.txt b/.teamcity/requirements.txt
index 4cd77073614710e3fd3bd39efdcbf9b0c601e93a..1899a87dcba5505f854acd7b9c69a92a3ac7bb99 100644
--- a/.teamcity/requirements.txt
+++ b/.teamcity/requirements.txt
@@ -1,5 +1,7 @@
# requirements for unittest
rarfile==3.1
+opencv-python<=4.3.0.34;python_version>="3"
+opencv-python==4.2.0.32;python_version<"3"
paddlepaddle-gpu==1.6.1.post97
gym
details
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a77bd684aecc364bf0053e36724fcf0fe880d2f0..435e27f2e0a3ed24964a639236a66de1f7a69f75 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -30,10 +30,20 @@ function(py_test TARGET_NAME)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS ENVS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
- add_test(NAME ${TARGET_NAME}
- COMMAND python -u ${py_test_SRCS} ${py_test_ARGS}
- WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
- set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 300)
+ if (${FILE_NAME} MATCHES ".*abs_test.py")
+ add_test(NAME ${TARGET_NAME}"_with_abs_path"
+ COMMAND python -u ${py_test_SRCS} ${py_test_ARGS}
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
+ set_tests_properties(${TARGET_NAME}"_with_abs_path" PROPERTIES TIMEOUT 300)
+ else()
+ get_filename_component(WORKING_DIR ${py_test_SRCS} DIRECTORY)
+ get_filename_component(FILE_NAME ${py_test_SRCS} NAME)
+ get_filename_component(COMBINED_PATH ${CMAKE_CURRENT_SOURCE_DIR}/${WORKING_DIR} ABSOLUTE)
+ add_test(NAME ${TARGET_NAME}
+ COMMAND python -u ${FILE_NAME} ${py_test_ARGS}
+ WORKING_DIRECTORY ${COMBINED_PATH})
+ set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 300)
+ endif()
endfunction()
function(import_test TARGET_NAME)
diff --git a/docs/zh_CN/Overview.md b/docs/zh_CN/Overview.md
index 2d795edf2a0432f01e4a90a8bc3cc33be7bf3917..52fd5e546108b86262163f17b6c4121c2417a441 100644
--- a/docs/zh_CN/Overview.md
+++ b/docs/zh_CN/Overview.md
@@ -126,7 +126,13 @@ yapf -i modified_file.py
```
- 持续集成测试
当增加代码时候,需要增加测试代码覆盖所添加的代码,测试代码得放在相关代码文件的`tests`文件夹下,以`_test.py`结尾(这样持续集成测试会自动拉取代码跑)。附:[测试代码示例](../../parl/tests/import_test.py)
-
+- 本地运行单元测试(非必要)
+如果你希望在自己的机器运行单测代码,可先在本地机器上安装Docker,再按以下步骤执行单测任务。
+```
+cd PARL
+docker build -t parl/parl-test:unittest .teamcity/
+nvidia-docker run -i --rm -v $PWD:/work -w /work parl/parl-test:unittest .teamcity/build.sh test
+```
## 反馈
- 在 GitHub 上[提交问题](https://github.com/PaddlePaddle/PARL/issues)
diff --git a/docs/zh_CN/xparl/introduction.md b/docs/zh_CN/xparl/introduction.md
index 38c749a7ca592569534145fd2030cc1b2eea46c4..5cfc5dc4064bd3db5066296ccd63fce1e3874441 100644
--- a/docs/zh_CN/xparl/introduction.md
+++ b/docs/zh_CN/xparl/introduction.md
@@ -24,4 +24,4 @@ PARL在实现底层的并行计算时,是通过端到端的这种网络传输
## 自动分发本地文件
市面上的并行框架大部分得要用户手动同步文件才可以跑起并行代码,比如配置文件得要手动或者通过命令分发到不同机器,parl可以自动分发当前目录下的代码文件,实现无缝的多机并行。
-
+
diff --git a/docs/zh_CN/xparl/tutorial.md b/docs/zh_CN/xparl/tutorial.md
index 066297f57576575180eed2e3fb05b459ff7c9575..8a0ef4087ffcd124caaa6060877608ab88f078cf 100644
--- a/docs/zh_CN/xparl/tutorial.md
+++ b/docs/zh_CN/xparl/tutorial.md
@@ -3,7 +3,7 @@
## 配置命令
这个教程将会演示如何搭建一个集群。
-搭建一个PARL集群,可以通过执行下面两个`xparl`命令:
+搭建一个PARL集群,可以通过执行下面的`xparl`命令:
### 启动集群
```bash
@@ -12,17 +12,17 @@ xparl start --port 6006
这个命令会启动一个主节点(master)来管理集群的计算资源,同时会把本地机器的CPU资源加入到集群中。命令中的6006端口只是作为示例,你可以修改成任何有效的端口。
-### 加入其它机器资源
-> 注意:如果你只有单台机器,可以忽略这部分教程。
+启动后可通过`xparl status`查看目前集群有多少CPU资源可用,你可以在`xparl start`的命令中加入选项`--cpu_num [CPU_NUM]` (例如:--cpu_num 10)指定本机加入集群的CPU数量。
-如果你想加入更多的CPU计算资源到集群中,可以在其他机器上运行下面命令:
+### 加入更多CPU资源
+
+启动集群后,就可以直接使用集群了,如果CPU资源不够用,你可以在任何时候和任何机器(包括本机或其他机器)上,通过执行`xparl connect`命令把更多CPU资源加入到集群中。
```bash
xparl connect --address [MASTER_ADDRESS]:6006
```
-它会启动一个工作节点(worker),并把当前机器的CPU资源加入到该master对应的集群。worker默认会把所有的CPU资源加入到集群中,如果你需要指定worker可使用的CPU数量,可以在上述命令上加入选项`--cpu_num [CPU_NUM]` (例如:----cpu_num 10)。
+它会启动一个工作节点(worker),并把当前机器的CPU资源加入到`--address`指定的master集群。worker默认会把当前机器所有的可用的CPU资源加入到集群中,如果你需要指定加入的CPU数量,也可以在上述命令上加入选项`--cpu_num [CPU_NUM]` 。
-注意:启动集群后,你可以在任何时候和任何机器上,通过执行`xparl connect`命令把更多CPU资源加入到集群中。
## 示例
这里我们给出了一个示例来演示如何通过`@parl.remote_class`来进行并行计算。
@@ -47,9 +47,9 @@ actor.add(1, 2) # 返回 3
```
## 关闭集群
-在master机器上运行`xparl stop`命令即可关闭集群程序。当master节点退出后,运行在其他机器的worker节点也会自动退出并结束相关程序。
+在master机器上运行`xparl stop`命令即可关闭集群程序。当master节点退出后,与之关联的worker节点也会自动退出并结束相关程序。
## 扩展阅读
-我们现在已经知道了如何搭建一个集群,以及如何通过修饰符`@parl.remote_class`来使用集群。
+我们现在已经知道了如何通过终端命令`xparl`搭建一个集群,以及如何通过修饰符`@parl.remote_class`来使用集群。
在[下一个教程](./example.md)我们将会演示如何通过这个修饰符来打破Python的全局解释器锁(Global Interpreter Lock, GIL)限制,从而实现真正的多线程计算。
diff --git a/parl/core/fluid/model.py b/parl/core/fluid/model.py
index bf7069a68c53748d870c1d9d21c2ec971fee05fe..80f748689532b7fb1af3aefdb4e2e592e71c3d9b 100644
--- a/parl/core/fluid/model.py
+++ b/parl/core/fluid/model.py
@@ -53,7 +53,7 @@ class Model(ModelBase):
copied_policy = copy.deepcopy(model)
Attributes:
- model_id(str): each model instance has its uniqe model_id.
+ model_id(str): each model instance has its unique model_id.
Public Functions:
- ``sync_weights_to``: synchronize parameters of the current model to another model.
diff --git a/parl/remote/client.py b/parl/remote/client.py
index 379459c5768914a012cf89182724f1233cbf1329..344debe2afccad628613e9b27b5f21e25e6f631e 100644
--- a/parl/remote/client.py
+++ b/parl/remote/client.py
@@ -19,9 +19,11 @@ import socket
import sys
import threading
import zmq
-from parl.utils import to_str, to_byte, get_ip_address, logger
+import parl
+from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook
from parl.remote import remote_constants
import time
+import glob
class Client(object):
@@ -50,7 +52,6 @@ class Client(object):
distributed_files (list): A list of files to be distributed at all
remote instances(e,g. the configuration
file for initialization) .
-
"""
self.master_address = master_address
self.process_id = process_id
@@ -66,6 +67,7 @@ class Client(object):
self.actor_num = 0
self._create_sockets(master_address)
+ self.check_version()
self.pyfiles = self.read_local_files(distributed_files)
def get_executable_path(self):
@@ -85,44 +87,58 @@ class Client(object):
Args:
distributed_files (list): A list of files to be distributed at all
remote instances(e,g. the configuration
- file for initialization) .
-
+ file for initialization) . RegExp of file
+ names is supported.
+ e.g.
+ distributed_files = ['./*.npy', './test*']
+
Returns:
A cloudpickled dictionary containing the python code in current
working directory.
"""
+
+ parsed_distributed_files = set()
+ for distributed_file in distributed_files:
+ parsed_list = glob.glob(distributed_file)
+ if not parsed_list:
+ raise ValueError(
+ "no local file is matched with '{}', please check your input"
+ .format(distributed_file))
+ # exclude the directiories
+ for pathname in parsed_list:
+ if not os.path.isdir(pathname):
+ parsed_distributed_files.add(pathname)
+
pyfiles = dict()
pyfiles['python_files'] = {}
pyfiles['other_files'] = {}
- code_files = filter(lambda x: x.endswith('.py'), os.listdir('./'))
-
- try:
- for file in code_files:
- assert os.path.exists(file)
- with open(file, 'rb') as code_file:
- code = code_file.read()
- pyfiles['python_files'][file] = code
-
- for file in distributed_files:
- assert os.path.exists(file)
- assert not os.path.isabs(
- file
- ), "[XPARL] Please do not distribute a file with absolute path."
- with open(file, 'rb') as f:
- content = f.read()
- pyfiles['other_files'][file] = content
- # append entry file to code list
+ if isnotebook():
+ main_folder = './'
+ else:
main_file = sys.argv[0]
- with open(main_file, 'rb') as code_file:
+ main_folder = './'
+ sep = os.sep
+ if sep in main_file:
+ main_folder = sep.join(main_file.split(sep)[:-1])
+ code_files = filter(lambda x: x.endswith('.py'),
+ os.listdir(main_folder))
+
+ for file_name in code_files:
+ file_path = os.path.join(main_folder, file_name)
+ assert os.path.exists(file_path)
+ with open(file_path, 'rb') as code_file:
code = code_file.read()
- # parl/remote/remote_decorator.py -> remote_decorator.py
- file_name = main_file.split(os.sep)[-1]
pyfiles['python_files'][file_name] = code
- except AssertionError as e:
- raise Exception(
- 'Failed to create the client, the file {} does not exist.'.
- format(file))
+
+ for file_name in parsed_distributed_files:
+ assert os.path.exists(file_name)
+ assert not os.path.isabs(
+ file_name
+ ), "[XPARL] Please do not distribute a file with absolute path."
+ with open(file_name, 'rb') as f:
+ content = f.read()
+ pyfiles['other_files'][file_name] = content
return cloudpickle.dumps(pyfiles)
def _create_sockets(self, master_address):
@@ -165,6 +181,24 @@ class Client(object):
"check if master is started and ensure the input "
"address {} is correct.".format(master_address))
+ def check_version(self):
+ '''Verify that the parl & python version in 'client' process matches that of the 'master' process'''
+ self.submit_job_socket.send_multipart(
+ [remote_constants.CHECK_VERSION_TAG])
+ message = self.submit_job_socket.recv_multipart()
+ tag = message[0]
+ if tag == remote_constants.NORMAL_TAG:
+ client_parl_version = parl.__version__
+ client_python_version = str(sys.version_info.major)
+ assert client_parl_version == to_str(message[1]) and client_python_version == to_str(message[2]),\
+ '''Version mismatch: the 'master' is of version 'parl={}, python={}'. However,
+ 'parl={}, python={}'is provided in your environment.'''.format(
+ to_str(message[1]), to_str(message[2]),
+ client_parl_version, client_python_version
+ )
+ else:
+ raise NotImplementedError
+
def _reply_heartbeat(self):
"""Reply heartbeat signals to the master node."""
diff --git a/parl/remote/job.py b/parl/remote/job.py
index d835e5389aa447bb69567b61f6f1c60b9cf99d58..aa677ebfb9e18842f60491e24929e626fe2730a6 100644
--- a/parl/remote/job.py
+++ b/parl/remote/job.py
@@ -311,8 +311,6 @@ class Job(object):
try:
file_name, class_name, end_of_file = cloudpickle.loads(
message[1])
- #/home/nlp-ol/Firework/baidu/nlp/evokit/python_api/es_agent -> es_agent
- file_name = file_name.split(os.sep)[-1]
cls = load_remote_class(file_name, class_name, end_of_file)
args, kwargs = cloudpickle.loads(message[2])
logfile_path = os.path.join(self.log_dir, 'stdout.log')
@@ -327,7 +325,10 @@ class Job(object):
to_byte(error_str + "\ntraceback:\n" + traceback_str)
])
return None
- reply_socket.send_multipart([remote_constants.NORMAL_TAG])
+ reply_socket.send_multipart([
+ remote_constants.NORMAL_TAG,
+ dumps_return(set(obj.__dict__.keys()))
+ ])
else:
logger.error("Message from job {}".format(message))
reply_socket.send_multipart([
@@ -397,24 +398,49 @@ class Job(object):
while True:
message = reply_socket.recv_multipart()
-
tag = message[0]
-
- if tag == remote_constants.CALL_TAG:
+ if tag in [
+ remote_constants.CALL_TAG,
+ remote_constants.GET_ATTRIBUTE_TAG,
+ remote_constants.SET_ATTRIBUTE_TAG,
+ ]:
try:
- function_name = to_str(message[1])
- data = message[2]
- args, kwargs = loads_argument(data)
+ if tag == remote_constants.CALL_TAG:
+ function_name = to_str(message[1])
+ data = message[2]
+ args, kwargs = loads_argument(data)
- # Redirect stdout to stdout.log temporarily
- logfile_path = os.path.join(self.log_dir, 'stdout.log')
- with redirect_stdout_to_file(logfile_path):
- ret = getattr(obj, function_name)(*args, **kwargs)
+ # Redirect stdout to stdout.log temporarily
+ logfile_path = os.path.join(self.log_dir, 'stdout.log')
+ with redirect_stdout_to_file(logfile_path):
+ ret = getattr(obj, function_name)(*args, **kwargs)
- ret = dumps_return(ret)
+ ret = dumps_return(ret)
+ reply_socket.send_multipart([
+ remote_constants.NORMAL_TAG, ret,
+ dumps_return(set(obj.__dict__.keys()))
+ ])
- reply_socket.send_multipart(
- [remote_constants.NORMAL_TAG, ret])
+ elif tag == remote_constants.GET_ATTRIBUTE_TAG:
+ attribute_name = to_str(message[1])
+ logfile_path = os.path.join(self.log_dir, 'stdout.log')
+ with redirect_stdout_to_file(logfile_path):
+ ret = getattr(obj, attribute_name)
+ ret = dumps_return(ret)
+ reply_socket.send_multipart(
+ [remote_constants.NORMAL_TAG, ret])
+ elif tag == remote_constants.SET_ATTRIBUTE_TAG:
+ attribute_name = to_str(message[1])
+ attribute_value = loads_return(message[2])
+ logfile_path = os.path.join(self.log_dir, 'stdout.log')
+ with redirect_stdout_to_file(logfile_path):
+ setattr(obj, attribute_name, attribute_value)
+ reply_socket.send_multipart([
+ remote_constants.NORMAL_TAG,
+ dumps_return(set(obj.__dict__.keys()))
+ ])
+ else:
+ pass
except Exception as e:
# reset the job
diff --git a/parl/remote/master.py b/parl/remote/master.py
index 8cca0290a7ad68407026f2e24c4613da83af56a3..7964c561c2ab4067d15574a931ceb41db9bbfe85 100644
--- a/parl/remote/master.py
+++ b/parl/remote/master.py
@@ -18,6 +18,8 @@ import threading
import time
import zmq
from collections import deque, defaultdict
+import parl
+import sys
from parl.utils import to_str, to_byte, logger, get_ip_address
from parl.remote import remote_constants
from parl.remote.job_center import JobCenter
@@ -208,6 +210,7 @@ class Master(object):
elif tag == remote_constants.CLIENT_CONNECT_TAG:
# `client_heartbeat_address` is the
# `reply_master_heartbeat_address` of the client
+
client_heartbeat_address = to_str(message[1])
client_hostname = to_str(message[2])
client_id = to_str(message[3])
@@ -225,6 +228,13 @@ class Master(object):
[remote_constants.NORMAL_TAG,
to_byte(log_monitor_address)])
+ elif tag == remote_constants.CHECK_VERSION_TAG:
+ self.client_socket.send_multipart([
+ remote_constants.NORMAL_TAG,
+ to_byte(parl.__version__),
+ to_byte(str(sys.version_info.major))
+ ])
+
# a client submits a job to the master
elif tag == remote_constants.CLIENT_SUBMIT_TAG:
# check available CPU resources
diff --git a/parl/remote/remote_constants.py b/parl/remote/remote_constants.py
index 8f49da56cc58f0414364b1230bdbb798679691dd..09f96caed5b84a437e219144214b8f98433e5990 100644
--- a/parl/remote/remote_constants.py
+++ b/parl/remote/remote_constants.py
@@ -27,8 +27,11 @@ SEND_FILE_TAG = b'[SEND_FILE]'
SUBMIT_JOB_TAG = b'[SUBMIT_JOB]'
NEW_JOB_TAG = b'[NEW_JOB]'
+CHECK_VERSION_TAG = b'[CHECK_VERSION]'
INIT_OBJECT_TAG = b'[INIT_OBJECT]'
CALL_TAG = b'[CALL]'
+GET_ATTRIBUTE_TAG = b'[GET_ATTRIBUTE]'
+SET_ATTRIBUTE_TAG = b'[SET_ATTRIBUTE]'
EXCEPTION_TAG = b'[EXCEPTION]'
ATTRIBUTE_EXCEPTION_TAG = b'[ATTRIBUTE_EXCEPTION]'
diff --git a/parl/remote/remote_decorator.py b/parl/remote/remote_decorator.py
index a066abc40832fdce00fd00d1784aa75c60925e00..7403f0313155d4c7e6a15eb531e16bd49952811f 100644
--- a/parl/remote/remote_decorator.py
+++ b/parl/remote/remote_decorator.py
@@ -19,6 +19,7 @@ import time
import zmq
import numpy as np
import inspect
+import sys
from parl.utils import get_ip_address, logger, to_str, to_byte
from parl.utils.communication import loads_argument, loads_return,\
@@ -27,6 +28,7 @@ from parl.remote import remote_constants
from parl.remote.exceptions import RemoteError, RemoteAttributeError,\
RemoteDeserializeError, RemoteSerializeError, ResourceError
from parl.remote.client import get_global_client
+from parl.remote.utils import locate_remote_file
def remote_class(*args, **kwargs):
@@ -93,7 +95,7 @@ def remote_class(*args, **kwargs):
class.
"""
self.GLOBAL_CLIENT = get_global_client()
-
+ self.remote_attribute_keys_set = set()
self.ctx = self.GLOBAL_CLIENT.ctx
# GLOBAL_CLIENT will set `master_is_alive` to False when hearbeat
@@ -120,21 +122,34 @@ def remote_class(*args, **kwargs):
self.job_shutdown = False
self.send_file(self.job_socket)
- file_name = inspect.getfile(cls)[:-3]
+ module_path = inspect.getfile(cls)
+ if module_path.endswith('pyc'):
+ module_path = module_path[:-4]
+ elif module_path.endswith('py'):
+ module_path = module_path[:-3]
+ else:
+ raise FileNotFoundError(
+ "cannot not find the module:{}".format(module_path))
+ res = inspect.getfile(cls)
+ file_path = locate_remote_file(module_path)
cls_source = inspect.getsourcelines(cls)
end_of_file = cls_source[1] + len(cls_source[0])
class_name = cls.__name__
self.job_socket.send_multipart([
remote_constants.INIT_OBJECT_TAG,
- cloudpickle.dumps([file_name, class_name, end_of_file]),
+ cloudpickle.dumps([file_path, class_name, end_of_file]),
cloudpickle.dumps([args, kwargs]),
])
message = self.job_socket.recv_multipart()
tag = message[0]
- if tag == remote_constants.EXCEPTION_TAG:
+ if tag == remote_constants.NORMAL_TAG:
+ self.remote_attribute_keys_set = loads_return(message[1])
+ elif tag == remote_constants.EXCEPTION_TAG:
traceback_str = to_str(message[1])
self.job_shutdown = True
raise RemoteError('__init__', traceback_str)
+ else:
+ pass
def __del__(self):
"""Delete the remote class object and release remote resources."""
@@ -179,25 +194,55 @@ def remote_class(*args, **kwargs):
cnt -= 1
return None
- def __getattr__(self, attr):
+ def set_remote_attr(self, attr, value):
+ self.internal_lock.acquire()
+ self.job_socket.send_multipart([
+ remote_constants.SET_ATTRIBUTE_TAG,
+ to_byte(attr),
+ dumps_return(value)
+ ])
+ message = self.job_socket.recv_multipart()
+ tag = message[0]
+ if tag == remote_constants.NORMAL_TAG:
+ self.remote_attribute_keys_set = loads_return(message[1])
+ self.internal_lock.release()
+ else:
+ self.job_shutdown = True
+ raise NotImplementedError()
+ return
+
+ def get_remote_attr(self, attr):
"""Call the function of the unwrapped class."""
+ #check if attr is a attribute or a function
+ is_attribute = attr in self.remote_attribute_keys_set
def wrapper(*args, **kwargs):
- if self.job_shutdown:
- raise RemoteError(
- attr, "This actor losts connection with the job.")
self.internal_lock.acquire()
- data = dumps_argument(*args, **kwargs)
-
- self.job_socket.send_multipart(
- [remote_constants.CALL_TAG,
- to_byte(attr), data])
+ if is_attribute:
+ self.job_socket.send_multipart([
+ remote_constants.GET_ATTRIBUTE_TAG,
+ to_byte(attr)
+ ])
+ else:
+ if self.job_shutdown:
+ raise RemoteError(
+ attr,
+ "This actor losts connection with the job.")
+ data = dumps_argument(*args, **kwargs)
+ self.job_socket.send_multipart(
+ [remote_constants.CALL_TAG,
+ to_byte(attr), data])
message = self.job_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
ret = loads_return(message[1])
+ if not is_attribute:
+ self.remote_attribute_keys_set = loads_return(
+ message[2])
+ self.internal_lock.release()
+ return ret
elif tag == remote_constants.EXCEPTION_TAG:
error_str = to_str(message[1])
@@ -223,13 +268,38 @@ def remote_class(*args, **kwargs):
self.job_shutdown = True
raise NotImplementedError()
- self.internal_lock.release()
- return ret
+ return wrapper() if is_attribute else wrapper
+
+ def proxy_wrapper_func(remote_wrapper):
+ '''
+ The 'proxy_wrapper_func' is defined on the top of class 'RemoteWrapper'
+ in order to set and get attributes of 'remoted_wrapper' and the corresponding
+ remote models individually.
+
+ With 'proxy_wrapper_func', it is allowed to define a attribute (or method) of
+ the same name in 'RemoteWrapper' and remote models.
+ '''
+
+ class ProxyWrapper(object):
+ def __init__(self, *args, **kwargs):
+ self.xparl_remote_wrapper_obj = remote_wrapper(
+ *args, **kwargs)
+
+ def __getattr__(self, attr):
+ return self.xparl_remote_wrapper_obj.get_remote_attr(attr)
+
+ def __setattr__(self, attr, value):
+ if attr == 'xparl_remote_wrapper_obj':
+ super(ProxyWrapper, self).__setattr__(attr, value)
+ else:
+ self.xparl_remote_wrapper_obj.set_remote_attr(
+ attr, value)
- return wrapper
+ return ProxyWrapper
RemoteWrapper._original = cls
- return RemoteWrapper
+ proxy_wrapper = proxy_wrapper_func(RemoteWrapper)
+ return proxy_wrapper
max_memory = kwargs.get('max_memory')
if len(args) == 1 and callable(args[0]):
diff --git a/parl/remote/scripts.py b/parl/remote/scripts.py
index cbae1ccf2059c8cab6da934984db466e4f54ad88..e32b819d24bb9a598da8cd9278b83ac4fc3339ed 100644
--- a/parl/remote/scripts.py
+++ b/parl/remote/scripts.py
@@ -171,22 +171,28 @@ def start_master(port, cpu_num, monitor_port, debug, log_server_port_range):
# Redirect the output to DEVNULL to solve the warning log.
_ = subprocess.Popen(
- master_command, stdout=FNULL, stderr=subprocess.STDOUT)
+ master_command, stdout=FNULL, stderr=subprocess.STDOUT, close_fds=True)
if cpu_num > 0:
# Sleep 1s for master ready
time.sleep(1)
_ = subprocess.Popen(
- worker_command, stdout=FNULL, stderr=subprocess.STDOUT)
+ worker_command,
+ stdout=FNULL,
+ stderr=subprocess.STDOUT,
+ close_fds=True)
if _IS_WINDOWS:
# TODO(@zenghsh3) redirecting stdout of monitor subprocess to FNULL will cause occasional failure
tmp_file = tempfile.TemporaryFile()
- _ = subprocess.Popen(monitor_command, stdout=tmp_file)
+ _ = subprocess.Popen(monitor_command, stdout=tmp_file, close_fds=True)
tmp_file.close()
else:
_ = subprocess.Popen(
- monitor_command, stdout=FNULL, stderr=subprocess.STDOUT)
+ monitor_command,
+ stdout=FNULL,
+ stderr=subprocess.STDOUT,
+ close_fds=True)
FNULL.close()
if cpu_num > 0:
@@ -285,7 +291,7 @@ def start_worker(address, cpu_num, log_server_port_range):
str(cpu_num), "--log_server_port",
str(log_server_port)
]
- p = subprocess.Popen(command)
+ p = subprocess.Popen(command, close_fds=True)
if not is_log_server_started(get_ip_address(), log_server_port):
click.echo("# Fail to start the log server.")
diff --git a/parl/remote/tests/get_set_attribute_test.py b/parl/remote/tests/get_set_attribute_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..68eef3f9ef63322828fdcbe25c9137e877e781b1
--- /dev/null
+++ b/parl/remote/tests/get_set_attribute_test.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import parl
+import numpy as np
+from parl.remote.client import disconnect
+from parl.utils import logger
+from parl.remote.master import Master
+from parl.remote.worker import Worker
+import time
+import threading
+import random
+
+
+@parl.remote_class
+class Actor(object):
+ def __init__(self, arg1, arg2, arg3, arg4):
+ self.arg1 = arg1
+ self.arg2 = arg2
+ self.arg3 = arg3
+ self.GLOBAL_CLIENT = arg4
+
+ def arg1(self, x, y):
+ time.sleep(0.2)
+ return x + y
+
+ def arg5(self):
+ return 100
+
+ def set_new_attr(self):
+ self.new_attr_1 = 200
+
+
+class Test_get_and_set_attribute(unittest.TestCase):
+ def tearDown(self):
+ disconnect()
+
+ def test_get_attribute(self):
+ port1 = random.randint(6100, 6200)
+ logger.info("running:test_get_attirbute")
+ master = Master(port=port1)
+ th = threading.Thread(target=master.run)
+ th.start()
+ time.sleep(3)
+ worker1 = Worker('localhost:{}'.format(port1), 1)
+ arg1 = np.random.randint(100)
+ arg2 = np.random.randn()
+ arg3 = np.random.randn(3, 3)
+ arg4 = 100
+ parl.connect('localhost:{}'.format(port1))
+ actor = Actor(arg1, arg2, arg3, arg4)
+ self.assertTrue(arg1 == actor.arg1)
+ self.assertTrue(arg2 == actor.arg2)
+ self.assertTrue((arg3 == actor.arg3).all())
+ self.assertTrue(arg4 == actor.GLOBAL_CLIENT)
+ master.exit()
+ worker1.exit()
+
+ def test_set_attribute(self):
+ port2 = random.randint(6200, 6300)
+ logger.info("running:test_set_attirbute")
+ master = Master(port=port2)
+ th = threading.Thread(target=master.run)
+ th.start()
+ time.sleep(3)
+ worker1 = Worker('localhost:{}'.format(port2), 1)
+ arg1 = 3
+ arg2 = 3.5
+ arg3 = np.random.randn(3, 3)
+ arg4 = 100
+ parl.connect('localhost:{}'.format(port2))
+ actor = Actor(arg1, arg2, arg3, arg4)
+ actor.arg1 = arg1
+ actor.arg2 = arg2
+ actor.arg3 = arg3
+ actor.GLOBAL_CLIENT = arg4
+ self.assertTrue(arg1 == actor.arg1)
+ self.assertTrue(arg2 == actor.arg2)
+ self.assertTrue((arg3 == actor.arg3).all())
+ self.assertTrue(arg4 == actor.GLOBAL_CLIENT)
+ master.exit()
+ worker1.exit()
+
+ def test_create_new_attribute_same_with_wrapper(self):
+ port3 = random.randint(6400, 6500)
+ logger.info("running:test_create_new_attribute_same_with_wrapper")
+ master = Master(port=port3)
+ th = threading.Thread(target=master.run)
+ th.start()
+ time.sleep(3)
+ worker1 = Worker('localhost:{}'.format(port3), 1)
+ arg1 = np.random.randint(100)
+ arg2 = np.random.randn()
+ arg3 = np.random.randn(3, 3)
+ arg4 = 100
+ parl.connect('localhost:{}'.format(port3))
+ actor = Actor(arg1, arg2, arg3, arg4)
+
+ actor.internal_lock = 50
+ self.assertTrue(actor.internal_lock == 50)
+ master.exit()
+ worker1.exit()
+
+ def test_same_name_of_attribute_and_method(self):
+ port4 = random.randint(6500, 6600)
+ logger.info("running:test_same_name_of_attribute_and_method")
+ master = Master(port=port4)
+ th = threading.Thread(target=master.run)
+ th.start()
+ time.sleep(3)
+ worker1 = Worker('localhost:{}'.format(port4), 1)
+ arg1 = np.random.randint(100)
+ arg2 = np.random.randn()
+ arg3 = np.random.randn(3, 3)
+ arg4 = 100
+ parl.connect('localhost:{}'.format(port4))
+ actor = Actor(arg1, arg2, arg3, arg4)
+ self.assertEqual(arg1, actor.arg1)
+
+ def call_method():
+ return actor.arg1(1, 2)
+
+ self.assertRaises(TypeError, call_method)
+ master.exit()
+ worker1.exit()
+
+ def test_non_existing_attribute_same_with_existing_method(self):
+ port5 = random.randint(6600, 6700)
+ logger.info(
+ "running:test_non_existing_attribute_same_with_existing_method")
+ master = Master(port=port5)
+ th = threading.Thread(target=master.run)
+ th.start()
+ time.sleep(3)
+ worker1 = Worker('localhost:{}'.format(port5), 1)
+ arg1 = np.random.randint(100)
+ arg2 = np.random.randn()
+ arg3 = np.random.randn(3, 3)
+ arg4 = 100
+ parl.connect('localhost:{}'.format(port5))
+ actor = Actor(arg1, arg2, arg3, arg4)
+ actor.new_attr_2 = 300
+ self.assertEqual(300, actor.new_attr_2)
+ actor.set_new_attr()
+ self.assertEqual(200, actor.new_attr_1)
+ self.assertTrue(callable(actor.arg5))
+
+ def call_non_existing_method():
+ return actor.arg2(10)
+
+ self.assertRaises(TypeError, call_non_existing_method)
+ master.exit()
+ worker1.exit()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/parl/remote/tests/log_server_test.py b/parl/remote/tests/log_server_test.py
index 931fc29538df1bc1c960c57e2f97a54e4bb8e0aa..868f5032c296af457e407a9b4555a0c36c75c895 100644
--- a/parl/remote/tests/log_server_test.py
+++ b/parl/remote/tests/log_server_test.py
@@ -24,6 +24,7 @@ import time
import unittest
import requests
+requests.adapters.DEFAULT_RETRIES = 5
import parl
from parl.remote.client import disconnect, get_global_client
@@ -125,10 +126,9 @@ class TestLogServer(unittest.TestCase):
th.start()
time.sleep(1)
# start the cluster monitor
- monitor_file = __file__.replace(
- os.path.join('tests', 'log_server_test.pyc'), 'monitor.py')
- monitor_file = monitor_file.replace(
- os.path.join('tests', 'log_server_test.py'), 'monitor.py')
+ monitor_file = __file__.replace('log_server_test.pyc', '../monitor.py')
+ monitor_file = monitor_file.replace('log_server_test.py',
+ '../monitor.py')
command = [
sys.executable, monitor_file, "--monitor_port",
str(monitor_port), "--address", "localhost:" + str(master_port)
@@ -138,10 +138,7 @@ class TestLogServer(unittest.TestCase):
else:
FNULL = open(os.devnull, 'w')
monitor_proc = subprocess.Popen(
- command,
- stdout=FNULL,
- stderr=subprocess.STDOUT,
- )
+ command, stdout=FNULL, stderr=subprocess.STDOUT, close_fds=True)
# Start worker
cluster_addr = 'localhost:{}'.format(master_port)
diff --git a/parl/remote/tests/reset_job_test.py b/parl/remote/tests/reset_job_test.py
index c76a8dc87feca88bcd7e5f6b9335fcb4133dae4e..fb1f8dca5795c8fd4e1d65e722206672dee1a213 100644
--- a/parl/remote/tests/reset_job_test.py
+++ b/parl/remote/tests/reset_job_test.py
@@ -69,7 +69,7 @@ class TestJob(unittest.TestCase):
file_path = __file__.replace('reset_job_test', 'simulate_client')
command = [sys.executable, file_path]
- proc = subprocess.Popen(command)
+ proc = subprocess.Popen(command, close_fds=True)
for _ in range(6):
if master.cpu_num == 0:
break
diff --git a/parl/remote/tests/support_RegExp_test.py b/parl/remote/tests/support_RegExp_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c51c3526262c3bd281c94c9e4f2fd55ca4776bec
--- /dev/null
+++ b/parl/remote/tests/support_RegExp_test.py
@@ -0,0 +1,99 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import os
+import shutil
+import parl
+from parl.remote.master import Master
+from parl.remote.worker import Worker
+import time
+import threading
+from parl.remote.client import disconnect
+from parl.remote import exceptions
+from parl.utils import logger
+
+
+@parl.remote_class
+class Actor(object):
+ def file_exists(self, filename):
+ return os.path.exists(filename)
+
+
+class TestCluster(unittest.TestCase):
+ def tearDown(self):
+ disconnect()
+
+ def test_distributed_files_with_RegExp(self):
+ if os.path.exists('distribute_test_dir'):
+ shutil.rmtree('distribute_test_dir')
+ os.mkdir('distribute_test_dir')
+ f = open('distribute_test_dir/test1.txt', 'wb')
+ f.close()
+ f = open('distribute_test_dir/test2.txt', 'wb')
+ f.close()
+ f = open('distribute_test_dir/data1.npy', 'wb')
+ f.close()
+ f = open('distribute_test_dir/data2.npy', 'wb')
+ f.close()
+ logger.info("running:test_distributed_files_with_RegExp")
+ master = Master(port=8605)
+ th = threading.Thread(target=master.run)
+ th.start()
+ time.sleep(3)
+ worker1 = Worker('localhost:8605', 1)
+ parl.connect(
+ 'localhost:8605',
+ distributed_files=[
+ 'distribute_test_dir/test*',
+ 'distribute_test_dir/*npy',
+ ])
+ actor = Actor()
+ self.assertTrue(actor.file_exists('distribute_test_dir/test1.txt'))
+ self.assertTrue(actor.file_exists('distribute_test_dir/test2.txt'))
+ self.assertTrue(actor.file_exists('distribute_test_dir/data1.npy'))
+ self.assertTrue(actor.file_exists('distribute_test_dir/data2.npy'))
+ self.assertFalse(actor.file_exists('distribute_test_dir/data3.npy'))
+ shutil.rmtree('distribute_test_dir')
+ master.exit()
+ worker1.exit()
+
+ def test_miss_match_case(self):
+ if os.path.exists('distribute_test_dir_2'):
+ shutil.rmtree('distribute_test_dir_2')
+ os.mkdir('distribute_test_dir_2')
+ f = open('distribute_test_dir_2/test1.txt', 'wb')
+ f.close()
+ f = open('distribute_test_dir_2/data1.npy', 'wb')
+ f.close()
+ logger.info("running:test_distributed_files_with_RegExp_error_case")
+ master = Master(port=8606)
+ th = threading.Thread(target=master.run)
+ th.start()
+ time.sleep(3)
+ worker1 = Worker('localhost:8606', 1)
+
+ def connect_test():
+ parl.connect(
+ 'localhost:8606',
+ distributed_files=['distribute_test_dir_2/miss_match*'])
+
+ self.assertRaises(ValueError, connect_test)
+ shutil.rmtree('distribute_test_dir_2')
+ master.exit()
+ worker1.exit()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/parl/remote/tests/test_import_module/Module2.py b/parl/remote/tests/test_import_module/Module2.py
new file mode 100644
index 0000000000000000000000000000000000000000..b04fc66a5ad2006df414480aa32bee364ecba375
--- /dev/null
+++ b/parl/remote/tests/test_import_module/Module2.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import parl
+
+
+@parl.remote_class
+class B(object):
+ def add_sum(self, a, b):
+ return a + b
diff --git a/parl/remote/tests/test_import_module/main_abs_test.py b/parl/remote/tests/test_import_module/main_abs_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c77dd3a516f136e1031a68d835477dfa3bd40712
--- /dev/null
+++ b/parl/remote/tests/test_import_module/main_abs_test.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import unittest
+import parl
+import time
+import threading
+from parl.remote.master import Master
+from parl.remote.worker import Worker
+from parl.remote.client import disconnect
+
+
+class TestImport(unittest.TestCase):
+ def tearDown(self):
+ disconnect()
+
+ def test_import_local_module(self):
+ from Module2 import B
+ port = 8448
+ master = Master(port=port)
+ th = threading.Thread(target=master.run)
+ th.start()
+ time.sleep(1)
+ worker = Worker('localhost:{}'.format(port), 1)
+ time.sleep(10)
+ parl.connect("localhost:8448")
+ obj = B()
+ res = obj.add_sum(10, 5)
+ self.assertEqual(res, 15)
+ worker.exit()
+ master.exit()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/parl/remote/tests/test_import_module/main_test.py b/parl/remote/tests/test_import_module/main_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..128c47c8ca5a406670c34486dd4fb7bd6d6b63a8
--- /dev/null
+++ b/parl/remote/tests/test_import_module/main_test.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import unittest
+import parl
+import time
+import threading
+from parl.remote.master import Master
+from parl.remote.worker import Worker
+from parl.remote.client import disconnect
+
+
+class TestImport(unittest.TestCase):
+ def tearDown(self):
+ disconnect()
+
+ def test_import_local_module(self):
+ from Module2 import B
+ port = 8442
+ master = Master(port=port)
+ th = threading.Thread(target=master.run)
+ th.start()
+ time.sleep(1)
+ worker = Worker('localhost:{}'.format(port), 1)
+ time.sleep(10)
+ parl.connect("localhost:8442")
+ obj = B()
+ res = obj.add_sum(10, 5)
+ self.assertEqual(res, 15)
+ worker.exit()
+ master.exit()
+
+ def test_import_subdir_module_0(self):
+ from subdir import Module
+ port = 8443
+ master = Master(port=port)
+ th = threading.Thread(target=master.run)
+ th.start()
+ time.sleep(1)
+ worker = Worker('localhost:{}'.format(port), 1)
+ time.sleep(10)
+ parl.connect(
+ "localhost:8443",
+ distributed_files=['./subdir/Module.py', './subdir/__init__.py'])
+ obj = Module.A()
+ res = obj.add_sum(10, 5)
+ self.assertEqual(res, 15)
+ worker.exit()
+ master.exit()
+
+ def test_import_subdir_module_1(self):
+ from subdir.Module import A
+ port = 8444
+ master = Master(port=port)
+ th = threading.Thread(target=master.run)
+ th.start()
+ time.sleep(1)
+ worker = Worker('localhost:{}'.format(port), 1)
+ time.sleep(10)
+ parl.connect(
+ "localhost:8444",
+ distributed_files=['./subdir/Module.py', './subdir/__init__.py'])
+ obj = A()
+ res = obj.add_sum(10, 5)
+ self.assertEqual(res, 15)
+ worker.exit()
+ master.exit()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/parl/remote/tests/test_import_module/subdir/Module.py b/parl/remote/tests/test_import_module/subdir/Module.py
new file mode 100644
index 0000000000000000000000000000000000000000..c06ba3bfe46d28476ab2d6eb94d0f724cab63851
--- /dev/null
+++ b/parl/remote/tests/test_import_module/subdir/Module.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import parl
+
+
+@parl.remote_class
+class A(object):
+ def add_sum(self, a, b):
+ return a + b
diff --git a/parl/remote/tests/test_import_module/subdir/__init__.py b/parl/remote/tests/test_import_module/subdir/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..847ddc47ac89114f2012bc6b9990a69abfe39fb3
--- /dev/null
+++ b/parl/remote/tests/test_import_module/subdir/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/parl/remote/utils.py b/parl/remote/utils.py
index 9a2ece8686ff7de73c8164565f34281e412aa4ee..5d6368e21d2f9d19a72f28eea0d214df8a913664 100644
--- a/parl/remote/utils.py
+++ b/parl/remote/utils.py
@@ -13,8 +13,12 @@
# limitations under the License.
import sys
from contextlib import contextmanager
+import os
+from parl.utils import isnotebook
-__all__ = ['load_remote_class', 'redirect_stdout_to_file']
+__all__ = [
+ 'load_remote_class', 'redirect_stdout_to_file', 'locate_remote_file'
+]
def simplify_code(code, end_of_file):
@@ -32,7 +36,7 @@ def simplify_code(code, end_of_file):
def data_process():
XXXX
------------------>
- The last two lines of the above code block will be removed as they are not class related.
+ The last two lines of the above code block will be removed as they are not class-related.
"""
to_write_lines = []
for i, line in enumerate(code):
@@ -60,12 +64,18 @@ def load_remote_class(file_name, class_name, end_of_file):
with open(file_name + '.py') as t_file:
code = t_file.readlines()
code = simplify_code(code, end_of_file)
- module_name = 'xparl_' + file_name
- tmp_file_name = 'xparl_' + file_name + '.py'
+ #folder/xx.py -> folder/xparl_xx.py
+ file_name = file_name.split(os.sep)
+ prefix = os.sep.join(file_name[:-1])
+ if prefix == "":
+ prefix = '.'
+ module_name = prefix + os.sep + 'xparl_' + file_name[-1]
+ tmp_file_name = module_name + '.py'
with open(tmp_file_name, 'w') as t_file:
for line in code:
t_file.write(line)
- mod = __import__(module_name)
+ module_name = module_name.lstrip('.' + os.sep).replace(os.sep, '.')
+ mod = __import__(module_name, globals(), locals(), [class_name], 0)
cls = getattr(mod, class_name)
return cls
@@ -74,6 +84,9 @@ def load_remote_class(file_name, class_name, end_of_file):
def redirect_stdout_to_file(file_path):
"""Redirect stdout (e.g., `print`) to specified file.
+ Args:
+ file_path: Path of the file to output the stdout.
+
Example:
>>> print('test')
test
@@ -81,10 +94,6 @@ def redirect_stdout_to_file(file_path):
... print('test') # Output nothing, `test` is printed to `test.log`.
>>> print('test')
test
-
- Args:
- file_path: Path of the file to output the stdout.
-
"""
tmp = sys.stdout
f = open(file_path, 'a')
@@ -94,3 +103,37 @@ def redirect_stdout_to_file(file_path):
finally:
sys.stdout = tmp
f.close()
+
+
+def locate_remote_file(module_path):
+ """xparl has to locate the file that has the class decorated by parl.remote_class.
+ This function returns the relative path between this file and the entry file.
+ Note that this function should support the jupyter-notebook environment.
+
+ Args:
+ module_path: Absolute path of the module.
+
+ Example:
+ module_path: /home/user/dir/subdir/my_module
+ entry_file: /home/user/dir/main.py
+ --------> relative_path: subdir/my_module
+ """
+ if isnotebook():
+ entry_path = os.getcwd()
+ else:
+ entry_file = sys.argv[0]
+ entry_file = entry_file.split(os.sep)[-1]
+ entry_path = None
+ for path in sys.path:
+ to_check_path = os.path.join(path, entry_file)
+ if os.path.isfile(to_check_path):
+ entry_path = path
+ break
+ if entry_path is None or \
+ (module_path.startswith(os.sep) and entry_path != module_path[:len(entry_path)]):
+ raise FileNotFoundError("cannot locate the remote file")
+ if module_path.startswith(os.sep):
+ relative_module_path = '.' + module_path[len(entry_path):]
+ else:
+ relative_module_path = module_path
+ return relative_module_path
diff --git a/parl/remote/worker.py b/parl/remote/worker.py
index eec5598c6d081ca054541657c61670ecffc70cee..e16e11d2345c80ff0c5f6f33a2296271c9a21c31 100644
--- a/parl/remote/worker.py
+++ b/parl/remote/worker.py
@@ -26,7 +26,7 @@ import threading
import warnings
import zmq
from datetime import datetime
-
+import parl
from parl.utils import get_ip_address, to_byte, to_str, logger, _IS_WINDOWS, kill_process
from parl.remote import remote_constants
from parl.remote.message import InitializedWorker
@@ -72,10 +72,10 @@ class Worker(object):
self.master_is_alive = True
self.worker_is_alive = True
self.worker_status = None # initialized at `self._create_jobs`
- self.lock = threading.Lock()
self._set_cpu_num(cpu_num)
self.job_buffer = queue.Queue(maxsize=self.cpu_num)
self._create_sockets()
+ self.check_version()
# create log server
self.log_server_proc, self.log_server_address = self._create_log_server(
port=log_server_port)
@@ -102,6 +102,24 @@ class Worker(object):
else:
self.cpu_num = multiprocessing.cpu_count()
+ def check_version(self):
+ '''Verify that the parl & python version in 'worker' process matches that of the 'master' process'''
+ self.request_master_socket.send_multipart(
+ [remote_constants.CHECK_VERSION_TAG])
+ message = self.request_master_socket.recv_multipart()
+ tag = message[0]
+ if tag == remote_constants.NORMAL_TAG:
+ worker_parl_version = parl.__version__
+ worker_python_version = str(sys.version_info.major)
+ assert worker_parl_version == to_str(message[1]) and worker_python_version == to_str(message[2]),\
+ '''Version mismatch: the "master" is of version "parl={}, python={}". However,
+ "parl={}, python={}"is provided in your environment.'''.format(
+ to_str(message[1]), to_str(message[2]),
+ worker_parl_version, worker_python_version
+ )
+ else:
+ raise NotImplementedError
+
def _create_sockets(self):
""" Each worker has three sockets at start:
@@ -209,7 +227,11 @@ class Worker(object):
# Redirect the output to DEVNULL
FNULL = open(os.devnull, 'w')
for _ in range(job_num):
- subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT)
+ subprocess.Popen(
+ command,
+ stdout=FNULL,
+ stderr=subprocess.STDOUT,
+ close_fds=True)
FNULL.close()
new_jobs = []
@@ -384,10 +406,7 @@ class Worker(object):
else:
FNULL = open(os.devnull, 'w')
log_server_proc = subprocess.Popen(
- command,
- stdout=FNULL,
- stderr=subprocess.STDOUT,
- )
+ command, stdout=FNULL, stderr=subprocess.STDOUT, close_fds=True)
FNULL.close()
log_server_address = "{}:{}".format(self.worker_ip, port)
diff --git a/parl/utils/csv_logger.py b/parl/utils/csv_logger.py
index e5152e599b831fab0dbed3a43d3bb40b3412bcb5..8b9045658338e63a2f2d92371e96b755d1d74442 100644
--- a/parl/utils/csv_logger.py
+++ b/parl/utils/csv_logger.py
@@ -19,12 +19,24 @@ __all__ = ['CSVLogger']
class CSVLogger(object):
def __init__(self, output_file):
- """CSV Logger which can write dict result to csv file
+ """CSV Logger which can write dict result to csv file.
+
+ Args:
+ output_file(str): filename of the csv file.
"""
self.output_file = open(output_file, "w")
self.csv_writer = None
def log_dict(self, result):
+ """Ouput result to the csv file.
+
+ Will create the header of the csv file automatically when the function is called for the first time.
+ Ususally, the keys of the result should be the same every time you call the function.
+
+ Args:
+ result(dict)
+ """
+ assert isinstance(result, dict), "the input should be a dict."
if self.csv_writer is None:
self.csv_writer = csv.DictWriter(self.output_file, result.keys())
self.csv_writer.writeheader()
@@ -38,4 +50,9 @@ class CSVLogger(object):
self.output_file.flush()
def close(self):
- self.output_file.close()
+ if not self.output_file.closed:
+ self.output_file.close()
+
+ def __del__(self):
+ if not self.output_file.closed:
+ self.output_file.close()
diff --git a/parl/utils/utils.py b/parl/utils/utils.py
index effb572811a6f72b9c6e363901fa632bdc75a625..69af1511a437cce56a9dc80885c6eb7086463160 100644
--- a/parl/utils/utils.py
+++ b/parl/utils/utils.py
@@ -20,7 +20,7 @@ import numpy as np
__all__ = [
'has_func', 'to_str', 'to_byte', 'is_PY2', 'is_PY3', 'MAX_INT32',
'_HAS_FLUID', '_HAS_TORCH', '_IS_WINDOWS', '_IS_MAC', 'kill_process',
- 'get_fluid_version'
+ 'get_fluid_version', 'isnotebook'
]
@@ -101,3 +101,19 @@ def kill_process(regex_pattern):
command = "ps aux | grep {} | awk '{{print $2}}' | xargs kill -9".format(
regex_pattern)
subprocess.call([command], shell=True)
+
+
+def isnotebook():
+ """check if the code is excuted in the IPython notebook
+ Reference: https://stackoverflow.com/a/39662359
+ """
+ try:
+ shell = get_ipython().__class__.__name__
+ if shell == 'ZMQInteractiveShell':
+ return True # Jupyter notebook or qtconsole
+ elif shell == 'TerminalInteractiveShell':
+ return False # Terminal running IPython
+ else:
+ return False # Other type (?)
+ except NameError:
+ return False # Probably standard Python interpreter
diff --git a/setup.py b/setup.py
index 659af892019a5142ad122bf8213f910e3be14483..ddd61cfe4e829f6964be488b2cedc737f1c27bd0 100644
--- a/setup.py
+++ b/setup.py
@@ -82,7 +82,7 @@ setup(
"click",
"psutil>=5.6.2",
"flask_cors",
- "visualdl>=2.0.0b;python_version>='3' and platform_system=='Linux'",
+ "visualdl>=2.0.0b;python_version>='3.7' and platform_system=='Linux'",
],
classifiers=[
'Intended Audience :: Developers',