机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 7aa0cc3c
Authored Feb 07, 2021 by sandyhouse

update

Parent: fa71ee87
Showing 2 changed files with 78 additions and 40 deletions (+78 -40)

python/paddle/distributed/fleet/meta_optimizers/model_parallel_optimizer.py  +43 -24
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py      +35 -16
python/paddle/distributed/fleet/meta_optimizers/model_parallel_optimizer.py
@@ -22,9 +22,10 @@ from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_u
 
 
 class ModelParallelHelper(object):
-    def __init__(self, role_maker, wait_port=True):
+    def __init__(self, role_maker, wait_port=True, megatron_dp=False):
         self.wait_port = wait_port
         self.role_maker = role_maker
+        self.megatron_dp = megatron_dp
 
     def update_startup_program(self,
                                startup_program=None,
@@ -48,24 +49,29 @@ class ModelParallelHelper(object):
                                 mp_endpoints, mp_rank, 0, self.wait_port)
         self._broadcast_params(0, broadcast_distributed_weight=False)
 
-        mp_num = len(endpoints) // inner_parallelism
-        if mp_num == 1: return
-        # Create rings for gpus as the same model parallel part
-        eps = []
-        dp_rank = rank // inner_parallelism
-        dp_id = rank % inner_parallelism
-        #if dp_rank == 1: dp_rank =0
-        #if dp_rank == 0: dp_rank =1
-        ring_id = 1
-        for idx, ep in enumerate(endpoints):
-            if idx % inner_parallelism == dp_id:
-                eps.append(ep)
-        #ep = eps.pop(0)
-        #eps.insert(1, ep)
-        print("data parallel eps:{}, rank{}".format(eps, dp_rank))
-        self._init_communicator(self.startup_program, current_endpoint, eps,
-                                dp_rank, ring_id, self.wait_port)
-        self._broadcast_params(ring_id, broadcast_distributed_weight=True)
+        print("megatron group size: {}".format(inner_parallelism))
+        print("megatron rank: {}".format(mp_rank))
+        print("megatron endpoints: {}".format(mp_endpoints))
+
+        if self.megatron_dp:
+            mp_num = len(endpoints) // inner_parallelism
+            if mp_num == 1: return
+            # Create rings for gpus as the same model parallel part
+            eps = []
+            dp_rank = rank // inner_parallelism
+            dp_id = rank % inner_parallelism
+            #if dp_rank == 1: dp_rank =0
+            #if dp_rank == 0: dp_rank =1
+            ring_id = 1
+            for idx, ep in enumerate(endpoints):
+                if idx % inner_parallelism == dp_id:
+                    eps.append(ep)
+            #ep = eps.pop(0)
+            #eps.insert(1, ep)
+            print("data parallel eps:{}, rank{}".format(eps, dp_rank))
+            self._init_communicator(self.startup_program, current_endpoint, eps,
+                                    dp_rank, ring_id, self.wait_port)
+            self._broadcast_params(ring_id, broadcast_distributed_weight=True)
 
     def _init_communicator(self, program, current_endpoint, endpoints, rank,
                            ring_id, wait_port):
@@ -129,9 +135,14 @@ class ModelParallelOptimizer(MetaOptimizerBase):
     def __init__(self, optimizer):
         super(ModelParallelOptimizer, self).__init__(optimizer)
         self.inner_opt = optimizer
         # we do not allow meta optimizer to be inner optimizer currently
-        self.meta_optimizers_white_list = []
+        self.meta_optimizers_white_list = [
+            "RecomputeOptimizer",
+            "AMPOptimizer",
+            "LarsOptimizer",
+            "LambOptimizer",
+        ]
         self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
+        self.megatron_dp = False
 
     def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
                         user_defined_strategy):
@@ -156,6 +167,10 @@ class ModelParallelOptimizer(MetaOptimizerBase):
         dist_strategy.model_parallel = True
         dist_strategy.model_parallel_configs = {"parallelism": 1, }
 
+    # the following function will be used by AMP if both Megatron and AMP are turn on together.
+    def apply_gradients(self, params_grads):
+        return self.minimize_impl(params_grads=params_grads)
+
     def minimize_impl(self,
                       loss,
                       startup_program=None,
@@ -167,6 +182,8 @@ class ModelParallelOptimizer(MetaOptimizerBase):
         if startup_program is None:
             self.startup_program = fluid.default_startup_program()
+        # (TODO) check the order of metaoptimizer
+        # (TODO) check the params_grads
         optimize_ops, params_grads = self.inner_opt.minimize(
             loss, self.startup_program, parameter_list, no_grad_set)
 
@@ -179,10 +196,12 @@ class ModelParallelOptimizer(MetaOptimizerBase):
                                          self.inner_parallelism)
 
         assert self.nranks % self.inner_parallelism == 0
-        # data parallelism
-        dp_parallelism = self.nranks // self.inner_parallelism
-
-        self._transpile_main_program(loss, dp_parallelism)
+        if self.megatron_dp:
+            # data parallelism
+            dp_parallelism = self.nranks // self.inner_parallelism
+
+            self._transpile_main_program(loss, dp_parallelism)
         return optimize_ops, params_grads
 
     def _transpile_main_program(self, loss, dp_parallelism):
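The hunks above gate the extra data-parallel ring behind the new megatron_dp switch; the grouping rule itself is unchanged. As a reading aid only, here is a minimal, framework-free sketch of that rule (the helper name select_dp_ring and the endpoint values are invented for illustration): workers whose endpoint index leaves the same remainder modulo inner_parallelism land in the same data-parallel ring, and rank // inner_parallelism is the worker's rank inside that ring.

# Standalone sketch (not part of the commit): reproduces the endpoint-selection
# rule used by ModelParallelHelper when megatron_dp is enabled.
def select_dp_ring(endpoints, rank, inner_parallelism):
    """Return (dp_rank, eps): this worker's rank in the data-parallel ring and
    the endpoints that belong to that ring, following the patched helper."""
    dp_rank = rank // inner_parallelism   # which data-parallel replica this rank belongs to
    dp_id = rank % inner_parallelism      # position inside the model-parallel group
    eps = [ep for idx, ep in enumerate(endpoints)
           if idx % inner_parallelism == dp_id]
    return dp_rank, eps


if __name__ == "__main__":
    # Hypothetical setup: 4 workers, model-parallel degree 2
    # -> two data-parallel rings of size 2.
    endpoints = ["127.0.0.1:6170", "127.0.0.1:6171",
                 "127.0.0.1:6172", "127.0.0.1:6173"]
    for rank in range(len(endpoints)):
        dp_rank, eps = select_dp_ring(endpoints, rank, inner_parallelism=2)
        print("rank {} -> dp_rank {}, ring endpoints {}".format(rank, dp_rank, eps))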
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
@@ -73,7 +73,7 @@ class FP16Utils(object):
     @staticmethod
     def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
         """
-        1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard
+        1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard
         2. revise amp inifine grad checking for sharding
         """
         # remove cast
@@ -103,6 +103,7 @@ class FP16Utils(object):
                 op._rename_input(inf_var_name, inf_var_name + "@sharding")
             if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                 reversed_x = []
+                reversed_x_paramname = []
                 for input_name in op.desc.input('X'):
                     param_name = input_name.strip("@GRAD")
                     if param_name not in shard.global_params:
@@ -111,12 +112,26 @@ class FP16Utils(object):
                         "be grads, but {} is not a grad".format(input_name))
                     if shard.has_param(param_name):
                         reversed_x.append(input_name)
+                        reversed_x_paramname.append(param_name)
                 op.desc.set_input('X', reversed_x)
                 op.desc.set_output('Out', reversed_x)
+
+                # the grad checking should take the all and only param in the current shard
+                to_check_param = set(reversed_x_paramname)
+                should_check_param = set(shard.global_params).intersection(
+                    set([
+                        param
+                        for param, worker_idx in shard.global_param2device.items()
+                        if worker_idx == shard.worker_idx
+                    ]))
+                assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
+                    should_check_param - to_check_param,
+                    to_check_param - should_check_param)
+
         if update_loss_scaling_op_idx == -1:
             return
         inf_var = block.var(inf_var_name)
-        inf_var_fp32 = block.create_var(
+        inf_var_int32 = block.create_var(
             name=inf_var_name + "@cast_int32",
             shape=inf_var.shape,
             dtype=core.VarDesc.VarType.INT32)
@@ -128,32 +143,36 @@ class FP16Utils(object):
             update_loss_scaling_op_idx,
             type='cast',
             inputs={'X': inf_var},
-            outputs={'Out': inf_var_fp32},
+            outputs={'Out': inf_var_int32},
             attrs={
                 "in_dtype": inf_var.dtype,
-                "out_dtype": inf_var_fp32.dtype,
+                "out_dtype": inf_var_int32.dtype,
                 OP_ROLE_KEY: OpRole.Optimize
             })
-        insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
-                            [inf_var_fp32])
+        # this allreduce communication should not overlap with calc
+        # insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
+        #                     [inf_var_int32])
         block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 2,
+            update_loss_scaling_op_idx + 1,
             type='c_allreduce_max',
-            inputs={'X': inf_var_fp32},
-            outputs={'Out': inf_var_fp32},
-            attrs={'ring_id': ring_id,
-                   OP_ROLE_KEY: OpRole.Optimize})
+            inputs={'X': inf_var_int32},
+            outputs={'Out': inf_var_int32},
+            attrs={
+                'ring_id': ring_id,
+                'use_calc_stream': True,
+                OP_ROLE_KEY: OpRole.Optimize
+            })
 
-        comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
-                                          ring_id, [inf_var_fp32])
+        # comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
+        #                                   ring_id, [inf_var_int32])
 
         block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 3 + comm_op_num,
+            update_loss_scaling_op_idx + 2,
             type='cast',
-            inputs={'X': inf_var_fp32},
+            inputs={'X': inf_var_int32},
             outputs={'Out': inf_var_sharding},
             attrs={
-                "in_dtype": inf_var_fp32.dtype,
+                "in_dtype": inf_var_int32.dtype,
                 "out_dtype": inf_var_sharding.dtype,
                 OP_ROLE_KEY: OpRole.Optimize
            })
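In fp16_helper.py the cross-shard overflow check now travels as an int32 value: the local found-infinite flag (inf_var) is cast to int32, reduced with c_allreduce_max on the calc stream (so the explicit calc/comm sync ops are commented out), and cast back into the inf_var_sharding variable consumed by the downstream loss-scaling logic. For orientation only, the short sketch below reproduces that semantics outside Paddle; the function name global_found_inf and the example flag values are invented.

# Standalone sketch (not Paddle code): semantics of the cast -> c_allreduce_max
# -> cast sequence inserted around update_loss_scaling.
def global_found_inf(local_flags):
    """Combine per-shard 'found infinite gradient' flags.

    Mirrors the inserted ops: cast bool -> int32, allreduce with MAX across the
    ring (simulated here with max() over a list), cast back to bool.
    """
    as_int32 = [int(flag) for flag in local_flags]   # cast op: BOOL -> INT32
    reduced = max(as_int32)                          # c_allreduce_max over the ring
    return bool(reduced)                             # cast op: INT32 -> BOOL


if __name__ == "__main__":
    print(global_found_inf([False, False, True, False]))   # True: one shard overflowed
    print(global_found_inf([False, False, False, False]))  # False: loss scale may grow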