Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f48611f3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f48611f3
编写于
5月 14, 2023
作者:
S
ShenLiang
提交者:
GitHub
5月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Cherry-Pick]Add identity hcg for hybridparallel (#53787)
* add utest * rm hack code
上级
cc6dcc7d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
90 addition
and
17 deletion
+90
-17
python/paddle/distributed/fleet/base/orthogonal_strategy.py
python/paddle/distributed/fleet/base/orthogonal_strategy.py
+17
-9
python/paddle/distributed/fleet/base/strategy_group.py
python/paddle/distributed/fleet/base/strategy_group.py
+37
-8
python/paddle/fluid/tests/unittests/collective/orthogonal_strategy.py
...e/fluid/tests/unittests/collective/orthogonal_strategy.py
+36
-0
未找到文件。
python/paddle/distributed/fleet/base/orthogonal_strategy.py
浏览文件 @
f48611f3
...
@@ -47,11 +47,16 @@ class OrthogonalStrategy:
...
@@ -47,11 +47,16 @@ class OrthogonalStrategy:
"""
"""
def
__init__
(
self
,
list_of_strategy
,
fused_strategy_dict
=
{}):
def
__init__
(
self
,
list_of_strategy
,
fused_strategy_dict
=
{},
strategy_rank_list
=
None
):
self
.
_list_of_strategy
=
list_of_strategy
self
.
_list_of_strategy
=
list_of_strategy
self
.
_fused_strategy_dict
=
fused_strategy_dict
self
.
_fused_strategy_dict
=
fused_strategy_dict
self
.
_rank
=
dist
.
get_rank
()
self
.
_strategy_rank_list
=
(
self
.
_rank_list_dict
=
{}
strategy_rank_list
if
strategy_rank_list
is
not
None
else
list
(
range
(
dist
.
get_world_size
()))
)
self
.
_name_to_group_dict
=
{}
self
.
_name_to_group_dict
=
{}
self
.
_name_to_degree_dict
=
{}
self
.
_name_to_degree_dict
=
{}
self
.
_list_of_strategy_name
=
[
self
.
_list_of_strategy_name
=
[
...
@@ -67,16 +72,17 @@ class OrthogonalStrategy:
...
@@ -67,16 +72,17 @@ class OrthogonalStrategy:
list_of_coord
=
[
list_of_coord
=
[
self
.
_coordinate
(
*
coord
)
for
coord
in
itertools
.
product
(
*
ranges
)
self
.
_coordinate
(
*
coord
)
for
coord
in
itertools
.
product
(
*
ranges
)
]
]
self
.
_coord_to_rank_dict
=
dict
(
self
.
_coord_to_rank_dict
=
dict
(
zip
(
list_of_coord
,
range
(
len
(
list_of_coord
))
)
zip
(
list_of_coord
,
self
.
_strategy_rank_list
)
)
)
for
idx
,
strategy
in
enumerate
(
list_of_strategy
):
for
idx
,
strategy
in
enumerate
(
list_of_strategy
):
strategy_name
=
strategy
[
0
]
strategy_name
=
strategy
[
0
]
self
.
_name_to_degree_dict
[
strategy_name
]
=
strategy
[
1
]
self
.
_name_to_degree_dict
[
strategy_name
]
=
strategy
[
1
]
self
.
_rank_list_dict
[
strategy_name
]
=
self
.
_calc_rank_list
(
idx
)
rank_list
=
self
.
_calc_rank_list
(
idx
)
self
.
_name_to_group_dict
[
strategy_name
]
=
strategy
[
2
](
self
.
_name_to_group_dict
[
strategy_name
]
=
strategy
[
2
](
self
.
_rank_list_dict
[
strategy_name
]
rank_list
,
)
)
self
.
_name_to_fused_group_dict
=
{}
self
.
_name_to_fused_group_dict
=
{}
...
@@ -136,11 +142,13 @@ class OrthogonalStrategy:
...
@@ -136,11 +142,13 @@ class OrthogonalStrategy:
num_of_ranks
=
functools
.
reduce
(
num_of_ranks
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
_list_of_degree
lambda
x
,
y
:
x
*
y
,
self
.
_list_of_degree
)
)
assert
(
num_of_ranks
==
dist
.
get_world_size
()
assert
num_of_ranks
==
len
(
self
.
_strategy_rank_list
),
"There are total {} ranks, but need {} ranks in this strategy."
.
format
(
),
"There are total {} ranks, but need {} ranks in this strategy."
.
format
(
dist
.
get_world_size
(
),
num_of_ranks
len
(
self
.
_strategy_rank_list
),
num_of_ranks
)
)
for
fused_strategy
in
self
.
_fused_strategy_dict
.
values
():
for
fused_strategy
in
self
.
_fused_strategy_dict
.
values
():
for
strategy
in
fused_strategy
:
for
strategy
in
fused_strategy
:
assert
(
assert
(
...
...
python/paddle/distributed/fleet/base/strategy_group.py
浏览文件 @
f48611f3
...
@@ -12,7 +12,9 @@
...
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
paddle.distributed
as
dist
import
paddle.distributed
as
dist
from
paddle.distributed.fleet.layers.mpu
import
RNGStatesTracker
class
StrategyGroupBase
:
class
StrategyGroupBase
:
...
@@ -39,6 +41,9 @@ class StrategyGroupBase:
...
@@ -39,6 +41,9 @@ class StrategyGroupBase:
"""
"""
def
__init__
(
self
,
list_of_ranks
):
def
__init__
(
self
,
list_of_ranks
):
"""
Initialize the communication group.
"""
assert
(
assert
(
dist
.
is_initialized
()
dist
.
is_initialized
()
),
"The global communication group need to be initialized."
),
"The global communication group need to be initialized."
...
@@ -46,6 +51,19 @@ class StrategyGroupBase:
...
@@ -46,6 +51,19 @@ class StrategyGroupBase:
self
.
_rank
=
dist
.
get_rank
()
self
.
_rank
=
dist
.
get_rank
()
self
.
_list_of_ranks
=
list_of_ranks
self
.
_list_of_ranks
=
list_of_ranks
self
.
_group
=
self
.
_create_group
()
self
.
_group
=
self
.
_create_group
()
self
.
random_states_tracker
=
RNGStatesTracker
()
def
add_random_seed
(
self
,
name
,
seed
):
"""
Add random seed for current rank.
"""
self
.
random_states_tracker
.
add
(
name
,
seed
)
def
get_random_states_tracker
(
self
):
"""
Get the random states tracker.
"""
return
self
.
random_states_tracker
@
property
@
property
def
world_size
(
self
):
def
world_size
(
self
):
...
@@ -74,17 +92,28 @@ class StrategyGroupBase:
...
@@ -74,17 +92,28 @@ class StrategyGroupBase:
return
self
.
_group
return
self
.
_group
def
_create_group
(
self
):
def
_create_group
(
self
):
list_of_group
=
[]
self
.
list_of_group
=
[]
for
ranks
in
self
.
_list_of_ranks
:
for
ranks
in
self
.
_list_of_ranks
:
group
=
dist
.
new_group
(
ranks
=
ranks
)
group
=
dist
.
new_group
(
ranks
=
ranks
)
if
self
.
_rank
in
ranks
:
if
self
.
_rank
in
ranks
:
list_of_group
.
append
(
group
)
self
.
list_of_group
.
append
(
group
)
assert
(
len
(
list_of_group
)
>
0
if
not
self
.
list_of_group
:
),
"Rank {} does not belong to the list_of_ranks {}."
.
format
(
return
None
self
.
_rank
,
self
.
_list_of_ranks
else
:
)
return
(
return
list_of_group
if
len
(
list_of_group
)
>
1
else
list_of_group
[
0
]
self
.
list_of_group
[
0
]
if
len
(
self
.
list_of_group
)
==
1
else
self
.
list_of_group
)
def
__repr__
(
self
):
debug_str
=
f
"seed:
{
self
.
_seed
}
; "
if
not
self
.
list_of_group
:
return
debug_str
+
"No group."
for
i
in
range
(
len
(
self
.
list_of_group
)):
debug_str
+=
f
"Group[
{
i
}
]:
{
str
(
self
.
list_of_group
[
i
])
}
; "
return
debug_str
class
DPGroup
(
StrategyGroupBase
):
class
DPGroup
(
StrategyGroupBase
):
...
...
python/paddle/fluid/tests/unittests/collective/orthogonal_strategy.py
浏览文件 @
f48611f3
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
import
unittest
import
unittest
import
paddle
import
paddle.distributed
as
dist
import
paddle.distributed
as
dist
from
paddle.distributed.fleet.base.orthogonal_strategy
import
OrthogonalStrategy
from
paddle.distributed.fleet.base.orthogonal_strategy
import
OrthogonalStrategy
from
paddle.distributed.fleet.base.strategy_group
import
(
from
paddle.distributed.fleet.base.strategy_group
import
(
...
@@ -52,5 +53,40 @@ class TestOrthogonalStrategyAPI(unittest.TestCase):
...
@@ -52,5 +53,40 @@ class TestOrthogonalStrategyAPI(unittest.TestCase):
self
.
assertEqual
(
fused_group
.
group
.
nranks
,
1
)
self
.
assertEqual
(
fused_group
.
group
.
nranks
,
1
)
class
TestOrthogonalStrategyCustomAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
_num_of_ranks
=
2
dist
.
init_parallel_env
()
self
.
_global_rank
=
dist
.
get_rank
()
self
.
_strategy
=
OrthogonalStrategy
(
[
(
"dp"
,
1
,
DPGroup
),
(
"mp"
,
2
,
MPGroup
),
(
"sharding"
,
1
,
ShardingGroup
),
(
"pp"
,
1
,
PPGroup
),
],
fused_strategy_dict
=
{
"checkness"
:
[
"mp"
,
"sharding"
,
"pp"
]},
strategy_rank_list
=
[
1
,
0
],
)
self
.
_strategy
.
strategy_group
(
"mp"
).
add_random_seed
(
"local_seed"
,
123
)
self
.
_strategy
.
strategy_group
(
"mp"
).
add_random_seed
(
"global_seed"
,
321
)
def
test_orthogonal_strategy
(
self
):
mp_group
=
self
.
_strategy
.
strategy_group
(
"mp"
)
self
.
assertEqual
(
mp_group
.
world_size
,
self
.
_num_of_ranks
)
self
.
assertEqual
(
mp_group
.
group
.
nranks
,
self
.
_num_of_ranks
)
self
.
assertEqual
(
self
.
_strategy
.
rank_in_strategy
(
"mp"
),
self
.
_global_rank
)
fused_group
=
self
.
_strategy
.
fused_strategy_group
(
"checkness"
)
self
.
assertEqual
(
fused_group
.
world_size
,
2
)
self
.
assertEqual
(
fused_group
.
group
.
nranks
,
2
)
with
mp_group
.
random_states_tracker
.
rng_state
(
"local_seed"
):
a
=
paddle
.
randint
(
0
,
100
,
[
10
]).
numpy
()[
0
]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录