Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PARL
提交
3565b546
P
PARL
项目概览
PaddlePaddle
/
PARL
通知
68
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
18
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PARL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
18
Issue
18
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3565b546
编写于
8月 19, 2020
作者:
T
TomorrowIsAnOtherDay
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' into xparl_doc
上级
df53a62b
e228e234
变更
26
显示空白变更内容
内联
并排
Showing
26 changed file
with
819 addition
and
111 deletion
+819
-111
.teamcity/requirements.txt
.teamcity/requirements.txt
+2
-0
CMakeLists.txt
CMakeLists.txt
+14
-4
docs/zh_CN/Overview.md
docs/zh_CN/Overview.md
+7
-1
docs/zh_CN/xparl/introduction.md
docs/zh_CN/xparl/introduction.md
+1
-1
docs/zh_CN/xparl/tutorial.md
docs/zh_CN/xparl/tutorial.md
+8
-8
parl/core/fluid/model.py
parl/core/fluid/model.py
+1
-1
parl/remote/client.py
parl/remote/client.py
+63
-29
parl/remote/job.py
parl/remote/job.py
+42
-16
parl/remote/master.py
parl/remote/master.py
+10
-0
parl/remote/remote_constants.py
parl/remote/remote_constants.py
+3
-0
parl/remote/remote_decorator.py
parl/remote/remote_decorator.py
+87
-17
parl/remote/scripts.py
parl/remote/scripts.py
+11
-5
parl/remote/tests/get_set_attribute_test.py
parl/remote/tests/get_set_attribute_test.py
+169
-0
parl/remote/tests/log_server_test.py
parl/remote/tests/log_server_test.py
+5
-8
parl/remote/tests/reset_job_test.py
parl/remote/tests/reset_job_test.py
+1
-1
parl/remote/tests/support_RegExp_test.py
parl/remote/tests/support_RegExp_test.py
+99
-0
parl/remote/tests/test_import_module/Module2.py
parl/remote/tests/test_import_module/Module2.py
+20
-0
parl/remote/tests/test_import_module/main_abs_test.py
parl/remote/tests/test_import_module/main_abs_test.py
+46
-0
parl/remote/tests/test_import_module/main_test.py
parl/remote/tests/test_import_module/main_test.py
+82
-0
parl/remote/tests/test_import_module/subdir/Module.py
parl/remote/tests/test_import_module/subdir/Module.py
+20
-0
parl/remote/tests/test_import_module/subdir/__init__.py
parl/remote/tests/test_import_module/subdir/__init__.py
+13
-0
parl/remote/utils.py
parl/remote/utils.py
+52
-9
parl/remote/worker.py
parl/remote/worker.py
+26
-7
parl/utils/csv_logger.py
parl/utils/csv_logger.py
+19
-2
parl/utils/utils.py
parl/utils/utils.py
+17
-1
setup.py
setup.py
+1
-1
未找到文件。
.teamcity/requirements.txt
浏览文件 @
3565b546
# requirements for unittest
# requirements for unittest
rarfile==3.1
rarfile==3.1
opencv-python<=4.3.0.34;python_version>="3"
opencv-python==4.2.0.32;python_version<"3"
paddlepaddle-gpu==1.6.1.post97
paddlepaddle-gpu==1.6.1.post97
gym
gym
details
details
...
...
CMakeLists.txt
浏览文件 @
3565b546
...
@@ -30,10 +30,20 @@ function(py_test TARGET_NAME)
...
@@ -30,10 +30,20 @@ function(py_test TARGET_NAME)
set
(
oneValueArgs
""
)
set
(
oneValueArgs
""
)
set
(
multiValueArgs SRCS DEPS ARGS ENVS
)
set
(
multiValueArgs SRCS DEPS ARGS ENVS
)
cmake_parse_arguments
(
py_test
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
cmake_parse_arguments
(
py_test
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
add_test
(
NAME
${
TARGET_NAME
}
if
(
${
FILE_NAME
}
MATCHES
".*abs_test.py"
)
add_test
(
NAME
${
TARGET_NAME
}
"_with_abs_path"
COMMAND python -u
${
py_test_SRCS
}
${
py_test_ARGS
}
COMMAND python -u
${
py_test_SRCS
}
${
py_test_ARGS
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
)
WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
)
set_tests_properties
(
${
TARGET_NAME
}
"_with_abs_path"
PROPERTIES TIMEOUT 300
)
else
()
get_filename_component
(
WORKING_DIR
${
py_test_SRCS
}
DIRECTORY
)
get_filename_component
(
FILE_NAME
${
py_test_SRCS
}
NAME
)
get_filename_component
(
COMBINED_PATH
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
WORKING_DIR
}
ABSOLUTE
)
add_test
(
NAME
${
TARGET_NAME
}
COMMAND python -u
${
FILE_NAME
}
${
py_test_ARGS
}
WORKING_DIRECTORY
${
COMBINED_PATH
}
)
set_tests_properties
(
${
TARGET_NAME
}
PROPERTIES TIMEOUT 300
)
set_tests_properties
(
${
TARGET_NAME
}
PROPERTIES TIMEOUT 300
)
endif
()
endfunction
()
endfunction
()
function
(
import_test TARGET_NAME
)
function
(
import_test TARGET_NAME
)
...
...
docs/zh_CN/Overview.md
浏览文件 @
3565b546
...
@@ -126,7 +126,13 @@ yapf -i modified_file.py
...
@@ -126,7 +126,13 @@ yapf -i modified_file.py
```
```
-
持续集成测试
<br>
-
持续集成测试
<br>
当增加代码时候,需要增加测试代码覆盖所添加的代码,测试代码得放在相关代码文件的
`tests`
文件夹下,以
`_test.py`
结尾(这样持续集成测试会自动拉取代码跑)。附:
[
测试代码示例
](
../../parl/tests/import_test.py
)
当增加代码时候,需要增加测试代码覆盖所添加的代码,测试代码得放在相关代码文件的
`tests`
文件夹下,以
`_test.py`
结尾(这样持续集成测试会自动拉取代码跑)。附:
[
测试代码示例
](
../../parl/tests/import_test.py
)
-
本地运行单元测试(非必要)
<br>
如果你希望在自己的机器运行单测代码,可先在本地机器上安装Docker,再按以下步骤执行单测任务。
```
cd PARL
docker build -t parl/parl-test:unittest .teamcity/
nvidia-docker run -i --rm -v $PWD:/work -w /work parl/parl-test:unittest .teamcity/build.sh test
```
## 反馈
## 反馈
-
在 GitHub 上
[
提交问题
](
https://github.com/PaddlePaddle/PARL/issues
)
-
在 GitHub 上
[
提交问题
](
https://github.com/PaddlePaddle/PARL/issues
)
docs/zh_CN/xparl/introduction.md
浏览文件 @
3565b546
...
@@ -24,4 +24,4 @@ PARL在实现底层的并行计算时,是通过端到端的这种网络传输
...
@@ -24,4 +24,4 @@ PARL在实现底层的并行计算时,是通过端到端的这种网络传输
## 自动分发本地文件
## 自动分发本地文件
市面上的并行框架大部分得要用户手动同步文件才可以跑起并行代码,比如配置文件得要手动或者通过命令分发到不同机器,parl可以自动分发当前目录下的代码文件,实现无缝的多机并行。
市面上的并行框架大部分得要用户手动同步文件才可以跑起并行代码,比如配置文件得要手动或者通过命令分发到不同机器,parl可以自动分发当前目录下的代码文件,实现无缝的多机并行。
<img
src=
"../../parallel_training/comparison.png"
width=
"
5
00"
/>
<img
src=
"../../parallel_training/comparison.png"
width=
"
10
00"
/>
docs/zh_CN/xparl/tutorial.md
浏览文件 @
3565b546
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
## 配置命令
## 配置命令
这个教程将会演示如何搭建一个集群。
这个教程将会演示如何搭建一个集群。
搭建一个PARL集群,可以通过执行下面
两个
`xparl`
命令:
搭建一个PARL集群,可以通过执行下面
的
`xparl`
命令:
### 启动集群
### 启动集群
```
bash
```
bash
...
@@ -12,17 +12,17 @@ xparl start --port 6006
...
@@ -12,17 +12,17 @@ xparl start --port 6006
这个命令会启动一个主节点(master)来管理集群的计算资源,同时会把本地机器的CPU资源加入到集群中。命令中的6006端口只是作为示例,你可以修改成任何有效的端口。
这个命令会启动一个主节点(master)来管理集群的计算资源,同时会把本地机器的CPU资源加入到集群中。命令中的6006端口只是作为示例,你可以修改成任何有效的端口。
### 加入其它机器资源
启动后可通过
`xparl status`
查看目前集群有多少CPU资源可用,你可以在
`xparl start`
的命令中加入选项
`--cpu_num [CPU_NUM]`
(例如:--cpu_num 10)指定本机加入集群的CPU数量。
> 注意:如果你只有单台机器,可以忽略这部分教程。
如果你想加入更多的CPU计算资源到集群中,可以在其他机器上运行下面命令:
### 加入更多CPU资源
启动集群后,就可以直接使用集群了,如果CPU资源不够用,你可以在任何时候和任何机器(包括本机或其他机器)上,通过执行
`xparl connect`
命令把更多CPU资源加入到集群中。
```
bash
```
bash
xparl connect
--address
[
MASTER_ADDRESS]:6006
xparl connect
--address
[
MASTER_ADDRESS]:6006
```
```
它会启动一个工作节点(worker),并把当前机器的CPU资源加入到
该master对应的集群。worker默认会把所有的CPU资源加入到集群中,如果你需要指定worker可使用的CPU数量,可以在上述命令上加入选项
`--cpu_num [CPU_NUM]`
(例如:----cpu_num 10)
。
它会启动一个工作节点(worker),并把当前机器的CPU资源加入到
`--address`
指定的master集群。worker默认会把当前机器所有的可用的CPU资源加入到集群中,如果你需要指定加入的CPU数量,也可以在上述命令上加入选项
`--cpu_num [CPU_NUM]`
。
注意:启动集群后,你可以在任何时候和任何机器上,通过执行
`xparl connect`
命令把更多CPU资源加入到集群中。
## 示例
## 示例
这里我们给出了一个示例来演示如何通过
`@parl.remote_class`
来进行并行计算。
这里我们给出了一个示例来演示如何通过
`@parl.remote_class`
来进行并行计算。
...
@@ -47,9 +47,9 @@ actor.add(1, 2) # 返回 3
...
@@ -47,9 +47,9 @@ actor.add(1, 2) # 返回 3
```
```
## 关闭集群
## 关闭集群
在master机器上运行
`xparl stop`
命令即可关闭集群程序。当master节点退出后,
运行在其他机器
的worker节点也会自动退出并结束相关程序。
在master机器上运行
`xparl stop`
命令即可关闭集群程序。当master节点退出后,
与之关联
的worker节点也会自动退出并结束相关程序。
## 扩展阅读
## 扩展阅读
我们现在已经知道了如何搭建一个集群,以及如何通过修饰符
`@parl.remote_class`
来使用集群。
我们现在已经知道了如何
通过终端命令
`xparl`
搭建一个集群,以及如何通过修饰符
`@parl.remote_class`
来使用集群。
在
[
下一个教程
](
./example.md
)
我们将会演示如何通过这个修饰符来打破Python的全局解释器锁(Global Interpreter Lock, GIL)限制,从而实现真正的多线程计算。
在
[
下一个教程
](
./example.md
)
我们将会演示如何通过这个修饰符来打破Python的全局解释器锁(Global Interpreter Lock, GIL)限制,从而实现真正的多线程计算。
parl/core/fluid/model.py
浏览文件 @
3565b546
...
@@ -53,7 +53,7 @@ class Model(ModelBase):
...
@@ -53,7 +53,7 @@ class Model(ModelBase):
copied_policy = copy.deepcopy(model)
copied_policy = copy.deepcopy(model)
Attributes:
Attributes:
model_id(str): each model instance has its uniqe model_id.
model_id(str): each model instance has its uniq
u
e model_id.
Public Functions:
Public Functions:
- ``sync_weights_to``: synchronize parameters of the current model to another model.
- ``sync_weights_to``: synchronize parameters of the current model to another model.
...
...
parl/remote/client.py
浏览文件 @
3565b546
...
@@ -19,9 +19,11 @@ import socket
...
@@ -19,9 +19,11 @@ import socket
import
sys
import
sys
import
threading
import
threading
import
zmq
import
zmq
from
parl.utils
import
to_str
,
to_byte
,
get_ip_address
,
logger
import
parl
from
parl.utils
import
to_str
,
to_byte
,
get_ip_address
,
logger
,
isnotebook
from
parl.remote
import
remote_constants
from
parl.remote
import
remote_constants
import
time
import
time
import
glob
class
Client
(
object
):
class
Client
(
object
):
...
@@ -50,7 +52,6 @@ class Client(object):
...
@@ -50,7 +52,6 @@ class Client(object):
distributed_files (list): A list of files to be distributed at all
distributed_files (list): A list of files to be distributed at all
remote instances(e,g. the configuration
remote instances(e,g. the configuration
file for initialization) .
file for initialization) .
"""
"""
self
.
master_address
=
master_address
self
.
master_address
=
master_address
self
.
process_id
=
process_id
self
.
process_id
=
process_id
...
@@ -66,6 +67,7 @@ class Client(object):
...
@@ -66,6 +67,7 @@ class Client(object):
self
.
actor_num
=
0
self
.
actor_num
=
0
self
.
_create_sockets
(
master_address
)
self
.
_create_sockets
(
master_address
)
self
.
check_version
()
self
.
pyfiles
=
self
.
read_local_files
(
distributed_files
)
self
.
pyfiles
=
self
.
read_local_files
(
distributed_files
)
def
get_executable_path
(
self
):
def
get_executable_path
(
self
):
...
@@ -85,44 +87,58 @@ class Client(object):
...
@@ -85,44 +87,58 @@ class Client(object):
Args:
Args:
distributed_files (list): A list of files to be distributed at all
distributed_files (list): A list of files to be distributed at all
remote instances(e,g. the configuration
remote instances(e,g. the configuration
file for initialization) .
file for initialization) . RegExp of file
names is supported.
e.g.
distributed_files = ['./*.npy', './test*']
Returns:
Returns:
A cloudpickled dictionary containing the python code in current
A cloudpickled dictionary containing the python code in current
working directory.
working directory.
"""
"""
parsed_distributed_files
=
set
()
for
distributed_file
in
distributed_files
:
parsed_list
=
glob
.
glob
(
distributed_file
)
if
not
parsed_list
:
raise
ValueError
(
"no local file is matched with '{}', please check your input"
.
format
(
distributed_file
))
# exclude the directiories
for
pathname
in
parsed_list
:
if
not
os
.
path
.
isdir
(
pathname
):
parsed_distributed_files
.
add
(
pathname
)
pyfiles
=
dict
()
pyfiles
=
dict
()
pyfiles
[
'python_files'
]
=
{}
pyfiles
[
'python_files'
]
=
{}
pyfiles
[
'other_files'
]
=
{}
pyfiles
[
'other_files'
]
=
{}
code_files
=
filter
(
lambda
x
:
x
.
endswith
(
'.py'
),
os
.
listdir
(
'./'
))
if
isnotebook
():
main_folder
=
'./'
try
:
else
:
for
file
in
code_files
:
main_file
=
sys
.
argv
[
0
]
assert
os
.
path
.
exists
(
file
)
main_folder
=
'./'
with
open
(
file
,
'rb'
)
as
code_file
:
sep
=
os
.
sep
if
sep
in
main_file
:
main_folder
=
sep
.
join
(
main_file
.
split
(
sep
)[:
-
1
])
code_files
=
filter
(
lambda
x
:
x
.
endswith
(
'.py'
),
os
.
listdir
(
main_folder
))
for
file_name
in
code_files
:
file_path
=
os
.
path
.
join
(
main_folder
,
file_name
)
assert
os
.
path
.
exists
(
file_path
)
with
open
(
file_path
,
'rb'
)
as
code_file
:
code
=
code_file
.
read
()
code
=
code_file
.
read
()
pyfiles
[
'python_files'
][
fil
e
]
=
code
pyfiles
[
'python_files'
][
file_nam
e
]
=
code
for
file
in
distributed_files
:
for
file_name
in
parsed_
distributed_files
:
assert
os
.
path
.
exists
(
fil
e
)
assert
os
.
path
.
exists
(
file_nam
e
)
assert
not
os
.
path
.
isabs
(
assert
not
os
.
path
.
isabs
(
fil
e
file_nam
e
),
"[XPARL] Please do not distribute a file with absolute path."
),
"[XPARL] Please do not distribute a file with absolute path."
with
open
(
fil
e
,
'rb'
)
as
f
:
with
open
(
file_nam
e
,
'rb'
)
as
f
:
content
=
f
.
read
()
content
=
f
.
read
()
pyfiles
[
'other_files'
][
file
]
=
content
pyfiles
[
'other_files'
][
file_name
]
=
content
# append entry file to code list
main_file
=
sys
.
argv
[
0
]
with
open
(
main_file
,
'rb'
)
as
code_file
:
code
=
code_file
.
read
()
# parl/remote/remote_decorator.py -> remote_decorator.py
file_name
=
main_file
.
split
(
os
.
sep
)[
-
1
]
pyfiles
[
'python_files'
][
file_name
]
=
code
except
AssertionError
as
e
:
raise
Exception
(
'Failed to create the client, the file {} does not exist.'
.
format
(
file
))
return
cloudpickle
.
dumps
(
pyfiles
)
return
cloudpickle
.
dumps
(
pyfiles
)
def
_create_sockets
(
self
,
master_address
):
def
_create_sockets
(
self
,
master_address
):
...
@@ -165,6 +181,24 @@ class Client(object):
...
@@ -165,6 +181,24 @@ class Client(object):
"check if master is started and ensure the input "
"check if master is started and ensure the input "
"address {} is correct."
.
format
(
master_address
))
"address {} is correct."
.
format
(
master_address
))
def
check_version
(
self
):
'''Verify that the parl & python version in 'client' process matches that of the 'master' process'''
self
.
submit_job_socket
.
send_multipart
(
[
remote_constants
.
CHECK_VERSION_TAG
])
message
=
self
.
submit_job_socket
.
recv_multipart
()
tag
=
message
[
0
]
if
tag
==
remote_constants
.
NORMAL_TAG
:
client_parl_version
=
parl
.
__version__
client_python_version
=
str
(
sys
.
version_info
.
major
)
assert
client_parl_version
==
to_str
(
message
[
1
])
and
client_python_version
==
to_str
(
message
[
2
]),
\
'''Version mismatch: the 'master' is of version 'parl={}, python={}'. However,
'parl={}, python={}'is provided in your environment.'''
.
format
(
to_str
(
message
[
1
]),
to_str
(
message
[
2
]),
client_parl_version
,
client_python_version
)
else
:
raise
NotImplementedError
def
_reply_heartbeat
(
self
):
def
_reply_heartbeat
(
self
):
"""Reply heartbeat signals to the master node."""
"""Reply heartbeat signals to the master node."""
...
...
parl/remote/job.py
浏览文件 @
3565b546
...
@@ -311,8 +311,6 @@ class Job(object):
...
@@ -311,8 +311,6 @@ class Job(object):
try
:
try
:
file_name
,
class_name
,
end_of_file
=
cloudpickle
.
loads
(
file_name
,
class_name
,
end_of_file
=
cloudpickle
.
loads
(
message
[
1
])
message
[
1
])
#/home/nlp-ol/Firework/baidu/nlp/evokit/python_api/es_agent -> es_agent
file_name
=
file_name
.
split
(
os
.
sep
)[
-
1
]
cls
=
load_remote_class
(
file_name
,
class_name
,
end_of_file
)
cls
=
load_remote_class
(
file_name
,
class_name
,
end_of_file
)
args
,
kwargs
=
cloudpickle
.
loads
(
message
[
2
])
args
,
kwargs
=
cloudpickle
.
loads
(
message
[
2
])
logfile_path
=
os
.
path
.
join
(
self
.
log_dir
,
'stdout.log'
)
logfile_path
=
os
.
path
.
join
(
self
.
log_dir
,
'stdout.log'
)
...
@@ -327,7 +325,10 @@ class Job(object):
...
@@ -327,7 +325,10 @@ class Job(object):
to_byte
(
error_str
+
"
\n
traceback:
\n
"
+
traceback_str
)
to_byte
(
error_str
+
"
\n
traceback:
\n
"
+
traceback_str
)
])
])
return
None
return
None
reply_socket
.
send_multipart
([
remote_constants
.
NORMAL_TAG
])
reply_socket
.
send_multipart
([
remote_constants
.
NORMAL_TAG
,
dumps_return
(
set
(
obj
.
__dict__
.
keys
()))
])
else
:
else
:
logger
.
error
(
"Message from job {}"
.
format
(
message
))
logger
.
error
(
"Message from job {}"
.
format
(
message
))
reply_socket
.
send_multipart
([
reply_socket
.
send_multipart
([
...
@@ -397,11 +398,14 @@ class Job(object):
...
@@ -397,11 +398,14 @@ class Job(object):
while
True
:
while
True
:
message
=
reply_socket
.
recv_multipart
()
message
=
reply_socket
.
recv_multipart
()
tag
=
message
[
0
]
tag
=
message
[
0
]
if
tag
in
[
if
tag
==
remote_constants
.
CALL_TAG
:
remote_constants
.
CALL_TAG
,
remote_constants
.
GET_ATTRIBUTE_TAG
,
remote_constants
.
SET_ATTRIBUTE_TAG
,
]:
try
:
try
:
if
tag
==
remote_constants
.
CALL_TAG
:
function_name
=
to_str
(
message
[
1
])
function_name
=
to_str
(
message
[
1
])
data
=
message
[
2
]
data
=
message
[
2
]
args
,
kwargs
=
loads_argument
(
data
)
args
,
kwargs
=
loads_argument
(
data
)
...
@@ -412,9 +416,31 @@ class Job(object):
...
@@ -412,9 +416,31 @@ class Job(object):
ret
=
getattr
(
obj
,
function_name
)(
*
args
,
**
kwargs
)
ret
=
getattr
(
obj
,
function_name
)(
*
args
,
**
kwargs
)
ret
=
dumps_return
(
ret
)
ret
=
dumps_return
(
ret
)
reply_socket
.
send_multipart
([
remote_constants
.
NORMAL_TAG
,
ret
,
dumps_return
(
set
(
obj
.
__dict__
.
keys
()))
])
elif
tag
==
remote_constants
.
GET_ATTRIBUTE_TAG
:
attribute_name
=
to_str
(
message
[
1
])
logfile_path
=
os
.
path
.
join
(
self
.
log_dir
,
'stdout.log'
)
with
redirect_stdout_to_file
(
logfile_path
):
ret
=
getattr
(
obj
,
attribute_name
)
ret
=
dumps_return
(
ret
)
reply_socket
.
send_multipart
(
reply_socket
.
send_multipart
(
[
remote_constants
.
NORMAL_TAG
,
ret
])
[
remote_constants
.
NORMAL_TAG
,
ret
])
elif
tag
==
remote_constants
.
SET_ATTRIBUTE_TAG
:
attribute_name
=
to_str
(
message
[
1
])
attribute_value
=
loads_return
(
message
[
2
])
logfile_path
=
os
.
path
.
join
(
self
.
log_dir
,
'stdout.log'
)
with
redirect_stdout_to_file
(
logfile_path
):
setattr
(
obj
,
attribute_name
,
attribute_value
)
reply_socket
.
send_multipart
([
remote_constants
.
NORMAL_TAG
,
dumps_return
(
set
(
obj
.
__dict__
.
keys
()))
])
else
:
pass
except
Exception
as
e
:
except
Exception
as
e
:
# reset the job
# reset the job
...
...
parl/remote/master.py
浏览文件 @
3565b546
...
@@ -18,6 +18,8 @@ import threading
...
@@ -18,6 +18,8 @@ import threading
import
time
import
time
import
zmq
import
zmq
from
collections
import
deque
,
defaultdict
from
collections
import
deque
,
defaultdict
import
parl
import
sys
from
parl.utils
import
to_str
,
to_byte
,
logger
,
get_ip_address
from
parl.utils
import
to_str
,
to_byte
,
logger
,
get_ip_address
from
parl.remote
import
remote_constants
from
parl.remote
import
remote_constants
from
parl.remote.job_center
import
JobCenter
from
parl.remote.job_center
import
JobCenter
...
@@ -208,6 +210,7 @@ class Master(object):
...
@@ -208,6 +210,7 @@ class Master(object):
elif
tag
==
remote_constants
.
CLIENT_CONNECT_TAG
:
elif
tag
==
remote_constants
.
CLIENT_CONNECT_TAG
:
# `client_heartbeat_address` is the
# `client_heartbeat_address` is the
# `reply_master_heartbeat_address` of the client
# `reply_master_heartbeat_address` of the client
client_heartbeat_address
=
to_str
(
message
[
1
])
client_heartbeat_address
=
to_str
(
message
[
1
])
client_hostname
=
to_str
(
message
[
2
])
client_hostname
=
to_str
(
message
[
2
])
client_id
=
to_str
(
message
[
3
])
client_id
=
to_str
(
message
[
3
])
...
@@ -225,6 +228,13 @@ class Master(object):
...
@@ -225,6 +228,13 @@ class Master(object):
[
remote_constants
.
NORMAL_TAG
,
[
remote_constants
.
NORMAL_TAG
,
to_byte
(
log_monitor_address
)])
to_byte
(
log_monitor_address
)])
elif
tag
==
remote_constants
.
CHECK_VERSION_TAG
:
self
.
client_socket
.
send_multipart
([
remote_constants
.
NORMAL_TAG
,
to_byte
(
parl
.
__version__
),
to_byte
(
str
(
sys
.
version_info
.
major
))
])
# a client submits a job to the master
# a client submits a job to the master
elif
tag
==
remote_constants
.
CLIENT_SUBMIT_TAG
:
elif
tag
==
remote_constants
.
CLIENT_SUBMIT_TAG
:
# check available CPU resources
# check available CPU resources
...
...
parl/remote/remote_constants.py
浏览文件 @
3565b546
...
@@ -27,8 +27,11 @@ SEND_FILE_TAG = b'[SEND_FILE]'
...
@@ -27,8 +27,11 @@ SEND_FILE_TAG = b'[SEND_FILE]'
SUBMIT_JOB_TAG
=
b
'[SUBMIT_JOB]'
SUBMIT_JOB_TAG
=
b
'[SUBMIT_JOB]'
NEW_JOB_TAG
=
b
'[NEW_JOB]'
NEW_JOB_TAG
=
b
'[NEW_JOB]'
CHECK_VERSION_TAG
=
b
'[CHECK_VERSION]'
INIT_OBJECT_TAG
=
b
'[INIT_OBJECT]'
INIT_OBJECT_TAG
=
b
'[INIT_OBJECT]'
CALL_TAG
=
b
'[CALL]'
CALL_TAG
=
b
'[CALL]'
GET_ATTRIBUTE_TAG
=
b
'[GET_ATTRIBUTE]'
SET_ATTRIBUTE_TAG
=
b
'[SET_ATTRIBUTE]'
EXCEPTION_TAG
=
b
'[EXCEPTION]'
EXCEPTION_TAG
=
b
'[EXCEPTION]'
ATTRIBUTE_EXCEPTION_TAG
=
b
'[ATTRIBUTE_EXCEPTION]'
ATTRIBUTE_EXCEPTION_TAG
=
b
'[ATTRIBUTE_EXCEPTION]'
...
...
parl/remote/remote_decorator.py
浏览文件 @
3565b546
...
@@ -19,6 +19,7 @@ import time
...
@@ -19,6 +19,7 @@ import time
import
zmq
import
zmq
import
numpy
as
np
import
numpy
as
np
import
inspect
import
inspect
import
sys
from
parl.utils
import
get_ip_address
,
logger
,
to_str
,
to_byte
from
parl.utils
import
get_ip_address
,
logger
,
to_str
,
to_byte
from
parl.utils.communication
import
loads_argument
,
loads_return
,
\
from
parl.utils.communication
import
loads_argument
,
loads_return
,
\
...
@@ -27,6 +28,7 @@ from parl.remote import remote_constants
...
@@ -27,6 +28,7 @@ from parl.remote import remote_constants
from
parl.remote.exceptions
import
RemoteError
,
RemoteAttributeError
,
\
from
parl.remote.exceptions
import
RemoteError
,
RemoteAttributeError
,
\
RemoteDeserializeError
,
RemoteSerializeError
,
ResourceError
RemoteDeserializeError
,
RemoteSerializeError
,
ResourceError
from
parl.remote.client
import
get_global_client
from
parl.remote.client
import
get_global_client
from
parl.remote.utils
import
locate_remote_file
def
remote_class
(
*
args
,
**
kwargs
):
def
remote_class
(
*
args
,
**
kwargs
):
...
@@ -93,7 +95,7 @@ def remote_class(*args, **kwargs):
...
@@ -93,7 +95,7 @@ def remote_class(*args, **kwargs):
class.
class.
"""
"""
self
.
GLOBAL_CLIENT
=
get_global_client
()
self
.
GLOBAL_CLIENT
=
get_global_client
()
self
.
remote_attribute_keys_set
=
set
()
self
.
ctx
=
self
.
GLOBAL_CLIENT
.
ctx
self
.
ctx
=
self
.
GLOBAL_CLIENT
.
ctx
# GLOBAL_CLIENT will set `master_is_alive` to False when hearbeat
# GLOBAL_CLIENT will set `master_is_alive` to False when hearbeat
...
@@ -120,21 +122,34 @@ def remote_class(*args, **kwargs):
...
@@ -120,21 +122,34 @@ def remote_class(*args, **kwargs):
self
.
job_shutdown
=
False
self
.
job_shutdown
=
False
self
.
send_file
(
self
.
job_socket
)
self
.
send_file
(
self
.
job_socket
)
file_name
=
inspect
.
getfile
(
cls
)[:
-
3
]
module_path
=
inspect
.
getfile
(
cls
)
if
module_path
.
endswith
(
'pyc'
):
module_path
=
module_path
[:
-
4
]
elif
module_path
.
endswith
(
'py'
):
module_path
=
module_path
[:
-
3
]
else
:
raise
FileNotFoundError
(
"cannot not find the module:{}"
.
format
(
module_path
))
res
=
inspect
.
getfile
(
cls
)
file_path
=
locate_remote_file
(
module_path
)
cls_source
=
inspect
.
getsourcelines
(
cls
)
cls_source
=
inspect
.
getsourcelines
(
cls
)
end_of_file
=
cls_source
[
1
]
+
len
(
cls_source
[
0
])
end_of_file
=
cls_source
[
1
]
+
len
(
cls_source
[
0
])
class_name
=
cls
.
__name__
class_name
=
cls
.
__name__
self
.
job_socket
.
send_multipart
([
self
.
job_socket
.
send_multipart
([
remote_constants
.
INIT_OBJECT_TAG
,
remote_constants
.
INIT_OBJECT_TAG
,
cloudpickle
.
dumps
([
file_
name
,
class_name
,
end_of_file
]),
cloudpickle
.
dumps
([
file_
path
,
class_name
,
end_of_file
]),
cloudpickle
.
dumps
([
args
,
kwargs
]),
cloudpickle
.
dumps
([
args
,
kwargs
]),
])
])
message
=
self
.
job_socket
.
recv_multipart
()
message
=
self
.
job_socket
.
recv_multipart
()
tag
=
message
[
0
]
tag
=
message
[
0
]
if
tag
==
remote_constants
.
EXCEPTION_TAG
:
if
tag
==
remote_constants
.
NORMAL_TAG
:
self
.
remote_attribute_keys_set
=
loads_return
(
message
[
1
])
elif
tag
==
remote_constants
.
EXCEPTION_TAG
:
traceback_str
=
to_str
(
message
[
1
])
traceback_str
=
to_str
(
message
[
1
])
self
.
job_shutdown
=
True
self
.
job_shutdown
=
True
raise
RemoteError
(
'__init__'
,
traceback_str
)
raise
RemoteError
(
'__init__'
,
traceback_str
)
else
:
pass
def
__del__
(
self
):
def
__del__
(
self
):
"""Delete the remote class object and release remote resources."""
"""Delete the remote class object and release remote resources."""
...
@@ -179,16 +194,41 @@ def remote_class(*args, **kwargs):
...
@@ -179,16 +194,41 @@ def remote_class(*args, **kwargs):
cnt
-=
1
cnt
-=
1
return
None
return
None
def
__getattr__
(
self
,
attr
):
def
set_remote_attr
(
self
,
attr
,
value
):
self
.
internal_lock
.
acquire
()
self
.
job_socket
.
send_multipart
([
remote_constants
.
SET_ATTRIBUTE_TAG
,
to_byte
(
attr
),
dumps_return
(
value
)
])
message
=
self
.
job_socket
.
recv_multipart
()
tag
=
message
[
0
]
if
tag
==
remote_constants
.
NORMAL_TAG
:
self
.
remote_attribute_keys_set
=
loads_return
(
message
[
1
])
self
.
internal_lock
.
release
()
else
:
self
.
job_shutdown
=
True
raise
NotImplementedError
()
return
def
get_remote_attr
(
self
,
attr
):
"""Call the function of the unwrapped class."""
"""Call the function of the unwrapped class."""
#check if attr is a attribute or a function
is_attribute
=
attr
in
self
.
remote_attribute_keys_set
def
wrapper
(
*
args
,
**
kwargs
):
def
wrapper
(
*
args
,
**
kwargs
):
self
.
internal_lock
.
acquire
()
if
is_attribute
:
self
.
job_socket
.
send_multipart
([
remote_constants
.
GET_ATTRIBUTE_TAG
,
to_byte
(
attr
)
])
else
:
if
self
.
job_shutdown
:
if
self
.
job_shutdown
:
raise
RemoteError
(
raise
RemoteError
(
attr
,
"This actor losts connection with the job."
)
attr
,
self
.
internal_lock
.
acquire
(
)
"This actor losts connection with the job."
)
data
=
dumps_argument
(
*
args
,
**
kwargs
)
data
=
dumps_argument
(
*
args
,
**
kwargs
)
self
.
job_socket
.
send_multipart
(
self
.
job_socket
.
send_multipart
(
[
remote_constants
.
CALL_TAG
,
[
remote_constants
.
CALL_TAG
,
to_byte
(
attr
),
data
])
to_byte
(
attr
),
data
])
...
@@ -198,6 +238,11 @@ def remote_class(*args, **kwargs):
...
@@ -198,6 +238,11 @@ def remote_class(*args, **kwargs):
if
tag
==
remote_constants
.
NORMAL_TAG
:
if
tag
==
remote_constants
.
NORMAL_TAG
:
ret
=
loads_return
(
message
[
1
])
ret
=
loads_return
(
message
[
1
])
if
not
is_attribute
:
self
.
remote_attribute_keys_set
=
loads_return
(
message
[
2
])
self
.
internal_lock
.
release
()
return
ret
elif
tag
==
remote_constants
.
EXCEPTION_TAG
:
elif
tag
==
remote_constants
.
EXCEPTION_TAG
:
error_str
=
to_str
(
message
[
1
])
error_str
=
to_str
(
message
[
1
])
...
@@ -223,13 +268,38 @@ def remote_class(*args, **kwargs):
...
@@ -223,13 +268,38 @@ def remote_class(*args, **kwargs):
self
.
job_shutdown
=
True
self
.
job_shutdown
=
True
raise
NotImplementedError
()
raise
NotImplementedError
()
self
.
internal_lock
.
release
()
return
wrapper
()
if
is_attribute
else
wrapper
return
ret
def
proxy_wrapper_func
(
remote_wrapper
):
'''
The 'proxy_wrapper_func' is defined on the top of class 'RemoteWrapper'
in order to set and get attributes of 'remoted_wrapper' and the corresponding
remote models individually.
With 'proxy_wrapper_func', it is allowed to define a attribute (or method) of
the same name in 'RemoteWrapper' and remote models.
'''
class
ProxyWrapper
(
object
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
xparl_remote_wrapper_obj
=
remote_wrapper
(
*
args
,
**
kwargs
)
def
__getattr__
(
self
,
attr
):
return
self
.
xparl_remote_wrapper_obj
.
get_remote_attr
(
attr
)
def
__setattr__
(
self
,
attr
,
value
):
if
attr
==
'xparl_remote_wrapper_obj'
:
super
(
ProxyWrapper
,
self
).
__setattr__
(
attr
,
value
)
else
:
self
.
xparl_remote_wrapper_obj
.
set_remote_attr
(
attr
,
value
)
return
w
rapper
return
ProxyW
rapper
RemoteWrapper
.
_original
=
cls
RemoteWrapper
.
_original
=
cls
return
RemoteWrapper
proxy_wrapper
=
proxy_wrapper_func
(
RemoteWrapper
)
return
proxy_wrapper
max_memory
=
kwargs
.
get
(
'max_memory'
)
max_memory
=
kwargs
.
get
(
'max_memory'
)
if
len
(
args
)
==
1
and
callable
(
args
[
0
]):
if
len
(
args
)
==
1
and
callable
(
args
[
0
]):
...
...
parl/remote/scripts.py
浏览文件 @
3565b546
...
@@ -171,22 +171,28 @@ def start_master(port, cpu_num, monitor_port, debug, log_server_port_range):
...
@@ -171,22 +171,28 @@ def start_master(port, cpu_num, monitor_port, debug, log_server_port_range):
# Redirect the output to DEVNULL to solve the warning log.
# Redirect the output to DEVNULL to solve the warning log.
_
=
subprocess
.
Popen
(
_
=
subprocess
.
Popen
(
master_command
,
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
)
master_command
,
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
,
close_fds
=
True
)
if
cpu_num
>
0
:
if
cpu_num
>
0
:
# Sleep 1s for master ready
# Sleep 1s for master ready
time
.
sleep
(
1
)
time
.
sleep
(
1
)
_
=
subprocess
.
Popen
(
_
=
subprocess
.
Popen
(
worker_command
,
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
)
worker_command
,
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
,
close_fds
=
True
)
if
_IS_WINDOWS
:
if
_IS_WINDOWS
:
# TODO(@zenghsh3) redirecting stdout of monitor subprocess to FNULL will cause occasional failure
# TODO(@zenghsh3) redirecting stdout of monitor subprocess to FNULL will cause occasional failure
tmp_file
=
tempfile
.
TemporaryFile
()
tmp_file
=
tempfile
.
TemporaryFile
()
_
=
subprocess
.
Popen
(
monitor_command
,
stdout
=
tmp_file
)
_
=
subprocess
.
Popen
(
monitor_command
,
stdout
=
tmp_file
,
close_fds
=
True
)
tmp_file
.
close
()
tmp_file
.
close
()
else
:
else
:
_
=
subprocess
.
Popen
(
_
=
subprocess
.
Popen
(
monitor_command
,
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
)
monitor_command
,
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
,
close_fds
=
True
)
FNULL
.
close
()
FNULL
.
close
()
if
cpu_num
>
0
:
if
cpu_num
>
0
:
...
@@ -285,7 +291,7 @@ def start_worker(address, cpu_num, log_server_port_range):
...
@@ -285,7 +291,7 @@ def start_worker(address, cpu_num, log_server_port_range):
str
(
cpu_num
),
"--log_server_port"
,
str
(
cpu_num
),
"--log_server_port"
,
str
(
log_server_port
)
str
(
log_server_port
)
]
]
p
=
subprocess
.
Popen
(
command
)
p
=
subprocess
.
Popen
(
command
,
close_fds
=
True
)
if
not
is_log_server_started
(
get_ip_address
(),
log_server_port
):
if
not
is_log_server_started
(
get_ip_address
(),
log_server_port
):
click
.
echo
(
"# Fail to start the log server."
)
click
.
echo
(
"# Fail to start the log server."
)
...
...
parl/remote/tests/get_set_attribute_test.py
0 → 100644
浏览文件 @
3565b546
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
parl
import
numpy
as
np
from
parl.remote.client
import
disconnect
from
parl.utils
import
logger
from
parl.remote.master
import
Master
from
parl.remote.worker
import
Worker
import
time
import
threading
import
random
@
parl
.
remote_class
class
Actor
(
object
):
def
__init__
(
self
,
arg1
,
arg2
,
arg3
,
arg4
):
self
.
arg1
=
arg1
self
.
arg2
=
arg2
self
.
arg3
=
arg3
self
.
GLOBAL_CLIENT
=
arg4
def
arg1
(
self
,
x
,
y
):
time
.
sleep
(
0.2
)
return
x
+
y
def
arg5
(
self
):
return
100
def
set_new_attr
(
self
):
self
.
new_attr_1
=
200
class
Test_get_and_set_attribute
(
unittest
.
TestCase
):
def
tearDown
(
self
):
disconnect
()
def
test_get_attribute
(
self
):
port1
=
random
.
randint
(
6100
,
6200
)
logger
.
info
(
"running:test_get_attirbute"
)
master
=
Master
(
port
=
port1
)
th
=
threading
.
Thread
(
target
=
master
.
run
)
th
.
start
()
time
.
sleep
(
3
)
worker1
=
Worker
(
'localhost:{}'
.
format
(
port1
),
1
)
arg1
=
np
.
random
.
randint
(
100
)
arg2
=
np
.
random
.
randn
()
arg3
=
np
.
random
.
randn
(
3
,
3
)
arg4
=
100
parl
.
connect
(
'localhost:{}'
.
format
(
port1
))
actor
=
Actor
(
arg1
,
arg2
,
arg3
,
arg4
)
self
.
assertTrue
(
arg1
==
actor
.
arg1
)
self
.
assertTrue
(
arg2
==
actor
.
arg2
)
self
.
assertTrue
((
arg3
==
actor
.
arg3
).
all
())
self
.
assertTrue
(
arg4
==
actor
.
GLOBAL_CLIENT
)
master
.
exit
()
worker1
.
exit
()
def
test_set_attribute
(
self
):
port2
=
random
.
randint
(
6200
,
6300
)
logger
.
info
(
"running:test_set_attirbute"
)
master
=
Master
(
port
=
port2
)
th
=
threading
.
Thread
(
target
=
master
.
run
)
th
.
start
()
time
.
sleep
(
3
)
worker1
=
Worker
(
'localhost:{}'
.
format
(
port2
),
1
)
arg1
=
3
arg2
=
3.5
arg3
=
np
.
random
.
randn
(
3
,
3
)
arg4
=
100
parl
.
connect
(
'localhost:{}'
.
format
(
port2
))
actor
=
Actor
(
arg1
,
arg2
,
arg3
,
arg4
)
actor
.
arg1
=
arg1
actor
.
arg2
=
arg2
actor
.
arg3
=
arg3
actor
.
GLOBAL_CLIENT
=
arg4
self
.
assertTrue
(
arg1
==
actor
.
arg1
)
self
.
assertTrue
(
arg2
==
actor
.
arg2
)
self
.
assertTrue
((
arg3
==
actor
.
arg3
).
all
())
self
.
assertTrue
(
arg4
==
actor
.
GLOBAL_CLIENT
)
master
.
exit
()
worker1
.
exit
()
def
test_create_new_attribute_same_with_wrapper
(
self
):
port3
=
random
.
randint
(
6400
,
6500
)
logger
.
info
(
"running:test_create_new_attribute_same_with_wrapper"
)
master
=
Master
(
port
=
port3
)
th
=
threading
.
Thread
(
target
=
master
.
run
)
th
.
start
()
time
.
sleep
(
3
)
worker1
=
Worker
(
'localhost:{}'
.
format
(
port3
),
1
)
arg1
=
np
.
random
.
randint
(
100
)
arg2
=
np
.
random
.
randn
()
arg3
=
np
.
random
.
randn
(
3
,
3
)
arg4
=
100
parl
.
connect
(
'localhost:{}'
.
format
(
port3
))
actor
=
Actor
(
arg1
,
arg2
,
arg3
,
arg4
)
actor
.
internal_lock
=
50
self
.
assertTrue
(
actor
.
internal_lock
==
50
)
master
.
exit
()
worker1
.
exit
()
def
test_same_name_of_attribute_and_method
(
self
):
port4
=
random
.
randint
(
6500
,
6600
)
logger
.
info
(
"running:test_same_name_of_attribute_and_method"
)
master
=
Master
(
port
=
port4
)
th
=
threading
.
Thread
(
target
=
master
.
run
)
th
.
start
()
time
.
sleep
(
3
)
worker1
=
Worker
(
'localhost:{}'
.
format
(
port4
),
1
)
arg1
=
np
.
random
.
randint
(
100
)
arg2
=
np
.
random
.
randn
()
arg3
=
np
.
random
.
randn
(
3
,
3
)
arg4
=
100
parl
.
connect
(
'localhost:{}'
.
format
(
port4
))
actor
=
Actor
(
arg1
,
arg2
,
arg3
,
arg4
)
self
.
assertEqual
(
arg1
,
actor
.
arg1
)
def
call_method
():
return
actor
.
arg1
(
1
,
2
)
self
.
assertRaises
(
TypeError
,
call_method
)
master
.
exit
()
worker1
.
exit
()
def
test_non_existing_attribute_same_with_existing_method
(
self
):
port5
=
random
.
randint
(
6600
,
6700
)
logger
.
info
(
"running:test_non_existing_attribute_same_with_existing_method"
)
master
=
Master
(
port
=
port5
)
th
=
threading
.
Thread
(
target
=
master
.
run
)
th
.
start
()
time
.
sleep
(
3
)
worker1
=
Worker
(
'localhost:{}'
.
format
(
port5
),
1
)
arg1
=
np
.
random
.
randint
(
100
)
arg2
=
np
.
random
.
randn
()
arg3
=
np
.
random
.
randn
(
3
,
3
)
arg4
=
100
parl
.
connect
(
'localhost:{}'
.
format
(
port5
))
actor
=
Actor
(
arg1
,
arg2
,
arg3
,
arg4
)
actor
.
new_attr_2
=
300
self
.
assertEqual
(
300
,
actor
.
new_attr_2
)
actor
.
set_new_attr
()
self
.
assertEqual
(
200
,
actor
.
new_attr_1
)
self
.
assertTrue
(
callable
(
actor
.
arg5
))
def
call_non_existing_method
():
return
actor
.
arg2
(
10
)
self
.
assertRaises
(
TypeError
,
call_non_existing_method
)
master
.
exit
()
worker1
.
exit
()
if
__name__
==
'__main__'
:
unittest
.
main
()
parl/remote/tests/log_server_test.py
浏览文件 @
3565b546
...
@@ -24,6 +24,7 @@ import time
...
@@ -24,6 +24,7 @@ import time
import
unittest
import
unittest
import
requests
import
requests
requests
.
adapters
.
DEFAULT_RETRIES
=
5
import
parl
import
parl
from
parl.remote.client
import
disconnect
,
get_global_client
from
parl.remote.client
import
disconnect
,
get_global_client
...
@@ -125,10 +126,9 @@ class TestLogServer(unittest.TestCase):
...
@@ -125,10 +126,9 @@ class TestLogServer(unittest.TestCase):
th
.
start
()
th
.
start
()
time
.
sleep
(
1
)
time
.
sleep
(
1
)
# start the cluster monitor
# start the cluster monitor
monitor_file
=
__file__
.
replace
(
monitor_file
=
__file__
.
replace
(
'log_server_test.pyc'
,
'../monitor.py'
)
os
.
path
.
join
(
'tests'
,
'log_server_test.pyc'
),
'monitor.py'
)
monitor_file
=
monitor_file
.
replace
(
'log_server_test.py'
,
monitor_file
=
monitor_file
.
replace
(
'../monitor.py'
)
os
.
path
.
join
(
'tests'
,
'log_server_test.py'
),
'monitor.py'
)
command
=
[
command
=
[
sys
.
executable
,
monitor_file
,
"--monitor_port"
,
sys
.
executable
,
monitor_file
,
"--monitor_port"
,
str
(
monitor_port
),
"--address"
,
"localhost:"
+
str
(
master_port
)
str
(
monitor_port
),
"--address"
,
"localhost:"
+
str
(
master_port
)
...
@@ -138,10 +138,7 @@ class TestLogServer(unittest.TestCase):
...
@@ -138,10 +138,7 @@ class TestLogServer(unittest.TestCase):
else
:
else
:
FNULL
=
open
(
os
.
devnull
,
'w'
)
FNULL
=
open
(
os
.
devnull
,
'w'
)
monitor_proc
=
subprocess
.
Popen
(
monitor_proc
=
subprocess
.
Popen
(
command
,
command
,
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
,
close_fds
=
True
)
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
,
)
# Start worker
# Start worker
cluster_addr
=
'localhost:{}'
.
format
(
master_port
)
cluster_addr
=
'localhost:{}'
.
format
(
master_port
)
...
...
parl/remote/tests/reset_job_test.py
浏览文件 @
3565b546
...
@@ -69,7 +69,7 @@ class TestJob(unittest.TestCase):
...
@@ -69,7 +69,7 @@ class TestJob(unittest.TestCase):
file_path
=
__file__
.
replace
(
'reset_job_test'
,
'simulate_client'
)
file_path
=
__file__
.
replace
(
'reset_job_test'
,
'simulate_client'
)
command
=
[
sys
.
executable
,
file_path
]
command
=
[
sys
.
executable
,
file_path
]
proc
=
subprocess
.
Popen
(
command
)
proc
=
subprocess
.
Popen
(
command
,
close_fds
=
True
)
for
_
in
range
(
6
):
for
_
in
range
(
6
):
if
master
.
cpu_num
==
0
:
if
master
.
cpu_num
==
0
:
break
break
...
...
parl/remote/tests/support_RegExp_test.py
0 → 100644
浏览文件 @
3565b546
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
os
import
shutil
import
parl
from
parl.remote.master
import
Master
from
parl.remote.worker
import
Worker
import
time
import
threading
from
parl.remote.client
import
disconnect
from
parl.remote
import
exceptions
from
parl.utils
import
logger
@
parl
.
remote_class
class
Actor
(
object
):
def
file_exists
(
self
,
filename
):
return
os
.
path
.
exists
(
filename
)
class
TestCluster
(
unittest
.
TestCase
):
def
tearDown
(
self
):
disconnect
()
def
test_distributed_files_with_RegExp
(
self
):
if
os
.
path
.
exists
(
'distribute_test_dir'
):
shutil
.
rmtree
(
'distribute_test_dir'
)
os
.
mkdir
(
'distribute_test_dir'
)
f
=
open
(
'distribute_test_dir/test1.txt'
,
'wb'
)
f
.
close
()
f
=
open
(
'distribute_test_dir/test2.txt'
,
'wb'
)
f
.
close
()
f
=
open
(
'distribute_test_dir/data1.npy'
,
'wb'
)
f
.
close
()
f
=
open
(
'distribute_test_dir/data2.npy'
,
'wb'
)
f
.
close
()
logger
.
info
(
"running:test_distributed_files_with_RegExp"
)
master
=
Master
(
port
=
8605
)
th
=
threading
.
Thread
(
target
=
master
.
run
)
th
.
start
()
time
.
sleep
(
3
)
worker1
=
Worker
(
'localhost:8605'
,
1
)
parl
.
connect
(
'localhost:8605'
,
distributed_files
=
[
'distribute_test_dir/test*'
,
'distribute_test_dir/*npy'
,
])
actor
=
Actor
()
self
.
assertTrue
(
actor
.
file_exists
(
'distribute_test_dir/test1.txt'
))
self
.
assertTrue
(
actor
.
file_exists
(
'distribute_test_dir/test2.txt'
))
self
.
assertTrue
(
actor
.
file_exists
(
'distribute_test_dir/data1.npy'
))
self
.
assertTrue
(
actor
.
file_exists
(
'distribute_test_dir/data2.npy'
))
self
.
assertFalse
(
actor
.
file_exists
(
'distribute_test_dir/data3.npy'
))
shutil
.
rmtree
(
'distribute_test_dir'
)
master
.
exit
()
worker1
.
exit
()
def
test_miss_match_case
(
self
):
if
os
.
path
.
exists
(
'distribute_test_dir_2'
):
shutil
.
rmtree
(
'distribute_test_dir_2'
)
os
.
mkdir
(
'distribute_test_dir_2'
)
f
=
open
(
'distribute_test_dir_2/test1.txt'
,
'wb'
)
f
.
close
()
f
=
open
(
'distribute_test_dir_2/data1.npy'
,
'wb'
)
f
.
close
()
logger
.
info
(
"running:test_distributed_files_with_RegExp_error_case"
)
master
=
Master
(
port
=
8606
)
th
=
threading
.
Thread
(
target
=
master
.
run
)
th
.
start
()
time
.
sleep
(
3
)
worker1
=
Worker
(
'localhost:8606'
,
1
)
def
connect_test
():
parl
.
connect
(
'localhost:8606'
,
distributed_files
=
[
'distribute_test_dir_2/miss_match*'
])
self
.
assertRaises
(
ValueError
,
connect_test
)
shutil
.
rmtree
(
'distribute_test_dir_2'
)
master
.
exit
()
worker1
.
exit
()
if
__name__
==
'__main__'
:
unittest
.
main
()
parl/remote/tests/test_import_module/Module2.py
0 → 100644
浏览文件 @
3565b546
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
parl
@
parl
.
remote_class
class
B
(
object
):
def
add_sum
(
self
,
a
,
b
):
return
a
+
b
parl/remote/tests/test_import_module/main_abs_test.py
0 → 100644
浏览文件 @
3565b546
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
unittest
import
parl
import
time
import
threading
from
parl.remote.master
import
Master
from
parl.remote.worker
import
Worker
from
parl.remote.client
import
disconnect
class
TestImport
(
unittest
.
TestCase
):
def
tearDown
(
self
):
disconnect
()
def
test_import_local_module
(
self
):
from
Module2
import
B
port
=
8448
master
=
Master
(
port
=
port
)
th
=
threading
.
Thread
(
target
=
master
.
run
)
th
.
start
()
time
.
sleep
(
1
)
worker
=
Worker
(
'localhost:{}'
.
format
(
port
),
1
)
time
.
sleep
(
10
)
parl
.
connect
(
"localhost:8448"
)
obj
=
B
()
res
=
obj
.
add_sum
(
10
,
5
)
self
.
assertEqual
(
res
,
15
)
worker
.
exit
()
master
.
exit
()
if
__name__
==
'__main__'
:
unittest
.
main
()
parl/remote/tests/test_import_module/main_test.py
0 → 100644
浏览文件 @
3565b546
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
unittest
import
parl
import
time
import
threading
from
parl.remote.master
import
Master
from
parl.remote.worker
import
Worker
from
parl.remote.client
import
disconnect
class
TestImport
(
unittest
.
TestCase
):
def
tearDown
(
self
):
disconnect
()
def
test_import_local_module
(
self
):
from
Module2
import
B
port
=
8442
master
=
Master
(
port
=
port
)
th
=
threading
.
Thread
(
target
=
master
.
run
)
th
.
start
()
time
.
sleep
(
1
)
worker
=
Worker
(
'localhost:{}'
.
format
(
port
),
1
)
time
.
sleep
(
10
)
parl
.
connect
(
"localhost:8442"
)
obj
=
B
()
res
=
obj
.
add_sum
(
10
,
5
)
self
.
assertEqual
(
res
,
15
)
worker
.
exit
()
master
.
exit
()
def
test_import_subdir_module_0
(
self
):
from
subdir
import
Module
port
=
8443
master
=
Master
(
port
=
port
)
th
=
threading
.
Thread
(
target
=
master
.
run
)
th
.
start
()
time
.
sleep
(
1
)
worker
=
Worker
(
'localhost:{}'
.
format
(
port
),
1
)
time
.
sleep
(
10
)
parl
.
connect
(
"localhost:8443"
,
distributed_files
=
[
'./subdir/Module.py'
,
'./subdir/__init__.py'
])
obj
=
Module
.
A
()
res
=
obj
.
add_sum
(
10
,
5
)
self
.
assertEqual
(
res
,
15
)
worker
.
exit
()
master
.
exit
()
def
test_import_subdir_module_1
(
self
):
from
subdir.Module
import
A
port
=
8444
master
=
Master
(
port
=
port
)
th
=
threading
.
Thread
(
target
=
master
.
run
)
th
.
start
()
time
.
sleep
(
1
)
worker
=
Worker
(
'localhost:{}'
.
format
(
port
),
1
)
time
.
sleep
(
10
)
parl
.
connect
(
"localhost:8444"
,
distributed_files
=
[
'./subdir/Module.py'
,
'./subdir/__init__.py'
])
obj
=
A
()
res
=
obj
.
add_sum
(
10
,
5
)
self
.
assertEqual
(
res
,
15
)
worker
.
exit
()
master
.
exit
()
if
__name__
==
'__main__'
:
unittest
.
main
()
parl/remote/tests/test_import_module/subdir/Module.py
0 → 100644
浏览文件 @
3565b546
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
parl
@
parl
.
remote_class
class
A
(
object
):
def
add_sum
(
self
,
a
,
b
):
return
a
+
b
parl/remote/tests/test_import_module/subdir/__init__.py
0 → 100644
浏览文件 @
3565b546
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
parl/remote/utils.py
浏览文件 @
3565b546
...
@@ -13,8 +13,12 @@
...
@@ -13,8 +13,12 @@
# limitations under the License.
# limitations under the License.
import
sys
import
sys
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
os
from
parl.utils
import
isnotebook
__all__
=
[
'load_remote_class'
,
'redirect_stdout_to_file'
]
__all__
=
[
'load_remote_class'
,
'redirect_stdout_to_file'
,
'locate_remote_file'
]
def
simplify_code
(
code
,
end_of_file
):
def
simplify_code
(
code
,
end_of_file
):
...
@@ -32,7 +36,7 @@ def simplify_code(code, end_of_file):
...
@@ -32,7 +36,7 @@ def simplify_code(code, end_of_file):
def data_process():
def data_process():
XXXX
XXXX
------------------>
------------------>
The last two lines of the above code block will be removed as they are not class
related.
The last two lines of the above code block will be removed as they are not class
-
related.
"""
"""
to_write_lines
=
[]
to_write_lines
=
[]
for
i
,
line
in
enumerate
(
code
):
for
i
,
line
in
enumerate
(
code
):
...
@@ -60,12 +64,18 @@ def load_remote_class(file_name, class_name, end_of_file):
...
@@ -60,12 +64,18 @@ def load_remote_class(file_name, class_name, end_of_file):
with
open
(
file_name
+
'.py'
)
as
t_file
:
with
open
(
file_name
+
'.py'
)
as
t_file
:
code
=
t_file
.
readlines
()
code
=
t_file
.
readlines
()
code
=
simplify_code
(
code
,
end_of_file
)
code
=
simplify_code
(
code
,
end_of_file
)
module_name
=
'xparl_'
+
file_name
#folder/xx.py -> folder/xparl_xx.py
tmp_file_name
=
'xparl_'
+
file_name
+
'.py'
file_name
=
file_name
.
split
(
os
.
sep
)
prefix
=
os
.
sep
.
join
(
file_name
[:
-
1
])
if
prefix
==
""
:
prefix
=
'.'
module_name
=
prefix
+
os
.
sep
+
'xparl_'
+
file_name
[
-
1
]
tmp_file_name
=
module_name
+
'.py'
with
open
(
tmp_file_name
,
'w'
)
as
t_file
:
with
open
(
tmp_file_name
,
'w'
)
as
t_file
:
for
line
in
code
:
for
line
in
code
:
t_file
.
write
(
line
)
t_file
.
write
(
line
)
mod
=
__import__
(
module_name
)
module_name
=
module_name
.
lstrip
(
'.'
+
os
.
sep
).
replace
(
os
.
sep
,
'.'
)
mod
=
__import__
(
module_name
,
globals
(),
locals
(),
[
class_name
],
0
)
cls
=
getattr
(
mod
,
class_name
)
cls
=
getattr
(
mod
,
class_name
)
return
cls
return
cls
...
@@ -74,6 +84,9 @@ def load_remote_class(file_name, class_name, end_of_file):
...
@@ -74,6 +84,9 @@ def load_remote_class(file_name, class_name, end_of_file):
def
redirect_stdout_to_file
(
file_path
):
def
redirect_stdout_to_file
(
file_path
):
"""Redirect stdout (e.g., `print`) to specified file.
"""Redirect stdout (e.g., `print`) to specified file.
Args:
file_path: Path of the file to output the stdout.
Example:
Example:
>>> print('test')
>>> print('test')
test
test
...
@@ -81,10 +94,6 @@ def redirect_stdout_to_file(file_path):
...
@@ -81,10 +94,6 @@ def redirect_stdout_to_file(file_path):
... print('test') # Output nothing, `test` is printed to `test.log`.
... print('test') # Output nothing, `test` is printed to `test.log`.
>>> print('test')
>>> print('test')
test
test
Args:
file_path: Path of the file to output the stdout.
"""
"""
tmp
=
sys
.
stdout
tmp
=
sys
.
stdout
f
=
open
(
file_path
,
'a'
)
f
=
open
(
file_path
,
'a'
)
...
@@ -94,3 +103,37 @@ def redirect_stdout_to_file(file_path):
...
@@ -94,3 +103,37 @@ def redirect_stdout_to_file(file_path):
finally
:
finally
:
sys
.
stdout
=
tmp
sys
.
stdout
=
tmp
f
.
close
()
f
.
close
()
def
locate_remote_file
(
module_path
):
"""xparl has to locate the file that has the class decorated by parl.remote_class.
This function returns the relative path between this file and the entry file.
Note that this function should support the jupyter-notebook environment.
Args:
module_path: Absolute path of the module.
Example:
module_path: /home/user/dir/subdir/my_module
entry_file: /home/user/dir/main.py
--------> relative_path: subdir/my_module
"""
if
isnotebook
():
entry_path
=
os
.
getcwd
()
else
:
entry_file
=
sys
.
argv
[
0
]
entry_file
=
entry_file
.
split
(
os
.
sep
)[
-
1
]
entry_path
=
None
for
path
in
sys
.
path
:
to_check_path
=
os
.
path
.
join
(
path
,
entry_file
)
if
os
.
path
.
isfile
(
to_check_path
):
entry_path
=
path
break
if
entry_path
is
None
or
\
(
module_path
.
startswith
(
os
.
sep
)
and
entry_path
!=
module_path
[:
len
(
entry_path
)]):
raise
FileNotFoundError
(
"cannot locate the remote file"
)
if
module_path
.
startswith
(
os
.
sep
):
relative_module_path
=
'.'
+
module_path
[
len
(
entry_path
):]
else
:
relative_module_path
=
module_path
return
relative_module_path
parl/remote/worker.py
浏览文件 @
3565b546
...
@@ -26,7 +26,7 @@ import threading
...
@@ -26,7 +26,7 @@ import threading
import
warnings
import
warnings
import
zmq
import
zmq
from
datetime
import
datetime
from
datetime
import
datetime
import
parl
from
parl.utils
import
get_ip_address
,
to_byte
,
to_str
,
logger
,
_IS_WINDOWS
,
kill_process
from
parl.utils
import
get_ip_address
,
to_byte
,
to_str
,
logger
,
_IS_WINDOWS
,
kill_process
from
parl.remote
import
remote_constants
from
parl.remote
import
remote_constants
from
parl.remote.message
import
InitializedWorker
from
parl.remote.message
import
InitializedWorker
...
@@ -72,10 +72,10 @@ class Worker(object):
...
@@ -72,10 +72,10 @@ class Worker(object):
self
.
master_is_alive
=
True
self
.
master_is_alive
=
True
self
.
worker_is_alive
=
True
self
.
worker_is_alive
=
True
self
.
worker_status
=
None
# initialized at `self._create_jobs`
self
.
worker_status
=
None
# initialized at `self._create_jobs`
self
.
lock
=
threading
.
Lock
()
self
.
_set_cpu_num
(
cpu_num
)
self
.
_set_cpu_num
(
cpu_num
)
self
.
job_buffer
=
queue
.
Queue
(
maxsize
=
self
.
cpu_num
)
self
.
job_buffer
=
queue
.
Queue
(
maxsize
=
self
.
cpu_num
)
self
.
_create_sockets
()
self
.
_create_sockets
()
self
.
check_version
()
# create log server
# create log server
self
.
log_server_proc
,
self
.
log_server_address
=
self
.
_create_log_server
(
self
.
log_server_proc
,
self
.
log_server_address
=
self
.
_create_log_server
(
port
=
log_server_port
)
port
=
log_server_port
)
...
@@ -102,6 +102,24 @@ class Worker(object):
...
@@ -102,6 +102,24 @@ class Worker(object):
else
:
else
:
self
.
cpu_num
=
multiprocessing
.
cpu_count
()
self
.
cpu_num
=
multiprocessing
.
cpu_count
()
def
check_version
(
self
):
'''Verify that the parl & python version in 'worker' process matches that of the 'master' process'''
self
.
request_master_socket
.
send_multipart
(
[
remote_constants
.
CHECK_VERSION_TAG
])
message
=
self
.
request_master_socket
.
recv_multipart
()
tag
=
message
[
0
]
if
tag
==
remote_constants
.
NORMAL_TAG
:
worker_parl_version
=
parl
.
__version__
worker_python_version
=
str
(
sys
.
version_info
.
major
)
assert
worker_parl_version
==
to_str
(
message
[
1
])
and
worker_python_version
==
to_str
(
message
[
2
]),
\
'''Version mismatch: the "master" is of version "parl={}, python={}". However,
"parl={}, python={}"is provided in your environment.'''
.
format
(
to_str
(
message
[
1
]),
to_str
(
message
[
2
]),
worker_parl_version
,
worker_python_version
)
else
:
raise
NotImplementedError
def
_create_sockets
(
self
):
def
_create_sockets
(
self
):
""" Each worker has three sockets at start:
""" Each worker has three sockets at start:
...
@@ -209,7 +227,11 @@ class Worker(object):
...
@@ -209,7 +227,11 @@ class Worker(object):
# Redirect the output to DEVNULL
# Redirect the output to DEVNULL
FNULL
=
open
(
os
.
devnull
,
'w'
)
FNULL
=
open
(
os
.
devnull
,
'w'
)
for
_
in
range
(
job_num
):
for
_
in
range
(
job_num
):
subprocess
.
Popen
(
command
,
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
)
subprocess
.
Popen
(
command
,
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
,
close_fds
=
True
)
FNULL
.
close
()
FNULL
.
close
()
new_jobs
=
[]
new_jobs
=
[]
...
@@ -384,10 +406,7 @@ class Worker(object):
...
@@ -384,10 +406,7 @@ class Worker(object):
else
:
else
:
FNULL
=
open
(
os
.
devnull
,
'w'
)
FNULL
=
open
(
os
.
devnull
,
'w'
)
log_server_proc
=
subprocess
.
Popen
(
log_server_proc
=
subprocess
.
Popen
(
command
,
command
,
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
,
close_fds
=
True
)
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
,
)
FNULL
.
close
()
FNULL
.
close
()
log_server_address
=
"{}:{}"
.
format
(
self
.
worker_ip
,
port
)
log_server_address
=
"{}:{}"
.
format
(
self
.
worker_ip
,
port
)
...
...
parl/utils/csv_logger.py
浏览文件 @
3565b546
...
@@ -19,12 +19,24 @@ __all__ = ['CSVLogger']
...
@@ -19,12 +19,24 @@ __all__ = ['CSVLogger']
class
CSVLogger
(
object
):
class
CSVLogger
(
object
):
def
__init__
(
self
,
output_file
):
def
__init__
(
self
,
output_file
):
"""CSV Logger which can write dict result to csv file
"""CSV Logger which can write dict result to csv file.
Args:
output_file(str): filename of the csv file.
"""
"""
self
.
output_file
=
open
(
output_file
,
"w"
)
self
.
output_file
=
open
(
output_file
,
"w"
)
self
.
csv_writer
=
None
self
.
csv_writer
=
None
def
log_dict
(
self
,
result
):
def
log_dict
(
self
,
result
):
"""Ouput result to the csv file.
Will create the header of the csv file automatically when the function is called for the first time.
Ususally, the keys of the result should be the same every time you call the function.
Args:
result(dict)
"""
assert
isinstance
(
result
,
dict
),
"the input should be a dict."
if
self
.
csv_writer
is
None
:
if
self
.
csv_writer
is
None
:
self
.
csv_writer
=
csv
.
DictWriter
(
self
.
output_file
,
result
.
keys
())
self
.
csv_writer
=
csv
.
DictWriter
(
self
.
output_file
,
result
.
keys
())
self
.
csv_writer
.
writeheader
()
self
.
csv_writer
.
writeheader
()
...
@@ -38,4 +50,9 @@ class CSVLogger(object):
...
@@ -38,4 +50,9 @@ class CSVLogger(object):
self
.
output_file
.
flush
()
self
.
output_file
.
flush
()
def
close
(
self
):
def
close
(
self
):
if
not
self
.
output_file
.
closed
:
self
.
output_file
.
close
()
def
__del__
(
self
):
if
not
self
.
output_file
.
closed
:
self
.
output_file
.
close
()
self
.
output_file
.
close
()
parl/utils/utils.py
浏览文件 @
3565b546
...
@@ -20,7 +20,7 @@ import numpy as np
...
@@ -20,7 +20,7 @@ import numpy as np
__all__
=
[
__all__
=
[
'has_func'
,
'to_str'
,
'to_byte'
,
'is_PY2'
,
'is_PY3'
,
'MAX_INT32'
,
'has_func'
,
'to_str'
,
'to_byte'
,
'is_PY2'
,
'is_PY3'
,
'MAX_INT32'
,
'_HAS_FLUID'
,
'_HAS_TORCH'
,
'_IS_WINDOWS'
,
'_IS_MAC'
,
'kill_process'
,
'_HAS_FLUID'
,
'_HAS_TORCH'
,
'_IS_WINDOWS'
,
'_IS_MAC'
,
'kill_process'
,
'get_fluid_version'
'get_fluid_version'
,
'isnotebook'
]
]
...
@@ -101,3 +101,19 @@ def kill_process(regex_pattern):
...
@@ -101,3 +101,19 @@ def kill_process(regex_pattern):
command
=
"ps aux | grep {} | awk '{{print $2}}' | xargs kill -9"
.
format
(
command
=
"ps aux | grep {} | awk '{{print $2}}' | xargs kill -9"
.
format
(
regex_pattern
)
regex_pattern
)
subprocess
.
call
([
command
],
shell
=
True
)
subprocess
.
call
([
command
],
shell
=
True
)
def
isnotebook
():
"""check if the code is excuted in the IPython notebook
Reference: https://stackoverflow.com/a/39662359
"""
try
:
shell
=
get_ipython
().
__class__
.
__name__
if
shell
==
'ZMQInteractiveShell'
:
return
True
# Jupyter notebook or qtconsole
elif
shell
==
'TerminalInteractiveShell'
:
return
False
# Terminal running IPython
else
:
return
False
# Other type (?)
except
NameError
:
return
False
# Probably standard Python interpreter
setup.py
浏览文件 @
3565b546
...
@@ -82,7 +82,7 @@ setup(
...
@@ -82,7 +82,7 @@ setup(
"click"
,
"click"
,
"psutil>=5.6.2"
,
"psutil>=5.6.2"
,
"flask_cors"
,
"flask_cors"
,
"visualdl>=2.0.0b;python_version>='3' and platform_system=='Linux'"
,
"visualdl>=2.0.0b;python_version>='3
.7
' and platform_system=='Linux'"
,
],
],
classifiers
=
[
classifiers
=
[
'Intended Audience :: Developers'
,
'Intended Audience :: Developers'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录