机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 174e25cf
Author: sandyhouse
Date: Feb 09, 2021
Parent: 997651ab

    update

1 file changed, 47 insertions(+), 19 deletions(-)

python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py (+47, -19)
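Read together, the hunks below set up four communication rings for the hybrid megatron + sharding + pipeline case. The ring ids are assigned explicitly in the later hunks; the snippet below only summarizes that layout for reference and is not part of the patch:

```python
# Ring-id layout used by this patch when sharding is the outer parallelism.
# The values mirror the assignments in the hunks below; this dict itself is
# illustrative and does not appear in the commit.
RING_IDS = {
    "mp_group_id": 0,       # inner (megatron-style) model-parallel ring
    "sharding_ring_id": 1,  # sharding ring
    "pp_ring_id": 2,        # pipeline ring
    "global_group_id": 3,   # ring spanning all workers of one model replica
}
```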
@@ -39,7 +39,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             "AMPOptimizer",
             "LarsOptimizer",
             "LambOptimizer",
-            "ModelParallelOptimizer",
+            # "ModelParallelOptimizer",
             "PipelineOptimizer",
         ]
         self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
@@ -358,6 +358,19 @@ class ShardingOptimizer(MetaOptimizerBase):
                                    self._nrings_sharding)
         # config sharding & dp groups
         self._init_comm()
+        # inner & outer model parallelism
+        if self._as_outer_parallelism:
+            self._collective_helper._init_communicator(
+                self._startup_program, self.current_endpoint,
+                self.global_group_endpoints, self.global_rank,
+                self.global_group_id, True)
+        if self._as_outer_parallelism:
+            self._collective_helper._init_communicator(
+                self._startup_program, self.current_endpoint,
+                self.mp_group_endpoints, self.mp_rank,
+                self.mp_group_id, False)
         # sharding
         print("sharding_group_endpoints:", self.sharding_group_endpoints)
         print("sharding_rank:", self.sharding_rank)
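When sharding is used as the outer parallelism, this hunk initializes two additional communicators before the sharding, dp, and pipeline rings: one over the global group (every worker of the replica) and one over the inner model-parallel group. A minimal, self-contained sketch of that ordering; the stub helper and the group dictionary are stand-ins for illustration only, not Paddle's CollectiveHelper API:

```python
# Sketch of the initialization order added in this hunk. A stub stands in for
# the collective helper so the example runs on its own; group contents are
# made-up values, not taken from the commit.
class StubCollectiveHelper:
    def init_communicator(self, program, cur_ep, endpoints, rank, ring_id, wait_port):
        print(f"ring {ring_id}: rank {rank} of {len(endpoints)}, wait_port={wait_port}")

def init_outer_parallelism_rings(helper, program, cur_ep, groups):
    # global ring first (waiting for peer ports), then the inner mp ring
    g, mp = groups["global"], groups["mp"]
    helper.init_communicator(program, cur_ep, g["endpoints"], g["rank"], g["ring_id"], True)
    helper.init_communicator(program, cur_ep, mp["endpoints"], mp["rank"], mp["ring_id"], False)

groups = {
    "global": {"endpoints": [f"worker{i}:6170" for i in range(8)], "rank": 5, "ring_id": 3},
    "mp": {"endpoints": ["worker4:6170", "worker5:6170"], "rank": 1, "ring_id": 0},
}
init_outer_parallelism_rings(StubCollectiveHelper(), None, "worker5:6170", groups)
```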
@@ -365,13 +378,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         self._collective_helper._init_communicator(
             self._startup_program, self.current_endpoint,
             self.sharding_group_endpoints, self.sharding_rank,
-            self.sharding_ring_id, True)
-        # inner & outer model parallelism
-        # if self._as_outer_parallelism:
-        #     self._collective_helper._init_communicator(
-        #         self._startup_program, self.current_endpoint,
-        #         self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True)
+            self.sharding_ring_id, False)
         # dp
         if self.hybrid_dp:
@@ -382,7 +389,7 @@ class ShardingOptimizer(MetaOptimizerBase):
if
self
.
use_pipeline
:
self
.
_collective_helper
.
_init_communicator
(
self
.
_startup_program
,
self
.
current_endpoint
,
self
.
pp_group_endpoints
,
self
.
pp_rank
,
self
.
pp_ring_id
,
Tru
e
)
self
.
pp_group_endpoints
,
self
.
pp_rank
,
self
.
pp_ring_id
,
Fals
e
)
startup_block
=
self
.
_startup_program
.
global_block
()
startup_block
.
_sync_with_cpp
()
...
...
@@ -482,7 +489,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         # group. and each Data Parallelism group should have its own sync of FoundInfinite
         Model_Paramllelism_ring_id = self.sharding_ring_id
         if self._as_outer_parallelism:
-            Model_Paramllelism_ring_id = self.mp_group_id
+            Model_Paramllelism_ring_id = self.global_group_id
         FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
                              Model_Paramllelism_ring_id)
         gradientclip_helper = GradientClipHelper(Model_Paramllelism_ring_id)
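Because the sharded FP16 state of one model replica is now spread over the whole global group rather than just the mp group, the FoundInfinite flag and the gradient-clip statistics are synchronized on the global ring. The plain-Python fragment below only illustrates why that flag must be reduced across every rank of the replica; it is not Paddle's FP16Utils:

```python
# If any rank that holds a shard of the replica sees inf/nan gradients, every
# rank must skip the optimizer step together, so the flag is OR-reduced over
# the full group (here simulated with a plain list instead of an allreduce).
def replica_found_inf(local_flags):
    return any(local_flags)

local_flags = [False, False, True, False]  # one shard overflowed
print(replica_found_inf(local_flags))      # True -> all ranks skip this update
```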
@@ -826,23 +833,42 @@ class ShardingOptimizer(MetaOptimizerBase):
                 ) == self.sharding_rank
             ]
         else:
             self.mp_group_id = 0
             self.sharding_ring_id = 1
             self.pp_ring_id = 2
             self.mp_rank = self.global_rank % self._inner_parallelism_size
             self.mp_group = self.global_rank // self._inner_parallelism_size
             self.mp_group_endpoints = [
                 ep for idx, ep in enumerate(self.endpoints)
                 if idx // self._inner_parallelism_size == self.mp_group
             ]
             print("megatron_group_endpoints:", self.mp_group_endpoints)
             print("megatron_rank:", self.mp_rank)
             # self.cards_per_node = 8
             self.sharding_group_size = self.user_defined_strategy.sharding_configs[
                 'sharding_group_size']
-            self.sharding_rank = self.global_rank // self._inner_parallelism_size % self.sharding_group_size
-            # self.sharding_group_id = self.global_rank // (self._inner_parallelism_size % self.sharding_group_size)
+            self.sharding_rank = (self.global_rank //
+                                  self._inner_parallelism_size) % self.sharding_group_size
+            self.sharding_group_id = self.global_rank // (
+                self._inner_parallelism_size * self.sharding_group_size)
+            self.megatron_rank = self.global_rank % self._inner_parallelism_size
             self.sharding_group_endpoints = [
                 ep for idx, ep in enumerate(self.endpoints)
-                if (idx // self._inner_parallelism_size % self.sharding_group_size
-                    ) == self.sharding_rank
+                if (idx // (self._inner_parallelism_size * self.sharding_group_size)
+                    ) == self.sharding_group_id and
+                idx % self._inner_parallelism_size == self.megatron_rank
             ]
             print("sharding_endpoint:", self.sharding_group_endpoints)
             print("sharding_rank:", self.sharding_rank)
             assert self.sharding_group_size * self.pipeline_nodes * \
                 self._inner_parallelism_size == self.role_maker._worker_num()
-            self.pp_rank = self.global_rank // (self.sharding_group_size *
-                                                self._inner_parallelism_size)
+            self.pp_rank = self.global_rank // (
+                self.sharding_group_size *
+                self._inner_parallelism_size) % self.pipeline_nodes
             offset = self.sharding_group_size * self._inner_parallelism_size
             # TODO: Adjust for dp
             idx_with_pp_0 = self.global_rank % (self.sharding_group_size *
                                                 self._inner_parallelism_size)
             self.pp_group_endpoints = []
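The else-branch decomposes the flat global_rank into the inner (megatron) rank, the sharding rank/group, and the pipeline rank, and filters the endpoint list accordingly. The standalone sketch below reproduces that arithmetic so it can be checked in isolation; the endpoint values and the 2x2x2 layout are invented for the example:

```python
def decompose_rank(global_rank, inner_parallelism_size, sharding_group_size,
                   pipeline_nodes, endpoints):
    # Same arithmetic as the else-branch above (illustrative reimplementation).
    megatron_rank = global_rank % inner_parallelism_size
    sharding_rank = (global_rank // inner_parallelism_size) % sharding_group_size
    sharding_group_id = global_rank // (inner_parallelism_size * sharding_group_size)
    pp_rank = global_rank // (inner_parallelism_size *
                              sharding_group_size) % pipeline_nodes
    sharding_group_endpoints = [
        ep for idx, ep in enumerate(endpoints)
        if idx // (inner_parallelism_size * sharding_group_size) == sharding_group_id
        and idx % inner_parallelism_size == megatron_rank
    ]
    return megatron_rank, sharding_rank, sharding_group_id, pp_rank, sharding_group_endpoints

# 2-way megatron x 2-way sharding x 2 pipeline stages = 8 workers, rank 5:
endpoints = [f"192.168.0.{i}:6170" for i in range(8)]
print(decompose_rank(5, 2, 2, 2, endpoints))
# (1, 0, 1, 1, ['192.168.0.5:6170', '192.168.0.7:6170'])
```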
@@ -850,15 +876,17 @@ class ShardingOptimizer(MetaOptimizerBase):
                 self.pp_group_endpoints.append(self.endpoints[idx_with_pp_0])
                 idx_with_pp_0 += offset
+            print("pp_group_endpoints:", self.pp_group_endpoints)
+            print("pp_rank:", self.pp_rank)
             #self.pp_group_endpoints = [
             #    ep for idx, ep in enumerate(self.endpoints)
             #    if (idx % self.sharding_group_size) == self.sharding_rank
             #]
-            self.mp_group_id = 1
-            self.mp_rank = self.global_rank
-            self.mp_group_size = self.role_maker._worker_num()
-            self.mp_group_endpoints = self.endpoints[:]
+            self.global_group_id = 3
+            self.global_rank = self.global_rank
+            self.global_group_size = self.role_maker._worker_num()
+            self.global_group_endpoints = self.endpoints[:]
             logging.info("Using Sharing as Outer parallelism mode !")
             self.dp_ring_id = -1
             self.dp_rank = -1
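The pipeline group above is built by stepping through the flat endpoint list with a stride of sharding_group_size * inner_parallelism_size, starting from this rank's slot inside pipeline stage 0, while the global group is simply every worker. A short sketch of the stride construction under the same invented 2x2x2 layout as before:

```python
def pp_endpoints_for(global_rank, sharding_group_size, inner_parallelism_size,
                     pipeline_nodes, endpoints):
    # One endpoint per pipeline stage, mirroring the loop in the hunk above.
    offset = sharding_group_size * inner_parallelism_size
    idx_with_pp_0 = global_rank % offset   # this rank's slot inside stage 0
    group = []
    for _ in range(pipeline_nodes):
        group.append(endpoints[idx_with_pp_0])
        idx_with_pp_0 += offset
    return group

endpoints = [f"192.168.0.{i}:6170" for i in range(8)]
print(pp_endpoints_for(5, 2, 2, 2, endpoints))
# ['192.168.0.1:6170', '192.168.0.5:6170']
```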