Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
41e2d413
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
41e2d413
编写于
8月 02, 2021
作者:
W
WangXi
提交者:
GitHub
8月 02, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[NPU] fix npu pipeline comm init (#34466)
上级
8b72a1a7
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
114 additions
and
25 deletions
+114
-25
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
...e/distributed/fleet/meta_optimizers/sharding_optimizer.py
+114
-25
未找到文件。
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
浏览文件 @
41e2d413
...
...
@@ -379,6 +379,119 @@ class ShardingOptimizer(MetaOptimizerBase):
self
.
_wait
()
return
optimize_ops
,
params_grads
def
_init_pair_comm
(
self
,
pair
,
ring_id
):
pp_group_endpoints
=
[
self
.
pp_group_endpoints
[
pair
[
0
]],
self
.
pp_group_endpoints
[
pair
[
1
]],
]
pp_rank
=
0
if
self
.
pp_rank
==
pair
[
0
]
else
1
self
.
_collective_helper
.
_init_communicator
(
self
.
_startup_program
,
self
.
current_endpoint
,
pp_group_endpoints
,
pp_rank
,
ring_id
,
False
,
global_ring_id
=
self
.
global_ring_id
,
sync
=
False
)
def _init_npu_pipeline_comm(self, startup_block):
    """Initialise this stage's pipeline pair communicators in fixed phases.

    Each phase initialises one pair communicator via ``_init_pair_comm`` and
    then appends a naive sync to ``startup_block``. Two phases always run
    (even -> odd, then odd -> even back); two more run when
    ``pp_degree > 2`` and cover the remaining neighbour, including the
    wrap-around pair between the last and the first stage.

    Args:
        startup_block: block handed to ``append_naive_sync`` after every
            communicator initialisation.

    Raises:
        AssertionError: if ``pp_degree`` is odd, or if a pair from
            ``self.pipeline_pair`` that involves this rank is not handled
            by any phase.
    """
    # NOTE(wangxi): some bug with hccl, must set pp_degree be even number
    assert (self.pp_degree % 2) == 0

    def ring_key(stage_pair):
        # pp_ring_map is keyed by ``src * 1000 + dst``.
        return stage_pair[0] * 1000 + stage_pair[1]

    highest_ring_id = -1
    pending = []  # pairs involving this rank that still need a communicator
    for stage_pair in self.pipeline_pair:
        mapped_ring = self.pp_ring_map[ring_key(stage_pair)]
        highest_ring_id = max(highest_ring_id, mapped_ring)
        logger.info("pp pair:{}, ring_id: {}".format(stage_pair, mapped_ring))

        if self.pp_rank in stage_pair:
            pending.append(stage_pair)

    # for example: self.pp_rank=2, self.pp_degree=4
    next_rank = (self.pp_rank + 1) % self.pp_degree
    prev_rank = (self.pp_rank - 1 + self.pp_degree) % self.pp_degree
    send_next = (self.pp_rank, next_rank)  # 2->3
    recv_next = (next_rank, self.pp_rank)  # 3->2
    recv_prev = (prev_rank, self.pp_rank)  # 1->2
    send_prev = (self.pp_rank, prev_rank)  # 2->1

    is_even = (self.pp_rank % 2) == 0

    mandatory_phases = (
        # 1. even send to next, odd recv from prev, 0->1, 2->3
        (send_next if is_even else recv_prev,
         "pair0(even->odd): pp pair:{}, ring_id: {}"),
        # 2. even recv from next, odd send to prev, 1->0, 3->2
        (recv_next if is_even else send_prev,
         "pair1(even<-odd): pp pair:{}, ring_id: {}"),
    )
    for stage_pair, log_fmt in mandatory_phases:
        ring_id = self.pp_ring_map[ring_key(stage_pair)]
        self._init_pair_comm(stage_pair, ring_id)
        append_naive_sync(startup_block, self.startup_prog_sync_var,
                          self.global_ring_id)
        pending.remove(stage_pair)
        logger.info(log_fmt.format(stage_pair, ring_id))

    # if pp_degree is 2, only need pair(0->1, 1->0)
    if self.pp_degree > 2:
        # The first and last stage handle the wrap-around pair in these
        # phases; that pair was never added to ``pending``, so nothing is
        # removed for them.
        is_edge_stage = (self.pp_rank == 0 or
                         self.pp_rank == self.pp_degree - 1)
        wrap_phases = (
            # 3. odd send to next, even recv from prev, 1->2, 3->0
            #    (3->0 not in pp_ring_map, falls back to max ring id + 1)
            (recv_prev if is_even else send_next, 1,
             "pair2(odd->even): pp pair:{}, ring_id: {}"),
            # 4. odd recv from next, even send to prev, 2->1, 0->3
            #    (0->3 not in pp_ring_map, falls back to max ring id + 2)
            (send_prev if is_even else recv_next, 2,
             "pair3(odd<-even): pp pair:{}, ring_id: {}"),
        )
        for stage_pair, fallback_offset, log_fmt in wrap_phases:
            ring_id = self.pp_ring_map.get(
                ring_key(stage_pair), highest_ring_id + fallback_offset)
            self._init_pair_comm(stage_pair, ring_id)
            append_naive_sync(startup_block, self.startup_prog_sync_var,
                              self.global_ring_id)
            if not is_edge_stage:
                pending.remove(stage_pair)
            logger.info(log_fmt.format(stage_pair, ring_id))

    assert not pending, (
        "Current pipeline does not support cross stage communication, "
        "please check unexpected pair {}").format(pending)
def _init_pipeline_comm(self, startup_block):
    """Initialise the pipeline-parallel pair communicators for this rank.

    When Paddle is compiled with NPU support the work is delegated to
    ``_init_npu_pipeline_comm``. Otherwise every pair in
    ``self.pipeline_pair`` that contains this rank gets a communicator via
    ``_init_pair_comm``, followed by a naive sync on the startup block.

    Args:
        startup_block: block passed to ``_init_npu_pipeline_comm`` or to
            ``append_naive_sync``.

    Raises:
        AssertionError: if ``self.pp_rank_`` and ``self.pp_rank`` disagree.
    """
    # TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
    assert self.pp_rank_ == self.pp_rank, \
        "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
            self.pp_rank_, self.pp_rank)

    if core.is_compiled_with_npu():
        self._init_npu_pipeline_comm(startup_block)
    else:
        # GPU
        for stage_pair in self.pipeline_pair:
            ring_id = self.pp_ring_map[stage_pair[0] * 1000 + stage_pair[1]]

            logger.info("pp pair:{}, ring_id: {}".format(stage_pair, ring_id))
            if self.pp_rank not in stage_pair:
                continue
            self._init_pair_comm(stage_pair, ring_id)
        # NOTE(review): the nesting level of this sync is not recoverable
        # from the flattened diff view; it is placed once after the loop,
        # matching the pre-refactor code -- confirm against the repository.
        append_naive_sync(startup_block, self.startup_prog_sync_var,
                          self.global_ring_id)
def
_init_comm
(
self
):
# config sharding & dp groups
...
...
@@ -435,31 +548,7 @@ class ShardingOptimizer(MetaOptimizerBase):
# pp ring
if
self
.
pp_degree
>
1
:
# TODO (JZ-LIANG) to unify this shit
assert
self
.
pp_rank_
==
self
.
pp_rank
,
"pp rank for pp opt [{}], pp rank for sharding opt [{}]"
.
format
(
self
.
pp_rank_
,
self
.
pp_rank
)
for
pair
in
self
.
pipeline_pair
:
pair_key
=
pair
[
0
]
*
1000
+
pair
[
1
]
ring_id
=
self
.
pp_ring_map
[
pair_key
]
print
(
"pp pair:{}, ring_id: {}"
.
format
(
pair
,
ring_id
))
if
self
.
pp_rank
in
pair
:
pp_group_endpoints
=
[
self
.
pp_group_endpoints
[
pair
[
0
]],
self
.
pp_group_endpoints
[
pair
[
1
]],
]
pp_rank
=
0
if
self
.
pp_rank
==
pair
[
0
]
else
1
self
.
_collective_helper
.
_init_communicator
(
self
.
_startup_program
,
self
.
current_endpoint
,
pp_group_endpoints
,
pp_rank
,
ring_id
,
False
,
global_ring_id
=
self
.
global_ring_id
,
sync
=
False
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
self
.
_init_pipeline_comm
(
startup_block
)
# pure dp ring
if
self
.
dp_degree
>
1
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录