Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9a3e1bce
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9a3e1bce
编写于
7月 28, 2022
作者:
K
kuizhiqing
提交者:
GitHub
7月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[LAUNCH] add distributed launch check tools (#44495)
* add launch test * launch test for cpu * bs 1
上级
067107ad
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
122 addition
and
5 deletion
+122
-5
python/paddle/distributed/launch/context/__init__.py
python/paddle/distributed/launch/context/__init__.py
+2
-1
python/paddle/distributed/launch/controllers/collective.py
python/paddle/distributed/launch/controllers/collective.py
+7
-3
python/paddle/distributed/launch/plugins/__init__.py
python/paddle/distributed/launch/plugins/__init__.py
+13
-1
python/paddle/distributed/launch/plugins/test.py
python/paddle/distributed/launch/plugins/test.py
+100
-0
未找到文件。
python/paddle/distributed/launch/context/__init__.py
浏览文件 @
9a3e1bce
...
...
@@ -101,6 +101,7 @@ class Context(object):
return
False
def
set_env_in_args
(
self
):
# this logic may not propre to replace args with env, but ...
for
k
,
v
in
env_args_mapping
.
items
():
if
k
in
self
.
envs
:
setattr
(
self
.
args
,
v
,
self
.
envs
[
k
]
)
setattr
(
self
.
args
,
v
,
type
(
getattr
(
self
.
args
,
v
))(
self
.
envs
[
k
])
)
python/paddle/distributed/launch/controllers/collective.py
浏览文件 @
9a3e1bce
...
...
@@ -97,10 +97,14 @@ class CollectiveController(Controller):
"PADDLE_TRAINERS_NUM"
:
"{}"
.
format
(
global_size
),
"PADDLE_RANK_IN_NODE"
:
str
(
i
),
}
if
self
.
pod
.
replicas
==
1
:
e
.
update
({
selected_dev_key
:
","
.
join
(
selected_dev_list
)})
if
len
(
selected_dev_list
)
>
0
:
if
self
.
pod
.
replicas
==
1
:
e
.
update
({
selected_dev_key
:
","
.
join
(
selected_dev_list
)})
else
:
e
.
update
({
selected_dev_key
:
selected_dev_list
[
i
]})
else
:
e
.
update
({
selected_dev_key
:
selected_dev_list
[
i
]})
e
.
update
({
'PADDLE_DISTRI_BACKEND'
:
'gloo'
})
self
.
add_container
(
envs
=
e
,
log_tag
=
i
)
return
True
...
...
python/paddle/distributed/launch/plugins/__init__.py
浏览文件 @
9a3e1bce
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
six
import
os
__all__
=
[]
...
...
@@ -60,4 +61,15 @@ def rewrite_host_ip(ctx):
ctx
.
node
.
ip
=
ctx
.
args
.
host
enabled_plugins
=
[
collective_compatible
,
rewrite_host_ip
,
process_args
]
def
test_mode
(
ctx
):
if
ctx
.
args
.
training_script
==
'test'
:
ctx
.
logger
.
info
(
'Paddle Distributed Test begin...'
)
if
int
(
ctx
.
args
.
nnodes
)
<
2
:
ctx
.
args
.
nnodes
=
2
ctx
.
args
.
training_script
=
'{}/test.py'
.
format
(
os
.
path
.
dirname
(
__file__
))
enabled_plugins
=
[
test_mode
,
collective_compatible
,
rewrite_host_ip
,
process_args
]
python/paddle/distributed/launch/plugins/test.py
0 → 100644
浏览文件 @
9a3e1bce
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle
from
paddle.distributed
import
fleet
from
paddle.vision.models
import
ResNet
from
paddle.vision.models.resnet
import
BottleneckBlock
from
paddle.io
import
Dataset
,
BatchSampler
,
DataLoader
base_lr
=
0.1
momentum_rate
=
0.9
l2_decay
=
1e-4
epoch
=
3
batch_num
=
1
batch_size
=
1
class_dim
=
102
# define a random dataset
class
RandomDataset
(
Dataset
):
def
__init__
(
self
,
num_samples
):
self
.
num_samples
=
num_samples
def
__getitem__
(
self
,
idx
):
image
=
np
.
random
.
random
([
3
,
224
,
224
]).
astype
(
'float32'
)
label
=
np
.
random
.
randint
(
0
,
class_dim
-
1
,
(
1
,
)).
astype
(
'int64'
)
return
image
,
label
def
__len__
(
self
):
return
self
.
num_samples
def
optimizer_setting
(
parameter_list
=
None
):
optimizer
=
paddle
.
optimizer
.
Momentum
(
learning_rate
=
base_lr
,
momentum
=
momentum_rate
,
weight_decay
=
paddle
.
regularizer
.
L2Decay
(
l2_decay
),
parameters
=
parameter_list
)
return
optimizer
def
train_resnet
():
fleet
.
init
(
is_collective
=
True
)
resnet
=
ResNet
(
BottleneckBlock
,
18
,
num_classes
=
class_dim
)
optimizer
=
optimizer_setting
(
parameter_list
=
resnet
.
parameters
())
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
)
resnet
=
fleet
.
distributed_model
(
resnet
)
dataset
=
RandomDataset
(
batch_num
*
batch_size
)
train_loader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
shuffle
=
True
,
drop_last
=
True
,
num_workers
=
2
)
print
(
"Distributed training start..."
)
for
eop
in
range
(
epoch
):
resnet
.
train
()
for
batch_id
,
data
in
enumerate
(
train_loader
()):
img
,
label
=
data
label
.
stop_gradient
=
True
out
=
resnet
(
img
)
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
out
,
label
=
label
)
avg_loss
=
paddle
.
mean
(
x
=
loss
)
acc_top1
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
1
)
acc_top5
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
5
)
avg_loss
.
backward
()
optimizer
.
step
()
resnet
.
clear_gradients
()
print
(
"[Epoch %d, batch %d] loss: %.5f, acc1: %.5f, acc5: %.5f"
%
(
eop
,
batch_id
,
avg_loss
,
acc_top1
,
acc_top5
))
print
(
"Distributed training completed"
)
if
__name__
==
'__main__'
:
import
os
nnodes
=
os
.
getenv
(
'PADDLE_NNODES'
)
cn
=
os
.
getenv
(
'PADDLE_LOCAL_SIZE'
)
print
(
f
"Prepare distributed training with
{
nnodes
}
nodes
{
cn
}
cards"
)
train_resnet
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录