Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
178d7e5e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
178d7e5e
编写于
10月 18, 2022
作者:
L
LiYuRio
提交者:
GitHub
10月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add strategy group (#47021)
上级
d68c38ef
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
606 addition
and
0 deletion
+606
-0
python/paddle/distributed/fleet/base/orthogonal_strategy.py
python/paddle/distributed/fleet/base/orthogonal_strategy.py
+181
-0
python/paddle/distributed/fleet/base/strategy_group.py
python/paddle/distributed/fleet/base/strategy_group.py
+227
-0
python/paddle/fluid/tests/unittests/collective/CMakeLists.txt
...on/paddle/fluid/tests/unittests/collective/CMakeLists.txt
+22
-0
python/paddle/fluid/tests/unittests/collective/orthogonal_strategy.py
...e/fluid/tests/unittests/collective/orthogonal_strategy.py
+45
-0
python/paddle/fluid/tests/unittests/collective/strategy_group.py
...paddle/fluid/tests/unittests/collective/strategy_group.py
+95
-0
python/paddle/fluid/tests/unittests/collective/test_orthogonal_strategy.sh
...id/tests/unittests/collective/test_orthogonal_strategy.sh
+17
-0
python/paddle/fluid/tests/unittests/collective/test_strategy_group.sh
...e/fluid/tests/unittests/collective/test_strategy_group.sh
+17
-0
python/paddle/fluid/tests/unittests/collective/testslist.csv
python/paddle/fluid/tests/unittests/collective/testslist.csv
+2
-0
未找到文件。
python/paddle/distributed/fleet/base/orthogonal_strategy.py
0 → 100644
浏览文件 @
178d7e5e
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
itertools
import
collections
import
functools
import
paddle.distributed
as
dist
from
paddle.distributed.fleet.base.strategy_group
import
StrategyGroupBase
class
OrthogonalStrategy
():
"""
A hybrid of multiple distributed strategies. Strategies need to be orthogonal, means the ranks are organized like
a square if there are two strategies, a cube if there aree three strategies, etc.
Args:
list_of_strategy(list): Stategy in the list should be represented as tuple, format as (strategy_name, degree, strategy_class).
fused_strategy_dict(dict, optional): Exist strategies can be fused to new strategy. Use the name of new strategy as key, a list of
strategy names you want to fuse as value.
Returns:
The instance of strategy.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.base.strategy_group import DPGroup, MPGroup, PPGroup
from paddle.distributed.fleet.base.orthogonal_strategy import OrthogonalStrategy
dist.init_parallel_env()
strategy = OrthogonalStrategy([("dp", 2, DPGroup), ("mp", 2, MPGroup), ("pp", 2, PPGroup)], fused_strategy_dict={"check": ["mp", "pp"]})
"""
def
__init__
(
self
,
list_of_strategy
,
fused_strategy_dict
=
{}):
self
.
_list_of_strategy
=
list_of_strategy
self
.
_fused_strategy_dict
=
fused_strategy_dict
self
.
_rank
=
dist
.
get_rank
()
self
.
_rank_list_dict
=
{}
self
.
_name_to_group_dict
=
{}
self
.
_name_to_degree_dict
=
{}
self
.
_list_of_strategy_name
=
[
strategy
[
0
]
for
strategy
in
list_of_strategy
]
self
.
_list_of_degree
=
[
strategy
[
1
]
for
strategy
in
list_of_strategy
]
self
.
_coordinate
=
collections
.
namedtuple
(
'Coordinate'
,
self
.
_list_of_strategy_name
)
self
.
_check_valid_strategy
()
ranges
=
[
range
(
degree
)
for
degree
in
self
.
_list_of_degree
]
list_of_coord
=
[
self
.
_coordinate
(
*
coord
)
for
coord
in
itertools
.
product
(
*
ranges
)
]
self
.
_coord_to_rank_dict
=
dict
(
zip
(
list_of_coord
,
range
(
len
(
list_of_coord
))))
for
idx
,
strategy
in
enumerate
(
list_of_strategy
):
strategy_name
=
strategy
[
0
]
self
.
_name_to_degree_dict
[
strategy_name
]
=
strategy
[
1
]
self
.
_rank_list_dict
[
strategy_name
]
=
self
.
_calc_rank_list
(
idx
)
self
.
_name_to_group_dict
[
strategy_name
]
=
strategy
[
2
](
self
.
_rank_list_dict
[
strategy_name
])
self
.
_name_to_fused_group_dict
=
{}
self
.
_create_fused_group
()
def
strategy_group
(
self
,
name
):
"""
Get strategy group with specific name.
Args:
name: The name of strategy group
Returns:
An instance of specific strategy group.
"""
assert
name
in
self
.
_list_of_strategy_name
,
"Strategy group {} is not created."
.
format
(
name
)
return
self
.
_name_to_group_dict
[
name
]
def
fused_strategy_group
(
self
,
name
):
"""
Get fused strategy group with specific name.
Args:
name: The name of fused strategy group
Returns:
(StrategyGroupBase): An instance of strategy group.
"""
assert
name
in
self
.
_name_to_fused_group_dict
,
"Fused strategy group {} is not created."
.
format
(
name
)
return
self
.
_name_to_fused_group_dict
[
name
]
def
rank_in_strategy
(
self
,
name
):
"""
Get local rank in strategy group with specific name.
Args:
name: The name of strategy group
Returns:
(Integer): Local rank in specific strategy.
"""
assert
name
in
self
.
_list_of_strategy_name
,
"Strategy group {} is not created."
.
format
(
name
)
return
self
.
_name_to_group_dict
[
name
].
group
.
rank
def
_check_valid_strategy
(
self
):
assert
len
(
self
.
_list_of_strategy_name
)
==
len
(
set
(
self
.
_list_of_strategy_name
)
),
"Defined duplicated strategies: {}"
.
format
(
list_of_strategy
)
num_of_ranks
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
_list_of_degree
)
assert
num_of_ranks
==
dist
.
get_world_size
(
),
"There are total {} ranks, but need {} ranks in this strategy."
.
format
(
dist
.
get_world_size
(),
num_of_ranks
)
for
fused_strategy
in
self
.
_fused_strategy_dict
.
values
():
for
strategy
in
fused_strategy
:
assert
strategy
in
self
.
_list_of_strategy_name
,
"Can not fuse strategy {} without defined previous."
.
format
(
strategy
)
def
_create_fused_group
(
self
):
for
name
in
self
.
_fused_strategy_dict
:
fused_strategy
=
self
.
_fused_strategy_dict
[
name
]
non_fused_strategy
=
list
(
set
(
self
.
_list_of_strategy_name
).
difference
(
fused_strategy
))
non_fused_ranges
=
[]
for
strategy
in
non_fused_strategy
:
non_fused_ranges
.
append
(
range
(
self
.
_name_to_degree_dict
[
strategy
]))
fused_ranges
=
[]
for
strategy
in
fused_strategy
:
fused_ranges
.
append
(
range
(
self
.
_name_to_degree_dict
[
strategy
]))
rank_list
=
[]
for
non_fused_ranks
in
itertools
.
product
(
*
non_fused_ranges
):
coord_dict
=
{}
ranks
=
[]
for
i
,
non_fused_rank
in
enumerate
(
non_fused_ranks
):
coord_dict
[
non_fused_strategy
[
i
]]
=
non_fused_rank
for
fused_ranks
in
itertools
.
product
(
*
fused_ranges
):
for
i
,
fused_rank
in
enumerate
(
fused_ranks
):
coord_dict
[
fused_strategy
[
i
]]
=
fused_rank
ranks
.
append
(
self
.
_coord_to_rank_dict
[
self
.
_coordinate
(
**
coord_dict
)])
rank_list
.
append
(
ranks
)
self
.
_name_to_fused_group_dict
[
name
]
=
StrategyGroupBase
(
rank_list
)
def
_calc_rank_list
(
self
,
strategy_axis
):
ranges
=
[]
for
idx
,
degree
in
enumerate
(
self
.
_list_of_degree
):
if
idx
==
strategy_axis
:
continue
ranges
.
append
(
range
(
degree
))
rank_list
=
[]
for
coord
in
itertools
.
product
(
*
ranges
):
ranks
=
[]
for
val
in
range
(
self
.
_list_of_degree
[
strategy_axis
]):
coord_list
=
list
(
coord
)
coord_list
.
insert
(
strategy_axis
,
val
)
ranks
.
append
(
self
.
_coord_to_rank_dict
[
self
.
_coordinate
(
*
coord_list
)])
rank_list
.
append
(
ranks
)
return
rank_list
python/paddle/distributed/fleet/base/strategy_group.py
0 → 100644
浏览文件 @
178d7e5e
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle.distributed
as
dist
class
StrategyGroupBase
():
"""
The base class of communication group with distributed strategy.
Args:
list_of_ranks: A 2D-array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in sublist represents
they are in the same communication group.
Returns:
The instance of strategy group.
Examples:
.. code-block:: python
import paddle.distributed as dist
from paddle.distributed.fleet.base.strategy_group import StrategyGroupBase
dist.init_parallel_env()
strategy_group = dist.fleet.base.strategy_group.StrategyGroupBase([[0, 1], [2, 3]])
print(strategy_group.world_size) # 2
"""
def
__init__
(
self
,
list_of_ranks
):
assert
dist
.
is_initialized
(
),
"The global communication group need to be initialized."
assert
len
(
list_of_ranks
),
"The list_of_ranks can not be empty."
self
.
_rank
=
dist
.
get_rank
()
self
.
_list_of_ranks
=
list_of_ranks
self
.
_group
=
self
.
_create_group
()
@
property
def
world_size
(
self
):
"""
The world size of communication group.
Returns:
Integer if the world_size of each group are equal, or a list of world_size if they are not equal.
"""
world_size_list
=
[]
for
ranks
in
self
.
_list_of_ranks
:
world_size_list
.
append
(
len
(
ranks
))
is_value
=
all
(
world_size
==
world_size_list
[
0
]
for
world_size
in
world_size_list
)
return
world_size_list
[
0
]
if
is_value
else
world_size_list
@
property
def
group
(
self
):
"""
The communication group which current rank belongs to.
Returns:
Group if current rank only belong to single communication group, or a list of Group if it belongs many.
"""
return
self
.
_group
def
_create_group
(
self
):
list_of_group
=
[]
for
ranks
in
self
.
_list_of_ranks
:
group
=
dist
.
new_group
(
ranks
=
ranks
)
if
self
.
_rank
in
ranks
:
list_of_group
.
append
(
group
)
assert
len
(
list_of_group
)
>
0
,
"Rank {} does not belong to the list_of_ranks {}."
.
format
(
self
.
_rank
,
self
.
_list_of_ranks
)
return
list_of_group
if
len
(
list_of_group
)
>
1
else
list_of_group
[
0
]
class
DPGroup
(
StrategyGroupBase
):
"""
The communication group strategy for data parallel.
Args:
list_of_ranks: A 2D-array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in sublist represents
they are in the same communication group.
Returns:
The instance of data parallel strategy group.
"""
def
__init__
(
self
,
list_of_ranks
):
super
(
DPGroup
,
self
).
__init__
(
list_of_ranks
)
assert
not
isinstance
(
self
.
group
,
list
),
"Rank {} belongs to multi dp groups"
.
format
(
self
.
_rank
)
class
MPGroup
(
StrategyGroupBase
):
"""
The communication group strategy for model parallel.
Args:
list_of_ranks: A 2D-array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in sublist represents
they are in the same communication group.
Returns:
The instance of model parallel strategy group.
"""
def
__init__
(
self
,
list_of_ranks
):
super
(
MPGroup
,
self
).
__init__
(
list_of_ranks
)
assert
not
isinstance
(
self
.
group
,
list
),
"Rank {} belongs to multi mp groups"
.
format
(
self
.
_rank
)
class
ShardingGroup
(
StrategyGroupBase
):
"""
The communication group strategy for sharding parallel.
Args:
list_of_ranks: A 2D-array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in sublist represents
they are in the same communication group.
Returns:
The instance of sharding parallel strategy group.
"""
def
__init__
(
self
,
list_of_ranks
):
super
(
ShardingGroup
,
self
).
__init__
(
list_of_ranks
)
assert
not
isinstance
(
self
.
group
,
list
),
"Rank {} belongs to multi sharding groups"
.
format
(
self
.
_rank
)
class
PPGroup
(
StrategyGroupBase
):
"""
The communication group strategy for pipeline parallel.
Args:
list_of_ranks: A 2D-array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in sublist represents
they are in the same communication group.
Returns:
The instance of pipeline parallel strategy group.
"""
def
__init__
(
self
,
list_of_ranks
):
super
(
PPGroup
,
self
).
__init__
(
list_of_ranks
)
assert
not
isinstance
(
self
.
group
,
list
),
"Rank {} belongs to multi pp groups"
.
format
(
self
.
_rank
)
self
.
_send_next_group
=
None
self
.
_send_prev_group
=
None
self
.
_recv_next_group
=
None
self
.
_recv_prev_group
=
None
self
.
_rank_of_next_stage
=
None
self
.
_rank_of_prev_stage
=
None
if
self
.
world_size
>
1
:
self
.
_create_p2p_group
()
@
property
def
rank_of_prev_stage
(
self
):
"""
Rank of the previous pp stage.
Returns:
The global rank of previous pp stage. `None` if without previous.
"""
return
self
.
_rank_of_prev_stage
@
property
def
rank_of_next_stage
(
self
):
"""
Rank of the next pp stage.
Returns:
The global rank of next pp stage. `None` if without next.
"""
return
self
.
_rank_of_next_stage
@
property
def
p2p_groups
(
self
):
"""
Communication subgroup in order to switch data with previous and next stage.
Returns:
Four subgroups including send/recv to/from prev/next.
"""
return
self
.
_send_next_group
,
self
.
_send_prev_group
,
self
.
_recv_next_group
,
self
.
_recv_prev_group
def
_create_p2p_group
(
self
):
degree
=
self
.
world_size
for
ranks
in
self
.
_list_of_ranks
:
for
idx
,
rank
in
enumerate
(
ranks
):
next_rank
=
ranks
[(
idx
+
1
)
%
degree
]
prev_rank
=
ranks
[(
idx
-
1
)
%
degree
]
if
self
.
_rank
==
rank
:
self
.
_rank_of_next_stage
=
next_rank
self
.
_rank_of_prev_stage
=
prev_rank
next_group
=
dist
.
new_group
(
ranks
=
[
rank
,
next_rank
])
if
self
.
_rank
==
rank
:
self
.
_send_next_group
=
next_group
elif
self
.
_rank
==
next_rank
:
self
.
_recv_prev_group
=
next_group
prev_group
=
dist
.
new_group
(
ranks
=
[
prev_rank
,
rank
])
if
self
.
_rank
==
rank
:
self
.
_send_prev_group
=
prev_group
elif
self
.
_rank
==
prev_rank
:
self
.
_recv_next_group
=
prev_group
assert
self
.
_send_next_group
and
self
.
_send_prev_group
and
self
.
_recv_next_group
and
self
.
_recv_prev_group
,
\
"Error occurs while creating p2p group for rank {}."
.
format
(
self
.
_rank
)
python/paddle/fluid/tests/unittests/collective/CMakeLists.txt
浏览文件 @
178d7e5e
...
...
@@ -391,5 +391,27 @@ if(WITH_MPI)
"PADDLE_DIST_UT_PORT=21672;http_proxy=;https_proxy="
)
endif
()
endif
()
if
((
WITH_ROCM OR WITH_GPU
)
AND
(
LINUX
))
bash_test_modules
(
test_strategy_group
START_BASH
test_strategy_group.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21814;http_proxy=;https_proxy="
)
set_tests_properties
(
test_strategy_group PROPERTIES TIMEOUT
"120"
)
endif
()
if
((
WITH_ROCM OR WITH_GPU
)
AND
(
LINUX
))
bash_test_modules
(
test_orthogonal_strategy
START_BASH
test_orthogonal_strategy.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21958;http_proxy=;https_proxy="
)
set_tests_properties
(
test_orthogonal_strategy PROPERTIES TIMEOUT
"120"
)
endif
()
add_subdirectory
(
fleet
)
add_subdirectory
(
multinode
)
python/paddle/fluid/tests/unittests/collective/orthogonal_strategy.py
0 → 100644
浏览文件 @
178d7e5e
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle.distributed
as
dist
from
paddle.distributed.fleet.base.strategy_group
import
DPGroup
,
ShardingGroup
,
MPGroup
,
PPGroup
from
paddle.distributed.fleet.base.orthogonal_strategy
import
OrthogonalStrategy
class
TestOrthogonalStrategyAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
_num_of_ranks
=
2
dist
.
init_parallel_env
()
self
.
_global_rank
=
dist
.
get_rank
()
self
.
_strategy
=
OrthogonalStrategy
(
[(
"dp"
,
2
,
DPGroup
),
(
"mp"
,
1
,
MPGroup
),
(
"sharding"
,
1
,
ShardingGroup
),
(
"pp"
,
1
,
PPGroup
)],
fused_strategy_dict
=
{
"checkness"
:
[
"mp"
,
"sharding"
,
"pp"
]})
def
test_orthogonal_strategy
(
self
):
dp_group
=
self
.
_strategy
.
strategy_group
(
"dp"
)
self
.
assertEqual
(
dp_group
.
world_size
,
self
.
_num_of_ranks
)
self
.
assertEqual
(
dp_group
.
group
.
nranks
,
self
.
_num_of_ranks
)
self
.
assertEqual
(
self
.
_strategy
.
rank_in_strategy
(
"dp"
),
self
.
_global_rank
)
fused_group
=
self
.
_strategy
.
fused_strategy_group
(
"checkness"
)
self
.
assertEqual
(
fused_group
.
world_size
,
1
)
self
.
assertEqual
(
fused_group
.
group
.
nranks
,
1
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/collective/strategy_group.py
0 → 100644
浏览文件 @
178d7e5e
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle
import
paddle.distributed
as
dist
from
paddle.distributed.fleet.base.strategy_group
import
StrategyGroupBase
,
DPGroup
,
MPGroup
,
PPGroup
,
ShardingGroup
def
_check_using_all_reduce
(
group
):
data
=
paddle
.
to_tensor
([
1
,
2
,
3
])
result
=
paddle
.
to_tensor
([
2
,
4
,
6
])
dist
.
all_reduce
(
data
,
group
=
group
)
assert
np
.
array_equal
(
data
,
result
)
def
_check_using_send
(
group
,
dst
):
data
=
paddle
.
to_tensor
([
1
,
2
,
3
])
dist
.
send
(
data
,
dst
=
dst
,
group
=
group
)
def
_check_using_recv
(
group
,
src
):
result
=
paddle
.
to_tensor
([
1
,
2
,
3
])
data
=
paddle
.
to_tensor
([
0
,
0
,
0
])
dist
.
recv
(
data
,
src
=
src
,
group
=
group
)
assert
np
.
array_equal
(
data
,
result
)
class
TestStrategyGroupAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
_num_of_ranks
=
2
self
.
_list_of_rank
=
[[
0
,
1
]]
self
.
_list_of_ranks
=
[[
0
,
1
],
[
0
,
1
]]
dist
.
init_parallel_env
()
self
.
_global_rank
=
dist
.
get_rank
()
self
.
_peer_rank
=
0
if
self
.
_global_rank
==
1
else
1
def
test_strategy_group_base
(
self
):
strategy_group
=
StrategyGroupBase
(
self
.
_list_of_rank
)
self
.
assertEqual
(
strategy_group
.
world_size
,
self
.
_num_of_ranks
)
self
.
assertEqual
(
strategy_group
.
group
.
nranks
,
self
.
_num_of_ranks
)
_check_using_all_reduce
(
strategy_group
.
group
)
def
test_data_parallel_group
(
self
):
dp_group
=
DPGroup
(
self
.
_list_of_rank
)
self
.
assertEqual
(
dp_group
.
world_size
,
self
.
_num_of_ranks
)
self
.
assertEqual
(
dp_group
.
group
.
nranks
,
self
.
_num_of_ranks
)
_check_using_all_reduce
(
dp_group
.
group
)
def
test_model_parallel_group
(
self
):
mp_group
=
MPGroup
(
self
.
_list_of_rank
)
self
.
assertEqual
(
mp_group
.
world_size
,
self
.
_num_of_ranks
)
self
.
assertEqual
(
mp_group
.
group
.
nranks
,
self
.
_num_of_ranks
)
_check_using_all_reduce
(
mp_group
.
group
)
def
test_sharding_parallel_group
(
self
):
sharding_group
=
ShardingGroup
(
self
.
_list_of_rank
)
self
.
assertEqual
(
sharding_group
.
world_size
,
self
.
_num_of_ranks
)
self
.
assertEqual
(
sharding_group
.
group
.
nranks
,
self
.
_num_of_ranks
)
_check_using_all_reduce
(
sharding_group
.
group
)
def
test_pipeline_parallel_group
(
self
):
pp_group
=
PPGroup
(
self
.
_list_of_rank
)
send_next_group
,
send_prev_group
,
recv_next_group
,
recv_prev_group
=
pp_group
.
p2p_groups
if
self
.
_global_rank
==
0
:
self
.
assertEqual
(
pp_group
.
rank_of_next_stage
,
1
)
self
.
assertEqual
(
pp_group
.
rank_of_prev_stage
,
1
)
_check_using_send
(
send_next_group
,
self
.
_peer_rank
)
_check_using_send
(
send_prev_group
,
self
.
_peer_rank
)
_check_using_recv
(
recv_prev_group
,
self
.
_peer_rank
)
_check_using_recv
(
recv_next_group
,
self
.
_peer_rank
)
else
:
self
.
assertEqual
(
pp_group
.
rank_of_next_stage
,
0
)
self
.
assertEqual
(
pp_group
.
rank_of_prev_stage
,
0
)
_check_using_recv
(
recv_prev_group
,
self
.
_peer_rank
)
_check_using_recv
(
recv_next_group
,
self
.
_peer_rank
)
_check_using_send
(
send_next_group
,
self
.
_peer_rank
)
_check_using_send
(
send_prev_group
,
self
.
_peer_rank
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/collective/test_orthogonal_strategy.sh
0 → 100644
浏览文件 @
178d7e5e
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set
-e
CUDA_VISIBLE_DEVICES
=
0,1 python
-m
paddle.distributed.launch
--gpus
=
0,1 orthogonal_strategy.py
python/paddle/fluid/tests/unittests/collective/test_strategy_group.sh
0 → 100644
浏览文件 @
178d7e5e
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set
-e
CUDA_VISIBLE_DEVICES
=
0,1 python
-m
paddle.distributed.launch
--gpus
=
0,1 strategy_group.py
python/paddle/fluid/tests/unittests/collective/testslist.csv
浏览文件 @
178d7e5e
...
...
@@ -46,3 +46,5 @@ test_gen_nccl_id_op,,gpu;rocm;ASCEND;ASCEND_CL,,DIST,../dist_test.sh,2,,http_pro
test_new_group_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_world_size_and_rank,linux,rocm;gpu,120,DIST,test_world_size_and_rank.sh,2,,http_proxy=;https_proxy=,
test_mpi_comm,linux,,,DIST,test_mpi_comm.sh,2,,http_proxy=;https_proxy=,WITH_MPI
test_strategy_group,linux,rocm;gpu,120,DIST,test_strategy_group.sh,2,,http_proxy=;https_proxy=,
test_orthogonal_strategy,linux,rocm;gpu,120,DIST,test_orthogonal_strategy.sh,2,,http_proxy=;https_proxy=,
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录