Commit 3565b546 authored by TomorrowIsAnOtherDay

Merge branch 'develop' into xparl_doc

# requirements for unittest
rarfile==3.1
opencv-python<=4.3.0.34;python_version>="3"
opencv-python==4.2.0.32;python_version<"3"
paddlepaddle-gpu==1.6.1.post97
gym
details
......
...@@ -30,10 +30,20 @@ function(py_test TARGET_NAME)
  set(oneValueArgs "")
  set(multiValueArgs SRCS DEPS ARGS ENVS)
  cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
  if (${FILE_NAME} MATCHES ".*abs_test.py")
    add_test(NAME ${TARGET_NAME}"_with_abs_path"
             COMMAND python -u ${py_test_SRCS} ${py_test_ARGS}
             WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
    set_tests_properties(${TARGET_NAME}"_with_abs_path" PROPERTIES TIMEOUT 300)
  else()
    get_filename_component(WORKING_DIR ${py_test_SRCS} DIRECTORY)
    get_filename_component(FILE_NAME ${py_test_SRCS} NAME)
    get_filename_component(COMBINED_PATH ${CMAKE_CURRENT_SOURCE_DIR}/${WORKING_DIR} ABSOLUTE)
    add_test(NAME ${TARGET_NAME}
             COMMAND python -u ${FILE_NAME} ${py_test_ARGS}
             WORKING_DIRECTORY ${COMBINED_PATH})
    set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 300)
  endif()
endfunction()
function(import_test TARGET_NAME)
......
...@@ -126,7 +126,13 @@ yapf -i modified_file.py
```
- Continuous integration tests<br>
When you add code, you also need to add test code that covers it. Test code must be placed in the `tests` folder next to the related code files and end with `_test.py`, so that the continuous integration tests pick it up and run it automatically. See: [test code example](../../parl/tests/import_test.py), and the minimal sketch after this list.
- Run unit tests locally (optional)<br>
If you want to run the unit tests on your own machine, first install Docker locally, then run the test task with the following steps.
```
cd PARL
docker build -t parl/parl-test:unittest .teamcity/
nvidia-docker run -i --rm -v $PWD:/work -w /work parl/parl-test:unittest .teamcity/build.sh test
```
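
For reference, a minimal sketch of what such a `*_test.py` file could look like; the class and test names below are hypothetical placeholders, not part of the repository:

```python
import unittest


class MyFeatureTest(unittest.TestCase):
    def test_add(self):
        # replace with assertions that actually cover the newly added code
        self.assertEqual(1 + 2, 3)


if __name__ == '__main__':
    unittest.main()
```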
## Feedback
- [Submit an issue](https://github.com/PaddlePaddle/PARL/issues) on GitHub
...@@ -24,4 +24,4 @@ When implementing the underlying parallel computation, PARL relies on this end-to-end network transmission
## Automatic distribution of local files
Most parallel frameworks on the market require users to synchronize files manually before parallel code can run; for example, configuration files have to be copied to different machines by hand or via extra commands. PARL automatically distributes the code files in the current directory, enabling seamless multi-machine parallelism.
<img src="../../parallel_training/comparison.png" width="1000"/>
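
As a rough sketch (the master address and file names below are placeholders), the client-side call typically looks like this; code files next to the entry script are distributed automatically, while extra data or configuration files can be listed explicitly with wildcard patterns, which this commit resolves via `glob`:

```python
import parl

# .py files in the directory of the entry script are distributed automatically;
# other files (data, configs) can be listed explicitly, wildcards included.
parl.connect(
    'localhost:6006',
    distributed_files=['./config.yaml', './data/*.npy'])
```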
...@@ -3,7 +3,7 @@
## Setup commands
This tutorial demonstrates how to set up a cluster.
A PARL cluster can be set up by running the following `xparl` commands:
### Start the cluster
```bash
...@@ -12,17 +12,17 @@ xparl start --port 6006
This command starts a master node to manage the cluster's computation resources and adds the CPU resources of the local machine to the cluster. Port 6006 is only an example; you can change it to any valid port.
After startup, you can check how many CPU resources are currently available with `xparl status`. You can also add the option `--cpu_num [CPU_NUM]` (e.g. --cpu_num 10) to the `xparl start` command to specify how many CPUs of this machine join the cluster.
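
For example (a small sketch; the port and CPU count are placeholders):

```bash
# start the master and contribute 10 local CPUs to the cluster
xparl start --port 6006 --cpu_num 10

# check how many CPUs are currently available in the cluster
xparl status
```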
### Add more CPU resources
Once the cluster is started, you can use it right away. If the CPU resources are not enough, you can add more CPUs to the cluster at any time, from any machine (the local machine or other machines), by running the `xparl connect` command.
```bash
xparl connect --address [MASTER_ADDRESS]:6006
```
This starts a worker node and adds the CPU resources of the current machine to the master cluster specified by `--address`. By default the worker contributes all available CPUs of the current machine; if you want to specify how many CPUs to add, append the option `--cpu_num [CPU_NUM]` to the command above, as in the sketch below.
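
For example, a worker that contributes only 10 CPUs (the address and CPU count are placeholders):

```bash
xparl connect --address [MASTER_ADDRESS]:6006 --cpu_num 10
```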
## Example
Here we give an example that demonstrates how to run parallel computation with `@parl.remote_class`.
...@@ -47,9 +47,9 @@ actor.add(1, 2) # returns 3
```
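
Since the code block above is collapsed in this view, here is a minimal self-contained sketch of the same pattern, assuming a cluster has already been started locally with `xparl start --port 6006`:

```python
import parl


@parl.remote_class
class Actor(object):
    def add(self, a, b):
        # runs in a separate process scheduled by the cluster
        return a + b


parl.connect('localhost:6006')  # address of the master started above

actor = Actor()
print(actor.add(1, 2))  # 3
```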
## Shut down the cluster
Run the `xparl stop` command on the master machine to shut down the cluster. When the master node exits, the worker nodes associated with it also exit automatically and terminate their processes.
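
The command itself, run on the master machine:

```bash
xparl stop
```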
## Further reading
We now know how to set up a cluster with the `xparl` command-line tool and how to use it through the decorator `@parl.remote_class`.
In the [next tutorial](./example.md) we will show how this decorator breaks the limitation of Python's Global Interpreter Lock (GIL) and achieves true multi-threaded computation.
...@@ -53,7 +53,7 @@ class Model(ModelBase):
copied_policy = copy.deepcopy(model)
Attributes:
model_id(str): each model instance has its unique model_id.
Public Functions:
- ``sync_weights_to``: synchronize parameters of the current model to another model.
......
...@@ -19,9 +19,11 @@ import socket
import sys
import threading
import zmq
import parl
from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook
from parl.remote import remote_constants
import time
import glob
class Client(object):
...@@ -50,7 +52,6 @@ class Client(object):
distributed_files (list): A list of files to be distributed at all
remote instances (e.g. the configuration
file for initialization).
"""
self.master_address = master_address
self.process_id = process_id
...@@ -66,6 +67,7 @@ class Client(object):
self.actor_num = 0
self._create_sockets(master_address)
self.check_version()
self.pyfiles = self.read_local_files(distributed_files)
def get_executable_path(self):
...@@ -85,44 +87,58 @@ class Client(object):
Args:
distributed_files (list): A list of files to be distributed at all
remote instances (e.g. the configuration
file for initialization). Wildcard patterns in file
names are supported.
e.g.
distributed_files = ['./*.npy', './test*']
Returns:
A cloudpickled dictionary containing the python code in current
working directory.
"""
parsed_distributed_files = set()
for distributed_file in distributed_files:
parsed_list = glob.glob(distributed_file)
if not parsed_list:
raise ValueError(
"no local file is matched with '{}', please check your input"
.format(distributed_file))
# exclude the directories
for pathname in parsed_list:
if not os.path.isdir(pathname):
parsed_distributed_files.add(pathname)
pyfiles = dict()
pyfiles['python_files'] = {}
pyfiles['other_files'] = {}
if isnotebook():
main_folder = './'
else:
main_file = sys.argv[0]
main_folder = './'
sep = os.sep
if sep in main_file:
main_folder = sep.join(main_file.split(sep)[:-1])
code_files = filter(lambda x: x.endswith('.py'),
os.listdir(main_folder))
for file_name in code_files:
file_path = os.path.join(main_folder, file_name)
assert os.path.exists(file_path)
with open(file_path, 'rb') as code_file:
code = code_file.read()
pyfiles['python_files'][file_name] = code
for file_name in parsed_distributed_files:
assert os.path.exists(file_name)
assert not os.path.isabs(
file_name
), "[XPARL] Please do not distribute a file with absolute path."
with open(file_name, 'rb') as f:
content = f.read()
pyfiles['other_files'][file_name] = content
# append entry file to code list
main_file = sys.argv[0]
with open(main_file, 'rb') as code_file:
code = code_file.read()
# parl/remote/remote_decorator.py -> remote_decorator.py
file_name = main_file.split(os.sep)[-1]
pyfiles['python_files'][file_name] = code
except AssertionError as e:
raise Exception(
'Failed to create the client, the file {} does not exist.'.
format(file))
return cloudpickle.dumps(pyfiles)
def _create_sockets(self, master_address):
...@@ -165,6 +181,24 @@ class Client(object):
"check if master is started and ensure the input "
"address {} is correct.".format(master_address))
def check_version(self):
'''Verify that the parl & python version in 'client' process matches that of the 'master' process'''
self.submit_job_socket.send_multipart(
[remote_constants.CHECK_VERSION_TAG])
message = self.submit_job_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
client_parl_version = parl.__version__
client_python_version = str(sys.version_info.major)
assert client_parl_version == to_str(message[1]) and client_python_version == to_str(message[2]),\
'''Version mismatch: the 'master' is of version 'parl={}, python={}'. However,
'parl={}, python={}'is provided in your environment.'''.format(
to_str(message[1]), to_str(message[2]),
client_parl_version, client_python_version
)
else:
raise NotImplementedError
def _reply_heartbeat(self):
"""Reply heartbeat signals to the master node."""
......
...@@ -311,8 +311,6 @@ class Job(object):
try:
file_name, class_name, end_of_file = cloudpickle.loads(
message[1])
#/home/nlp-ol/Firework/baidu/nlp/evokit/python_api/es_agent -> es_agent
file_name = file_name.split(os.sep)[-1]
cls = load_remote_class(file_name, class_name, end_of_file)
args, kwargs = cloudpickle.loads(message[2])
logfile_path = os.path.join(self.log_dir, 'stdout.log')
...@@ -327,7 +325,10 @@ class Job(object):
to_byte(error_str + "\ntraceback:\n" + traceback_str)
])
return None
reply_socket.send_multipart([
remote_constants.NORMAL_TAG,
dumps_return(set(obj.__dict__.keys()))
])
else:
logger.error("Message from job {}".format(message))
reply_socket.send_multipart([
...@@ -397,11 +398,14 @@ class Job(object):
while True:
message = reply_socket.recv_multipart()
tag = message[0]
if tag in [
remote_constants.CALL_TAG,
remote_constants.GET_ATTRIBUTE_TAG,
remote_constants.SET_ATTRIBUTE_TAG,
]:
try:
if tag == remote_constants.CALL_TAG:
function_name = to_str(message[1])
data = message[2]
args, kwargs = loads_argument(data)
...@@ -412,9 +416,31 @@ class Job(object):
ret = getattr(obj, function_name)(*args, **kwargs)
ret = dumps_return(ret)
reply_socket.send_multipart([
remote_constants.NORMAL_TAG, ret,
dumps_return(set(obj.__dict__.keys()))
])
elif tag == remote_constants.GET_ATTRIBUTE_TAG:
attribute_name = to_str(message[1])
logfile_path = os.path.join(self.log_dir, 'stdout.log')
with redirect_stdout_to_file(logfile_path):
ret = getattr(obj, attribute_name)
ret = dumps_return(ret)
reply_socket.send_multipart(
[remote_constants.NORMAL_TAG, ret])
elif tag == remote_constants.SET_ATTRIBUTE_TAG:
attribute_name = to_str(message[1])
attribute_value = loads_return(message[2])
logfile_path = os.path.join(self.log_dir, 'stdout.log')
with redirect_stdout_to_file(logfile_path):
setattr(obj, attribute_name, attribute_value)
reply_socket.send_multipart([
remote_constants.NORMAL_TAG,
dumps_return(set(obj.__dict__.keys()))
])
else:
pass
except Exception as e:
# reset the job
......
...@@ -18,6 +18,8 @@ import threading
import time
import zmq
from collections import deque, defaultdict
import parl
import sys
from parl.utils import to_str, to_byte, logger, get_ip_address
from parl.remote import remote_constants
from parl.remote.job_center import JobCenter
...@@ -208,6 +210,7 @@ class Master(object):
elif tag == remote_constants.CLIENT_CONNECT_TAG:
# `client_heartbeat_address` is the
# `reply_master_heartbeat_address` of the client
client_heartbeat_address = to_str(message[1])
client_hostname = to_str(message[2])
client_id = to_str(message[3])
...@@ -225,6 +228,13 @@ class Master(object):
[remote_constants.NORMAL_TAG,
to_byte(log_monitor_address)])
elif tag == remote_constants.CHECK_VERSION_TAG:
self.client_socket.send_multipart([
remote_constants.NORMAL_TAG,
to_byte(parl.__version__),
to_byte(str(sys.version_info.major))
])
# a client submits a job to the master
elif tag == remote_constants.CLIENT_SUBMIT_TAG:
# check available CPU resources
......
...@@ -27,8 +27,11 @@ SEND_FILE_TAG = b'[SEND_FILE]'
SUBMIT_JOB_TAG = b'[SUBMIT_JOB]'
NEW_JOB_TAG = b'[NEW_JOB]'
CHECK_VERSION_TAG = b'[CHECK_VERSION]'
INIT_OBJECT_TAG = b'[INIT_OBJECT]'
CALL_TAG = b'[CALL]'
GET_ATTRIBUTE_TAG = b'[GET_ATTRIBUTE]'
SET_ATTRIBUTE_TAG = b'[SET_ATTRIBUTE]'
EXCEPTION_TAG = b'[EXCEPTION]'
ATTRIBUTE_EXCEPTION_TAG = b'[ATTRIBUTE_EXCEPTION]'
......
...@@ -19,6 +19,7 @@ import time
import zmq
import numpy as np
import inspect
import sys
from parl.utils import get_ip_address, logger, to_str, to_byte
from parl.utils.communication import loads_argument, loads_return,\
...@@ -27,6 +28,7 @@ from parl.remote import remote_constants
from parl.remote.exceptions import RemoteError, RemoteAttributeError,\
RemoteDeserializeError, RemoteSerializeError, ResourceError
from parl.remote.client import get_global_client
from parl.remote.utils import locate_remote_file
def remote_class(*args, **kwargs):
...@@ -93,7 +95,7 @@ def remote_class(*args, **kwargs):
class.
"""
self.GLOBAL_CLIENT = get_global_client()
self.remote_attribute_keys_set = set()
self.ctx = self.GLOBAL_CLIENT.ctx
# GLOBAL_CLIENT will set `master_is_alive` to False when heartbeat
...@@ -120,21 +122,34 @@ def remote_class(*args, **kwargs):
self.job_shutdown = False
self.send_file(self.job_socket)
module_path = inspect.getfile(cls)
if module_path.endswith('pyc'):
module_path = module_path[:-4]
elif module_path.endswith('py'):
module_path = module_path[:-3]
else:
raise FileNotFoundError(
"cannot not find the module:{}".format(module_path))
res = inspect.getfile(cls)
file_path = locate_remote_file(module_path)
cls_source = inspect.getsourcelines(cls)
end_of_file = cls_source[1] + len(cls_source[0])
class_name = cls.__name__
self.job_socket.send_multipart([
remote_constants.INIT_OBJECT_TAG,
cloudpickle.dumps([file_path, class_name, end_of_file]),
cloudpickle.dumps([args, kwargs]),
])
message = self.job_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
self.remote_attribute_keys_set = loads_return(message[1])
elif tag == remote_constants.EXCEPTION_TAG:
traceback_str = to_str(message[1])
self.job_shutdown = True
raise RemoteError('__init__', traceback_str)
else:
pass
def __del__(self):
"""Delete the remote class object and release remote resources."""
...@@ -179,16 +194,41 @@ def remote_class(*args, **kwargs):
cnt -= 1
return None
def set_remote_attr(self, attr, value):
self.internal_lock.acquire()
self.job_socket.send_multipart([
remote_constants.SET_ATTRIBUTE_TAG,
to_byte(attr),
dumps_return(value)
])
message = self.job_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
self.remote_attribute_keys_set = loads_return(message[1])
self.internal_lock.release()
else:
self.job_shutdown = True
raise NotImplementedError()
return
def get_remote_attr(self, attr):
"""Call the function of the unwrapped class.""" """Call the function of the unwrapped class."""
#check if attr is a attribute or a function
is_attribute = attr in self.remote_attribute_keys_set
def wrapper(*args, **kwargs):
self.internal_lock.acquire()
if is_attribute:
self.job_socket.send_multipart([
remote_constants.GET_ATTRIBUTE_TAG,
to_byte(attr)
])
else:
if self.job_shutdown:
raise RemoteError(
attr,
"This actor lost connection with the job.")
data = dumps_argument(*args, **kwargs)
self.job_socket.send_multipart(
[remote_constants.CALL_TAG,
to_byte(attr), data])
...@@ -198,6 +238,11 @@ def remote_class(*args, **kwargs):
if tag == remote_constants.NORMAL_TAG:
ret = loads_return(message[1])
if not is_attribute:
self.remote_attribute_keys_set = loads_return(
message[2])
self.internal_lock.release()
return ret
elif tag == remote_constants.EXCEPTION_TAG:
error_str = to_str(message[1])
...@@ -223,13 +268,38 @@ def remote_class(*args, **kwargs):
self.job_shutdown = True
raise NotImplementedError()
return wrapper() if is_attribute else wrapper
def proxy_wrapper_func(remote_wrapper):
'''
The 'proxy_wrapper_func' wraps the class 'RemoteWrapper' so that attributes
of the local 'RemoteWrapper' instance and of the corresponding remote object
can be set and retrieved individually.
With 'proxy_wrapper_func', it is allowed to define an attribute (or method) of
the same name in 'RemoteWrapper' and in the remote object.
'''
class ProxyWrapper(object):
def __init__(self, *args, **kwargs):
self.xparl_remote_wrapper_obj = remote_wrapper(
*args, **kwargs)
def __getattr__(self, attr):
return self.xparl_remote_wrapper_obj.get_remote_attr(attr)
def __setattr__(self, attr, value):
if attr == 'xparl_remote_wrapper_obj':
super(ProxyWrapper, self).__setattr__(attr, value)
else:
self.xparl_remote_wrapper_obj.set_remote_attr(
attr, value)
return ProxyWrapper
RemoteWrapper._original = cls
proxy_wrapper = proxy_wrapper_func(RemoteWrapper)
return proxy_wrapper
max_memory = kwargs.get('max_memory')
if len(args) == 1 and callable(args[0]):
......
...@@ -171,22 +171,28 @@ def start_master(port, cpu_num, monitor_port, debug, log_server_port_range):
# Redirect the output to DEVNULL to solve the warning log.
_ = subprocess.Popen(
master_command, stdout=FNULL, stderr=subprocess.STDOUT, close_fds=True)
if cpu_num > 0:
# Sleep 1s for master ready
time.sleep(1)
_ = subprocess.Popen(
worker_command,
stdout=FNULL,
stderr=subprocess.STDOUT,
close_fds=True)
if _IS_WINDOWS:
# TODO(@zenghsh3) redirecting stdout of monitor subprocess to FNULL will cause occasional failure
tmp_file = tempfile.TemporaryFile()
_ = subprocess.Popen(monitor_command, stdout=tmp_file, close_fds=True)
tmp_file.close()
else:
_ = subprocess.Popen(
monitor_command,
stdout=FNULL,
stderr=subprocess.STDOUT,
close_fds=True)
FNULL.close()
if cpu_num > 0:
...@@ -285,7 +291,7 @@ def start_worker(address, cpu_num, log_server_port_range):
str(cpu_num), "--log_server_port",
str(log_server_port)
]
p = subprocess.Popen(command, close_fds=True)
if not is_log_server_started(get_ip_address(), log_server_port):
click.echo("# Fail to start the log server.")
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
import numpy as np
from parl.remote.client import disconnect
from parl.utils import logger
from parl.remote.master import Master
from parl.remote.worker import Worker
import time
import threading
import random
@parl.remote_class
class Actor(object):
def __init__(self, arg1, arg2, arg3, arg4):
self.arg1 = arg1
self.arg2 = arg2
self.arg3 = arg3
self.GLOBAL_CLIENT = arg4
def arg1(self, x, y):
time.sleep(0.2)
return x + y
def arg5(self):
return 100
def set_new_attr(self):
self.new_attr_1 = 200
class Test_get_and_set_attribute(unittest.TestCase):
def tearDown(self):
disconnect()
def test_get_attribute(self):
port1 = random.randint(6100, 6200)
logger.info("running:test_get_attirbute")
master = Master(port=port1)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:{}'.format(port1), 1)
arg1 = np.random.randint(100)
arg2 = np.random.randn()
arg3 = np.random.randn(3, 3)
arg4 = 100
parl.connect('localhost:{}'.format(port1))
actor = Actor(arg1, arg2, arg3, arg4)
self.assertTrue(arg1 == actor.arg1)
self.assertTrue(arg2 == actor.arg2)
self.assertTrue((arg3 == actor.arg3).all())
self.assertTrue(arg4 == actor.GLOBAL_CLIENT)
master.exit()
worker1.exit()
def test_set_attribute(self):
port2 = random.randint(6200, 6300)
logger.info("running:test_set_attirbute")
master = Master(port=port2)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:{}'.format(port2), 1)
arg1 = 3
arg2 = 3.5
arg3 = np.random.randn(3, 3)
arg4 = 100
parl.connect('localhost:{}'.format(port2))
actor = Actor(arg1, arg2, arg3, arg4)
actor.arg1 = arg1
actor.arg2 = arg2
actor.arg3 = arg3
actor.GLOBAL_CLIENT = arg4
self.assertTrue(arg1 == actor.arg1)
self.assertTrue(arg2 == actor.arg2)
self.assertTrue((arg3 == actor.arg3).all())
self.assertTrue(arg4 == actor.GLOBAL_CLIENT)
master.exit()
worker1.exit()
def test_create_new_attribute_same_with_wrapper(self):
port3 = random.randint(6400, 6500)
logger.info("running:test_create_new_attribute_same_with_wrapper")
master = Master(port=port3)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:{}'.format(port3), 1)
arg1 = np.random.randint(100)
arg2 = np.random.randn()
arg3 = np.random.randn(3, 3)
arg4 = 100
parl.connect('localhost:{}'.format(port3))
actor = Actor(arg1, arg2, arg3, arg4)
actor.internal_lock = 50
self.assertTrue(actor.internal_lock == 50)
master.exit()
worker1.exit()
def test_same_name_of_attribute_and_method(self):
port4 = random.randint(6500, 6600)
logger.info("running:test_same_name_of_attribute_and_method")
master = Master(port=port4)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:{}'.format(port4), 1)
arg1 = np.random.randint(100)
arg2 = np.random.randn()
arg3 = np.random.randn(3, 3)
arg4 = 100
parl.connect('localhost:{}'.format(port4))
actor = Actor(arg1, arg2, arg3, arg4)
self.assertEqual(arg1, actor.arg1)
def call_method():
return actor.arg1(1, 2)
self.assertRaises(TypeError, call_method)
master.exit()
worker1.exit()
def test_non_existing_attribute_same_with_existing_method(self):
port5 = random.randint(6600, 6700)
logger.info(
"running:test_non_existing_attribute_same_with_existing_method")
master = Master(port=port5)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:{}'.format(port5), 1)
arg1 = np.random.randint(100)
arg2 = np.random.randn()
arg3 = np.random.randn(3, 3)
arg4 = 100
parl.connect('localhost:{}'.format(port5))
actor = Actor(arg1, arg2, arg3, arg4)
actor.new_attr_2 = 300
self.assertEqual(300, actor.new_attr_2)
actor.set_new_attr()
self.assertEqual(200, actor.new_attr_1)
self.assertTrue(callable(actor.arg5))
def call_non_existing_method():
return actor.arg2(10)
self.assertRaises(TypeError, call_non_existing_method)
master.exit()
worker1.exit()
if __name__ == '__main__':
unittest.main()
...@@ -24,6 +24,7 @@ import time
import unittest
import requests
requests.adapters.DEFAULT_RETRIES = 5
import parl
from parl.remote.client import disconnect, get_global_client
...@@ -125,10 +126,9 @@ class TestLogServer(unittest.TestCase):
th.start()
time.sleep(1)
# start the cluster monitor
monitor_file = __file__.replace('log_server_test.pyc', '../monitor.py')
monitor_file = monitor_file.replace('log_server_test.py',
'../monitor.py')
command = [
sys.executable, monitor_file, "--monitor_port",
str(monitor_port), "--address", "localhost:" + str(master_port)
...@@ -138,10 +138,7 @@ class TestLogServer(unittest.TestCase):
else:
FNULL = open(os.devnull, 'w')
monitor_proc = subprocess.Popen(
command, stdout=FNULL, stderr=subprocess.STDOUT, close_fds=True)
# Start worker
cluster_addr = 'localhost:{}'.format(master_port)
......
...@@ -69,7 +69,7 @@ class TestJob(unittest.TestCase):
file_path = __file__.replace('reset_job_test', 'simulate_client')
command = [sys.executable, file_path]
proc = subprocess.Popen(command, close_fds=True)
for _ in range(6):
if master.cpu_num == 0:
break
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import shutil
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
import time
import threading
from parl.remote.client import disconnect
from parl.remote import exceptions
from parl.utils import logger
@parl.remote_class
class Actor(object):
def file_exists(self, filename):
return os.path.exists(filename)
class TestCluster(unittest.TestCase):
def tearDown(self):
disconnect()
def test_distributed_files_with_RegExp(self):
if os.path.exists('distribute_test_dir'):
shutil.rmtree('distribute_test_dir')
os.mkdir('distribute_test_dir')
f = open('distribute_test_dir/test1.txt', 'wb')
f.close()
f = open('distribute_test_dir/test2.txt', 'wb')
f.close()
f = open('distribute_test_dir/data1.npy', 'wb')
f.close()
f = open('distribute_test_dir/data2.npy', 'wb')
f.close()
logger.info("running:test_distributed_files_with_RegExp")
master = Master(port=8605)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:8605', 1)
parl.connect(
'localhost:8605',
distributed_files=[
'distribute_test_dir/test*',
'distribute_test_dir/*npy',
])
actor = Actor()
self.assertTrue(actor.file_exists('distribute_test_dir/test1.txt'))
self.assertTrue(actor.file_exists('distribute_test_dir/test2.txt'))
self.assertTrue(actor.file_exists('distribute_test_dir/data1.npy'))
self.assertTrue(actor.file_exists('distribute_test_dir/data2.npy'))
self.assertFalse(actor.file_exists('distribute_test_dir/data3.npy'))
shutil.rmtree('distribute_test_dir')
master.exit()
worker1.exit()
def test_miss_match_case(self):
if os.path.exists('distribute_test_dir_2'):
shutil.rmtree('distribute_test_dir_2')
os.mkdir('distribute_test_dir_2')
f = open('distribute_test_dir_2/test1.txt', 'wb')
f.close()
f = open('distribute_test_dir_2/data1.npy', 'wb')
f.close()
logger.info("running:test_distributed_files_with_RegExp_error_case")
master = Master(port=8606)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:8606', 1)
def connect_test():
parl.connect(
'localhost:8606',
distributed_files=['distribute_test_dir_2/miss_match*'])
self.assertRaises(ValueError, connect_test)
shutil.rmtree('distribute_test_dir_2')
master.exit()
worker1.exit()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
@parl.remote_class
class B(object):
def add_sum(self, a, b):
return a + b
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import parl
import time
import threading
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
class TestImport(unittest.TestCase):
def tearDown(self):
disconnect()
def test_import_local_module(self):
from Module2 import B
port = 8448
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker = Worker('localhost:{}'.format(port), 1)
time.sleep(10)
parl.connect("localhost:8448")
obj = B()
res = obj.add_sum(10, 5)
self.assertEqual(res, 15)
worker.exit()
master.exit()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import parl
import time
import threading
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
class TestImport(unittest.TestCase):
def tearDown(self):
disconnect()
def test_import_local_module(self):
from Module2 import B
port = 8442
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker = Worker('localhost:{}'.format(port), 1)
time.sleep(10)
parl.connect("localhost:8442")
obj = B()
res = obj.add_sum(10, 5)
self.assertEqual(res, 15)
worker.exit()
master.exit()
def test_import_subdir_module_0(self):
from subdir import Module
port = 8443
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker = Worker('localhost:{}'.format(port), 1)
time.sleep(10)
parl.connect(
"localhost:8443",
distributed_files=['./subdir/Module.py', './subdir/__init__.py'])
obj = Module.A()
res = obj.add_sum(10, 5)
self.assertEqual(res, 15)
worker.exit()
master.exit()
def test_import_subdir_module_1(self):
from subdir.Module import A
port = 8444
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker = Worker('localhost:{}'.format(port), 1)
time.sleep(10)
parl.connect(
"localhost:8444",
distributed_files=['./subdir/Module.py', './subdir/__init__.py'])
obj = A()
res = obj.add_sum(10, 5)
self.assertEqual(res, 15)
worker.exit()
master.exit()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
@parl.remote_class
class A(object):
def add_sum(self, a, b):
return a + b
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -13,8 +13,12 @@
# limitations under the License.
import sys
from contextlib import contextmanager
import os
from parl.utils import isnotebook
__all__ = [
'load_remote_class', 'redirect_stdout_to_file', 'locate_remote_file'
]
def simplify_code(code, end_of_file):
...@@ -32,7 +36,7 @@ def simplify_code(code, end_of_file):
def data_process():
XXXX
------------------>
The last two lines of the above code block will be removed as they are not class-related.
"""
to_write_lines = []
for i, line in enumerate(code):
...@@ -60,12 +64,18 @@ def load_remote_class(file_name, class_name, end_of_file):
with open(file_name + '.py') as t_file:
code = t_file.readlines()
code = simplify_code(code, end_of_file)
# folder/xx.py -> folder/xparl_xx.py
file_name = file_name.split(os.sep)
prefix = os.sep.join(file_name[:-1])
if prefix == "":
prefix = '.'
module_name = prefix + os.sep + 'xparl_' + file_name[-1]
tmp_file_name = module_name + '.py'
with open(tmp_file_name, 'w') as t_file:
for line in code:
t_file.write(line)
module_name = module_name.lstrip('.' + os.sep).replace(os.sep, '.')
mod = __import__(module_name, globals(), locals(), [class_name], 0)
cls = getattr(mod, class_name)
return cls
...@@ -74,6 +84,9 @@ def load_remote_class(file_name, class_name, end_of_file):
def redirect_stdout_to_file(file_path):
"""Redirect stdout (e.g., `print`) to specified file.
Args:
file_path: Path of the file to output the stdout.
Example:
>>> print('test')
test
...@@ -81,10 +94,6 @@ def redirect_stdout_to_file(file_path):
... print('test') # Output nothing, `test` is printed to `test.log`.
>>> print('test')
test
"""
tmp = sys.stdout
f = open(file_path, 'a')
...@@ -94,3 +103,37 @@ def redirect_stdout_to_file(file_path):
finally:
sys.stdout = tmp
f.close()
def locate_remote_file(module_path):
"""xparl has to locate the file that has the class decorated by parl.remote_class.
This function returns the relative path between this file and the entry file.
Note that this function should support the jupyter-notebook environment.
Args:
module_path: Absolute path of the module.
Example:
module_path: /home/user/dir/subdir/my_module
entry_file: /home/user/dir/main.py
--------> relative_path: subdir/my_module
"""
if isnotebook():
entry_path = os.getcwd()
else:
entry_file = sys.argv[0]
entry_file = entry_file.split(os.sep)[-1]
entry_path = None
for path in sys.path:
to_check_path = os.path.join(path, entry_file)
if os.path.isfile(to_check_path):
entry_path = path
break
if entry_path is None or \
(module_path.startswith(os.sep) and entry_path != module_path[:len(entry_path)]):
raise FileNotFoundError("cannot locate the remote file")
if module_path.startswith(os.sep):
relative_module_path = '.' + module_path[len(entry_path):]
else:
relative_module_path = module_path
return relative_module_path
...@@ -26,7 +26,7 @@ import threading
import warnings
import zmq
from datetime import datetime
import parl
from parl.utils import get_ip_address, to_byte, to_str, logger, _IS_WINDOWS, kill_process
from parl.remote import remote_constants
from parl.remote.message import InitializedWorker
...@@ -72,10 +72,10 @@ class Worker(object):
self.master_is_alive = True
self.worker_is_alive = True
self.worker_status = None  # initialized at `self._create_jobs`
self.lock = threading.Lock()
self._set_cpu_num(cpu_num)
self.job_buffer = queue.Queue(maxsize=self.cpu_num)
self._create_sockets()
self.check_version()
# create log server
self.log_server_proc, self.log_server_address = self._create_log_server(
port=log_server_port)
...@@ -102,6 +102,24 @@ class Worker(object):
else:
self.cpu_num = multiprocessing.cpu_count()
def check_version(self):
'''Verify that the parl & python version in 'worker' process matches that of the 'master' process'''
self.request_master_socket.send_multipart(
[remote_constants.CHECK_VERSION_TAG])
message = self.request_master_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
worker_parl_version = parl.__version__
worker_python_version = str(sys.version_info.major)
assert worker_parl_version == to_str(message[1]) and worker_python_version == to_str(message[2]),\
'''Version mismatch: the "master" is of version "parl={}, python={}". However,
"parl={}, python={}"is provided in your environment.'''.format(
to_str(message[1]), to_str(message[2]),
worker_parl_version, worker_python_version
)
else:
raise NotImplementedError
def _create_sockets(self):
""" Each worker has three sockets at start:
...@@ -209,7 +227,11 @@ class Worker(object):
# Redirect the output to DEVNULL
FNULL = open(os.devnull, 'w')
for _ in range(job_num):
subprocess.Popen(
command,
stdout=FNULL,
stderr=subprocess.STDOUT,
close_fds=True)
FNULL.close()
new_jobs = []
...@@ -384,10 +406,7 @@ class Worker(object):
else:
FNULL = open(os.devnull, 'w')
log_server_proc = subprocess.Popen(
command, stdout=FNULL, stderr=subprocess.STDOUT, close_fds=True)
FNULL.close()
log_server_address = "{}:{}".format(self.worker_ip, port)
......
...@@ -19,12 +19,24 @@ __all__ = ['CSVLogger']
class CSVLogger(object):
def __init__(self, output_file):
"""CSV Logger which can write dict result to csv file.
Args:
output_file(str): filename of the csv file.
"""
self.output_file = open(output_file, "w")
self.csv_writer = None
def log_dict(self, result):
"""Ouput result to the csv file.
Will create the header of the csv file automatically when the function is called for the first time.
Ususally, the keys of the result should be the same every time you call the function.
Args:
result(dict)
"""
assert isinstance(result, dict), "the input should be a dict."
if self.csv_writer is None:
self.csv_writer = csv.DictWriter(self.output_file, result.keys())
self.csv_writer.writeheader()
...@@ -38,4 +50,9 @@ class CSVLogger(object):
self.output_file.flush()
def close(self):
if not self.output_file.closed:
self.output_file.close()
def __del__(self):
if not self.output_file.closed:
self.output_file.close()
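
A short usage sketch of the `CSVLogger` shown above; the file name and dictionary keys are illustrative, and the import path assumes `CSVLogger` is re-exported from `parl.utils` as its `__all__` entry suggests:

```python
from parl.utils import CSVLogger

csv_logger = CSVLogger('train_log.csv')            # illustrative file name
csv_logger.log_dict({'step': 1, 'reward': 10.5})   # first call writes the header, then the row
csv_logger.log_dict({'step': 2, 'reward': 12.0})   # keep the same keys on every call
csv_logger.close()
```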
...@@ -20,7 +20,7 @@ import numpy as np
__all__ = [
'has_func', 'to_str', 'to_byte', 'is_PY2', 'is_PY3', 'MAX_INT32',
'_HAS_FLUID', '_HAS_TORCH', '_IS_WINDOWS', '_IS_MAC', 'kill_process',
'get_fluid_version', 'isnotebook'
]
...@@ -101,3 +101,19 @@ def kill_process(regex_pattern):
command = "ps aux | grep {} | awk '{{print $2}}' | xargs kill -9".format(
regex_pattern)
subprocess.call([command], shell=True)
def isnotebook():
"""check if the code is excuted in the IPython notebook
Reference: https://stackoverflow.com/a/39662359
"""
try:
shell = get_ipython().__class__.__name__
if shell == 'ZMQInteractiveShell':
return True # Jupyter notebook or qtconsole
elif shell == 'TerminalInteractiveShell':
return False # Terminal running IPython
else:
return False # Other type (?)
except NameError:
return False # Probably standard Python interpreter
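
A small usage sketch of `isnotebook`, which the imports earlier in this commit show is also exposed through `parl.utils`:

```python
from parl.utils import isnotebook

if isnotebook():
    print("running inside a Jupyter/IPython notebook")
else:
    print("running in a plain Python interpreter")
```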
...@@ -82,7 +82,7 @@ setup(
"click",
"psutil>=5.6.2",
"flask_cors",
"visualdl>=2.0.0b;python_version>='3.7' and platform_system=='Linux'",
],
classifiers=[
'Intended Audience :: Developers',
......