Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PARL
提交
50f3bd31
P
PARL
项目概览
PaddlePaddle
/
PARL
通知
67
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
18
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PARL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
18
Issue
18
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
50f3bd31
编写于
8月 03, 2019
作者:
B
Bo Zhou
提交者:
GitHub
8月 03, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Compatibility (#118)
* fix the vital issue on compatibility * resolve the warning log * yapf * yapf
上级
a13dcce5
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
116 addition
and
95 deletion
+116
-95
parl/remote/client.py
parl/remote/client.py
+6
-2
parl/remote/job.py
parl/remote/job.py
+19
-19
parl/remote/master.py
parl/remote/master.py
+8
-7
parl/remote/remote_decorator.py
parl/remote/remote_decorator.py
+0
-2
parl/remote/scripts.py
parl/remote/scripts.py
+22
-14
parl/remote/worker.py
parl/remote/worker.py
+56
-40
parl/utils/logger.py
parl/utils/logger.py
+5
-11
未找到文件。
parl/remote/client.py
浏览文件 @
50f3bd31
...
...
@@ -18,6 +18,7 @@ import threading
import
zmq
from
parl.utils
import
to_str
,
to_byte
,
get_ip_address
,
logger
from
parl.remote
import
remote_constants
import
time
class
Client
(
object
):
...
...
@@ -78,7 +79,8 @@ class Client(object):
zmq
.
RCVTIMEO
,
remote_constants
.
HEARTBEAT_TIMEOUT_S
*
1000
)
self
.
submit_job_socket
.
connect
(
"tcp://{}"
.
format
(
master_address
))
thread
=
threading
.
Thread
(
target
=
self
.
_reply_heartbeat
,
daemon
=
True
)
thread
=
threading
.
Thread
(
target
=
self
.
_reply_heartbeat
)
thread
.
setDaemon
(
True
)
thread
.
start
()
self
.
heartbeat_socket_initialized
.
wait
()
...
...
@@ -127,7 +129,7 @@ class Client(object):
When a `@parl.remote_class` object is created, the global client
sends a job to the master node. Then the master node will allocate
a vacant job from its job pool to the remote object.
a vacant job from its job pool to the remote object.
Returns:
IP address of the job.
...
...
@@ -151,6 +153,8 @@ class Client(object):
# no vacant CPU resources, can not submit a new job
elif
tag
==
remote_constants
.
CPU_TAG
:
job_address
=
None
# wait 1 second to avoid requesting in a high frequency.
time
.
sleep
(
1
)
else
:
raise
NotImplementedError
else
:
...
...
parl/remote/job.py
浏览文件 @
50f3bd31
...
...
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
''
os
.
environ
[
'XPARL'
]
=
'True'
import
argparse
import
cloudpickle
import
pickle
...
...
@@ -27,9 +30,6 @@ from parl.utils.communication import loads_argument, loads_return,\
from
parl.remote
import
remote_constants
from
parl.utils.exceptions
import
SerializeError
,
DeserializeError
import
os
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
''
class
Job
(
object
):
"""Base class for the job.
...
...
@@ -41,7 +41,6 @@ class Job(object):
def
__init__
(
self
,
worker_address
):
self
.
job_is_alive
=
True
self
.
heartbeat_socket_initialized
=
threading
.
Event
()
self
.
worker_address
=
worker_address
self
.
_create_sockets
()
...
...
@@ -68,19 +67,9 @@ class Job(object):
reply_thread
=
threading
.
Thread
(
target
=
self
.
_reply_heartbeat
,
args
=
(
"worker {}"
.
format
(
self
.
worker_address
),
)
,
daemon
=
True
)
args
=
(
"worker {}"
.
format
(
self
.
worker_address
),
)
)
reply_thread
.
setDaemon
(
True
)
reply_thread
.
start
()
self
.
heartbeat_socket_initialized
.
wait
()
# job_socket: sends job_address and heartbeat_address to worker
self
.
job_socket
=
self
.
ctx
.
socket
(
zmq
.
REQ
)
self
.
job_socket
.
connect
(
"tcp://{}"
.
format
(
self
.
worker_address
))
self
.
job_socket
.
send_multipart
([
remote_constants
.
NORMAL_TAG
,
to_byte
(
self
.
job_address
),
to_byte
(
self
.
heartbeat_worker_address
)
])
_
=
self
.
job_socket
.
recv_multipart
()
def
_reply_heartbeat
(
self
,
target
):
"""reply heartbeat signals to the target"""
...
...
@@ -90,9 +79,20 @@ class Job(object):
remote_constants
.
HEARTBEAT_RCVTIMEO_S
*
1000
)
socket
.
linger
=
0
heartbeat_worker_port
=
socket
.
bind_to_random_port
(
addr
=
"tcp://*"
)
self
.
heartbeat_worker_address
=
"{}:{}"
.
format
(
self
.
job_ip
,
heartbeat_worker_port
)
self
.
heartbeat_socket_initialized
.
set
()
heartbeat_worker_address
=
"{}:{}"
.
format
(
self
.
job_ip
,
heartbeat_worker_port
)
# job_socket: sends job_address and heartbeat_address to worker
job_socket
=
self
.
ctx
.
socket
(
zmq
.
REQ
)
job_socket
.
connect
(
"tcp://{}"
.
format
(
self
.
worker_address
))
job_socket
.
send_multipart
([
remote_constants
.
NORMAL_TAG
,
to_byte
(
self
.
job_address
),
to_byte
(
heartbeat_worker_address
),
to_byte
(
str
(
os
.
getpid
()))
])
_
=
job_socket
.
recv_multipart
()
# a flag to decide when to exit heartbeat loop
self
.
worker_is_alive
=
True
while
self
.
worker_is_alive
and
self
.
job_is_alive
:
...
...
parl/remote/master.py
浏览文件 @
50f3bd31
...
...
@@ -41,7 +41,7 @@ class Master(object):
Attributes:
worker_pool (dict): A dict to store connected workers.
job_pool (list): A list to store the job address of vacant cpu, when
job_pool (list): A list to store the job address of vacant cpu, when
this number is 0, the master will refuse to create
new remote object.
client_job_dict (dict): A dict of list to record the job submitted by
...
...
@@ -201,8 +201,9 @@ class Master(object):
self
.
worker_pool
[
worker
.
address
]
=
worker
self
.
worker_locks
[
worker
.
address
]
=
threading
.
Lock
()
logger
.
info
(
"A new worker {} is added, "
.
format
(
worker
.
address
)
+
"cluster has {} CPUs.
\n
"
.
format
(
len
(
self
.
job_pool
)))
logger
.
info
(
"A new worker {} is added, "
.
format
(
worker
.
address
)
+
"the cluster has {} CPUs.
\n
"
.
format
(
len
(
self
.
job_pool
)))
# a thread for sending heartbeat signals to `worker.address`
thread
=
threading
.
Thread
(
...
...
@@ -210,8 +211,8 @@ class Master(object):
args
=
(
worker_heartbeat_address
,
worker
.
address
,
)
,
daemon
=
True
)
)
)
thread
.
setDaemon
(
True
)
thread
.
start
()
self
.
client_socket
.
send_multipart
([
remote_constants
.
NORMAL_TAG
])
...
...
@@ -224,8 +225,8 @@ class Master(object):
thread
=
threading
.
Thread
(
target
=
self
.
_create_client_monitor
,
args
=
(
client_heartbeat_address
,
)
,
daemon
=
True
)
args
=
(
client_heartbeat_address
,
)
)
thread
.
setDaemon
(
True
)
thread
.
start
()
self
.
client_socket
.
send_multipart
([
remote_constants
.
NORMAL_TAG
])
...
...
parl/remote/remote_decorator.py
浏览文件 @
50f3bd31
...
...
@@ -110,7 +110,6 @@ def remote_class(cls):
_
=
self
.
job_socket
.
recv_multipart
()
except
zmq
.
error
.
Again
as
e
:
logger
.
error
(
"Job socket failed."
)
logger
.
info
(
"[connect_job] job_address:{}"
.
format
(
job_address
))
def
__del__
(
self
):
"""Delete the remote class object and release remote resources."""
...
...
@@ -138,7 +137,6 @@ def remote_class(cls):
logger
.
warning
(
"No vacant cpu resources at present, "
"will try {} times later."
.
format
(
cnt
))
cnt
-=
1
time
.
sleep
(
1
)
return
None
def
__getattr__
(
self
,
attr
):
...
...
parl/remote/scripts.py
浏览文件 @
50f3bd31
...
...
@@ -14,11 +14,13 @@
import
click
import
locale
import
sys
import
os
import
subprocess
import
threading
import
warnings
from
multiprocessing
import
Process
from
parl.utils
import
logger
# A flag to mark if parl is started from a command line
os
.
environ
[
'XPARL'
]
=
'True'
...
...
@@ -27,18 +29,22 @@ os.environ['XPARL'] = 'True'
# to use ASCII as encoding for the environment` error.
locale
.
setlocale
(
locale
.
LC_ALL
,
"en_US.UTF-8"
)
warnings
.
simplefilter
(
"ignore"
,
ResourceWarning
)
#TODO: this line will cause error in python2/macOS
if
sys
.
version_info
.
major
==
3
:
warnings
.
simplefilter
(
"ignore"
,
ResourceWarning
)
def
is_port_
in_us
e
(
port
):
def
is_port_
availabl
e
(
port
):
""" Check if a port is used.
True if the port is not available. Otherwise, this port can be used for
connection.
True if the port is available for connection.
"""
port
=
int
(
port
)
import
socket
with
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
as
s
:
return
s
.
connect_ex
((
'localhost'
,
int
(
port
)))
==
0
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
available
=
sock
.
connect_ex
((
'localhost'
,
port
))
sock
.
close
()
return
available
def
is_master_started
(
address
):
...
...
@@ -71,22 +77,24 @@ def cli():
help
=
"Set number of cpu manually. If not set, it will use all "
"cpus of this machine."
)
def
start_master
(
port
,
cpu_num
):
if
is_port_in_us
e
(
port
):
if
not
is_port_availabl
e
(
port
):
raise
Exception
(
"The master address localhost:{} already in use."
.
format
(
port
))
cpu_num
=
str
(
cpu_num
)
if
cpu_num
else
''
command
=
[
"python"
,
"{}/start.py"
.
format
(
__file__
[:
-
11
]),
"--name"
,
"master"
,
"--port"
,
port
]
start_file
=
__file__
.
replace
(
'scripts.pyc'
,
'start.py'
)
start_file
=
start_file
.
replace
(
'scripts.py'
,
'start.py'
)
command
=
[
"python"
,
start_file
,
"--name"
,
"master"
,
"--port"
,
port
]
p
=
subprocess
.
Popen
(
command
)
command
=
[
"python"
,
"{}/start.py"
.
format
(
__file__
[:
-
11
]),
"--name"
,
"worker
"
,
"
--address"
,
"
localhost:"
+
str
(
port
),
"--cpu_num"
,
"python"
,
start_file
,
"--name"
,
"worker"
,
"--address
"
,
"localhost:"
+
str
(
port
),
"--cpu_num"
,
str
(
cpu_num
)
]
p
=
subprocess
.
Popen
(
command
)
# Redirect the output to DEVNULL to solve the warning log.
FNULL
=
open
(
os
.
devnull
,
'w'
)
p
=
subprocess
.
Popen
(
command
,
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
)
FNULL
.
close
()
@
click
.
command
(
"connect"
,
short_help
=
"Start a worker node."
)
...
...
parl/remote/worker.py
浏览文件 @
50f3bd31
...
...
@@ -15,15 +15,20 @@
import
cloudpickle
import
multiprocessing
import
os
import
signal
import
subprocess
import
sys
import
time
import
threading
import
warnings
import
zmq
from
parl.utils
import
get_ip_address
,
to_byte
,
to_str
,
logger
from
parl.remote
import
remote_constants
if
sys
.
version_info
.
major
==
3
:
warnings
.
simplefilter
(
"ignore"
,
ResourceWarning
)
class
WorkerInfo
(
object
):
"""A WorkerInfo object records the computation resources of a worker.
...
...
@@ -138,7 +143,7 @@ class Worker(object):
self
.
master_is_alive
=
False
return
self
.
_init_jobs
()
self
.
_init_jobs
(
job_num
=
self
.
cpu_num
)
self
.
request_master_socket
.
setsockopt
(
zmq
.
RCVTIMEO
,
remote_constants
.
HEARTBEAT_TIMEOUT_S
*
1000
)
...
...
@@ -146,8 +151,8 @@ class Worker(object):
list
(
self
.
job_pid
.
keys
()))
reply_thread
=
threading
.
Thread
(
target
=
self
.
_reply_heartbeat
,
args
=
(
"master {}"
.
format
(
self
.
master_address
),
)
,
daemon
=
True
)
args
=
(
"master {}"
.
format
(
self
.
master_address
),
)
)
reply_thread
.
setDaemon
(
True
)
reply_thread
.
start
()
self
.
heartbeat_socket_initialized
.
wait
()
...
...
@@ -158,55 +163,66 @@ class Worker(object):
])
_
=
self
.
request_master_socket
.
recv_multipart
()
def
_init_job
(
self
):
"""Create one job."""
def
_init_jobs
(
self
,
job_num
):
"""Create jobs.
Args:
job_num(int): the number of jobs to create.
"""
job_file
=
__file__
.
replace
(
'worker.pyc'
,
'job.py'
)
job_file
=
job_file
.
replace
(
'worker.py'
,
'job.py'
)
command
=
[
"python"
,
"{}/job.py"
.
format
(
__file__
[:
-
10
]),
"--worker_address"
,
self
.
reply_job_address
"python"
,
job_file
,
"--worker_address"
,
self
.
reply_job_address
]
with
open
(
os
.
devnull
,
"w"
)
as
null
:
pid
=
subprocess
.
Popen
(
command
,
stdout
=
null
,
stderr
=
null
)
self
.
lock
.
acquire
()
job_message
=
self
.
reply_job_socket
.
recv_multipart
(
)
self
.
reply_job_socket
.
send_multipart
([
remote_constants
.
NORMAL_TAG
]
)
job_address
=
to_str
(
job_message
[
1
])
heartbeat_job_address
=
to_str
(
job_message
[
2
])
self
.
job_pid
[
job_address
]
=
pid
self
.
lock
.
release
()
# a thread for sending heartbeat signals to job
thread
=
threading
.
Thread
(
target
=
self
.
_create_job_monitor
,
args
=
(
job_address
,
heartbeat_job_address
,
),
daemon
=
True
)
thread
.
start
()
return
job_address
def
_init_jobs
(
self
):
"""Create cpu_num jobs when the worker is created."""
job_threads
=
[]
for
_
in
range
(
self
.
cpu_num
):
t
=
threading
.
Thread
(
target
=
self
.
_init_job
,
daemon
=
True
)
t
.
start
()
job_threads
.
append
(
t
)
for
th
in
job_threads
:
th
.
join
()
# Redirect the output to DEVNULL
FNULL
=
open
(
os
.
devnull
,
'w'
)
for
_
in
range
(
job_num
):
pid
=
subprocess
.
Popen
(
command
,
stdout
=
FNULL
,
stderr
=
subprocess
.
STDOUT
)
FNULL
.
close
(
)
new_job_address
=
[]
for
_
in
range
(
job_num
):
job_message
=
self
.
reply_job_socket
.
recv_multipart
()
self
.
reply_job_socket
.
send_multipart
([
remote_constants
.
NORMAL_TAG
])
job_address
=
to_str
(
job_message
[
1
])
new_job_address
.
append
(
job_address
)
heartbeat_job_address
=
to_str
(
job_message
[
2
])
pid
=
to_str
(
job_message
[
3
])
self
.
job_pid
[
job_address
]
=
int
(
pid
)
# a thread for sending heartbeat signals to job
thread
=
threading
.
Thread
(
target
=
self
.
_create_job_monitor
,
args
=
(
job_address
,
heartbeat_job_address
,
))
thread
.
setDaemon
(
True
)
thread
.
start
()
assert
len
(
new_job_address
)
>
0
,
"init jobs failed"
if
len
(
new_job_address
)
>
1
:
return
new_job_address
else
:
return
new_job_address
[
0
]
def
_kill_job
(
self
,
job_address
):
"""kill problematic job process and update worker information"""
if
job_address
in
self
.
job_pid
:
self
.
job_pid
[
job_address
].
kill
()
self
.
lock
.
acquire
()
pid
=
self
.
job_pid
[
job_address
]
try
:
os
.
kill
(
pid
,
signal
.
SIGTERM
)
except
OSError
:
logger
.
warn
(
"job:{} has been killed before"
.
format
(
pid
))
self
.
job_pid
.
pop
(
job_address
)
logger
.
warning
(
"Worker kills job process {},"
.
format
(
job_address
))
self
.
lock
.
release
()
# When a old job is killed, the worker will create a new job.
if
self
.
master_is_alive
:
new_job_address
=
self
.
_init_job
(
)
new_job_address
=
self
.
_init_job
s
(
job_num
=
1
)
self
.
lock
.
acquire
()
self
.
request_master_socket
.
send_multipart
([
...
...
parl/utils/logger.py
浏览文件 @
50f3bd31
...
...
@@ -18,7 +18,6 @@ import os
import
os.path
import
sys
from
termcolor
import
colored
import
shutil
__all__
=
[
'set_dir'
,
'get_dir'
,
'set_level'
]
...
...
@@ -86,16 +85,10 @@ def _getlogger():
def
create_file_after_first_call
(
func_name
):
def
call
(
*
args
,
**
kwargs
):
global
_logger
if
LOG_DIR
is
None
:
if
LOG_DIR
is
None
and
hasattr
(
mod
,
'__file__'
):
basename
=
os
.
path
.
basename
(
mod
.
__file__
)
if
basename
.
rfind
(
'.'
)
==
-
1
:
basename
=
basename
else
:
basename
=
basename
[:
basename
.
rfind
(
'.'
)]
auto_dirname
=
os
.
path
.
join
(
'log_dir'
,
basename
)
shutil
.
rmtree
(
auto_dirname
,
ignore_errors
=
True
)
auto_dirname
=
os
.
path
.
join
(
'log_dir'
,
basename
[:
basename
.
rfind
(
'.'
)])
set_dir
(
auto_dirname
)
func
=
getattr
(
_logger
,
func_name
)
...
...
@@ -165,4 +158,5 @@ def get_dir():
# Will save log to log_dir/main_file_name/log.log by default
mod
=
sys
.
modules
[
'__main__'
]
_logger
.
info
(
"Argv: "
+
' '
.
join
(
sys
.
argv
))
if
hasattr
(
mod
,
'__file__'
)
and
'XPARL'
not
in
os
.
environ
:
_logger
.
info
(
"Argv: "
+
' '
.
join
(
sys
.
argv
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录