Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
51c414b6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
51c414b6
编写于
7月 04, 2023
作者:
T
Tian
提交者:
GitHub
7月 04, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add perf test api to fleet (#54856)
上级
21fa0346
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
172 addition
and
1 deletion
+172
-1
python/paddle/distributed/fleet/__init__.py
python/paddle/distributed/fleet/__init__.py
+1
-0
python/paddle/distributed/fleet/fleet.py
python/paddle/distributed/fleet/fleet.py
+75
-1
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+3
-0
python/paddle/fluid/tests/unittests/hybrid_parallel_perf_test.py
...paddle/fluid/tests/unittests/hybrid_parallel_perf_test.py
+66
-0
python/paddle/fluid/tests/unittests/test_fleet_perf_test.py
python/paddle/fluid/tests/unittests/test_fleet_perf_test.py
+26
-0
tools/parallel_UT_rule.py
tools/parallel_UT_rule.py
+1
-0
未找到文件。
python/paddle/distributed/fleet/__init__.py
浏览文件 @
51c414b6
...
...
@@ -103,4 +103,5 @@ set_log_level = log_util.set_log_level
get_log_level_code
=
log_util
.
get_log_level_code
get_log_level_name
=
log_util
.
get_log_level_name
save_cache_table
=
fleet
.
save_cache_table
perf_test
=
fleet
.
perf_test
from
..
import
auto_parallel
as
auto
python/paddle/distributed/fleet/fleet.py
浏览文件 @
51c414b6
...
...
@@ -14,6 +14,7 @@
import
copy
import
os
import
time
import
paddle
from
paddle.fluid
import
compiler
...
...
@@ -192,7 +193,6 @@ class Fleet:
log_level (Integer, String, optional): A ``Integer`` or ``String`` Variable determining how hight
the logging level is. Default is "INFO".
Returns:
None
...
...
@@ -382,6 +382,80 @@ class Fleet:
)
return
self
def
perf_test
(
self
,
round
=
50
):
# test allreduce perf
def
allreduce_test
(
iteration
,
x
,
group
):
paddle
.
distributed
.
barrier
()
paddle
.
device
.
cuda
.
synchronize
()
start_t
=
time
.
time
()
for
_
in
range
(
iteration
):
paddle
.
distributed
.
all_reduce
(
x
,
group
=
group
)
paddle
.
device
.
cuda
.
synchronize
()
end_t
=
time
.
time
()
return
(
end_t
-
start_t
)
/
iteration
# test reduce perf
def
reduce_test
(
iteration
,
x
,
group
):
paddle
.
distributed
.
barrier
()
paddle
.
device
.
cuda
.
synchronize
()
start_t
=
time
.
time
()
for
_
in
range
(
iteration
):
# TODO: shuffle dst
paddle
.
distributed
.
reduce
(
x
,
dst
=
min
(
group
.
ranks
),
group
=
group
)
paddle
.
device
.
cuda
.
synchronize
()
end_t
=
time
.
time
()
return
(
end_t
-
start_t
)
/
iteration
# test broadcast perf
def
broadcast_test
(
iteration
,
x
,
group
):
paddle
.
distributed
.
barrier
()
paddle
.
device
.
cuda
.
synchronize
()
start_t
=
time
.
time
()
for
_
in
range
(
iteration
):
# TODO: shuffle src
paddle
.
distributed
.
broadcast
(
x
,
src
=
min
(
group
.
ranks
),
group
=
group
)
paddle
.
device
.
cuda
.
synchronize
()
end_t
=
time
.
time
()
return
(
end_t
-
start_t
)
/
iteration
hcg
=
self
.
get_hybrid_communicate_group
()
dp_group
=
hcg
.
get_data_parallel_group
()
sharding_group
=
hcg
.
get_sharding_parallel_group
()
test_group
=
None
if
dp_group
.
nranks
>
1
:
test_group
=
dp_group
elif
sharding_group
.
nranks
>
1
:
test_group
=
sharding_group
else
:
logger
.
warning
(
f
"hcg created with dp_degree:
{
dp_group
.
nranks
}
and sharding_degree:
{
sharding_group
.
nranks
}
, skipping perf test..."
)
return
# test 1M ~ 1G
nbytes
=
1
<<
20
# 1048576(1MB)
final_nbytes
=
1
<<
30
# 1073741824(1GB)
dtype
=
paddle
.
float32
while
nbytes
<=
final_nbytes
:
x
=
paddle
.
zeros
([
nbytes
//
4
],
dtype
=
dtype
)
# warmup
allreduce_test
(
iteration
=
10
,
x
=
x
,
group
=
test_group
)
# test-allreduce
ret
=
allreduce_test
(
iteration
=
round
,
x
=
x
,
group
=
test_group
)
logger
.
info
(
f
"[AllReduceTest] nbytes
{
nbytes
}
B test result:
{
ret
}
s/iter"
)
ret
=
reduce_test
(
iteration
=
round
,
x
=
x
,
group
=
test_group
)
logger
.
info
(
f
"[ReduceTest] nbytes
{
nbytes
}
B test result:
{
ret
}
s/iter"
)
ret
=
broadcast_test
(
iteration
=
round
,
x
=
x
,
group
=
test_group
)
logger
.
info
(
f
"[BroadcastTest] nbytes
{
nbytes
}
B test result:
{
ret
}
s/iter"
)
nbytes
=
nbytes
<<
1
def
_init_hybrid_parallel_env
(
self
):
"""initialize the hybrid environment"""
self
.
hybrid_configs
=
self
.
_user_defined_strategy
.
hybrid_configs
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
51c414b6
...
...
@@ -87,6 +87,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
list
(
REMOVE_ITEM TEST_OPS test_c_comm_init_all_op
)
list
(
REMOVE_ITEM TEST_OPS test_c_embedding_op
)
list
(
REMOVE_ITEM TEST_OPS test_pipeline_parallel
)
list
(
REMOVE_ITEM TEST_OPS test_fleet_perf_test
)
list
(
REMOVE_ITEM TEST_OPS test_memcpy_op
)
list
(
REMOVE_ITEM TEST_OPS test_raw_program_optimizer
)
list
(
REMOVE_ITEM TEST_OPS test_fleet_gradient_scale
)
...
...
@@ -1062,11 +1063,13 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_pipeline_parallel PROPERTIES LABELS
"RUN_TYPE=DIST"
)
set_tests_properties
(
test_fleet_perf_test PROPERTIES LABELS
"RUN_TYPE=DIST"
)
set_tests_properties
(
test_reducescatter PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_allgather PROPERTIES TIMEOUT 120
)
endif
()
set_tests_properties
(
test_paddle_multiprocessing PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_pipeline_parallel PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_fleet_perf_test PROPERTIES TIMEOUT 120
)
endif
()
if
(
WITH_GPU OR WITH_ROCM
)
set_tests_properties
(
test_rank_attention_op PROPERTIES TIMEOUT 120
)
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_perf_test.py
0 → 100644
浏览文件 @
51c414b6
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
random
import
unittest
import
numpy
as
np
import
paddle
from
paddle.distributed
import
fleet
def
set_random_seed
(
seed
,
dp_id
,
rank_id
):
"""Set random seed for reproducability."""
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
+
dp_id
)
paddle
.
seed
(
seed
+
dp_id
)
batch_size
=
4
micro_batch_size
=
2
class
TestDistDPTraning
(
unittest
.
TestCase
):
def
setUp
(
self
):
strategy
=
fleet
.
DistributedStrategy
()
self
.
model_parallel_size
=
1
self
.
data_parallel_size
=
2
self
.
pipeline_parallel_size
=
1
strategy
.
hybrid_configs
=
{
"dp_degree"
:
self
.
data_parallel_size
,
"mp_degree"
:
self
.
model_parallel_size
,
"pp_degree"
:
self
.
pipeline_parallel_size
,
}
strategy
.
pipeline_configs
=
{
"accumulate_steps"
:
batch_size
//
micro_batch_size
,
"micro_batch_size"
:
micro_batch_size
,
}
fleet
.
init
(
is_collective
=
True
,
strategy
=
strategy
)
def
build_optimizer
(
self
,
model
):
scheduler
=
paddle
.
optimizer
.
lr
.
PiecewiseDecay
(
boundaries
=
[
2
],
values
=
[
0.001
,
0.002
],
verbose
=
True
)
optimizer
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
scheduler
,
parameters
=
model
.
parameters
()
)
return
scheduler
,
optimizer
def
test_communication_perf
(
self
):
fleet
.
perf_test
(
round
=
1
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_fleet_perf_test.py
0 → 100644
浏览文件 @
51c414b6
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
test_parallel_dygraph_dataparallel
import
TestMultipleGpus
class
TestFleetPerfTest
(
TestMultipleGpus
):
def
test_fleet_perf_test
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_perf_test.py'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
tools/parallel_UT_rule.py
浏览文件 @
51c414b6
...
...
@@ -538,6 +538,7 @@ HIGH_PARALLEL_JOB_NEW = [
'test_dist_fleet_ps3'
,
'test_dist_mnist_pg'
,
'test_pipeline_parallel'
,
'test_fleet_perf_test'
,
'test_dist_fleet_ps5'
,
'test_dist_fleet_sparse_embedding_ctr'
,
'test_collective_broadcast_api'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录