Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
67c6ddff
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
67c6ddff
编写于
3月 15, 2022
作者:
K
kuizhiqing
提交者:
GitHub
3月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
New design for launch/run (#40086)
上级
464f65b1
变更
25
隐藏空白更改
内联
并排
Showing
25 changed files
with
2546 additions
and
0 deletions
+2546
-0
python/paddle/distributed/run/__init__.py
python/paddle/distributed/run/__init__.py
+86
-0
python/paddle/distributed/run/__main__.py
python/paddle/distributed/run/__main__.py
+28
-0
python/paddle/distributed/run/context/__init__.py
python/paddle/distributed/run/context/__init__.py
+219
-0
python/paddle/distributed/run/context/device.py
python/paddle/distributed/run/context/device.py
+88
-0
python/paddle/distributed/run/context/event.py
python/paddle/distributed/run/context/event.py
+20
-0
python/paddle/distributed/run/context/node.py
python/paddle/distributed/run/context/node.py
+64
-0
python/paddle/distributed/run/context/resource.py
python/paddle/distributed/run/context/resource.py
+18
-0
python/paddle/distributed/run/context/status.py
python/paddle/distributed/run/context/status.py
+58
-0
python/paddle/distributed/run/controllers/__init__.py
python/paddle/distributed/run/controllers/__init__.py
+32
-0
python/paddle/distributed/run/controllers/collective.py
python/paddle/distributed/run/controllers/collective.py
+185
-0
python/paddle/distributed/run/controllers/controller.py
python/paddle/distributed/run/controllers/controller.py
+192
-0
python/paddle/distributed/run/controllers/master.py
python/paddle/distributed/run/controllers/master.py
+289
-0
python/paddle/distributed/run/controllers/ps.py
python/paddle/distributed/run/controllers/ps.py
+221
-0
python/paddle/distributed/run/job/__init__.py
python/paddle/distributed/run/job/__init__.py
+25
-0
python/paddle/distributed/run/job/container.py
python/paddle/distributed/run/job/container.py
+179
-0
python/paddle/distributed/run/job/job.py
python/paddle/distributed/run/job/job.py
+80
-0
python/paddle/distributed/run/job/pod.py
python/paddle/distributed/run/job/pod.py
+185
-0
python/paddle/distributed/run/job/status.py
python/paddle/distributed/run/job/status.py
+24
-0
python/paddle/distributed/run/plugins/__init__.py
python/paddle/distributed/run/plugins/__init__.py
+50
-0
python/paddle/distributed/run/plugins/ip.py
python/paddle/distributed/run/plugins/ip.py
+30
-0
python/paddle/distributed/run/utils/kv_client.py
python/paddle/distributed/run/utils/kv_client.py
+94
-0
python/paddle/distributed/run/utils/kv_server.py
python/paddle/distributed/run/utils/kv_server.py
+121
-0
python/paddle/distributed/run/utils/process_context.py
python/paddle/distributed/run/utils/process_context.py
+83
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/test_run.py
python/paddle/fluid/tests/unittests/test_run.py
+174
-0
未找到文件。
python/paddle/distributed/run/__init__.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.job.container
import
Container
from
.job.pod
import
Pod
from
.job.job
import
Job
from
.
import
plugins
#__all__ = [Container, Pod, Job]
'''
Paddle distribution training entry ``python -m paddle.distributed.run``.
Help
# for arg usage and explanation, try the following command
# python -m paddle.distributed.run -h
Collective Mode
Case 1: 1 node
use all visible devices
# python -m paddle.distributed.run train.py
use specified devices
# python -m paddle.distributed.run --devices=0,1,2,3 train.py
Case 2: multi-node, auto detect ip/port
# python -m paddle.distributed.run --np 2 train.py
# auto print following command
# python -m paddle.distributed.run --master 10.0.0.1:13538 --np 2 demo.py
# then copy and paste above command to other nodes
Case 3: multi-node, specified master/rendezvous server
# python -m paddle.distributed.run --np 2 --master 10.0.0.1:2379 train.py
# the master ip must be one of the node and the port must available
Parameter Server Mode
Case 1.1: 1 node, 1 ps, 1 trainer
# python -m paddle.distributed.run --mode ps train.py
# python -m paddle.distributed.run --server_num=1 --trainer_num=1 train.py
Case 1.2: 1 node, 2 ps, 2 trainer
# python -m paddle.distributed.run --server_num=2 --trainer_num=2 train.py
Case 2: 2 node, 2 ps, 2 trainer per node
# python -m paddle.distributed.run --server_num=2 --trainer_num=2 --np 2 train.py
# auto print following command
# python -m paddle.distributed.run --master 10.0.0.1:13538 --server_num=2 --trainer_num=2 --np 2 train.py
# then copy and paste above command to other nodes
Case 3: multi-node, specified master/rendezvous server
# python -m paddle.distributed.run --master 10.0.0.1:13538 --server_num=2 --trainer_num=2 --np 2 train.py
# the master ip must be one of the node and the port must available
Case 4: specified servers and trainers in each node
python -m paddle.distributed.run --servers 127.0.0.1:8900,127.0.0.1:8901 --trainers 127.0.0.1:8902,127.0.0.1:8903 train.py
Elastic Mode
# run following command in 3 node to run immediately, or in 2 node to run after elastic_timeout
# python -m paddle.distributed.run --master etcd://10.0.0.1:2379 --np 2:3 train.py
# once the peer number changes between 2:3, the strategy holds
'''
python/paddle/distributed/run/__main__.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .context import Context
from . import controllers

# initialize the context to run: parses CLI args and environment
ctx = Context()

# initialize the selected controller (first one whose enable() accepts ctx)
c = controllers.init(ctx)

# run the pods
c.run()

# manage the pods until done, or just wait for them, then exit
c.finalize()
python/paddle/distributed/run/context/__init__.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
argparse
import
ArgumentParser
,
REMAINDER
import
os
,
copy
from
paddle.distributed.run
import
plugins
from
.node
import
Node
from
.status
import
Status
import
logging
class Context(object):
    """Global running context for ``python -m paddle.distributed.run``.

    Holds the parsed command-line arguments, a filtered copy of the
    process environment, a logger, the local node description and the
    job status.  Environment variables with well-known PADDLE_* names
    override the corresponding command-line arguments
    (see ``set_env_in_args``).
    """

    def __init__(self, enable_plugin=True):
        # proxies break local rendezvous/KV traffic, drop them up front
        os.environ.pop('http_proxy', None)
        os.environ.pop('https_proxy', None)

        self.args = self.parse_args()
        self.envs = self.fetch_envs()
        self.logger = self.get_logger()

        self.node = Node()
        self.status = Status()

        self.set_env_in_args()

        # design for event queue, later
        self.events = []

        if enable_plugin:
            self._enable_plugin()

    def get_envs(self):
        """Return a copy of the captured environment dict."""
        return self.envs.copy()

    def _enable_plugin(self):
        # each enabled plugin gets a chance to mutate the context
        for pl in plugins.enabled_plugins:
            pl(self)

    def parse_args(self):
        """Build the argument parser and parse ``sys.argv``."""
        parser = ArgumentParser()

        base_group = parser.add_argument_group("Base Parameters")

        base_group.add_argument(
            "--master",
            type=str,
            default=None,
            help="the master/rendezvous server, ip:port")

        base_group.add_argument(
            "--rank", type=int, default=-1, help="the peer rank")

        base_group.add_argument(
            "--log", type=str, default="INFO", help="log level. Default INFO")

        base_group.add_argument(
            "--np",
            type=str,
            default="1",
            help="the number of peers, i.e. pod/node number")

        base_group.add_argument(
            "--nproc_per_node",
            type=int,
            default=None,
            help="the number of processes in a pod")

        base_group.add_argument(
            "--log_dir",
            type=str,
            default="log",
            help="the path for each process's log. Default ./log")

        base_group.add_argument(
            "--mode",
            type=str,
            default="collective",
            help="run mode of the job, collective/ps/ps-heter")

        base_group.add_argument(
            "--id",
            type=str,
            default="default",
            help="unique id of the job. Default default")

        base_group.add_argument(
            "--devices",
            type=str,
            default=None,
            help="accelerate devices. as --gpus,npus,xps")

        base_group.add_argument("--host", type=str, default=None, help="host ip")

        base_group.add_argument(
            "training_script",
            type=str,
            help="the full path of py script,"
            "followed by arguments for the "
            "training script")

        base_group.add_argument('training_script_args', nargs=REMAINDER)

        ps_group = parser.add_argument_group("Parameter-Server Parameters")
        # for parameter server
        ps_group.add_argument(
            "--servers", type=str, default='', help="servers endpoints full list")
        ps_group.add_argument(
            "--trainers", type=str, default='', help="trainers endpoints full list")
        ps_group.add_argument(
            "--trainer_num", type=int, default=None, help="number of trainers")
        ps_group.add_argument(
            "--server_num", type=int, default=None, help="number of servers")
        ps_group.add_argument(
            "--gloo_port", type=int, default=6767, help="gloo http port")
        ps_group.add_argument(
            "--with_gloo", type=str, default="0", help="use gloo or not")

        # parameter elastic mode
        elastic_group = parser.add_argument_group("Elastic Parameters")
        elastic_group.add_argument(
            "--max_restart",
            type=int,
            default=3,
            help="the times can restart. Default 3")
        elastic_group.add_argument(
            "--elastic_level",
            type=int,
            default=-1,
            help="elastic level: -1 disable, 0 failed exit, peers hold, 1 internal restart"
        )
        elastic_group.add_argument(
            "--elastic_timeout",
            type=int,
            default=30,
            help="seconds to wait before elastic perform training")

        return parser.parse_args()

    def _valide_env(self, key):
        """Return True for env keys relevant to a paddle run (whitelist)."""
        if key in ['POD_IP']:
            return True
        if key.endswith('_VISIBLE_DEVICES'):
            return True
        if key.startswith('PADDLE_'):
            return True
        return False

    def fetch_envs(self):
        """Copy os.environ minus the proxy blacklist."""
        ge = os.environ.copy()

        black_env_list = ['http_proxy', 'https_proxy']
        for key in black_env_list:
            ge.pop(key, None)

        return ge
        # use black list instead of white list; the whitelist variant was:
        #   return {k: ge[k] for k in ge if self._valide_env(k)}

    def get_logger(self, level=logging.INFO):
        """Create the PADDLERUN stream logger at args.log level."""
        logger = logging.getLogger("PADDLERUN")
        logger.setLevel(self.args.log.upper() or level)
        formatter = logging.Formatter(
            fmt='%(name)s %(levelname)s %(asctime)s %(message)s')
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        return logger

    def set_env_in_args(self):
        """Override parsed args with their PADDLE_* environment equivalents."""
        env_args = {
            'POD_IP': 'host',
            'PADDLE_MASTER': 'master',
            'PADDLE_DEVICES': 'devices',
            'PADDLE_NP': 'np',
            'PADDLE_MODE': 'mode',
            'PADDLE_LOG': 'log',
            'PADDLE_NPROC_PER_NODE': 'nproc_per_node',
            'PADDLE_JOB_ID': 'id',
            'PADDLE_RANK': 'rank',
            'PADDLE_LOG_DIR': 'log_dir',
            # fixed typo: was 'PADDLE_MAX_RESTlRT', which made the
            # PADDLE_MAX_RESTART environment override unreachable
            'PADDLE_MAX_RESTART': 'max_restart',
            'PADDLE_ELASTIC_LEVEL': 'elastic_level',
            'PADDLE_ELASTIC_TIMEOUT': 'elastic_timeout',
            'PADDLE_SERVER_NUM': 'server_num',
            'PADDLE_TRAINER_NUM': 'trainer_num',
            'PADDLE_SERVERS_ENDPOINTS': 'servers',
            'PADDLE_TRAINERS_ENDPOINTS': 'trainers',
            'PADDLE_GLOO_PORT': 'gloo_port',
            'PADDLE_WITH_GLOO': 'with_gloo',
        }

        for k, v in env_args.items():
            if k in self.envs:
                setattr(self.args, v, self.envs[k])
python/paddle/distributed/run/context/device.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
class DeviceType:
    """String tags for the accelerator kinds a node may carry."""
    CPU = 'cpu'
    GPU = 'gpu'
    XPU = 'xpu'
    NPU = 'npu'


class Device(object):
    """Accelerator description: type, count and the visible device labels."""

    def __init__(self, dtype=None, count=1, memory="", labels=""):
        self.dtype = dtype
        self.count = count
        self.memory = memory
        self.labels = labels

    def __str__(self):
        # comma-joined visible-device labels, e.g. "0,1,2"
        return ",".join(self.labels)

    @classmethod
    def parse_device(cls):
        """Build a Device from *_VISIBLE_DEVICES environment variables.

        Falls back to detect_device() (compiled-device probing) when no
        usable visibility variable is set or it equals 'all'.
        """
        device = Device()

        # (dtype, candidate env vars) in priority order; first match wins
        candidates = (
            (DeviceType.GPU, ('CUDA_VISIBLE_DEVICES', 'NVIDIA_VISIBLE_DEVICES')),
            (DeviceType.XPU, ('XPU_VISIBLE_DEVICES',)),
            (DeviceType.NPU, ('ASCEND_VISIBLE_DEVICES',)),
        )

        visible = None
        for dtype, keys in candidates:
            if any(k in os.environ for k in keys):
                device.dtype = dtype
                for k in keys:
                    # keep the first non-empty value, like a chained `or`
                    visible = visible or os.getenv(k)
                break

        if not visible or visible == 'all':
            return cls.detect_device()

        device.labels = visible.split(',')
        device.count = len(device.labels)
        return device

    @classmethod
    def detect_device(cls):
        """Probe the installed paddle build for its device type and count."""
        import paddle.fluid as fluid

        device = Device()
        num = 0
        visible = None

        if fluid.core.is_compiled_with_cuda():
            device.dtype = DeviceType.GPU
            num = fluid.core.get_cuda_device_count()
            visible = os.getenv("CUDA_VISIBLE_DEVICES") or os.getenv(
                "NVIDIA_VISIBLE_DEVICES")
        elif fluid.core.is_compiled_with_xpu():
            device.dtype = DeviceType.XPU
            num = fluid.core.get_xpu_device_count()
            visible = os.getenv("XPU_VISIBLE_DEVICES")
        elif fluid.core.is_compiled_with_npu():
            device.dtype = DeviceType.NPU
            num = fluid.core.get_npu_device_count()
            visible = os.getenv("ASCEND_VISIBLE_DEVICES")

        if num == 0:
            # no accelerator compiled in / present -> CPU
            device.dtype = DeviceType.CPU
        elif visible is None or visible == "all" or visible == "":
            device.labels = [str(x) for x in range(0, num)]
            device.count = num
        else:
            device.labels = visible.split(',')
            device.count = len(device.labels)

        return device
python/paddle/distributed/run/context/event.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Event(object):
    """A simple record of something that happened during a run.

    kind:    event category, defaults to "status".
    message: human-readable detail text.
    fatal:   whether this event should abort the job.
    """

    def __init__(self, kind="status", message="", fatal=False):
        self.kind, self.message, self.fatal = kind, message, fatal
python/paddle/distributed/run/context/node.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.device
import
Device
import
socket
import
struct
from
contextlib
import
closing
class Node(object):
    """The local machine: its accelerator device, IP and allocated ports."""

    def __init__(self):
        # self.device = Device.detect_device()
        self.device = Device.parse_device()
        self.ip = self.get_host_ip()
        self.free_ports = []

    def get_host_ip(self):
        """Resolve and cache this host's IP; fall back to loopback."""
        try:
            self.hostname = socket.gethostname()
            self.ip = socket.gethostbyname(socket.getfqdn(self.hostname))
            return self.ip
        except:  # deliberately broad: any resolution failure -> loopback
            return '127.0.0.1'

    def get_free_ports(self, n=1):
        """Pick *n* free ports and record them as occupied."""
        picked = [self.get_free_port() for _ in range(n)]
        self.free_ports += picked
        return picked

    def get_ports_occupied(self):
        # NOTE(review): returns the ports handed out by get_free_ports
        return self.free_ports

    @classmethod
    def get_free_port(cls):
        """Ask the OS for an unused TCP port number."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            # zero-timeout SO_LINGER so the port is released immediately
            s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
                         struct.pack('ii', 1, 0))
            s.bind(('', 0))
            return s.getsockname()[1]

    @classmethod
    def is_server_ready(cls, ip, port):
        """Return True when a TCP connect to (ip, port) succeeds."""
        with closing(socket.socket(socket.AF_INET,
                                   socket.SOCK_STREAM)) as probe:
            #sock.settimeout(0.01)
            #sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if hasattr(socket, 'SO_REUSEPORT'):
                probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            return probe.connect_ex((ip, int(port))) == 0
python/paddle/distributed/run/context/resource.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Resource(object):
    """Placeholder for the compute resources attached to a node."""

    def __init__(self):
        # devices held by this resource; empty until populated elsewhere
        self.devices = []
python/paddle/distributed/run/context/status.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Status(object):
    """A tiny state machine tracking the lifecycle of the run."""

    UNINIT = "uninit"
    READY = "ready"
    RUNNING = "running"
    FAILED = "failed"
    TERMINATING = "terminating"
    RESTARTING = "restarting"
    UNKNOWN = "unknown"
    COMPLETED = "completed"
    DONE = "done"  # should exit whatever status

    def __init__(self):
        self._current_status = None

    def current(self):
        """Return the current state string (None before any transition)."""
        return self._current_status

    def is_running(self):
        return self._current_status == self.RUNNING

    def is_restarting(self):
        return self._current_status == self.RESTARTING

    def is_done(self):
        """True once the run reached a terminal state."""
        return self._current_status in (self.DONE, self.COMPLETED, self.FAILED)

    # --- transitions -------------------------------------------------
    def run(self):
        self._current_status = self.RUNNING

    def fail(self):
        self._current_status = self.FAILED

    def complete(self):
        self._current_status = self.COMPLETED

    def restart(self):
        self._current_status = self.RESTARTING

    def done(self):
        self._current_status = self.DONE
python/paddle/distributed/run/controllers/__init__.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["init"]

from .collective import CollectiveController
from .collective import CollectiveElasticController
from .ps import PSController

# the order is extremely important: each controller's enable() is probed
# in sequence and the first match wins
_controllers = [
    CollectiveElasticController,
    PSController,
    CollectiveController,
]


def init(ctx):
    # Instantiate the first controller that accepts this context.
    # Falls through (returning None) if no controller enables.
    for c in _controllers:
        if c.enable(ctx):
            return c(ctx)
python/paddle/distributed/run/controllers/collective.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.controller
import
Controller
import
json
import
os
import
six
import
time
class CollectiveController(Controller):
    """Controller for collective (all-reduce style) training jobs."""

    @classmethod
    def enable(cls, ctx):
        # catch-all controller: enabled for any truthy context, so it must
        # be probed last (see controllers/__init__._controllers ordering)
        if ctx:
            ctx.logger.debug("{} enabled".format(cls.__name__))
            return True
        else:
            return False

    def build_pod(self):
        """Sync with peers through the master and create the containers.

        Returns True on success, False when peer sync yielded no peers.
        """
        self.pod.replicas = self.pod_replicas()

        # rank will be reset when restart
        self.pod.rank = self.ctx.args.rank

        # candidate port offered to peers for the collective master role
        port = self.ctx.node.get_free_port()

        # compatible
        endpoints = [
            "{}:{}".format(self.ctx.node.ip, p)
            for p in self.ctx.node.get_free_ports(self.pod.replicas)
        ]

        # info advertised to the other pods through the master store
        data = json.dumps({
            'name': self.pod.name,
            'rank': self.pod.rank,
            'replicas': self.pod.replicas,
            'dtype': self.ctx.node.device.dtype,
            'candidate': '{}:{}'.format(self.ctx.node.ip, port),
            'endpoints': ",".join(endpoints),
        })

        # blocks until all job.replicas peers have published their info;
        # also assigns this pod its final rank
        peer_list, rank = self.master.sync_peers(
            '/{}/info'.format(self.job.id), self.pod.name, data,
            self.job.replicas, self.pod.rank)
        self.pod.rank = rank

        if len(peer_list) < 1:
            return False

        peer_list = [json.loads(i) for i in peer_list]

        self.ctx.logger.debug("sync peers done {}".format(peer_list))
        self.save_pod_log(peer_list)

        # total process count across all pods, and this pod's global offset
        global_size = sum([i['replicas'] for i in peer_list])
        rank_offset = sum([i['replicas'] for i in peer_list[:rank]])
        '''
        The new designed collective need nothing but a master endpoint
        '''
        collective_master = peer_list[0]['candidate']

        job_endpoints = [i['endpoints'] for i in peer_list]

        self.pod.reset()
        for i in range(self.pod.replicas):
            e = {
                "PADDLE_MASTER": collective_master,
                "PADDLE_GLOBAL_SIZE": "{}".format(global_size),
                "PADDLE_LOCAL_SIZE": "{}".format(self.pod.replicas),
                "PADDLE_GLOBAL_RANK": "{}".format(i + rank_offset),
                "PADDLE_LOCAL_RANK": "{}".format(i),
                ## compatible env
                "PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints),
                "PADDLE_CURRENT_ENDPOINT": endpoints[i],
                "PADDLE_TRAINER_ID": "{}".format(i + rank_offset),
                "PADDLE_TRAINERS_NUM": "{}".format(global_size),
                "PADDLE_RANK_IN_NODE": str(i),
            }
            self.add_container(envs=e, log_tag=i)

        return True
class CollectiveElasticController(CollectiveController):
    """Collective controller with elastic (etcd-backed) fault tolerance."""

    @classmethod
    def enable(cls, ctx):
        # only active when an etcd master is configured
        if ctx.args.master and ctx.args.master.startswith("etcd://"):
            ctx.logger.debug("{} enabled".format(cls.__name__))
            return True
        else:
            return False

    def register(self):
        """Register this pod's heartbeat under the job id on the master."""
        if self.job.id == 'default':
            self.ctx.logger.warning(
                'Using default job name may cause conflict, add --id in args')

        self.master.register_heartbeat(self.job.id, self.pod.name)

    def watch(self) -> bool:
        '''
        watch self and peer status, return true to exit
        '''
        while not self.ctx.status.is_done():
            # self status
            status = self.pod.watch(timeout=2)

            self.ctx.logger.debug("Pod status {}, Ctx status {}".format(
                status, self.ctx.status.current()))

            # completed
            if status == self.ctx.status.COMPLETED:
                self.master.set_status(status)
                self.ctx.status.complete()
                self.ctx.logger.info("Pod complete {}".format(status))
                return True

            # self failure
            elif status == self.ctx.status.FAILED:
                self.master.set_status(status)
                self.master.restart_peer()
                self.ctx.logger.info("Pod failed {}".format(status))
                self.pod.stop()

                # with elasticity disabled a local failure ends the job
                if self.ctx.args.elastic_level <= 0:
                    return True
                else:
                    return False

            # peer failure
            if self.ctx.status.is_restarting() and self.master.get_status(
            ) != self.ctx.status.COMPLETED:
                self.pod.stop()
                return False

            #peers = self.master.fetch_peer_alive()
            #print("peers {}".format(peers))

    def run(self):
        """Main elastic loop: (re)build job and pod until done or exhausted."""
        # non-elastic jobs get a much longer grace period for peers
        timeout = self.ctx.args.elastic_timeout if self.job.elastic else self.ctx.args.elastic_timeout * 10
        self.register()

        while self.pod.restart <= self.ctx.args.max_restart:

            self.build_job()

            ok, replicas = self.master.wait_peer_ready(
                self.job.replicas_min, self.job.replicas_max, timeout)
            if ok:
                self.job.replicas = replicas
            else:
                # fixed typo: was `self.ctx.logger.warnning(...)`, which
                # would raise AttributeError exactly on this failure path
                self.ctx.logger.warning("peer not ready {}".format(self.job))
                break

            self.ctx.logger.debug("Run {}".format(self.job))

            if not self.build_pod():
                continue

            self.master.set_status(self.ctx.status.RUNNING)
            self.ctx.status.run()

            assert len(self.pod.containers) > 0, "No container in the pod"
            self.ctx.logger.debug("Run {}".format(self.pod))
            self.ctx.logger.debug("Run {}".format(self.pod.containers[0]))

            self.pod.deploy()

            if self.watch():
                break

        self.ctx.logger.debug("Job done {}".format(self.job))
python/paddle/distributed/run/controllers/controller.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
os
import
signal
from
paddle.distributed.run.job
import
Job
from
paddle.distributed.run.job
import
Pod
from
paddle.distributed.run.job
import
Container
from
.master
import
Master
import
time
class ControleMode:
    # job run-mode tags; class name kept as-is (sic) since other modules
    # may import it by this spelling
    COLLECTIVE = "collective"
    PS = "ps"
class ControllerBase(object):
    """Shared controller machinery: signal handling, deploy/watch loop,
    shutdown and process exit."""

    def __init__(self, ctx):
        # route termination signals through our handler for clean teardown
        signal.signal(signal.SIGTERM, self.signal_handler)
        signal.signal(signal.SIGABRT, self.signal_handler)
        signal.signal(signal.SIGINT, self.signal_handler)

        self.ctx = ctx
        # master store (HTTP or etcd) chosen from ctx.args.master
        self.master = Master.factory(self.ctx)

        self.job = Job(np=self.ctx.args.np,
                       mode=self.ctx.args.mode,
                       id=self.ctx.args.id)
        self.pod = Pod()

        self.join_server = None

    def run(self):
        # build the job/pod (subclass hooks), deploy containers, then watch
        self.build_job()
        self.build_pod()

        if len(self.pod.containers) < 1:
            self.ctx.logger.error("No container in the pod {}".format(
                self.pod))
            return

        self.ctx.logger.info("Run {}".format(self.pod))
        self.ctx.logger.debug(self.pod.containers[0])

        self.pod.deploy()

        self.watch()

    def watch(self) -> bool:
        # block until the pod reports a terminal status, log the outcome,
        # then stop the pod
        status = self.pod.watch()

        if status == self.ctx.status.COMPLETED:
            self.ctx.logger.info("Pod {}".format(status))
        elif status == self.ctx.status.FAILED:
            self.ctx.logger.info("Pod {}".format(status))
            self.ctx.logger.error("Container failed !!!\n{}".format(
                self.pod.failed_container()))
            # dump the tail of the failed container's log for diagnosis
            self.pod.tail()
        self.pod.stop()

    def stop(self, sigint=None):
        # stop master first so peers stop waiting on us, then the pod
        self.ctx.logger.debug("Controller stop")
        self.master.stop()
        self.pod.stop(sigint)

    def finalize(self):
        # wait for all containers, then exit with the pod's exit code
        self.pod.join()
        self.master.stop()

        self.ctx.logger.info("Exit code {}".format(self.pod.exit_code))
        sys.exit(self.pod.exit_code)

    def signal_handler(self, sigint, frame):
        # second signal while already terminating: give teardown 5s, then die
        self.ctx.logger.info("Terminating with signal {}".format(sigint))

        if hasattr(self, 'sigint'):
            time.sleep(5)
            sys.exit(sigint)

        self.sigint = sigint
        self.ctx.status.done()
        self.stop(sigint)
        time.sleep(1)
        self.ctx.logger.debug("Exit with signal {}".format(sigint))
        sys.exit(sigint)
class Controller(ControllerBase):
    '''
    Controller API for customization
    '''

    def build_job(self):
        '''
        build job fill the job info.
        '''
        self.ctx.logger.info(self.job)

    def build_pod(self) -> bool:
        '''
        build pod includes creating containers etc.

        Return True if succeed
        '''
        raise NotImplementedError

    def _get_entrypoint(self):
        # command line to launch one training process:
        # <python> -u <training_script> <training_script_args...>
        entrypoint = [sys.executable, "-u", self.ctx.args.training_script]
        entrypoint.extend(self.ctx.args.training_script_args)
        return entrypoint

    def _get_out_err_file(self, out=None, err=None):
        # prefix file names with log_dir when set; err defaults to out
        if out and self.ctx.args.log_dir != "":
            out = os.path.join(self.ctx.args.log_dir, out)
        if err and self.ctx.args.log_dir != "":
            err = os.path.join(self.ctx.args.log_dir, err)
        return out, (err or out)

    def new_container(self,
                      entrypoint=None,
                      envs={},
                      use_ctx_env=True,
                      out=None,
                      err=None):
        # NOTE(review): `envs={}` is a mutable default argument — safe only
        # as long as callers never mutate it; consider envs=None upstream
        c = Container(
            entrypoint=(entrypoint or self._get_entrypoint()),
            env=(self.ctx.get_envs() if use_ctx_env else {}),
        )
        c.outfile, c.errfile = self._get_out_err_file(out, err)
        c.update_env(envs)
        return c

    def add_container(self,
                      container=None,
                      entrypoint=None,
                      envs={},
                      log_tag=None,
                      is_init=False):
        # per-container log file: <job_id>.<pod_name>.<log_tag>.log
        # (init containers and untagged containers get no log file)
        if not is_init and log_tag is not None:
            log_file = "{}.{}.{}.log".format(self.job.id, self.pod.name,
                                             log_tag)
        else:
            log_file = None

        if not container:
            container = self.new_container(
                entrypoint=entrypoint, envs=envs, out=log_file, err=log_file)

        if is_init:
            self.pod.add_init_container(container)
        else:
            self.pod.add_container(container)

    def pod_replicas(self):
        '''
        how many process/container should be run in pod
        '''
        # explicit --nproc_per_node wins, otherwise one process per device
        if self.ctx.args.nproc_per_node:
            return int(self.ctx.args.nproc_per_node)
        else:
            return self.ctx.node.device.count

    def save_pod_log(self, info):
        '''
        save_pod_log append *info* to the log file of pod.name
        '''
        if not self.ctx.args.log_dir:
            return

        f = os.path.join(self.ctx.args.log_dir,
                         '{}.{}.log'.format(self.job.id, self.pod.name))
        try:
            os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, 'a+') as fd:
                fd.write(str(info))
        except Exception as e:
            # best-effort: logging failures must not kill the run
            self.ctx.logger.error("save log failed because {}".format(e))
python/paddle/distributed/run/controllers/master.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
paddle.distributed.run.utils.kv_client
import
KVClient
from
paddle.distributed.run.utils.kv_server
import
KVServer
import
time
import
sys
import
six
import
threading
import
copy
import
random
ETCD_PROTOCAL
=
'etcd://'
class Master(object):
    '''
    Master is a distributed store design to exchange info among nodes
    '''

    # role tags a master instance may take
    MAIN = "main"
    STANDBY = "standby"
    PATICIPANT = "participant"

    def __init__(self, ctx):
        self.ctx = ctx
        self.server = None
        self.initialized = False
        self.endpoint = None

    def stop(self):
        """Shut the store down; concrete masters must implement this."""
        raise NotImplementedError

    def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int):
        """Publish (key, value) under prefix and wait for *size* peers;
        concrete masters must implement this."""
        raise NotImplementedError

    @classmethod
    def factory(cls, ctx):
        """Pick the concrete master from the --master argument scheme."""
        master = ctx.args.master
        if master and master.startswith(ETCD_PROTOCAL):
            return ETCDMaster(ctx)
        return HTTPMaster(ctx)
class HTTPMaster(Master):
    # Master backed by the bundled HTTP KV server.  The node whose ip matches
    # the --master endpoint (or the node launched without --master) starts the
    # KVServer and becomes MAIN; every other node is a PATICIPANT client.

    def lazy_init(self):
        # Initialize on first use: decide this node's role, start the KV
        # server if this node should host it, and create the client.
        if self.initialized:
            return

        self.role = Master.PATICIPANT

        if self.ctx.args.master:
            self.endpoint = self.ctx.args.master
            ip, port = self.endpoint.split(':')
            if ip in ['127.0.0.1', self.ctx.node.ip]:
                # random backoff so that concurrent local launches do not all
                # race to bind the port at the same instant
                time.sleep(2 * random.random())
                while not self.ctx.node.is_server_ready(ip, int(port)):
                    try:
                        # first binder wins and becomes MAIN; losers keep
                        # polling until some process serves the port
                        self.server = KVServer(int(port))
                        self.role = Master.MAIN
                        break
                    except Exception as e:
                        self.ctx.logger.warning("start master failed {}".format(
                            e))
                        time.sleep(0.1)
                        continue
        else:
            # no --master given: host the KV server on a free local port and
            # print the command other nodes should run to join this job
            port = self.ctx.node.get_free_port()
            self.endpoint = "{}:{}".format(self.ctx.node.ip, port)
            self.server = KVServer(port)
            self.role = Master.MAIN

            print("Copy the following command to other nodes to run.")

            cmd = [
                sys.executable.split('/')[-1], "-m", "paddle.distributed.run"
            ]
            cmd.extend(["--master", self.endpoint])
            cmd.extend(sys.argv[1:])

            print("-" * 80)
            print(" ".join(cmd))
            print("-" * 80)

            if self.ctx.args.rank >= 0:
                self.ctx.logger.warning(
                    "--rank set in the command may not compatible in auto mode")

        # advertise a routable address, never the loopback
        if '127.0.0.1' in self.endpoint:
            self.endpoint = self.endpoint.replace('127.0.0.1',
                                                  self.ctx.node.ip)
        self.client = KVClient(self.endpoint)

        self.initialized = True

        self._start_server()

    def _start_server(self):
        # idempotent: only starts a server this node owns and has not started
        if self.server and not self.server.started:
            self.server.start()
            self.ctx.logger.debug("KV server start at {}".format(self.endpoint))

    def _stop_server(self):
        # idempotent counterpart of _start_server
        if self.server and not self.server.stopped:
            self.server.stop()
            self.ctx.logger.debug("KV server stopped")

    def stop(self):
        self._stop_server()

    def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int):
        # Publish (key, value) under *prefix* and block until *size* peers
        # have done the same.  Returns (values, this_peer_index); values are
        # ordered by rank when ranks were supplied, else by sorted key.
        if size < 2:
            # single-node job: nothing to exchange
            return [value], 0

        self.lazy_init()

        while not self.ctx.status.is_done():
            if self.client.wait_server_ready(timeout=5):
                break
            else:
                self.ctx.logger.warning("master not ready")
                time.sleep(0.1)

        # 'aaaaaa' make sure main pod (master server) sorts first, i.e. rank 0
        ky = 'aaaaaa' if rank < 0 and self.role == Master.MAIN else key
        k = "{}/{}/{}".format(prefix, ky, rank)

        while not self.ctx.status.is_done():
            if not self.client.put(k, value):
                self.ctx.logger.warning("put value failed")
                time.sleep(0.1)
                continue

            rjson = self.client.get_prefix(prefix)
            self.ctx.logger.debug("sync peers {}".format(rjson))
            if rjson and len(rjson) == size:
                if rank < 0:
                    # no explicit ranks: order peers by key (MAIN first, see
                    # the 'aaaaaa' trick above) and derive our index
                    keys = list(rjson.keys())
                    keys.sort()
                    ret = [rjson[k] for k in keys]
                    idx = ret.index(value)
                    return ret, idx
                else:
                    # explicit ranks: place each value at the rank encoded in
                    # the trailing path component of its key
                    ret = [None] * size
                    for k, v in rjson.items():
                        ret[int(k.split('/')[-1])] = v
                    return ret, rank
            else:
                time.sleep(0.5)
        # job was marked done before the barrier completed
        return [], 0
class ETCDMaster(Master):
    '''
    Master backed by an external etcd cluster.  Peer info is exchanged under
    a shared prefix, and pod liveness is tracked with leased heartbeat keys.
    '''

    def __init__(self, ctx):
        super().__init__(ctx)

        if self.ctx.args.master:
            # e.g. etcd://localhost:2379
            # NOTE: the previous code used .strip("etcd://"), which strips any
            # of the characters {e,t,c,d,:,/} from BOTH ends and would mangle
            # hosts such as "etcd-server".  Remove the scheme prefix exactly.
            master = self.ctx.args.master
            if master.startswith(ETCD_PROTOCAL):
                self.endpoint = master[len(ETCD_PROTOCAL):]
            else:
                self.endpoint = master

        import etcd3

        host, port = self.endpoint.split(':')
        self.client = etcd3.client(host=host, port=port)

    def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int):
        '''
        sync_peers gather all value for key under scope prefix
        result always be sorted either by rank or alphabet of pod.name
        '''
        path = "{}/{}/{}".format(prefix, key, rank)

        # start from a clean scope so stale peers do not satisfy the barrier
        self.client.delete_prefix(prefix)

        self.ctx.logger.debug("sync path {} value {}".format(path, value))

        while not self.ctx.status.is_done():
            # re-put every round: our key was deleted by the cleanup above or
            # possibly by a late-joining peer's cleanup
            self.client.put(path, six.b(value))

            result = [i for i in self.client.get_prefix(prefix)]
            result = copy.deepcopy(result)
            self.ctx.logger.debug("sync peers {}".format(result))

            if len(result) == size:
                if rank < 0:
                    # no explicit rank: order peers by their etcd key
                    keys = [six.ensure_str(i[1].key) for i in result]
                    sorted_keys = [six.ensure_str(i[1].key) for i in result]
                    sorted_keys.sort()
                    values = [six.ensure_str(i[0]) for i in result]
                    ret = [values[keys.index(k)] for k in sorted_keys]
                    idx = ret.index(value)
                    return ret, idx
                else:
                    # explicit rank: the trailing path component is the rank
                    ret = [None] * size
                    for v, k in result:
                        ii = int(six.ensure_str(k.key).split('/')[-1])
                        if ii < 0:
                            self.ctx.logger.error(
                                "rank {} error in sync".format(ii))
                        ret[ii] = six.ensure_str(v)
                    return ret, rank
            else:
                time.sleep(0.5)
        # job was marked done before the barrier completed; mirror the
        # HTTPMaster contract so callers can still unpack the result
        return [], 0

    def register_heartbeat(self, job_id, pod_id, ttl=10):
        # Register this pod under a leased key and keep refreshing the lease
        # on a daemon thread; any change under the heartbeat prefix (a peer
        # dying or joining) triggers a job restart via the watch callback.
        if hasattr(self, 'heartbeat_prefix'):
            self.ctx.logger.warning("Heartbeat already done")
            return

        self.job_prefix = '/paddle/{}'.format(job_id)
        self.heartbeat_prefix = '{}/heartbeat'.format(self.job_prefix)

        lease = self.client.lease(ttl)

        #self.client.delete_prefix(self.job_prefix)

        beat_path = "{}/{}".format(self.heartbeat_prefix, pod_id)
        self.client.put(beat_path, six.b(pod_id), lease=lease)

        def _beat_watch(event):
            self.ctx.status.restart()

        beat_watch = self.client.add_watch_prefix_callback(
            self.heartbeat_prefix, _beat_watch)

        def _heartbeat():
            while not self.ctx.status.is_done():
                try:
                    lease.refresh()
                    if pod_id not in self.fetch_peer_alive():
                        # key expired (e.g. lease briefly lost); re-register
                        self.client.put(beat_path, six.b(pod_id), lease=lease)
                        self.ctx.logger.debug("Heartbeat register again")
                except Exception as e:
                    self.ctx.logger.error("Heartbeat error {}".format(e))
                time.sleep(ttl / 2)
            self.ctx.logger.debug("Heartbeat done")
            self.client.cancel_watch(beat_watch)

        self.beat_thread = threading.Thread(
            name='heartbeat', target=_heartbeat, daemon=True)
        self.beat_thread.start()

    def fetch_peer_alive(self):
        # pod ids whose heartbeat key is still live
        peer_alive = [
            six.ensure_str(i[0])
            for i in self.client.get_prefix(self.heartbeat_prefix)
        ]
        self.ctx.logger.debug("peer alive {}".format(peer_alive))
        return peer_alive

    def wait_peer_ready(self, replicas_min, replicas_max, timeout):
        # Wait up to *timeout* seconds for replicas_max peers; accept any
        # count within [replicas_min, replicas_max] once time runs out.
        # Returns (ok, peer_count).
        end = time.time() + timeout
        while not self.ctx.status.is_done() and time.time() < end:
            if len(self.fetch_peer_alive()) == replicas_max:
                return (True, replicas_max)
            else:
                time.sleep(0.5)

        np = len(self.fetch_peer_alive())
        if np >= replicas_min and np <= replicas_max:
            return (True, np)
        else:
            return (False, np)

    def restart_peer(self):
        # deleting heartbeat keys fires every peer's watch -> global restart
        self.client.delete_prefix(self.heartbeat_prefix)

    def set_status(self, status):
        # NOTE: assert is stripped under python -O; the put is still executed
        # but a failure would then go unnoticed.
        assert self.client.put(
            self.job_prefix, six.b(status),
            lease=self.client.lease(600)), "set status failed {}".format(status)

    def get_status(self):
        return six.ensure_str(self.client.get(self.job_prefix)[0] or '')

    def stop(self):
        if hasattr(self, 'beat_thread'):
            self.ctx.status.done()
            # TODO(kuizhiqing) thread should exit
            #self.beat_thread.join()
python/paddle/distributed/run/controllers/ps.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.controller
import
Controller
,
ControleMode
import
json
import
os
,
shutil
class PSController(Controller):
    # Controller for parameter-server jobs: builds one pod containing the
    # local pserver and trainer containers.

    @classmethod
    def enable(cls, ctx):
        # selected when --mode ps is given explicitly, or when server
        # arguments (--server_num / --servers) are present
        if ctx.args.mode == ControleMode.PS or ctx.args.server_num or len(
                ctx.args.servers) > 0:
            ctx.logger.debug("{} enabled".format(cls.__name__))
            ctx.args.mode = ControleMode.PS
            return True
        else:
            return False

    def build_pod(self):
        # static layout when all endpoints are given on the command line,
        # otherwise discover peers through the master store
        if self.ctx.args.servers and self.ctx.args.trainers:
            self._build_pod_with_args()
        else:
            self._build_pod_with_master()

    def _build_pod_with_args(self):
        # endpoints are matched against this host to select the local subset
        if '127.0.0.1' in self.ctx.args.servers:
            host = '127.0.0.1'
        else:
            host = self.ctx.node.ip

        server_endpoints = [s for s in self.ctx.args.servers.split(",")]
        trainer_endpoints = [s for s in self.ctx.args.trainers.split(",")]
        # endpoints that live on this node
        servers = [
            s for s in self.ctx.args.servers.split(",") if s.startswith(host)
        ]
        trainers = [
            s for s in self.ctx.args.trainers.split(",") if s.startswith(host)
        ]
        server_num = len(servers)
        trainer_num = len(trainers)

        self.pod.replicas = server_num + trainer_num

        self.save_pod_log([server_endpoints, trainer_endpoints])

        import tempfile
        gloo_rendezvous_dir = tempfile.mkdtemp()
        # mkdtemp creates the dir; remove it so gloo can create it fresh
        if os.path.exists(gloo_rendezvous_dir):
            shutil.rmtree(gloo_rendezvous_dir)

        # gloo http rendezvous is hosted on the first server's host
        gloo_port = self.ctx.args.gloo_port
        gloo_http = "{}:{}".format(server_endpoints[0].split(":")[0], gloo_port)

        _gloo_envs = {
            "PADDLE_GLOO_RENDEZVOUS": "3",
            "PADDLE_GLOO_FS_PATH": gloo_rendezvous_dir,
            "PADDLE_GLOO_HTTP_ENDPOINT": gloo_http,
            "PADDLE_WITH_GLOO": self.ctx.args.with_gloo
        }

        # one container per local pserver
        for i in range(server_num):
            e = {
                "PADDLE_PSERVERS_IP_PORT_LIST": self.ctx.args.servers,
                "PADDLE_TRAINER_ENDPOINTS": self.ctx.args.trainers,
                "PADDLE_PORT": servers[i].split(":")[1],
                "PADDLE_ROLE": "PSERVER",
                "TRAINING_ROLE": "PSERVER",
                "PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)),
                "POD_IP": self.ctx.node.ip,
            }
            e.update(_gloo_envs)
            log_tag = "ps.{}".format(i)
            self.add_container(envs=e, log_tag=log_tag)

        # global id of this node's first trainer = index of the first trainer
        # endpoint located on this host
        trainer_rank_offset = 0
        for s in trainer_endpoints:
            if s.startswith(host):
                break
            else:
                trainer_rank_offset += 1

        # one container per local trainer
        for i in range(trainer_num):
            e = {
                "PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints),
                "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
                "PADDLE_PORT": trainers[i].split(":")[1],
                "PADDLE_ROLE": "TRAINER",
                "TRAINING_ROLE": "TRAINER",
                "PADDLE_TRAINER_ID": "{}".format(i + trainer_rank_offset),
                "PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)),
                "POD_IP": self.ctx.node.ip,
            }
            e.update(_gloo_envs)
            log_tag = "trainer.{}".format(i)
            self.add_container(envs=e, log_tag=log_tag)

    def _build_pod_with_master(self):
        # Allocate free local ports, publish them through the master, and
        # build containers once every pod's endpoints are known.
        self.pod.rank = self.ctx.args.rank

        server_num = self.ctx.args.server_num or 1
        servers = [
            "{}:{}".format(self.ctx.node.ip, p)
            for p in self.ctx.node.get_free_ports(server_num)
        ]
        trainer_num = self.ctx.args.trainer_num or 1
        trainers = [
            "{}:{}".format(self.ctx.node.ip, p)
            for p in self.ctx.node.get_free_ports(trainer_num)
        ]

        data = json.dumps({
            'name': self.pod.name,
            'rank': self.pod.rank,
            'servers': servers,
            'trainers': trainers,
            'dtype': self.ctx.node.device.dtype,
            'gloo_port': self.ctx.node.get_free_port(),
        })

        # barrier with all peers; our pod rank is assigned here
        peer_list, rank = self.master.sync_peers(
            '/{}/info'.format(self.job.id), self.pod.name, data,
            self.job.replicas, self.pod.rank)

        self.ctx.logger.debug("sync peers done {}".format(peer_list))

        peer_list = [json.loads(i) for i in peer_list]

        self.save_pod_log(peer_list)

        server_endpoints = [j for i in peer_list for j in i['servers']]
        trainer_endpoints = [j for i in peer_list for j in i['trainers']]

        #rank_offset = sum([i['replicas'] for i in peer_list[:rank]])
        # global index of this pod's first server / first trainer
        server_rank_offset = sum([len(i['servers']) for i in peer_list[:rank]])
        trainer_rank_offset = sum(
            [len(i['trainers']) for i in peer_list[:rank]])

        self.pod.rank = rank

        self.pod.replicas = server_num + trainer_num

        import tempfile
        gloo_rendezvous_dir = tempfile.mkdtemp()
        if os.path.exists(gloo_rendezvous_dir):
            shutil.rmtree(gloo_rendezvous_dir)

        # gloo settings come from the rank-0 pod
        gloo_port = peer_list[0]['gloo_port']
        gloo_http = "{}:{}".format(server_endpoints[0].split(":")[0], gloo_port)
        _gloo_envs = {
            "PADDLE_GLOO_RENDEZVOUS": "3",
            "PADDLE_GLOO_FS_PATH": gloo_rendezvous_dir,
            "PADDLE_GLOO_HTTP_ENDPOINT": gloo_http,
            "PADDLE_WITH_GLOO": self.ctx.args.with_gloo
        }

        for i in range(server_num):
            e = {
                "PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints),
                "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
                "PADDLE_PORT":
                server_endpoints[i + server_rank_offset].split(":")[1],
                "PADDLE_ROLE": "PSERVER",
                "TRAINING_ROLE": "PSERVER",
                "PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)),
                "POD_IP": self.ctx.node.ip,
            }
            e.update(_gloo_envs)
            log_tag = "ps.{}".format(i)
            self.add_container(envs=e, log_tag=log_tag)

        for i in range(trainer_num):
            e = {
                "PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints),
                "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
                "PADDLE_PORT":
                trainer_endpoints[i + trainer_rank_offset].split(":")[1],
                "PADDLE_ROLE": "TRAINER",
                "TRAINING_ROLE": "TRAINER",
                "PADDLE_TRAINER_ID": "{}".format(i + trainer_rank_offset),
                "PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)),
                "POD_IP": self.ctx.node.ip,
            }
            e.update(_gloo_envs)
            log_tag = "trainer.{}".format(i)
            self.add_container(envs=e, log_tag=log_tag)

        ''' NEW VERSION
        for i in range(server_num):
            e = {
                "PADDLE_PSERVER_ENDPOINTS": ",".join(server_endpoints),
                "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
                "PADDLE_ROLE": "PSERVER",
                "PADDLE_RANK": "{}".format(i + server_rank_offset),
            }
            log_tag = "ps.{}".format(i)
            self.add_container(envs=e, log_tag=log_tag)
        for i in range(trainer_num):
            e = {
                "PADDLE_PSERVER_ENDPOINTS": ",".join(server_endpoints),
                "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
                "PADDLE_ROLE": "TRAINER_CPU",
                "PADDLE_RANK": "{}".format(i + trainer_rank_offset),
            }
            log_tag = "trainer.{}".format(i)
            self.add_container(envs=e, log_tag=log_tag)
        '''
python/paddle/distributed/run/job/__init__.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.pod
import
Pod
from
.job
import
Job
from
.container
import
Container
from
.status
import
Status
__all__
=
[
'Pod'
,
'Job'
,
'Container'
,
'Status'
,
]
python/paddle/distributed/run/job/container.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
OrderedDict
from
paddle.distributed.run.utils.process_context
import
ProcessContext
from
.status
import
Status
import
os
,
copy
,
sys
import
time
class Container(object):
    '''
    TODO(kuizhiqing) A container can be run by process/thread or just a callable function

    Wraps one child process (via ProcessContext) plus its environment and
    stdout/stderr log files.
    '''

    def __init__(self, entrypoint=None, rank=-1, env=None):
        # NOTE: previously entrypoint=[] and env={} were mutable defaults;
        # update_env mutates self._env, so the dict was shared across every
        # Container constructed without an explicit env.  Use None sentinels.
        self._entrypoint = entrypoint if entrypoint is not None else []
        self._rank = rank
        self._out = None  # stdout log file path
        self._err = None  # stderr log file path
        self._env = env if env is not None else {}
        self._proc = None

        self._retry: int = 3
        self._grace_period = 10

        self._log_handler = None

    @property
    def entrypoint(self):
        return self._entrypoint

    @entrypoint.setter
    def entrypoint(self, entry):
        self._entrypoint = entry

    @property
    def rank(self):
        return self._rank

    @rank.setter
    def rank(self, r):
        self._rank = r

    @property
    def outfile(self):
        return self._out

    @outfile.setter
    def outfile(self, out):
        self._out = out

    @property
    def errfile(self):
        return self._err

    @errfile.setter
    def errfile(self, err):
        self._err = err

    def update_env(self, env=None, **kwargs):
        """Merge string-valued entries of *env* and *kwargs* into the env."""
        env = env or {}
        # environment values must be strings; silently drop anything else
        env = {k: v for k, v in env.items() if isinstance(v, str)}
        self._env.update(env)

        kwargs = {k: v for k, v in kwargs.items() if isinstance(v, str)}
        self._env.update(kwargs)

    def _get_fd(self, pth):
        """Open *pth* for writing (creating parent dirs); None on failure."""
        if not pth:
            return None

        try:
            d = os.path.dirname(pth)
            if not os.path.isdir(d):
                os.makedirs(d, exist_ok=True)
            return open(pth, 'w')
        except OSError:
            # was a bare except; only filesystem errors are expected here
            return None

    def start(self, timeout=-1):
        """Launch the process; with timeout>0, wait and report clean exit."""
        end = time.time() + timeout

        if self._proc and self._proc.alive():
            return True

        self._stdout = self._get_fd(self._out) or sys.stdout
        if self._out == self._err:
            # same path for both streams -> share one file descriptor
            self._stderr = self._stdout
        elif self._err:
            self._stderr = self._get_fd(self._err) or sys.stderr

        self._proc = ProcessContext(
            self._entrypoint, env=self._env, out=self._stdout,
            err=self._stderr)
        self._proc.start()

        while timeout > 0 and time.time() < end:
            if self._proc.alive():
                time.sleep(0.1)
                continue
            if self._proc.exit_code() == 0:
                return True
        return False

    def terminate(self, force=False):
        if self._log_handler:
            self._log_handler.close()
            self._log_handler = None

        if self._proc and self._proc.alive():
            return self._proc.terminate(force)

    def wait(self, timeout=None):
        self._proc.wait(timeout)

    def exit_code(self):
        # -1 means "never started"
        return self._proc.exit_code() if self._proc else -1

    def status(self):
        if not self._proc:
            return Status.UNINIT
        if self._proc.alive():
            return Status.RUNNING
        elif self._proc.exit_code() == 0:
            return Status.COMPLETED
        else:
            return Status.FAILED

    def __str__(self):
        return 'Container rank {} status {} cmd {} code {} log {}\nenv {}'.format(
            self._rank,
            self.status(),
            self._entrypoint,
            self.exit_code(),
            self.errfile,
            self._env, )

    def logs(self, fn=None, offset=0, whence=1, lines=1000):
        """Copy up to *lines* log lines to *fn*; returns the new offset.

        NOTE: the ``finally: return`` deliberately swallows read errors so a
        broken log file never takes down the launcher.
        """
        if not self._log_handler:
            self._log_handler = open(self._out)

        if fn is None:
            fn = sys.stdout

        self._log_handler.seek(offset, whence)

        try:
            idx = 0
            for line in self._log_handler:
                fn.write(line)
                idx += 1
                if idx > lines:
                    break
        finally:
            return self._log_handler.tell()

    def tail(self, length=3000):
        """Print roughly the last *length* bytes of the log."""
        if not self._log_handler:
            self._log_handler = open(self._out)

        self._log_handler.seek(0, 2)
        ed = self._log_handler.tell()

        if ed > length:
            self.logs(offset=ed - length, whence=0)
        else:
            self.logs(offset=0, whence=0)
python/paddle/distributed/run/job/job.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class JobMode:
    # supported distributed execution modes
    COLLECTIVE = 'collective'
    PS = 'ps'
    HETER = 'heter'


class Job(object):
    '''
    Job describes a distributed job: its id, mode, and the desired number of
    pod replicas — possibly an elastic "min:max" range.
    '''

    def __init__(self, id='default', mode=JobMode.COLLECTIVE, np="1"):
        self._mode = mode
        self._id = id

        self._replicas = 0
        self._replicas_min = self._replicas
        self._replicas_max = self._replicas
        self._elastic = False

        self.set_replicas(str(np))

    def __str__(self):
        return "Job: {}, mode {}, replicas {}[{}:{}], elastic {}".format(
            self.id, self.mode, self._replicas, self._replicas_min,
            self._replicas_max, self.elastic)

    @property
    def mode(self):
        return self._mode

    @property
    def id(self):
        return self._id

    @property
    def elastic(self):
        return self._elastic

    @property
    def replicas(self):
        return self._replicas

    @property
    def replicas_min(self):
        return self._replicas_min

    @property
    def replicas_max(self):
        return self._replicas_max

    @replicas.setter
    def replicas(self, replicas):
        self._replicas = replicas

    def set_replicas(self, np: str):
        """Parse an "n" or elastic "min:max" replica spec (empty -> "1")."""
        spec = str(np) if np else '1'

        if ':' in spec:
            bounds = spec.split(':')
            self._replicas_min = int(bounds[0])
            self._replicas_max = int(bounds[1])
            self._replicas = self._replicas_max  # default to max
            self._elastic = True
        else:
            count = int(spec)
            self._replicas = count
            self._replicas_min = count
            self._replicas_max = count
            self._elastic = False
python/paddle/distributed/run/job/pod.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
OrderedDict
from
.container
import
Container
from
.status
import
Status
import
random
import
time
class PodSepc(object):
    # Plain data holder for a pod's spec/state; Pod adds behavior on top.
    # (Name keeps its historical spelling; it is part of the public API.)

    def __init__(self):
        # random 6-letter pod name used for ranking/identification
        self._name = ''.join(
            random.choice('abcdefghijklmnopqrstuvwxyz') for _ in range(6))

        # by controller
        # NOTE(review): `List` is not imported here; the annotations are never
        # evaluated at runtime (PEP 526 attribute annotations), but they would
        # not resolve — confirm before evaluating them via typing tools.
        self._init_containers: List[Container] = []
        self._containers: List[Container] = []

        #self.resource: Resource = None
        #self.status: Status = None

        self._rank = -1
        self._init_timeout = 120  # 2 min timeout for each init container
        self._restart = -1  # incremented on each deploy(), so first run -> 0
        self._replicas = 0  # number of containers
        self._exit_code = 0
class Pod(PodSepc):
    # A pod groups this node's containers and aggregates their status.

    def __init__(self):
        super().__init__()

    def __str__(self):
        return "Pod: {}, replicas {}, status {}".format(self.name,
                                                        self.replicas,
                                                        self.status())

    def failed_container(self):
        # first container in FAILED state, or None
        for c in self._containers:
            if c.status() == Status.FAILED:
                return c
        return None

    @property
    def name(self):
        return self._name

    @property
    def replicas(self):
        return self._replicas

    @replicas.setter
    def replicas(self, r):
        self._replicas = r

    @property
    def rank(self):
        return self._rank

    @rank.setter
    def rank(self, r):
        self._rank = r

    @property
    def restart(self):
        return self._restart

    @property
    def containers(self):
        return self._containers

    def add_container(self, c):
        # container rank is its position in the pod
        c.rank = len(self._containers)
        self._containers.append(c)

    @property
    def init_containers(self):
        return self._init_containers

    def add_init_container(self, c):
        c.rank = len(self._init_containers)
        self._init_containers.append(c)

    @property
    def exit_code(self):
        # first non-zero container exit code, else 0
        for c in self._containers:
            if c.exit_code() != 0:
                return c.exit_code()
        return 0

    def deploy(self):
        # init containers run (and are waited on) before the main containers
        for i in self._init_containers:
            i.start(self._init_timeout)
        for c in self._containers:
            c.start()
        self._restart += 1

    def stop(self, sigint=0):
        for c in self._containers:
            # sigint 9 means force-kill
            force = True if sigint == 9 else False
            c.terminate(force)

    def join(self):
        # block until every container exits
        for c in self._containers:
            c.wait(None)

    def status(self):
        # FAILED dominates; COMPLETED only when all completed; else READY
        if self.is_failed():
            return Status.FAILED

        if self.is_completed():
            return Status.COMPLETED

        return Status.READY

    def reset(self):
        self._init_containers = []
        self._containers = []

    def is_failed(self):
        for c in self._containers:
            if c.status() == Status.FAILED:
                return True
        return False

    def is_completed(self):
        for c in self._containers:
            if c.status() != Status.COMPLETED:
                return False
        return True

    def logs(self, idx=None):
        # default: show the failed container's logs, else the first one
        if idx is None:
            if self.failed_container():
                self.failed_container().logs()
            else:
                self._containers[0].logs()
        else:
            self._containers[idx].logs()

    def tail(self, idx=None):
        if idx is None:
            if self.failed_container():
                self.failed_container().tail()
            else:
                self._containers[0].tail()
        else:
            self._containers[idx].tail()

    def watch(self,
              all_list=[Status.COMPLETED],
              any_list=[Status.FAILED],
              interval=1,
              timeout=-1):
        '''
        watch return if any container status in any_list
        or all container status in all_list
        '''
        # timeout < 0 means watch forever
        end = time.time() + timeout
        while timeout < 0 or time.time() < end:
            for c in self._containers:
                if c.status() in any_list:
                    return c.status()

            s = [c.status() for c in self._containers]
            if len(set(s)) == 1 and s[0] in all_list:
                return s[0]

            time.sleep(interval)
python/paddle/distributed/run/job/status.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Status(object):
    """String constants describing the lifecycle of a container/pod/job."""

    UNINIT = "uninit"            # never started
    READY = "ready"
    RUNNING = "running"
    FAILED = "failed"            # exited with a non-zero code
    TERMINATING = "terminating"
    RESTARTING = "restarting"
    UNKNOWN = "unknown"
    COMPLETED = "completed"      # exited cleanly
python/paddle/distributed/run/plugins/__init__.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
six
__all__
=
[]
def log(ctx):
    """Log every parsed argument, one per line, framed by separators."""
    ctx.logger.info("----------- Configuration ----------------------")
    # six.iteritems is a py2 compatibility shim; this package is py3-only
    # (it uses bare super() elsewhere), so plain dict.items() suffices.
    for arg, value in sorted(vars(ctx.args).items()):
        ctx.logger.info("%s: %s" % (arg, value))
    ctx.logger.info("--------------------------------------------------")
def process_args(ctx):
    """Override the detected device list with the --devices argument."""
    # reset device by args
    #argdev = ctx.args.gpus or ctx.args.xpus or ctx.args.npus
    requested = ctx.args.devices
    if not requested:
        return
    device = ctx.node.device
    device.labels = requested.split(',')
    device.count = len(device.labels)
    ctx.logger.debug('Device reset by args {}'.format(requested))
def collective_compatible(ctx):
    """Honor legacy collective env vars: first listed endpoint is the master."""
    envs = ctx.envs
    if 'PADDLE_TRAINER_ENDPOINTS' in envs:
        ctx.master = envs['PADDLE_TRAINER_ENDPOINTS'].split(',')[0]
    # DISTRIBUTED_TRAINER_ENDPOINTS takes precedence when both are present
    if 'DISTRIBUTED_TRAINER_ENDPOINTS' in envs:
        ctx.master = envs['DISTRIBUTED_TRAINER_ENDPOINTS'].split(',')[0]
def rewrite_host_ip(ctx):
    """If --host looks like an ip address (contains a dot), adopt it as the node ip."""
    host = ctx.args.host
    if host is not None and "." in host:
        ctx.logger.warning('Host ip reset to {}'.format(host))
        ctx.node.ip = host
# Plugins run in order on context startup; order matters: legacy env
# compatibility first, then the host override, then the device reset,
# and finally configuration logging.
enabled_plugins = [collective_compatible, rewrite_host_ip, process_args, log]
python/paddle/distributed/run/plugins/ip.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
socket
def get_local_ip(ctx):
    """Resolve this machine's ip and record it on args.host and the POD_IP env.

    NOTE: if hostname resolution fails, _get_host_name_ip returns None and
    the unpacking below raises TypeError.
    """
    _, ip = _get_host_name_ip()
    ctx.args.host = ip
    ctx.envs["POD_IP"] = ip
def
_get_host_name_ip
():
try
:
host_name
=
socket
.
gethostname
()
host_ip
=
socket
.
gethostbyname
(
host_name
)
return
host_name
,
host_ip
except
:
return
None
python/paddle/distributed/run/utils/kv_client.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
requests
import
time
class KVClient(object):
    # Minimal HTTP client for KVServer. All methods are best-effort: network
    # errors are swallowed and reported through the return value, never raised.

    def __init__(self, endpoint='localhost:2379'):
        # accept both "host:port" and "http://host:port"
        self.endpoint = endpoint if endpoint.startswith(
            "http://") else "http://{}".format(endpoint)

    def put(self, key, value):
        # Store *value* under *key*. True on HTTP 200, False otherwise.
        key = key if key.startswith('/') else "/{}".format(key)
        u = "{}{}".format(self.endpoint, key)
        try:
            r = requests.post(u, data=value, timeout=3)
            if r.status_code == 200:
                return True
            else:
                return False
        except:
            return False

    def get(self, key):
        # Fetch the value at *key*: '' when absent or on network error,
        # the literal string "error" on a non-200 response.
        key = key if key.startswith('/') else "/{}".format(key)
        u = "{}{}".format(self.endpoint, key)
        try:
            r = requests.get(u, timeout=3)
            if r.status_code == 200:
                # server replies {key: value, ...}
                ret = r.json()
                return ret.get(key, '')
            else:
                return "error"
        except:
            return ""

    def get_prefix(self, key):
        # Fetch all pairs whose key starts with *key* as a dict.
        # NOTE(review): returns '' on network error but implicitly None on a
        # non-200 response — callers must treat both as falsy; verify this
        # asymmetry is intended.
        key = key if key.startswith('/') else "/{}".format(key)
        u = "{}{}".format(self.endpoint, key)
        try:
            r = requests.get(u, timeout=3)
            if r.status_code == 200:
                return r.json()
        except:
            return ""

    def delete(self, key):
        # Delete *key*. True on HTTP 200, False otherwise.
        key = key if key.startswith('/') else "/{}".format(key)
        u = "{}{}".format(self.endpoint, key)
        try:
            r = requests.delete(u, timeout=3)
            if r.status_code == 200:
                return True
            else:
                return False
        except:
            return False

    def wait_server_ready(self, timeout=3):
        # Poll the /healthy key until the server answers 'ok'.
        # Returns True when ready; returns None (falsy) on timeout.
        end = time.time() + timeout
        while time.time() < end:
            if self.get("/healthy") == "ok":
                return True
if __name__ == '__main__':
    # Manual smoke test; requires a KV server listening on localhost:8090
    # (start kv_server.py first).
    # FIX: this referenced PKVClient, which does not exist in this module —
    # the class defined above is KVClient.
    cli = KVClient("http://localhost:8090")
    data = {"/workers/1": "rank1", "/workers/2": "rank2"}
    for k, v in data.items():
        cli.put(k, v)
    x = cli.get_prefix("/workers")
    print(x)
    for k, v in data.items():
        assert x[k] == v

    cli.put("key", "value")
    print(cli.get("key"))
    assert cli.get("key") == "value"
    cli.delete("key")
    print(cli.get("/key"))

    print(cli.get("/healthy"))
    assert cli.get("/healthy") == "ok"
python/paddle/distributed/run/utils/kv_server.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
http.server
import
HTTPServer
import
http.server
as
SimpleHTTPServer
from
multiprocessing
import
Process
import
threading
import
json
class KVHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
    """Exposes the owning server's in-memory dict as a tiny HTTP KV store.

    GET does a prefix scan, POST/PUT stores the body under the path and
    DELETE removes the path. All access goes through the server's lock.
    """

    def do_GET(self):
        with self.server.kv_lock:
            found = {}
            for k, v in self.server.kv.items():
                # prefix match: every key starting with the request path
                if k.startswith(self.path):
                    found[k] = v.decode(encoding="utf-8")
            if found:
                self.output(200, json.dumps(found).encode("utf-8"))
            else:
                self.output(404)

    def do_PUT(self):
        # PUT and POST are equivalent here
        self.do_POST()

    def do_POST(self):
        length = int(self.headers['Content-Length'] or 0)
        try:
            body = self.rfile.read(length)
            with self.server.kv_lock:
                self.server.kv[self.path] = body
            self.output(200)
            return
        except:
            self.output(500)

    def do_DELETE(self):
        with self.server.kv_lock:
            if self.path in self.server.kv:
                del self.server.kv[self.path]
                self.output(200)
            else:
                self.output(404)

    def output(self, code, value=''):
        """Send *code* plus an optional JSON body."""
        self.send_response(code)
        self.send_header("Content-Length", len(value))
        self.send_header("Content-Type", "application/json; charset=utf8")
        self.end_headers()
        if value:
            self.wfile.write(value)

    def log_message(self, format, *args):
        # silence the default per-request stderr logging
        return


class KVServer(HTTPServer, object):
    """HTTP KV store served from a background thread.

    Handlers share ``self.kv`` (path -> bytes) under ``self.kv_lock``.
    """

    def __init__(self, port):
        super(KVServer, self).__init__(('', port), KVHandler)
        self.kv_lock = threading.Lock()
        self.kv = {'/healthy': b'ok'}  # liveness probe key
        self.port = port
        self.stopped = False
        self.started = False

    def start(self):
        """Begin serving on a daemon-less background thread (non-blocking)."""
        self.listen_thread = threading.Thread(target=self.serve_forever)
        self.listen_thread.start()
        self.started = True

    def stop(self):
        """Stop serving, join the worker thread and release the socket."""
        self.shutdown()
        self.listen_thread.join()
        self.server_close()
        self.stopped = True
class PKVServer():
    # Runs a KVServer inside a separate process.
    # NOTE(review): stop() calls self._server.stop() in the *parent* process,
    # but the server's thread and socket live in the child — the shutdown and
    # the started/stopped flags do not cross the process boundary. Verify
    # before relying on this class (it appears unused by the controllers).

    def __init__(self, port):
        self._server = KVServer(port)

    def start(self):
        # daemon process so it dies with the launcher
        self.proc = Process(target=self._server.start)
        self.proc.daemon = True
        self.proc.start()

    def stop(self):
        self._server.stop()
        self.proc.join()

    @property
    def started(self):
        # reflects the parent's copy of the server object only
        return self._server.started

    @property
    def stopped(self):
        return self._server.stopped
if __name__ == '__main__':
    # Manual smoke test: serve the in-process KV store on port 8090 and
    # keep the main thread alive for ten minutes.
    import time

    server = KVServer(8090)
    server.start()
    time.sleep(600)
python/paddle/distributed/run/utils/process_context.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
subprocess
import
os
,
sys
,
signal
,
time
class ProcessContext(object):
    """Wrap a subprocess.Popen with process-group handling and stream cleanup.

    When `group` is true (forced off on Windows) the child is started in its
    own session via os.setsid so terminate() can signal the whole group.
    """

    def __init__(self,
                 cmd,
                 env=os.environ,
                 out=sys.stdout,
                 err=sys.stderr,
                 group=True,
                 preexec_fn=None):
        self._cmd = cmd
        self._env = env
        self._preexec_fn = preexec_fn
        self._stdout = out
        self._stderr = err
        # Process groups are a POSIX concept; disable them on Windows.
        self._group = group if os.name != 'nt' else False
        self._proc = None
        self._code = None

    def _start(self):
        # New session so the whole group can be signalled together; an
        # explicit preexec_fn from the caller takes precedence.
        pre_fn = os.setsid if self._group else None
        self._proc = subprocess.Popen(
            self._cmd,
            env=self._env,
            stdout=self._stdout,
            stderr=self._stderr,
            preexec_fn=self._preexec_fn or pre_fn)

    def _close_std(self):
        # Close redirected log files; leave real terminals alone.
        # Narrowed from a bare `except:`; cleanup stays best-effort but no
        # longer swallows KeyboardInterrupt/SystemExit.
        try:
            if not self._stdout.isatty():
                self._stdout.close()
            if not self._stderr.isatty():
                self._stderr.close()
        except Exception:
            pass

    def alive(self):
        """Return True while the child process is running (False if never started)."""
        # bool() so callers get False, not None, before start().
        return bool(self._proc) and self._proc.poll() is None

    def exit_code(self):
        """Return the child's exit code, or None if not started / still running."""
        return self._proc.poll() if self._proc else None

    def start(self):
        """Launch the child process."""
        self._start()

    def terminate(self, force=False, max_retry=3):
        """SIGTERM the child (or its whole group) up to `max_retry` times.

        With `force`, SIGKILL whatever survives.  Returns the final alive()
        state, i.e. False means the process is gone.
        """
        for i in range(max_retry):
            if self.alive():
                if self._group:
                    os.killpg(os.getpgid(self._proc.pid), signal.SIGTERM)
                else:
                    self._proc.terminate()
                time.sleep(0.2)
            else:
                break

        if force and self.alive():
            self._proc.kill()

        self._close_std()
        return self.alive()

    def wait(self, timeout=None):
        """Block until the child exits (or `timeout` seconds elapse)."""
        self._proc.wait(timeout)
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
67c6ddff
...
@@ -949,6 +949,7 @@ if (WITH_DISTRIBUTE AND NOT APPLE)
...
@@ -949,6 +949,7 @@ if (WITH_DISTRIBUTE AND NOT APPLE)
endif
()
endif
()
# setting timeout value as 15S
# setting timeout value as 15S
set_tests_properties
(
test_run PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_sync_batch_norm_op PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_sync_batch_norm_op PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_cross_op PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_cross_op PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_imperative_lod_tensor_to_selected_rows PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_imperative_lod_tensor_to_selected_rows PROPERTIES TIMEOUT 200
)
...
...
python/paddle/fluid/tests/unittests/test_run.py
0 → 100644
浏览文件 @
67c6ddff
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
subprocess
import
sys
,
os
import
json
import
shutil
import
random
from
os
import
listdir
from
os.path
import
isfile
,
join
# Name of the generated training-entry script each test writes into cwd.
pyname = 'train.py'
# Entry script for the collective-mode tests: it only asserts that the
# launcher exported the expected PADDLE_* environment variables.
colpyfile = '''# train.py for unitest
import os
env = os.environ.copy()
assert "PADDLE_MASTER" in env
assert "PADDLE_GLOBAL_SIZE" in env
assert "PADDLE_LOCAL_SIZE" in env
assert "PADDLE_GLOBAL_RANK" in env
assert "PADDLE_LOCAL_RANK" in env
'''
# Entry script for the parameter-server-mode tests.
pspyfile = '''# train.py for unitest
import os
env = os.environ.copy()
assert "PADDLE_PSERVERS_IP_PORT_LIST" in env
assert "PADDLE_TRAINER_ENDPOINTS" in env
#assert "PADDLE_PSERVER_ENDPOINTS" in env
#assert "PADDLE_TRAINER_ENDPOINTS" in env
#assert "PADDLE_ROLE" in env
#assert "PADDLE_RANK" in env
'''
def write_file(name, ct):
    """Create or overwrite the file *name* with the text content *ct*."""
    with open(name, "w") as fout:
        fout.write(ct)
def get_files(pth, prefix):
    """Return names of regular files directly under *pth* starting with *prefix*."""
    matched = []
    for entry in listdir(pth):
        if entry.startswith(prefix) and isfile(join(pth, entry)):
            matched.append(entry)
    return matched
class Collective_Test(unittest.TestCase):
    """End-to-end checks for `python -m paddle.distributed.run` (collective mode)."""

    def setUp(self):
        # Each test launches this generated script as the training entry.
        write_file(pyname, colpyfile)

    def pdrun(self, args, env=None):
        """Spawn the launcher with CLI `args`; returns the Popen handle."""
        # Use the full interpreter path; the previous basename-only form
        # (`sys.executable.split('/')[-1]`) depended on a PATH lookup.
        cmd = [sys.executable, "-m", "paddle.distributed.run"]
        if args:
            cmd.extend(args.split(" "))
        cmd.extend([pyname])
        # BUGFIX: `env` must be passed by keyword — positionally it was
        # consumed as Popen's `bufsize` and the environment never applied.
        proc = subprocess.Popen(cmd, env=env)
        return proc

    '''
    def test_collective_1(self):
        args = "--id test1"
        p = self.pdrun(args)
        p.wait()
        self.assertTrue(p.poll() == 0)
    '''

    def test_collective_2(self):
        if os.path.exists('./log'):
            shutil.rmtree('./log')

        args = "--id test2 --devices 0,1,2"
        p = self.pdrun(args)
        p.wait()
        self.assertEqual(p.poll(), 0)

        # Expects 4 log files under ./log — presumably one per worker plus
        # the launcher's own log; confirm against the launcher's log layout.
        c = get_files('log', 'test2')
        self.assertEqual(len(c), 4)

    def test_collective_3(self):
        if os.path.exists('./log'):
            shutil.rmtree('./log')

        port = random.randrange(6000, 8000)
        args = "--id test3 --devices 0,1 --master 127.0.0.1:{} --np 2".format(
            port)
        p1 = self.pdrun(args)
        p2 = self.pdrun(args)
        p1.wait()
        p2.wait()
        self.assertEqual(p1.poll(), 0)
        self.assertEqual(p2.poll(), 0)

        c = get_files('log', 'test3')
        self.assertEqual(len(c), 6)
class PS_Test(unittest.TestCase):
    """End-to-end checks for `python -m paddle.distributed.run` (PS mode)."""

    def setUp(self):
        write_file(pyname, pspyfile)

    def pdrun(self, args, env=None):
        """Spawn the launcher with CLI `args`; returns the Popen handle.

        NOTE: duplicated from Collective_Test.pdrun; keep the two in sync.
        """
        # Full interpreter path instead of a PATH-dependent basename.
        cmd = [sys.executable, "-m", "paddle.distributed.run"]
        if args:
            cmd.extend(args.split(" "))
        cmd.extend([pyname])
        # BUGFIX: `env` must be passed by keyword — positionally it was
        # consumed as Popen's `bufsize` and the environment never applied.
        proc = subprocess.Popen(cmd, env=env)
        return proc

    '''
    def test_ps_1(self):
        args = "--mode ps"
        p = self.pdrun(args)
        p.wait()
        self.assertTrue(p.poll() == 0)

    def test_ps_2(self):
        if os.path.exists('./log'):
            shutil.rmtree('./log')

        args = "--id ps2 --server_num=2 --trainer_num=2"
        p = self.pdrun(args)
        p.wait()
        self.assertTrue(p.poll() == 0)

        c = get_files('log', 'ps2')
        self.assertTrue(len(c) == 5)
    '''

    def test_ps_3(self):
        if os.path.exists('./log'):
            shutil.rmtree('./log')

        port = random.randrange(6000, 8000)
        args = "--id ps3 --master 127.0.0.1:{} --np 2 --server_num=1 --trainer_num=1".format(
            port)
        p1 = self.pdrun(args)
        p2 = self.pdrun(args)
        p1.wait()
        p2.wait()
        self.assertEqual(p1.poll(), 0)
        self.assertEqual(p2.poll(), 0)

        c = get_files('log', 'ps3')
        self.assertEqual(len(c), 6)

    def test_ps_4(self):
        if os.path.exists('./log'):
            shutil.rmtree('./log')

        args = "--id ps4 --servers 127.0.0.1:8900,127.0.0.1:8901 --trainers 127.0.0.1:8902,127.0.0.1:8903"
        p1 = self.pdrun(args)
        p1.wait()
        self.assertEqual(p1.poll(), 0)

        c = get_files('log', 'ps4')
        self.assertEqual(len(c), 5)
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录