Commit 67c6ddff (unverified), authored Mar 15, 2022 by kuizhiqing, committed by GitHub on Mar 15, 2022
New design for launch/run (#40086)
Parent 464f65b1
Showing 25 changed files with 2,546 additions and 0 deletions (+2546 −0)
python/paddle/distributed/run/__init__.py  +86 −0
python/paddle/distributed/run/__main__.py  +28 −0
python/paddle/distributed/run/context/__init__.py  +219 −0
python/paddle/distributed/run/context/device.py  +88 −0
python/paddle/distributed/run/context/event.py  +20 −0
python/paddle/distributed/run/context/node.py  +64 −0
python/paddle/distributed/run/context/resource.py  +18 −0
python/paddle/distributed/run/context/status.py  +58 −0
python/paddle/distributed/run/controllers/__init__.py  +32 −0
python/paddle/distributed/run/controllers/collective.py  +185 −0
python/paddle/distributed/run/controllers/controller.py  +192 −0
python/paddle/distributed/run/controllers/master.py  +289 −0
python/paddle/distributed/run/controllers/ps.py  +221 −0
python/paddle/distributed/run/job/__init__.py  +25 −0
python/paddle/distributed/run/job/container.py  +179 −0
python/paddle/distributed/run/job/job.py  +80 −0
python/paddle/distributed/run/job/pod.py  +185 −0
python/paddle/distributed/run/job/status.py  +24 −0
python/paddle/distributed/run/plugins/__init__.py  +50 −0
python/paddle/distributed/run/plugins/ip.py  +30 −0
python/paddle/distributed/run/utils/kv_client.py  +94 −0
python/paddle/distributed/run/utils/kv_server.py  +121 −0
python/paddle/distributed/run/utils/process_context.py  +83 −0
python/paddle/fluid/tests/unittests/CMakeLists.txt  +1 −0
python/paddle/fluid/tests/unittests/test_run.py  +174 −0
python/paddle/distributed/run/__init__.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .job.container import Container
from .job.pod import Pod
from .job.job import Job
from . import plugins

#__all__ = [Container, Pod, Job]
'''
Paddle distributed training entry ``python -m paddle.distributed.run``.
Help
# for arg usage and explanation, try the following command
# python -m paddle.distributed.run -h
Collective Mode
Case 1: 1 node
use all visible devices
# python -m paddle.distributed.run train.py
use specified devices
# python -m paddle.distributed.run --devices=0,1,2,3 train.py
Case 2: multi-node, auto detect ip/port
# python -m paddle.distributed.run --np 2 train.py
# auto print following command
# python -m paddle.distributed.run --master 10.0.0.1:13538 --np 2 demo.py
# then copy and paste above command to other nodes
Case 3: multi-node, specified master/rendezvous server
# python -m paddle.distributed.run --np 2 --master 10.0.0.1:2379 train.py
# the master ip must be one of the nodes and the port must be available
Parameter Server Mode
Case 1.1: 1 node, 1 ps, 1 trainer
# python -m paddle.distributed.run --mode ps train.py
# python -m paddle.distributed.run --server_num=1 --trainer_num=1 train.py
Case 1.2: 1 node, 2 ps, 2 trainer
# python -m paddle.distributed.run --server_num=2 --trainer_num=2 train.py
Case 2: 2 node, 2 ps, 2 trainer per node
# python -m paddle.distributed.run --server_num=2 --trainer_num=2 --np 2 train.py
# auto print following command
# python -m paddle.distributed.run --master 10.0.0.1:13538 --server_num=2 --trainer_num=2 --np 2 train.py
# then copy and paste above command to other nodes
Case 3: multi-node, specified master/rendezvous server
# python -m paddle.distributed.run --master 10.0.0.1:13538 --server_num=2 --trainer_num=2 --np 2 train.py
# the master ip must be one of the nodes and the port must be available
Case 4: specified servers and trainers in each node
python -m paddle.distributed.run --servers 127.0.0.1:8900,127.0.0.1:8901 --trainers 127.0.0.1:8902,127.0.0.1:8903 train.py
Elastic Mode
# run the following command on 3 nodes to start immediately, or on 2 nodes to start after elastic_timeout
# python -m paddle.distributed.run --master etcd://10.0.0.1:2379 --np 2:3 train.py
# once the peer number changes between 2:3, the strategy holds
'''
python/paddle/distributed/run/__main__.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .context import Context
from . import controllers

# initialize the context to run
ctx = Context()

# initialize the selected controller
c = controllers.init(ctx)

# run the pods
c.run()

# manage or just wait for the pod
c.finalize()
python/paddle/distributed/run/context/__init__.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser, REMAINDER
import os, copy

from paddle.distributed.run import plugins

from .node import Node
from .status import Status

import logging


class Context(object):
    def __init__(self, enable_plugin=True):
        os.environ.pop('http_proxy', None)
        os.environ.pop('https_proxy', None)

        self.args = self.parse_args()
        self.envs = self.fetch_envs()
        self.logger = self.get_logger()

        self.node = Node()
        self.status = Status()

        self.set_env_in_args()

        # design for event queue, later
        self.events = []

        if enable_plugin:
            self._enable_plugin()

    def get_envs(self):
        return self.envs.copy()

    def _enable_plugin(self):
        for pl in plugins.enabled_plugins:
            pl(self)

    def parse_args(self):
        parser = ArgumentParser()

        base_group = parser.add_argument_group("Base Parameters")

        base_group.add_argument(
            "--master",
            type=str,
            default=None,
            help="the master/rendezvous server, ip:port")

        base_group.add_argument(
            "--rank", type=int, default=-1, help="the peer rank")

        base_group.add_argument(
            "--log", type=str, default="INFO", help="log level. Default INFO")

        base_group.add_argument(
            "--np",
            type=str,
            default="1",
            help="the number of peers, i.e. pod/node number")

        base_group.add_argument(
            "--nproc_per_node",
            type=int,
            default=None,
            help="the number of processes in a pod")

        base_group.add_argument(
            "--log_dir",
            type=str,
            default="log",
            help="the path for each process's log. Default ./log")

        base_group.add_argument(
            "--mode",
            type=str,
            default="collective",
            help="run mode of the job, collective/ps/ps-heter")

        base_group.add_argument(
            "--id",
            type=str,
            default="default",
            help="unique id of the job. Default default")

        base_group.add_argument(
            "--devices",
            type=str,
            default=None,
            help="accelerate devices. as --gpus,npus,xps")

        base_group.add_argument(
            "--host", type=str, default=None, help="host ip")

        base_group.add_argument(
            "training_script",
            type=str,
            help="the full path of py script,"
            "followed by arguments for the "
            "training script")

        base_group.add_argument('training_script_args', nargs=REMAINDER)

        ps_group = parser.add_argument_group("Parameter-Server Parameters")
        # for parameter server
        ps_group.add_argument(
            "--servers",
            type=str,
            default='',
            help="servers endpoints full list")
        ps_group.add_argument(
            "--trainers",
            type=str,
            default='',
            help="trainers endpoints full list")
        ps_group.add_argument(
            "--trainer_num", type=int, default=None, help="number of trainers")
        ps_group.add_argument(
            "--server_num", type=int, default=None, help="number of servers")
        ps_group.add_argument(
            "--gloo_port", type=int, default=6767, help="gloo http port")
        ps_group.add_argument(
            "--with_gloo", type=str, default="0", help="use gloo or not")

        # parameter elastic mode
        elastic_group = parser.add_argument_group("Elastic Parameters")
        elastic_group.add_argument(
            "--max_restart",
            type=int,
            default=3,
            help="the times can restart. Default 3")
        elastic_group.add_argument(
            "--elastic_level",
            type=int,
            default=-1,
            help="elastic level: -1 disable, 0 failed exit, peers hold, 1 internal restart")
        elastic_group.add_argument(
            "--elastic_timeout",
            type=int,
            default=30,
            help="seconds to wait before elastic perform training")

        return parser.parse_args()

    def _valide_env(self, key):
        if key in ['POD_IP']:
            return True
        if key.endswith('_VISIBLE_DEVICES'):
            return True
        if key.startswith('PADDLE_'):
            return True
        return False

    def fetch_envs(self):
        ge = os.environ.copy()

        black_env_list = ['http_proxy', 'https_proxy']
        for key in black_env_list:
            ge.pop(key, None)

        return ge
        '''
        # use black list instead white list
        return {k: ge[k] for k in ge if self._valide_env(k)}
        '''

    def get_logger(self, level=logging.INFO):
        logger = logging.getLogger("PADDLERUN")
        logger.setLevel(self.args.log.upper() or level)
        formatter = logging.Formatter(
            fmt='%(name)s %(levelname)s %(asctime)s %(message)s')
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        return logger

    def set_env_in_args(self):
        env_args = {
            'POD_IP': 'host',
            'PADDLE_MASTER': 'master',
            'PADDLE_DEVICES': 'devices',
            'PADDLE_NP': 'np',
            'PADDLE_MODE': 'mode',
            'PADDLE_LOG': 'log',
            'PADDLE_NPROC_PER_NODE': 'nproc_per_node',
            'PADDLE_JOB_ID': 'id',
            'PADDLE_RANK': 'rank',
            'PADDLE_LOG_DIR': 'log_dir',
            'PADDLE_MAX_RESTART': 'max_restart',
            'PADDLE_ELASTIC_LEVEL': 'elastic_level',
            'PADDLE_ELASTIC_TIMEOUT': 'elastic_timeout',
            'PADDLE_SERVER_NUM': 'server_num',
            'PADDLE_TRAINER_NUM': 'trainer_num',
            'PADDLE_SERVERS_ENDPOINTS': 'servers',
            'PADDLE_TRAINERS_ENDPOINTS': 'trainers',
            'PADDLE_GLOO_PORT': 'gloo_port',
            'PADDLE_WITH_GLOO': 'with_gloo',
        }

        for k, v in env_args.items():
            if k in self.envs:
                setattr(self.args, v, self.envs[k])
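To make the precedence concrete, here is a minimal sketch of how set_env_in_args lets environment variables override argument defaults. It assumes Paddle is installed, fakes sys.argv, and uses hypothetical values.

import os, sys

# hypothetical invocation, equivalent to:
#   PADDLE_NP=4 PADDLE_JOB_ID=demo python -m paddle.distributed.run train.py
sys.argv = ['paddle.distributed.run', 'train.py']
os.environ['PADDLE_NP'] = '4'
os.environ['PADDLE_JOB_ID'] = 'demo'

from paddle.distributed.run.context import Context

ctx = Context(enable_plugin=False)
print(ctx.args.np)  # '4'    -- overridden by PADDLE_NP
print(ctx.args.id)  # 'demo' -- overridden by PADDLE_JOB_ID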
python/paddle/distributed/run/context/device.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os


class DeviceType:
    CPU = 'cpu'
    GPU = 'gpu'
    XPU = 'xpu'
    NPU = 'npu'


class Device(object):
    def __init__(self, dtype=None, count=1, memory="", labels=""):
        self.dtype = dtype
        self.count = count
        self.memory = memory
        self.labels = labels

    def __str__(self):
        return ",".join(self.labels)

    @classmethod
    def parse_device(self):
        dev = Device()
        visible_devices = None
        if 'CUDA_VISIBLE_DEVICES' in os.environ or 'NVIDIA_VISIBLE_DEVICES' in os.environ:
            dev.dtype = DeviceType.GPU
            visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") or os.getenv(
                "NVIDIA_VISIBLE_DEVICES")
        elif 'XPU_VISIBLE_DEVICES' in os.environ:
            dev.dtype = DeviceType.XPU
            visible_devices = os.getenv("XPU_VISIBLE_DEVICES")
        elif 'ASCEND_VISIBLE_DEVICES' in os.environ:
            dev.dtype = DeviceType.NPU
            visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES")

        if visible_devices and visible_devices != 'all':
            dev.labels = visible_devices.split(',')
            dev.count = len(dev.labels)
        else:
            return self.detect_device()

        return dev

    @classmethod
    def detect_device(self):
        import paddle.fluid as fluid

        dev = Device()
        num = 0
        visible_devices = None
        if fluid.core.is_compiled_with_cuda():
            dev.dtype = DeviceType.GPU
            num = fluid.core.get_cuda_device_count()
            visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") or os.getenv(
                "NVIDIA_VISIBLE_DEVICES")
        elif fluid.core.is_compiled_with_xpu():
            dev.dtype = DeviceType.XPU
            num = fluid.core.get_xpu_device_count()
            visible_devices = os.getenv("XPU_VISIBLE_DEVICES")
        elif fluid.core.is_compiled_with_npu():
            dev.dtype = DeviceType.NPU
            num = fluid.core.get_npu_device_count()
            visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES")

        if num == 0:
            dev.dtype = DeviceType.CPU
        elif visible_devices is None or visible_devices == "all" or visible_devices == "":
            dev.labels = [str(x) for x in range(0, num)]
            dev.count = num
        else:
            dev.labels = visible_devices.split(',')
            dev.count = len(dev.labels)

        return dev
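A quick sketch of the detection path above: when a *_VISIBLE_DEVICES variable is present and not 'all', parse_device answers from the environment alone and never touches paddle.fluid. The device ids here are hypothetical.

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # hypothetical GPU set

from paddle.distributed.run.context.device import Device, DeviceType

dev = Device.parse_device()
print(dev.dtype == DeviceType.GPU)  # True
print(dev.count)                    # 2
print(str(dev))                     # '0,1' -- labels joined by comma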
python/paddle/distributed/run/context/event.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Event(object):
    def __init__(self, kind="status", message="", fatal=False):
        self.kind = kind
        self.message = message
        self.fatal = fatal
python/paddle/distributed/run/context/node.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .device import Device

import socket
import struct
from contextlib import closing


class Node(object):
    def __init__(self):
        # self.device = Device.detect_device()
        self.device = Device.parse_device()
        self.ip = self.get_host_ip()
        self.free_ports = []

    def get_host_ip(self):
        try:
            self.hostname = socket.gethostname()
            self.ip = socket.gethostbyname(socket.getfqdn(self.hostname))
            return self.ip
        except:
            return '127.0.0.1'

    def get_free_ports(self, n=1):
        free_ports = [self.get_free_port() for i in range(n)]
        self.free_ports += free_ports
        return free_ports

    def get_ports_occupied(self):
        return self.free_ports

    @classmethod
    def get_free_port(self):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
                         struct.pack('ii', 1, 0))
            s.bind(('', 0))
            return s.getsockname()[1]

    @classmethod
    def is_server_ready(self, ip, port):
        with closing(socket.socket(socket.AF_INET,
                                   socket.SOCK_STREAM)) as sock:
            #sock.settimeout(0.01)
            #sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if hasattr(socket, 'SO_REUSEPORT'):
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            result = sock.connect_ex((ip, int(port)))
            if result == 0:
                return True
            else:
                return False
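The two classmethods are usable on their own, e.g. for a quick sanity check. A minimal sketch (the probe can race with another process grabbing the port, so the False is only typical):

from paddle.distributed.run.context.node import Node

# the OS picks an unused port; SO_LINGER releases it immediately on close
port = Node.get_free_port()
print(port)

# nothing listens on it yet, so the readiness probe should fail
print(Node.is_server_ready('127.0.0.1', port))  # False, typically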
python/paddle/distributed/run/context/resource.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Resource(object):
    def __init__(self):
        self.devices = []
python/paddle/distributed/run/context/status.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Status(object):
    UNINIT = "uninit"
    READY = "ready"
    RUNNING = "running"
    FAILED = "failed"
    TERMINATING = "terminating"
    RESTARTING = "restarting"
    UNKNOWN = "unknown"
    COMPLETED = "completed"
    DONE = "done"  # should exit whatever status

    def __init__(self):
        self._current_status = None

    def current(self):
        return self._current_status

    def is_running(self):
        return self._current_status == self.RUNNING

    def is_restarting(self):
        return self._current_status == self.RESTARTING

    def is_done(self):
        if self._current_status in [self.DONE, self.COMPLETED, self.FAILED]:
            return True
        else:
            return False

    def run(self):
        self._current_status = self.RUNNING

    def fail(self):
        self._current_status = self.FAILED

    def complete(self):
        self._current_status = self.COMPLETED

    def restart(self):
        self._current_status = self.RESTARTING

    def done(self):
        self._current_status = self.DONE
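The class doubles as an enum of state constants and a tiny state machine, for example (assuming Paddle is installed so the module imports):

from paddle.distributed.run.context.status import Status

s = Status()
print(s.current())     # None -- nothing set yet
s.run()
print(s.is_running())  # True
s.fail()
print(s.is_done())     # True -- FAILED, COMPLETED and DONE are terminal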
python/paddle/distributed/run/controllers/__init__.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["init"]

from .collective import CollectiveController
from .collective import CollectiveElasticController
from .ps import PSController

# the order is extremely important
_controllers = [
    CollectiveElasticController,
    PSController,
    CollectiveController,
]


def init(ctx):
    for c in _controllers:
        if c.enable(ctx):
            return c(ctx)
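Since CollectiveController.enable accepts any context, the list order is what routes a job: etcd master first, then ps, then the collective fallback. A self-contained sketch of that routing (the real enable() also logs; SimpleNamespace stands in for ctx.args):

from types import SimpleNamespace

def pick(args):
    # mirrors the enable() checks, in _controllers order
    if args.master and args.master.startswith("etcd://"):
        return "CollectiveElasticController"
    if args.mode == "ps" or args.server_num or len(args.servers) > 0:
        return "PSController"
    return "CollectiveController"

base = dict(master=None, mode="collective", server_num=None, servers="")
print(pick(SimpleNamespace(**dict(base, master="etcd://10.0.0.1:2379"))))  # elastic
print(pick(SimpleNamespace(**dict(base, mode="ps"))))                      # ps
print(pick(SimpleNamespace(**base)))                                       # collective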
python/paddle/distributed/run/controllers/collective.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .controller import Controller

import json
import os
import six
import time


class CollectiveController(Controller):
    @classmethod
    def enable(cls, ctx):
        if ctx:
            ctx.logger.debug("{} enabled".format(cls.__name__))
            return True
        else:
            return False

    def build_pod(self):
        self.pod.replicas = self.pod_replicas()

        # rank will be reset when restart
        self.pod.rank = self.ctx.args.rank

        port = self.ctx.node.get_free_port()

        # compatible
        endpoints = [
            "{}:{}".format(self.ctx.node.ip, p)
            for p in self.ctx.node.get_free_ports(self.pod.replicas)
        ]

        data = json.dumps({
            'name': self.pod.name,
            'rank': self.pod.rank,
            'replicas': self.pod.replicas,
            'dtype': self.ctx.node.device.dtype,
            'candidate': '{}:{}'.format(self.ctx.node.ip, port),
            'endpoints': ",".join(endpoints),
        })

        peer_list, rank = self.master.sync_peers(
            '/{}/info'.format(self.job.id), self.pod.name, data,
            self.job.replicas, self.pod.rank)
        self.pod.rank = rank

        if len(peer_list) < 1:
            return False

        peer_list = [json.loads(i) for i in peer_list]

        self.ctx.logger.debug("sync peers done {}".format(peer_list))
        self.save_pod_log(peer_list)

        global_size = sum([i['replicas'] for i in peer_list])
        rank_offset = sum([i['replicas'] for i in peer_list[:rank]])
        '''
        The new designed collective need nothing but a master endpoint
        '''
        collective_master = peer_list[0]['candidate']

        job_endpoints = [i['endpoints'] for i in peer_list]

        self.pod.reset()
        for i in range(self.pod.replicas):
            e = {
                "PADDLE_MASTER": collective_master,
                "PADDLE_GLOBAL_SIZE": "{}".format(global_size),
                "PADDLE_LOCAL_SIZE": "{}".format(self.pod.replicas),
                "PADDLE_GLOBAL_RANK": "{}".format(i + rank_offset),
                "PADDLE_LOCAL_RANK": "{}".format(i),
                ## compatible env
                "PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints),
                "PADDLE_CURRENT_ENDPOINT": endpoints[i],
                "PADDLE_TRAINER_ID": "{}".format(i + rank_offset),
                "PADDLE_TRAINERS_NUM": "{}".format(global_size),
                "PADDLE_RANK_IN_NODE": str(i),
            }
            self.add_container(envs=e, log_tag=i)

        return True


class CollectiveElasticController(CollectiveController):
    @classmethod
    def enable(cls, ctx):
        if ctx.args.master and ctx.args.master.startswith("etcd://"):
            ctx.logger.debug("{} enabled".format(cls.__name__))
            return True
        else:
            return False

    def register(self):
        if self.job.id == 'default':
            self.ctx.logger.warning(
                'Using default job name may cause conflict, add --id in args')

        self.master.register_heartbeat(self.job.id, self.pod.name)

    def watch(self) -> bool:
        '''
        watch self and peer status, return true to exit
        '''
        while not self.ctx.status.is_done():
            # self status
            status = self.pod.watch(timeout=2)
            self.ctx.logger.debug("Pod status {}, Ctx status {}".format(
                status, self.ctx.status.current()))

            # completed
            if status == self.ctx.status.COMPLETED:
                self.master.set_status(status)
                self.ctx.status.complete()
                self.ctx.logger.info("Pod complete {}".format(status))
                return True

            # self failure
            elif status == self.ctx.status.FAILED:
                self.master.set_status(status)
                self.master.restart_peer()
                self.ctx.logger.info("Pod failed {}".format(status))
                self.pod.stop()

                if self.ctx.args.elastic_level <= 0:
                    return True
                else:
                    return False

            # peer failure
            if self.ctx.status.is_restarting() and self.master.get_status(
            ) != self.ctx.status.COMPLETED:
                self.pod.stop()
                return False

            #peers = self.master.fetch_peer_alive()
            #print("peers {}".format(peers))

    def run(self):

        timeout = self.ctx.args.elastic_timeout if self.job.elastic else self.ctx.args.elastic_timeout * 10
        self.register()

        while self.pod.restart <= self.ctx.args.max_restart:

            self.build_job()

            ok, replicas = self.master.wait_peer_ready(
                self.job.replicas_min, self.job.replicas_max, timeout)
            if ok:
                self.job.replicas = replicas
            else:
                self.ctx.logger.warning("peer not ready {}".format(self.job))
                break

            self.ctx.logger.debug("Run {}".format(self.job))

            if not self.build_pod():
                continue

            self.master.set_status(self.ctx.status.RUNNING)
            self.ctx.status.run()

            assert len(self.pod.containers) > 0, "No container in the pod"
            self.ctx.logger.debug("Run {}".format(self.pod))
            self.ctx.logger.debug("Run {}".format(self.pod.containers[0]))

            self.pod.deploy()

            if self.watch():
                break

        self.ctx.logger.debug("Job done {}".format(self.job))
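A worked example of the rank bookkeeping in build_pod: with two peers reporting 4 and 2 replicas, the pod that lands at peer rank 1 numbers its processes starting at 4.

peer_list = [{'replicas': 4}, {'replicas': 2}]  # hypothetical sync result
rank = 1                                        # this pod's peer rank

global_size = sum(i['replicas'] for i in peer_list)         # 6
rank_offset = sum(i['replicas'] for i in peer_list[:rank])  # 4

for i in range(peer_list[rank]['replicas']):
    print("PADDLE_GLOBAL_RANK =", i + rank_offset)  # 4, then 5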
python/paddle/distributed/run/controllers/controller.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import signal

from paddle.distributed.run.job import Job
from paddle.distributed.run.job import Pod
from paddle.distributed.run.job import Container

from .master import Master

import time


class ControleMode:
    COLLECTIVE = "collective"
    PS = "ps"


class ControllerBase(object):
    def __init__(self, ctx):
        signal.signal(signal.SIGTERM, self.signal_handler)
        signal.signal(signal.SIGABRT, self.signal_handler)
        signal.signal(signal.SIGINT, self.signal_handler)

        self.ctx = ctx
        self.master = Master.factory(self.ctx)

        self.job = Job(np=self.ctx.args.np,
                       mode=self.ctx.args.mode,
                       id=self.ctx.args.id)
        self.pod = Pod()

        self.join_server = None

    def run(self):
        self.build_job()
        self.build_pod()

        if len(self.pod.containers) < 1:
            self.ctx.logger.error("No container in the pod {}".format(
                self.pod))
            return

        self.ctx.logger.info("Run {}".format(self.pod))
        self.ctx.logger.debug(self.pod.containers[0])

        self.pod.deploy()

        self.watch()

    def watch(self) -> bool:
        status = self.pod.watch()

        if status == self.ctx.status.COMPLETED:
            self.ctx.logger.info("Pod {}".format(status))
        elif status == self.ctx.status.FAILED:
            self.ctx.logger.info("Pod {}".format(status))
            self.ctx.logger.error("Container failed !!!\n{}".format(
                self.pod.failed_container()))
            self.pod.tail()
        self.pod.stop()

    def stop(self, sigint=None):
        self.ctx.logger.debug("Controller stop")
        self.master.stop()
        self.pod.stop(sigint)

    def finalize(self):
        self.pod.join()
        self.master.stop()

        self.ctx.logger.info("Exit code {}".format(self.pod.exit_code))
        sys.exit(self.pod.exit_code)

    def signal_handler(self, sigint, frame):
        self.ctx.logger.info("Terminating with signal {}".format(sigint))

        if hasattr(self, 'sigint'):
            time.sleep(5)
            sys.exit(sigint)

        self.sigint = sigint
        self.ctx.status.done()
        self.stop(sigint)
        time.sleep(1)
        self.ctx.logger.debug("Exit with signal {}".format(sigint))
        sys.exit(sigint)


class Controller(ControllerBase):
    '''
    Controller API for customization
    '''

    def build_job(self):
        '''
        build job fill the job info.
        '''
        self.ctx.logger.info(self.job)

    def build_pod(self) -> bool:
        '''
        build pod includes creating containers etc.

        Return True if succeed
        '''
        raise NotImplementedError

    def _get_entrypoint(self):
        entrypoint = [sys.executable, "-u", self.ctx.args.training_script]
        entrypoint.extend(self.ctx.args.training_script_args)
        return entrypoint

    def _get_out_err_file(self, out=None, err=None):
        if out and self.ctx.args.log_dir != "":
            out = os.path.join(self.ctx.args.log_dir, out)
        if err and self.ctx.args.log_dir != "":
            err = os.path.join(self.ctx.args.log_dir, err)
        return out, (err or out)

    def new_container(self,
                      entrypoint=None,
                      envs={},
                      use_ctx_env=True,
                      out=None,
                      err=None):
        c = Container(
            entrypoint=(entrypoint or self._get_entrypoint()),
            env=(self.ctx.get_envs() if use_ctx_env else {}), )
        c.outfile, c.errfile = self._get_out_err_file(out, err)
        c.update_env(envs)
        return c

    def add_container(self,
                      container=None,
                      entrypoint=None,
                      envs={},
                      log_tag=None,
                      is_init=False):
        if not is_init and log_tag is not None:
            log_file = "{}.{}.{}.log".format(self.job.id, self.pod.name,
                                             log_tag)
        else:
            log_file = None

        if not container:
            container = self.new_container(
                entrypoint=entrypoint, envs=envs, out=log_file, err=log_file)

        if is_init:
            self.pod.add_init_container(container)
        else:
            self.pod.add_container(container)

    def pod_replicas(self):
        '''
        how many process/container should be run in pod
        '''
        if self.ctx.args.nproc_per_node:
            return int(self.ctx.args.nproc_per_node)
        else:
            return self.ctx.node.device.count

    def save_pod_log(self, info):
        '''
        save_pod_log append *info* to the log file of pod.name
        '''
        if not self.ctx.args.log_dir:
            return

        f = os.path.join(self.ctx.args.log_dir,
                         '{}.{}.log'.format(self.job.id, self.pod.name))
        try:
            os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, 'a+') as fd:
                fd.write(str(info))
        except Exception as e:
            self.ctx.logger.error("save log failed because {}".format(e))
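For orientation, a minimal sketch of the customization surface: a subclass only has to provide enable() to claim the context and build_pod() to populate containers. The class name and env value below are hypothetical.

from paddle.distributed.run.controllers.controller import Controller

class SingleProcController(Controller):
    @classmethod
    def enable(cls, ctx):
        # claim any context, like CollectiveController does
        return True

    def build_pod(self):
        # one container running the training script with the ctx env
        self.add_container(envs={"PADDLE_LOCAL_RANK": "0"}, log_tag=0)
        return True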
python/paddle/distributed/run/controllers/master.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.run.utils.kv_client import KVClient
from paddle.distributed.run.utils.kv_server import KVServer

import time
import sys
import six
import threading
import copy
import random

ETCD_PROTOCAL = 'etcd://'


class Master(object):
    '''
    Master is a distributed store design to exchange info among nodes
    '''

    MAIN = "main"
    STANDBY = "standby"
    PATICIPANT = "participant"

    def __init__(self, ctx):
        self.ctx = ctx
        self.server = None
        self.initialized = False
        self.endpoint = None

    def stop(self):
        raise NotImplementedError

    def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int):
        raise NotImplementedError

    @classmethod
    def factory(cls, ctx):
        if ctx.args.master and ctx.args.master.startswith(ETCD_PROTOCAL):
            return ETCDMaster(ctx)
        else:
            return HTTPMaster(ctx)


class HTTPMaster(Master):
    def lazy_init(self):
        if self.initialized:
            return

        self.role = Master.PATICIPANT

        if self.ctx.args.master:
            self.endpoint = self.ctx.args.master
            ip, port = self.endpoint.split(':')
            if ip in ['127.0.0.1', self.ctx.node.ip]:
                time.sleep(2 * random.random())
                while not self.ctx.node.is_server_ready(ip, int(port)):
                    try:
                        self.server = KVServer(int(port))
                        self.role = Master.MAIN
                        break
                    except Exception as e:
                        self.ctx.logger.warning(
                            "start master failed {}".format(e))
                        time.sleep(0.1)
                        continue
        else:
            port = self.ctx.node.get_free_port()
            self.endpoint = "{}:{}".format(self.ctx.node.ip, port)
            self.server = KVServer(port)
            self.role = Master.MAIN

            print("Copy the following command to other nodes to run.")
            cmd = [
                sys.executable.split('/')[-1], "-m", "paddle.distributed.run"
            ]
            cmd.extend(["--master", self.endpoint])
            cmd.extend(sys.argv[1:])
            print("-" * 80)
            print(" ".join(cmd))
            print("-" * 80)

            if self.ctx.args.rank >= 0:
                self.ctx.logger.warning(
                    "--rank set in the command may not compatible in auto mode")

        if '127.0.0.1' in self.endpoint:
            self.endpoint = self.endpoint.replace('127.0.0.1',
                                                  self.ctx.node.ip)
        self.client = KVClient(self.endpoint)

        self.initialized = True

        self._start_server()

    def _start_server(self):
        if self.server and not self.server.started:
            self.server.start()
            self.ctx.logger.debug("KV server start at {}".format(
                self.endpoint))

    def _stop_server(self):
        if self.server and not self.server.stopped:
            self.server.stop()
            self.ctx.logger.debug("KV server stopped")

    def stop(self):
        self._stop_server()

    def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int):
        if size < 2:
            return [value], 0

        self.lazy_init()

        while not self.ctx.status.is_done():
            if self.client.wait_server_ready(timeout=5):
                break
            else:
                self.ctx.logger.warning("master not ready")
                time.sleep(0.1)

        # 'aaaaaa' make sure main pod (master server) as rank 0
        ky = 'aaaaaa' if rank < 0 and self.role == Master.MAIN else key
        k = "{}/{}/{}".format(prefix, ky, rank)

        while not self.ctx.status.is_done():
            if not self.client.put(k, value):
                self.ctx.logger.warning("put value failed")
                time.sleep(0.1)
                continue

            rjson = self.client.get_prefix(prefix)
            self.ctx.logger.debug("sync peers {}".format(rjson))
            if rjson and len(rjson) == size:
                if rank < 0:
                    keys = list(rjson.keys())
                    keys.sort()
                    ret = [rjson[k] for k in keys]
                    idx = ret.index(value)
                    return ret, idx
                else:
                    ret = [None] * size
                    for k, v in rjson.items():
                        ret[int(k.split('/')[-1])] = v
                    return ret, rank
            else:
                time.sleep(0.5)
        return [], 0


class ETCDMaster(Master):
    def __init__(self, ctx):
        super().__init__(ctx)

        if self.ctx.args.master:
            # etcd://localhost:2379
            self.endpoint = self.ctx.args.master.strip("etcd://")

        import etcd3

        host, port = self.endpoint.split(':')
        self.client = etcd3.client(host=host, port=port)

    def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int):
        '''
        sync_peers gather all value for key under scope prefix
        result always be sorted either by rank or alphabet of pod.name
        '''
        path = "{}/{}/{}".format(prefix, key, rank)

        self.client.delete_prefix(prefix)

        self.ctx.logger.debug("sync path {} value {}".format(path, value))

        while not self.ctx.status.is_done():
            self.client.put(path, six.b(value))

            result = [i for i in self.client.get_prefix(prefix)]
            result = copy.deepcopy(result)
            self.ctx.logger.debug("sync peers {}".format(result))

            if len(result) == size:
                if rank < 0:
                    keys = [six.ensure_str(i[1].key) for i in result]
                    sorted_keys = [six.ensure_str(i[1].key) for i in result]
                    sorted_keys.sort()
                    values = [six.ensure_str(i[0]) for i in result]
                    ret = [values[keys.index(k)] for k in sorted_keys]
                    idx = ret.index(value)
                    return ret, idx
                else:
                    ret = [None] * size
                    for v, k in result:
                        ii = int(six.ensure_str(k.key).split('/')[-1])
                        if ii < 0:
                            self.ctx.logger.error(
                                "rank {} error in sync".format(ii))
                        ret[ii] = six.ensure_str(v)
                    return ret, rank
            else:
                time.sleep(0.5)

    def register_heartbeat(self, job_id, pod_id, ttl=10):
        if hasattr(self, 'heartbeat_prefix'):
            self.ctx.logger.warning("Heartbeat already done")
            return

        self.job_prefix = '/paddle/{}'.format(job_id)
        self.heartbeat_prefix = '{}/heartbeat'.format(self.job_prefix)

        lease = self.client.lease(ttl)

        #self.client.delete_prefix(self.job_prefix)

        beat_path = "{}/{}".format(self.heartbeat_prefix, pod_id)
        self.client.put(beat_path, six.b(pod_id), lease=lease)

        def _beat_watch(event):
            self.ctx.status.restart()

        beat_watch = self.client.add_watch_prefix_callback(
            self.heartbeat_prefix, _beat_watch)

        def _heartbeat():
            while not self.ctx.status.is_done():
                try:
                    lease.refresh()
                    if pod_id not in self.fetch_peer_alive():
                        self.client.put(beat_path, six.b(pod_id), lease=lease)
                        self.ctx.logger.debug("Heartbeat register again")
                except Exception as e:
                    self.ctx.logger.error("Heartbeat error {}".format(e))
                time.sleep(ttl / 2)
            self.ctx.logger.debug("Heartbeat done")
            self.client.cancel_watch(beat_watch)

        self.beat_thread = threading.Thread(
            name='heartbeat', target=_heartbeat, daemon=True)
        self.beat_thread.start()

    def fetch_peer_alive(self):
        peer_alive = [
            six.ensure_str(i[0])
            for i in self.client.get_prefix(self.heartbeat_prefix)
        ]
        self.ctx.logger.debug("peer alive {}".format(peer_alive))
        return peer_alive

    def wait_peer_ready(self, replicas_min, replicas_max, timeout):
        end = time.time() + timeout
        while not self.ctx.status.is_done() and time.time() < end:
            if len(self.fetch_peer_alive()) == replicas_max:
                return (True, replicas_max)
            else:
                time.sleep(0.5)

        np = len(self.fetch_peer_alive())
        if np >= replicas_min and np <= replicas_max:
            return (True, np)
        else:
            return (False, np)

    def restart_peer(self):
        self.client.delete_prefix(self.heartbeat_prefix)

    def set_status(self, status):
        assert self.client.put(
            self.job_prefix, six.b(status),
            lease=self.client.lease(600)), "set status failed {}".format(
                status)

    def get_status(self):
        return six.ensure_str(self.client.get(self.job_prefix)[0] or '')

    def stop(self):
        if hasattr(self, 'beat_thread'):
            self.ctx.status.done()
            # TODO(kuizhiqing) thread should exit
            #self.beat_thread.join()
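The 'aaaaaa' trick in HTTPMaster.sync_peers relies on plain key sorting; a small worked example with hypothetical pod names shows why the main pod always comes out as rank 0.

entries = {
    '/default/info/aaaaaa/-1': 'main-pod-data',  # main pod uses 'aaaaaa'
    '/default/info/qwerty/-1': 'peer-pod-data',  # random 6-letter pod name
}
keys = sorted(entries)             # 'aaaaaa' sorts before any pod name
ret = [entries[k] for k in keys]
print(ret.index('main-pod-data'))  # 0 -> main pod gets rank 0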
python/paddle/distributed/run/controllers/ps.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .controller import Controller, ControleMode

import json
import os, shutil


class PSController(Controller):
    @classmethod
    def enable(cls, ctx):
        if ctx.args.mode == ControleMode.PS or ctx.args.server_num or len(
                ctx.args.servers) > 0:
            ctx.logger.debug("{} enabled".format(cls.__name__))
            ctx.args.mode = ControleMode.PS
            return True
        else:
            return False

    def build_pod(self):
        if self.ctx.args.servers and self.ctx.args.trainers:
            self._build_pod_with_args()
        else:
            self._build_pod_with_master()

    def _build_pod_with_args(self):
        if '127.0.0.1' in self.ctx.args.servers:
            host = '127.0.0.1'
        else:
            host = self.ctx.node.ip

        server_endpoints = [s for s in self.ctx.args.servers.split(",")]
        trainer_endpoints = [s for s in self.ctx.args.trainers.split(",")]

        servers = [
            s for s in self.ctx.args.servers.split(",") if s.startswith(host)
        ]
        trainers = [
            s for s in self.ctx.args.trainers.split(",") if s.startswith(host)
        ]

        server_num = len(servers)
        trainer_num = len(trainers)

        self.pod.replicas = server_num + trainer_num

        self.save_pod_log([server_endpoints, trainer_endpoints])

        import tempfile
        gloo_rendezvous_dir = tempfile.mkdtemp()
        if os.path.exists(gloo_rendezvous_dir):
            shutil.rmtree(gloo_rendezvous_dir)

        gloo_port = self.ctx.args.gloo_port
        gloo_http = "{}:{}".format(server_endpoints[0].split(":")[0],
                                   gloo_port)

        _gloo_envs = {
            "PADDLE_GLOO_RENDEZVOUS": "3",
            "PADDLE_GLOO_FS_PATH": gloo_rendezvous_dir,
            "PADDLE_GLOO_HTTP_ENDPOINT": gloo_http,
            "PADDLE_WITH_GLOO": self.ctx.args.with_gloo
        }

        for i in range(server_num):
            e = {
                "PADDLE_PSERVERS_IP_PORT_LIST": self.ctx.args.servers,
                "PADDLE_TRAINER_ENDPOINTS": self.ctx.args.trainers,
                "PADDLE_PORT": servers[i].split(":")[1],
                "PADDLE_ROLE": "PSERVER",
                "TRAINING_ROLE": "PSERVER",
                "PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)),
                "POD_IP": self.ctx.node.ip,
            }
            e.update(_gloo_envs)
            log_tag = "ps.{}".format(i)
            self.add_container(envs=e, log_tag=log_tag)

        trainer_rank_offset = 0
        for s in trainer_endpoints:
            if s.startswith(host):
                break
            else:
                trainer_rank_offset += 1

        for i in range(trainer_num):
            e = {
                "PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints),
                "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
                "PADDLE_PORT": trainers[i].split(":")[1],
                "PADDLE_ROLE": "TRAINER",
                "TRAINING_ROLE": "TRAINER",
                "PADDLE_TRAINER_ID": "{}".format(i + trainer_rank_offset),
                "PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)),
                "POD_IP": self.ctx.node.ip,
            }
            e.update(_gloo_envs)
            log_tag = "trainer.{}".format(i)
            self.add_container(envs=e, log_tag=log_tag)

    def _build_pod_with_master(self):

        self.pod.rank = self.ctx.args.rank

        server_num = self.ctx.args.server_num or 1
        servers = [
            "{}:{}".format(self.ctx.node.ip, p)
            for p in self.ctx.node.get_free_ports(server_num)
        ]
        trainer_num = self.ctx.args.trainer_num or 1
        trainers = [
            "{}:{}".format(self.ctx.node.ip, p)
            for p in self.ctx.node.get_free_ports(trainer_num)
        ]

        data = json.dumps({
            'name': self.pod.name,
            'rank': self.pod.rank,
            'servers': servers,
            'trainers': trainers,
            'dtype': self.ctx.node.device.dtype,
            'gloo_port': self.ctx.node.get_free_port(),
        })

        peer_list, rank = self.master.sync_peers(
            '/{}/info'.format(self.job.id), self.pod.name, data,
            self.job.replicas, self.pod.rank)

        self.ctx.logger.debug("sync peers done {}".format(peer_list))

        peer_list = [json.loads(i) for i in peer_list]

        self.save_pod_log(peer_list)

        server_endpoints = [j for i in peer_list for j in i['servers']]
        trainer_endpoints = [j for i in peer_list for j in i['trainers']]
        #rank_offset = sum([i['replicas'] for i in peer_list[:rank]])
        server_rank_offset = sum(
            [len(i['servers']) for i in peer_list[:rank]])
        trainer_rank_offset = sum(
            [len(i['trainers']) for i in peer_list[:rank]])

        self.pod.rank = rank

        self.pod.replicas = server_num + trainer_num

        import tempfile
        gloo_rendezvous_dir = tempfile.mkdtemp()
        if os.path.exists(gloo_rendezvous_dir):
            shutil.rmtree(gloo_rendezvous_dir)

        gloo_port = peer_list[0]['gloo_port']
        gloo_http = "{}:{}".format(server_endpoints[0].split(":")[0],
                                   gloo_port)

        _gloo_envs = {
            "PADDLE_GLOO_RENDEZVOUS": "3",
            "PADDLE_GLOO_FS_PATH": gloo_rendezvous_dir,
            "PADDLE_GLOO_HTTP_ENDPOINT": gloo_http,
            "PADDLE_WITH_GLOO": self.ctx.args.with_gloo
        }

        for i in range(server_num):
            e = {
                "PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints),
                "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
                "PADDLE_PORT":
                server_endpoints[i + server_rank_offset].split(":")[1],
                "PADDLE_ROLE": "PSERVER",
                "TRAINING_ROLE": "PSERVER",
                "PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)),
                "POD_IP": self.ctx.node.ip,
            }
            e.update(_gloo_envs)
            log_tag = "ps.{}".format(i)
            self.add_container(envs=e, log_tag=log_tag)

        for i in range(trainer_num):
            e = {
                "PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints),
                "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
                "PADDLE_PORT":
                trainer_endpoints[i + trainer_rank_offset].split(":")[1],
                "PADDLE_ROLE": "TRAINER",
                "TRAINING_ROLE": "TRAINER",
                "PADDLE_TRAINER_ID": "{}".format(i + trainer_rank_offset),
                "PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)),
                "POD_IP": self.ctx.node.ip,
            }
            e.update(_gloo_envs)
            log_tag = "trainer.{}".format(i)
            self.add_container(envs=e, log_tag=log_tag)
''' NEW VERSION
for i in range(server_num):
e = {
"PADDLE_PSERVER_ENDPOINTS": ",".join(server_endpoints),
"PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
"PADDLE_ROLE": "PSERVER",
"PADDLE_RANK": "{}".format(i + server_rank_offset),
}
log_tag = "ps.{}".format(i)
self.add_container(envs=e, log_tag=log_tag)
for i in range(trainer_num):
e = {
"PADDLE_PSERVER_ENDPOINTS": ",".join(server_endpoints),
"PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
"PADDLE_ROLE": "TRAINER_CPU",
"PADDLE_RANK": "{}".format(i + trainer_rank_offset),
}
log_tag = "trainer.{}".format(i)
self.add_container(envs=e, log_tag=log_tag)
'''
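The host-prefix filtering in _build_pod_with_args is the whole multi-node story for Case 4: each node launches only the endpoints that start with its own ip. A worked example with hypothetical addresses:

host = '10.0.0.2'  # this node's ip
servers = '10.0.0.1:8900,10.0.0.2:8900'
trainers = '10.0.0.1:8902,10.0.0.2:8902'

local_servers = [s for s in servers.split(',') if s.startswith(host)]
local_trainers = [s for s in trainers.split(',') if s.startswith(host)]
print(local_servers)   # ['10.0.0.2:8900']
print(local_trainers)  # ['10.0.0.2:8902']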
python/paddle/distributed/run/job/__init__.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .pod import Pod
from .job import Job
from .container import Container
from .status import Status

__all__ = [
    'Pod',
    'Job',
    'Container',
    'Status',
]
python/paddle/distributed/run/job/container.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from paddle.distributed.run.utils.process_context import ProcessContext

from .status import Status

import os, copy, sys
import time


class Container(object):
    '''
    TODO(kuizhiqing) A container can be run by process/thread or just a callable function
    '''

    def __init__(self, entrypoint=[], rank=-1, env={}):
        self._entrypoint = entrypoint
        self._rank = rank
        self._out = None
        self._err = None
        self._env = env
        self._proc = None

        self._retry: int = 3
        self._grace_period = 10

        self._log_handler = None

    @property
    def entrypoint(self):
        return self._entrypoint

    @entrypoint.setter
    def entrypoint(self, entry):
        self._entrypoint = entry

    @property
    def rank(self):
        return self._rank

    @rank.setter
    def rank(self, r):
        self._rank = r

    @property
    def outfile(self):
        return self._out

    @outfile.setter
    def outfile(self, out):
        self._out = out

    @property
    def errfile(self):
        return self._err

    @errfile.setter
    def errfile(self, err):
        self._err = err

    def update_env(self, env={}, **kwargs):
        env = {k: v for k, v in env.items() if isinstance(v, str)}
        self._env.update(env)

        kwargs = {k: v for k, v in kwargs.items() if isinstance(v, str)}
        self._env.update(kwargs)

    def _get_fd(self, pth):
        if not pth:
            return None

        try:
            d = os.path.dirname(pth)
            if not os.path.isdir(d):
                os.makedirs(d, exist_ok=True)
            return open(pth, 'w')
        except:
            return None

    def start(self, timeout=-1):
        end = time.time() + timeout

        if self._proc and self._proc.alive():
            return True

        self._stdout = self._get_fd(self._out) or sys.stdout
        if self._out == self._err:
            self._stderr = self._stdout
        elif self._err:
            self._stderr = self._get_fd(self._err) or sys.stderr

        self._proc = ProcessContext(
            self._entrypoint,
            env=self._env,
            out=self._stdout,
            err=self._stderr)
        self._proc.start()

        while timeout > 0 and time.time() < end:
            if self._proc.alive():
                time.sleep(0.1)
                continue
            if self._proc.exit_code() == 0:
                return True
            return False

    def terminate(self, force=False):
        if self._log_handler:
            self._log_handler.close()
            self._log_handler = None

        if self._proc and self._proc.alive():
            return self._proc.terminate(force)

    def wait(self, timeout=None):
        self._proc.wait(timeout)

    def exit_code(self):
        return self._proc.exit_code() if self._proc else -1

    def status(self):
        if not self._proc:
            return Status.UNINIT
        if self._proc.alive():
            return Status.RUNNING
        elif self._proc.exit_code() == 0:
            return Status.COMPLETED
        else:
            return Status.FAILED

    def __str__(self):
        return 'Container rank {} status {} cmd {} code {} log {} \nenv {}'.format(
            self._rank,
            self.status(),
            self._entrypoint,
            self.exit_code(),
            self.errfile,
            self._env, )

    def logs(self, fn=None, offset=0, whence=1, lines=1000):
        if not self._log_handler:
            self._log_handler = open(self._out)

        if fn is None:
            fn = sys.stdout

        self._log_handler.seek(offset, whence)

        try:
            idx = 0
            for line in self._log_handler:
                fn.write(line)
                idx += 1
                if idx > lines:
                    break
        finally:
            return self._log_handler.tell()

    def tail(self, length=3000):
        if not self._log_handler:
            self._log_handler = open(self._out)

        self._log_handler.seek(0, 2)
        ed = self._log_handler.tell()

        if ed > length:
            self.logs(offset=ed - length, whence=0)
        else:
            self.logs(offset=0, whence=0)
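Normally the controller creates containers, but the class can be driven directly. A minimal sketch, assuming ProcessContext wraps subprocess as usual and using a hypothetical log path; note both outfile and errfile should be set before start(), since start() only wires stderr when one of them is present.

import sys
from paddle.distributed.run.job.container import Container

c = Container(entrypoint=[sys.executable, '-c', 'print("hello")'])
c.outfile = '/tmp/demo.log'  # hypothetical path
c.errfile = '/tmp/demo.log'  # same file -> stderr shares the handle
c.start()
c.wait(None)
print(c.exit_code())  # 0
print(c.status())     # 'completed'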
python/paddle/distributed/run/job/job.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class JobMode:
    COLLECTIVE = 'collective'
    PS = 'ps'
    HETER = 'heter'


class Job(object):
    def __init__(self, id='default', mode=JobMode.COLLECTIVE, np="1"):
        self._mode = mode
        self._id = id

        self._replicas = 0
        self._replicas_min = self._replicas
        self._replicas_max = self._replicas
        self._elastic = False

        self.set_replicas(str(np))

    def __str__(self):
        return "Job: {}, mode {}, replicas {}[{}:{}], elastic {}".format(
            self.id, self.mode, self._replicas, self._replicas_min,
            self._replicas_max, self.elastic)

    @property
    def mode(self):
        return self._mode

    @property
    def id(self):
        return self._id

    @property
    def elastic(self):
        return self._elastic

    @property
    def replicas(self):
        return self._replicas

    @property
    def replicas_min(self):
        return self._replicas_min

    @property
    def replicas_max(self):
        return self._replicas_max

    @replicas.setter
    def replicas(self, replicas):
        self._replicas = replicas

    def set_replicas(self, np: str):
        np = str(np) if np else '1'

        if ':' in np:
            nps = np.split(':')
            self._replicas_min, self._replicas_max = int(nps[0]), int(nps[1])
            self._replicas = self._replicas_max  # default to max
            self._elastic = True
        else:
            self._replicas = int(np)
            self._replicas_min, self._replicas_max = self._replicas, self._replicas
            self._elastic = False
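set_replicas is where the --np syntax is decoded: a plain number pins the job size, while min:max turns on elastic mode and defaults to the maximum. For example (assuming Paddle is installed):

from paddle.distributed.run.job.job import Job

j = Job(np="2:3")                      # as passed by --np 2:3
print(j.elastic)                       # True
print(j.replicas_min, j.replicas_max)  # 2 3
print(j.replicas)                      # 3 -- defaults to max

j = Job(np="4")
print(j.elastic)   # False
print(j.replicas)  # 4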
python/paddle/distributed/run/job/pod.py (new file, 0 → 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import List

from .container import Container
from .status import Status

import random
import time


class PodSpec(object):
    def __init__(self):
        self._name = ''.join(
            random.choice('abcdefghijklmnopqrstuvwxyz') for _ in range(6))

        # filled in by the controller
        self._init_containers: List[Container] = []
        self._containers: List[Container] = []

        #self.resource: Resource = None
        #self.status: Status = None

        self._rank = -1
        self._init_timeout = 120  # 2 min timeout for each init container
        self._restart = -1
        self._replicas = 0  # number of containers
        self._exit_code = 0


class Pod(PodSpec):
    def __init__(self):
        super().__init__()

    def __str__(self):
        return "Pod: {}, replicas {}, status {}".format(self.name,
                                                        self.replicas,
                                                        self.status())

    def failed_container(self):
        for c in self._containers:
            if c.status() == Status.FAILED:
                return c
        return None

    @property
    def name(self):
        return self._name

    @property
    def replicas(self):
        return self._replicas

    @replicas.setter
    def replicas(self, r):
        self._replicas = r

    @property
    def rank(self):
        return self._rank

    @rank.setter
    def rank(self, r):
        self._rank = r

    @property
    def restart(self):
        return self._restart

    @property
    def containers(self):
        return self._containers

    def add_container(self, c):
        c.rank = len(self._containers)
        self._containers.append(c)

    @property
    def init_containers(self):
        return self._init_containers

    def add_init_container(self, c):
        c.rank = len(self._init_containers)
        self._init_containers.append(c)

    @property
    def exit_code(self):
        for c in self._containers:
            if c.exit_code() != 0:
                return c.exit_code()
        return 0

    def deploy(self):
        for i in self._init_containers:
            i.start(self._init_timeout)

        for c in self._containers:
            c.start()

        self._restart += 1

    def stop(self, sigint=0):
        for c in self._containers:
            force = True if sigint == 9 else False
            c.terminate(force)

    def join(self):
        for c in self._containers:
            c.wait(None)

    def status(self):
        if self.is_failed():
            return Status.FAILED

        if self.is_completed():
            return Status.COMPLETED

        return Status.READY

    def reset(self):
        self._init_containers = []
        self._containers = []

    def is_failed(self):
        for c in self._containers:
            if c.status() == Status.FAILED:
                return True
        return False

    def is_completed(self):
        for c in self._containers:
            if c.status() != Status.COMPLETED:
                return False
        return True

    def logs(self, idx=None):
        if idx is None:
            if self.failed_container():
                self.failed_container().logs()
            else:
                self._containers[0].logs()
        else:
            self._containers[idx].logs()

    def tail(self, idx=None):
        if idx is None:
            if self.failed_container():
                self.failed_container().tail()
            else:
                self._containers[0].tail()
        else:
            self._containers[idx].tail()

    def watch(self,
              all_list=[Status.COMPLETED],
              any_list=[Status.FAILED],
              interval=1,
              timeout=-1):
        '''
        watch returns once any container status is in any_list,
        or once all container statuses are in all_list
        '''
        end = time.time() + timeout
        while timeout < 0 or time.time() < end:
            for c in self._containers:
                if c.status() in any_list:
                    return c.status()

            s = [c.status() for c in self._containers]
            if len(set(s)) == 1 and s[0] in all_list:
                return s[0]

            time.sleep(interval)
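To make the watch contract concrete, a minimal sketch with stub containers (Stub is hypothetical; real callers pass Container objects):

    # Sketch: watch returns the first status found in any_list, or the
    # shared status once all containers agree on one from all_list.
    class Stub(object):  # hypothetical stand-in for Container
        def __init__(self, s):
            self._s = s

        def status(self):
            return self._s


    pod = Pod()
    pod.add_container(Stub(Status.COMPLETED))
    pod.add_container(Stub(Status.COMPLETED))
    assert pod.watch(timeout=5) == Status.COMPLETED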
python/paddle/distributed/run/job/status.py
0 → 100644
View file @ 67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Status(object):
    UNINIT = "uninit"
    READY = "ready"
    RUNNING = "running"
    FAILED = "failed"
    TERMINATING = "terminating"
    RESTARTING = "restarting"
    UNKNOWN = "unknown"
    COMPLETED = "completed"
python/paddle/distributed/run/plugins/__init__.py
0 → 100644
View file @ 67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six

__all__ = []


def log(ctx):
    ctx.logger.info("----------- Configuration ----------------------")
    for arg, value in sorted(six.iteritems(vars(ctx.args))):
        ctx.logger.info("%s: %s" % (arg, value))
    ctx.logger.info("--------------------------------------------------")


def process_args(ctx):
    # reset device by args
    #argdev = ctx.args.gpus or ctx.args.xpus or ctx.args.npus
    argdev = ctx.args.devices
    if argdev:
        ctx.node.device.labels = argdev.split(',')
        ctx.node.device.count = len(ctx.node.device.labels)
        ctx.logger.debug('Device reset by args {}'.format(argdev))


def collective_compatible(ctx):
    if 'PADDLE_TRAINER_ENDPOINTS' in ctx.envs:
        ctx.master = ctx.envs['PADDLE_TRAINER_ENDPOINTS'].split(',')[0]
    if 'DISTRIBUTED_TRAINER_ENDPOINTS' in ctx.envs:
        ctx.master = ctx.envs['DISTRIBUTED_TRAINER_ENDPOINTS'].split(',')[0]


def rewrite_host_ip(ctx):
    if ctx.args.host is not None and "." in ctx.args.host:
        ctx.logger.warning('Host ip reset to {}'.format(ctx.args.host))
        ctx.node.ip = ctx.args.host


enabled_plugins = [collective_compatible, rewrite_host_ip, process_args, log]
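The list above is presumably consumed by the launcher's context; a sketch of the expected driver loop, assuming ctx carries the args, envs, node and logger attributes the plugins touch:

    # Sketch: apply the enabled plugins in list order, each mutating
    # the shared context in place.
    def run_plugins(ctx):
        for plugin in enabled_plugins:
            plugin(ctx)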
python/paddle/distributed/run/plugins/ip.py
0 → 100644
View file @ 67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket


def get_local_ip(ctx):
    _, ip = _get_host_name_ip()
    ctx.args.host = ip
    ctx.envs["POD_IP"] = ip


def _get_host_name_ip():
    try:
        host_name = socket.gethostname()
        host_ip = socket.gethostbyname(host_name)
        return host_name, host_ip
    except:
        # return a pair so callers can always unpack the result
        return None, None
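An illustrative sketch of what the plugin does to a context (ctx here is a hypothetical stand-in built from namespaces):

    # Sketch: get_local_ip fills ctx.args.host and ctx.envs["POD_IP"].
    import argparse

    ctx = argparse.Namespace(args=argparse.Namespace(host=None), envs={})
    get_local_ip(ctx)
    print(ctx.args.host, ctx.envs["POD_IP"])  # the node's resolved IP, twice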
python/paddle/distributed/run/utils/kv_client.py
0 → 100644
View file @ 67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import requests
import time


class KVClient(object):
    def __init__(self, endpoint='localhost:2379'):
        self.endpoint = endpoint if endpoint.startswith(
            "http://") else "http://{}".format(endpoint)

    def put(self, key, value):
        key = key if key.startswith('/') else "/{}".format(key)
        u = "{}{}".format(self.endpoint, key)
        try:
            r = requests.post(u, data=value, timeout=3)
            if r.status_code == 200:
                return True
            else:
                return False
        except:
            return False

    def get(self, key):
        key = key if key.startswith('/') else "/{}".format(key)
        u = "{}{}".format(self.endpoint, key)
        try:
            r = requests.get(u, timeout=3)
            if r.status_code == 200:
                ret = r.json()
                return ret.get(key, '')
            else:
                return "error"
        except:
            return ""

    def get_prefix(self, key):
        key = key if key.startswith('/') else "/{}".format(key)
        u = "{}{}".format(self.endpoint, key)
        try:
            r = requests.get(u, timeout=3)
            if r.status_code == 200:
                return r.json()
        except:
            return ""

    def delete(self, key):
        key = key if key.startswith('/') else "/{}".format(key)
        u = "{}{}".format(self.endpoint, key)
        try:
            r = requests.delete(u, timeout=3)
            if r.status_code == 200:
                return True
            else:
                return False
        except:
            return False

    def wait_server_ready(self, timeout=3):
        end = time.time() + timeout
        while time.time() < end:
            if self.get("/healthy") == "ok":
                return True
            time.sleep(0.1)  # avoid busy-polling the server
        return False


if __name__ == '__main__':
    cli = KVClient("http://localhost:8090")
    data = {"/workers/1": "rank1", "/workers/2": "rank2"}
    for k, v in data.items():
        cli.put(k, v)
    x = cli.get_prefix("/workers")
    print(x)
    for k, v in data.items():
        assert x[k] == v

    cli.put("key", "value")
    print(cli.get("key"))
    assert cli.get("key") == "value"
    cli.delete("key")
    print(cli.get("/key"))

    print(cli.get("/healthy"))
    assert cli.get("/healthy") == "ok"
python/paddle/distributed/run/utils/kv_server.py
0 → 100644
View file @ 67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http.server import HTTPServer
import http.server as SimpleHTTPServer
from multiprocessing import Process
import threading
import json


class KVHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
    def do_GET(self):
        with self.server.kv_lock:
            ret = {}
            for k, v in self.server.kv.items():
                if k.startswith(self.path):
                    ret[k] = v.decode(encoding="utf-8")
            if ret:
                self.output(200, json.dumps(ret).encode("utf-8"))
            else:
                self.output(404)

    def do_PUT(self):
        self.do_POST()

    def do_POST(self):
        content_length = int(self.headers['Content-Length'] or 0)
        try:
            value = self.rfile.read(content_length)
            with self.server.kv_lock:
                self.server.kv[self.path] = value
            self.output(200)
            return
        except:
            self.output(500)

    def do_DELETE(self):
        with self.server.kv_lock:
            if self.path in self.server.kv:
                del self.server.kv[self.path]
                self.output(200)
            else:
                self.output(404)

    def output(self, code, value=''):
        self.send_response(code)
        self.send_header("Content-Length", len(value))
        self.send_header("Content-Type", "application/json; charset=utf8")
        self.end_headers()
        if value:
            self.wfile.write(value)

    def log_message(self, format, *args):
        return


class KVServer(HTTPServer, object):
    def __init__(self, port):
        super(KVServer, self).__init__(('', port), KVHandler)

        self.kv_lock = threading.Lock()
        self.kv = {'/healthy': b'ok'}
        self.port = port
        self.stopped = False
        self.started = False

    def start(self):
        self.listen_thread = threading.Thread(target=self.serve_forever)
        self.listen_thread.start()
        self.started = True

    def stop(self):
        self.shutdown()
        self.listen_thread.join()
        self.server_close()
        self.stopped = True


class PKVServer():
    def __init__(self, port):
        self._server = KVServer(port)

    def start(self):
        self.proc = Process(target=self._server.start)
        self.proc.daemon = True
        self.proc.start()

    def stop(self):
        self._server.stop()
        self.proc.join()

    @property
    def started(self):
        return self._server.started

    @property
    def stopped(self):
        return self._server.stopped


if __name__ == '__main__':
    #kv = PKVServer(8090)
    kv = KVServer(8090)
    kv.start()
    import time
    #print("serve at 8090 for 600 s")
    time.sleep(600)
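A sketch pairing KVServer with the KVClient defined in kv_client.py above (the import path mirrors this commit's layout; ports are illustrative):

    # Sketch: in-process smoke test of the KV server/client pair.
    from paddle.distributed.run.utils.kv_client import KVClient

    server = KVServer(8090)
    server.start()

    cli = KVClient("localhost:8090")
    if cli.wait_server_ready(timeout=3):
        cli.put("/workers/0", "rank0")
        print(cli.get_prefix("/workers"))  # -> {'/workers/0': 'rank0'}

    server.stop()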
python/paddle/distributed/run/utils/process_context.py
0 → 100644
View file @ 67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import os, sys, signal, time


class ProcessContext(object):
    def __init__(self,
                 cmd,
                 env=os.environ,
                 out=sys.stdout,
                 err=sys.stderr,
                 group=True,
                 preexec_fn=None):
        self._cmd = cmd
        self._env = env
        self._preexec_fn = preexec_fn
        self._stdout = out
        self._stderr = err
        self._group = group if os.name != 'nt' else False
        self._proc = None
        self._code = None

    def _start(self):
        pre_fn = os.setsid if self._group else None
        self._proc = subprocess.Popen(
            self._cmd,
            env=self._env,
            stdout=self._stdout,
            stderr=self._stderr,
            preexec_fn=self._preexec_fn or pre_fn)

    def _close_std(self):
        try:
            if not self._stdout.isatty():
                self._stdout.close()

            if not self._stderr.isatty():
                self._stderr.close()
        except:
            pass

    def alive(self):
        return self._proc and self._proc.poll() is None

    def exit_code(self):
        return self._proc.poll() if self._proc else None

    def start(self):
        self._start()

    def terminate(self, force=False, max_retry=3):
        for i in range(max_retry):
            if self.alive():
                if self._group:
                    # signal the whole process group created by setsid
                    os.killpg(os.getpgid(self._proc.pid), signal.SIGTERM)
                else:
                    self._proc.terminate()
                time.sleep(0.2)
            else:
                break

        if force and self.alive():
            self._proc.kill()

        self._close_std()

        return self.alive()

    def wait(self, timeout=None):
        self._proc.wait(timeout)
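A minimal usage sketch for ProcessContext (the command is illustrative):

    # Sketch: launch a child in its own process group, then reap it.
    import sys

    pc = ProcessContext([sys.executable, "-c", "print('hello')"])
    pc.start()
    pc.wait(timeout=10)
    assert pc.exit_code() == 0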
python/paddle/fluid/tests/unittests/CMakeLists.txt
View file @ 67c6ddff
...
@@ -949,6 +949,7 @@ if (WITH_DISTRIBUTE AND NOT APPLE)
        endif()
    endif()
    # setting timeout value as 15S
    set_tests_properties(test_run PROPERTIES TIMEOUT 200)
    set_tests_properties(test_sync_batch_norm_op PROPERTIES TIMEOUT 120)
    set_tests_properties(test_cross_op PROPERTIES TIMEOUT 120)
    set_tests_properties(test_imperative_lod_tensor_to_selected_rows PROPERTIES TIMEOUT 200)
...
python/paddle/fluid/tests/unittests/test_run.py
0 → 100644
View file @ 67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import subprocess
import sys, os
import json
import shutil
import random
from os import listdir
from os.path import isfile, join

pyname = 'train.py'
colpyfile = '''# train.py for unittest
import os
env = os.environ.copy()
assert "PADDLE_MASTER" in env
assert "PADDLE_GLOBAL_SIZE" in env
assert "PADDLE_LOCAL_SIZE" in env
assert "PADDLE_GLOBAL_RANK" in env
assert "PADDLE_LOCAL_RANK" in env
'''

pspyfile = '''# train.py for unittest
import os
env = os.environ.copy()
assert "PADDLE_PSERVERS_IP_PORT_LIST" in env
assert "PADDLE_TRAINER_ENDPOINTS" in env
#assert "PADDLE_PSERVER_ENDPOINTS" in env
#assert "PADDLE_TRAINER_ENDPOINTS" in env
#assert "PADDLE_ROLE" in env
#assert "PADDLE_RANK" in env
'''


def write_file(name, ct):
    with open(name, "w") as f:
        f.write(ct)


def get_files(pth, prefix):
    return [
        f for f in listdir(pth) if isfile(join(pth, f)) and f.startswith(prefix)
    ]


class Collective_Test(unittest.TestCase):
    def setUp(self):
        write_file(pyname, colpyfile)

    def pdrun(self, args, env=None):
        cmd = [sys.executable.split('/')[-1], "-m", "paddle.distributed.run"]
        if args:
            cmd.extend(args.split(" "))
        cmd.extend([pyname])
        proc = subprocess.Popen(cmd, env=env)
        return proc

    '''
    def test_collective_1(self):
        args = "--id test1"
        p = self.pdrun(args)
        p.wait()
        self.assertTrue(p.poll() == 0)
    '''

    def test_collective_2(self):
        if os.path.exists('./log'):
            shutil.rmtree('./log')

        args = "--id test2 --devices 0,1,2"
        p = self.pdrun(args)
        p.wait()
        self.assertTrue(p.poll() == 0)

        c = get_files('log', 'test2')
        self.assertTrue(len(c) == 4)

    def test_collective_3(self):
        if os.path.exists('./log'):
            shutil.rmtree('./log')

        port = random.randrange(6000, 8000)
        args = "--id test3 --devices 0,1 --master 127.0.0.1:{} --np 2".format(
            port)
        p1 = self.pdrun(args)
        p2 = self.pdrun(args)
        p1.wait()
        p2.wait()
        self.assertTrue(p1.poll() == 0)
        self.assertTrue(p2.poll() == 0)

        c = get_files('log', 'test3')
        self.assertTrue(len(c) == 6)


class PS_Test(unittest.TestCase):
    def setUp(self):
        write_file(pyname, pspyfile)

    def pdrun(self, args, env=None):
        cmd = [sys.executable.split('/')[-1], "-m", "paddle.distributed.run"]
        if args:
            cmd.extend(args.split(" "))
        cmd.extend([pyname])
        proc = subprocess.Popen(cmd, env=env)
        return proc

    '''
    def test_ps_1(self):
        args = "--mode ps"
        p = self.pdrun(args)
        p.wait()
        self.assertTrue(p.poll() == 0)

    def test_ps_2(self):
        if os.path.exists('./log'):
            shutil.rmtree('./log')

        args = "--id ps2 --server_num=2 --trainer_num=2"
        p = self.pdrun(args)
        p.wait()
        self.assertTrue(p.poll() == 0)

        c = get_files('log', 'ps2')
        self.assertTrue(len(c) == 5)
    '''

    def test_ps_3(self):
        if os.path.exists('./log'):
            shutil.rmtree('./log')

        port = random.randrange(6000, 8000)
        args = "--id ps3 --master 127.0.0.1:{} --np 2 --server_num=1 --trainer_num=1".format(
            port)
        p1 = self.pdrun(args)
        p2 = self.pdrun(args)
        p1.wait()
        p2.wait()
        self.assertTrue(p1.poll() == 0)
        self.assertTrue(p2.poll() == 0)

        c = get_files('log', 'ps3')
        self.assertTrue(len(c) == 6)

    def test_ps_4(self):
        if os.path.exists('./log'):
            shutil.rmtree('./log')

        args = "--id ps4 --servers 127.0.0.1:8900,127.0.0.1:8901 --trainers 127.0.0.1:8902,127.0.0.1:8903"
        p1 = self.pdrun(args)
        p1.wait()
        self.assertTrue(p1.poll() == 0)

        c = get_files('log', 'ps4')
        self.assertTrue(len(c) == 5)


if __name__ == '__main__':
    unittest.main()