PaddlePaddle / Paddle
Commit 920806db
Authored Feb 07, 2021 by sandyhouse

update

Parent 70eb21c5
Showing 4 changed files with 533 additions and 94 deletions (+533 −94)
python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py        +3   −0
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py        +86  −65
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py    +404 −18
python/paddle/fluid/backward.py                                          +40  −11
python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py @ 920806db

@@ -126,6 +126,9 @@ class ProgramDeps(object):
     def should_remove_op(self, op_idx):
         op = self._block.ops[op_idx]
+        # remove check_finite_and_unscale op if its input 'X' is empty
+        if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
+            return True
         for output_name in op.desc.output_arg_names():
             if output_name not in self._should_removed_var:
                 return False
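As a standalone illustration of the added early exit (a sketch, not Paddle API): once sharding's gradient pruning has emptied the `X` input list of an AMP `check_finite_and_unscale` op, the op has nothing left to check and can itself be removed. The helper name below is hypothetical.

# Minimal sketch of the added rule; should_remove_amp_check is a hypothetical
# stand-in for the new early-exit branch in ProgramDeps.should_remove_op.
def should_remove_amp_check(op_type, x_inputs):
    # an AMP check op whose 'X' list was emptied by pruning is dead code
    return op_type == 'check_finite_and_unscale' and len(x_inputs) == 0

assert should_remove_amp_check('check_finite_and_unscale', [])
assert not should_remove_amp_check('check_finite_and_unscale', ['fc_0.w_0@GRAD'])
assert not should_remove_amp_check('elementwise_add', [])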
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @ 920806db

@@ -28,21 +28,24 @@ def check_broadcast(block):
     if the broadcasted var has a fill_constant op, the fill_constant
     op should stay forward before the broadcast op, and before a
     sync_calc op. Otherwise, raise error.
+    should ignore and skip broadcast_op of inner_parallelism (e.g. Megatron)
     """
     broadcast_vars = {}
     for idx, op in enumerate(block.ops):
         if op.type == "c_broadcast":
-            var_name = op.desc.input_arg_names()[0]
-            if "@BroadCast" in var_name:
-                if var_name in broadcast_vars:
-                    raise ValueError("var_name areadly exist: {}"
-                                     "the old pos is {}, the new pos is {}".
-                                     format(var_name, broadcast_vars[var_name][
-                                         "broadcast_pos"], idx))
-                broadcast_vars[var_name] = {
-                    "fill_constant_pos": -1,
-                    "broadcast_pos": idx,
-                }
+            if op.all_attrs()["use_calc_stream"] == False:
+                var_name = op.desc.input_arg_names()[0]
+                if "@BroadCast" in var_name:
+                    if var_name in broadcast_vars:
+                        raise ValueError("var_name areadly exist: {}"
+                                         "the old pos is {}, the new pos is {}".
+                                         format(var_name, broadcast_vars[
+                                             var_name]["broadcast_pos"], idx))
+                    broadcast_vars[var_name] = {
+                        "fill_constant_pos": -1,
+                        "broadcast_pos": idx,
+                    }

     for idx, op in enumerate(block.ops):
         if op.type == "fill_constant":
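The effect of this hunk is that check_broadcast now only inspects `c_broadcast` ops whose `use_calc_stream` attribute is False, so broadcasts issued by an inner parallelism (e.g. Megatron tensor parallelism) on the calc stream are skipped. A rough, self-contained sketch of that filtering idea over plain dictionaries (hypothetical records, not Paddle's Block/Operator API):

# Hypothetical op records; in the real check these are block.ops and op.all_attrs().
ops = [
    {"type": "c_broadcast", "use_calc_stream": True},   # inner-parallelism broadcast: skipped
    {"type": "c_broadcast", "use_calc_stream": False},  # sharding broadcast: checked
    {"type": "fill_constant", "use_calc_stream": False},
]

checked = [
    idx for idx, op in enumerate(ops)
    if op["type"] == "c_broadcast" and not op["use_calc_stream"]
]
print(checked)  # [1]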
@@ -61,14 +64,15 @@ def check_broadcast(block):
             last_sync_calc_op_idx = idx
             continue
         if op.type == "c_broadcast":
-            var_name = op.desc.input_arg_names()[0]
-            if "@BroadCast" in var_name:
-                if broadcast_vars[var_name]["fill_constant_pos"] != -1:
-                    assert (last_sync_calc_op_idx != -1)
-                    assert (broadcast_vars[var_name]["fill_constant_pos"] <
-                            last_sync_calc_op_idx)
-                    assert (last_sync_calc_op_idx < idx)
-                continue
+            if op.all_attrs()["use_calc_stream"] == False:
+                var_name = op.desc.input_arg_names()[0]
+                if "@BroadCast" in var_name:
+                    if broadcast_vars[var_name]["fill_constant_pos"] != -1:
+                        assert (last_sync_calc_op_idx != -1)
+                        assert (broadcast_vars[var_name]["fill_constant_pos"] <
+                                last_sync_calc_op_idx)
+                        assert (last_sync_calc_op_idx < idx)
+                    continue
         for input_name in op.desc.input_arg_names():
             if input_name in broadcast_vars:
                 assert (broadcast_vars[input_name]["broadcast_pos"] != -1)
@@ -78,7 +82,7 @@ def check_broadcast(block):
     return


-def check_allreduce_sum(block, shard, dp_ring_id=-1):
+def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
     """
     the op order should be:
         grad:
@@ -89,32 +93,36 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
             - 4: allreuce_sum_dp (dp_grads)
             - 5: sync_comm (dp_grads)
             - 6: op that use Var (dp_grads & sum)
+    should ignore and skip allreduce_op of inner_parallelism (e.g. Megatron)
     """
     vars_status = {}
     dp_grads_status = {}
     idx_last_grad_allreduce = -1
     idx_amp_allreduce = -1
     idx_gradient_clip_allreduce = -1
     for idx, op in enumerate(block.ops):
         if op.type == "c_allreduce_sum":
-            ring_id = op.desc.attr("ring_id")
-            var_name = op.desc.input_arg_names()[0]
-            param = var_name.split("@")[0]
-
-            assert 'sum' in var_name or ("@GRAD" in var_name)
-            if 'sum' in var_name or (not shard.has_param(param)):
-                vars_status[var_name] = -1
-            else:
-                dp_grads_status[var_name] = -1
-
-            if ring_id != 0:
-                assert shard.has_param(param)
-                assert ring_id == dp_ring_id
-
-            if "sum" in var_name:
-                idx_amp_allreduce = idx
-            elif "@GRAD":
-                idx_last_grad_allreduce = idx
+            if op.all_attrs()["use_calc_stream"] == False:
+                ring_id = op.desc.attr("ring_id")
+                var_name = op.desc.input_arg_names()[0]
+                param = var_name.split("@")[0]
+
+                assert 'sum' in var_name or ("@GRAD" in var_name)
+                if 'sum' in var_name or (not shard.has_param(param)):
+                    vars_status[var_name] = -1
+                else:
+                    dp_grads_status[var_name] = -1
+
+                if ring_id != sharding_ring_id:
+                    assert shard.has_param(param)
+                    assert ring_id == dp_ring_id
+
+                if "sum" in var_name:
+                    idx_amp_allreduce = idx
+                elif "@GRAD":
+                    idx_last_grad_allreduce = idx

         if op.type == "c_allreduce_max":
             idx_gradient_clip_allreduce = idx
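The key change is that check_allreduce_sum now receives the sharding ring id explicitly instead of assuming ring 0, and it only inspects collectives launched on the communication stream. As a rough, self-contained sketch of how a ring id separates sharding-group allreduces from data-parallel ones (the ring-id values and helper are assumptions for illustration):

# Hypothetical ring ids, for illustration only; real values come from the
# sharding configuration.
SHARDING_RING_ID = 1
DP_RING_ID = 2

def classify_allreduce(ring_id, owns_param):
    """Mirror of the relaxed check: allreduces on the sharding ring cover grads
    and AMP 'sum' vars; anything else must be a data-parallel allreduce of a
    gradient whose parameter this rank owns."""
    if ring_id == SHARDING_RING_ID:
        return "sharding"
    assert owns_param and ring_id == DP_RING_ID
    return "data_parallel"

print(classify_allreduce(SHARDING_RING_ID, owns_param=False))  # sharding
print(classify_allreduce(DP_RING_ID, owns_param=True))         # data_parallel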
@@ -130,36 +138,38 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
                 dp_grads_status[var_name] = 1

         elif op.type == "c_allreduce_sum":
-            var_name = op.desc.input_arg_names()[0]
-            ring_id = op.desc.attr("ring_id")
-            if ring_id == 0:
-                if var_name in vars_status:
-                    _status = vars_status[var_name]
-                else:
-                    _status = dp_grads_status[var_name]
-                if _status == -1:
-                    raise ValueError("{} is not generated, but you are"
-                                     "trying to all-reduce it".format(var_name))
-                if _status == 0:
-                    raise ValueError("There should be a sync_calc op "
-                                     "after generate Var: {} and before the"
-                                     "c_allreduce_sum op".format(var_name))
-                assert (_status == 1)
-                if var_name in vars_status:
-                    vars_status[var_name] = 2
-                else:
-                    dp_grads_status[var_name] = 2
-            else:
-                assert ring_id == dp_ring_id
-                param = var_name.split("@")[0]
-                assert shard.has_param(param)
-                assert dp_grads_status[var_name] == 3
-                dp_grads_status[var_name] = 4
+            if op.all_attrs()["use_calc_stream"] == False:
+                var_name = op.desc.input_arg_names()[0]
+                ring_id = op.desc.attr("ring_id")
+                if ring_id == sharding_ring_id:
+                    if var_name in vars_status:
+                        _status = vars_status[var_name]
+                    else:
+                        _status = dp_grads_status[var_name]
+                    if _status == -1:
+                        raise ValueError("{} is not generated, but you are"
+                                         "trying to all-reduce it".format(
+                                             var_name))
+                    if _status == 0:
+                        raise ValueError("There should be a sync_calc op "
+                                         "after generate Var: {} and before the"
+                                         "c_allreduce_sum op".format(var_name))
+                    assert (_status == 1)
+                    if var_name in vars_status:
+                        vars_status[var_name] = 2
+                    else:
+                        dp_grads_status[var_name] = 2
+                else:
+                    assert ring_id == dp_ring_id
+                    param = var_name.split("@")[0]
+                    assert shard.has_param(param)
+                    assert dp_grads_status[var_name] == 3
+                    dp_grads_status[var_name] = 4

         elif op.type == "c_sync_comm_stream":
             var_name = op.desc.input_arg_names()[0]
             ring_id = op.desc.attr("ring_id")
-            if ring_id == 0:
+            if ring_id == sharding_ring_id:
                 for var_name in op.desc.input_arg_names():
                     if var_name in vars_status:
                         assert vars_status[var_name] == 2
@@ -217,9 +227,14 @@ def get_valid_op_role(block, insert_idx):
     return OpRole.Forward or OpRole.Backward
     """
     op_role = block.ops[insert_idx].attr('op_role')
-    if (insert_idx >= len(block.ops)) or (
-            op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
-        return OpRole.Backward
+    #if (insert_idx >= len(block.ops)) or (
+    #    op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
+    #    return OpRole.Backward
+    #if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
+    #    return OpRole.Forward
+    if insert_idx >= len(block.ops): return OpRole.Optimize
+    if op_role == int(OpRole.Backward): return OpRole.Backward
+    if op_role == int(OpRole.Optimize): return OpRole.Optimize
     if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
         return OpRole.Forward
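For intuition, the reworked role lookup can be mimicked with plain integers: a past-the-end insertion now resolves to Optimize rather than Backward, and Backward/Optimize roles are mapped back directly. The OpRole values below are illustrative stand-ins chosen for this sketch, not the framework's actual constants.

from enum import IntEnum

# Illustrative stand-in for OpRole; the numeric values are assumptions.
class OpRole(IntEnum):
    Forward = 0
    Backward = 1
    Optimize = 2
    Loss = 256

def valid_op_role(op_roles, insert_idx):
    # past-the-end insertions are now treated as Optimize instead of Backward
    if insert_idx >= len(op_roles):
        return OpRole.Optimize
    op_role = op_roles[insert_idx]
    if op_role == int(OpRole.Backward):
        return OpRole.Backward
    if op_role == int(OpRole.Optimize):
        return OpRole.Optimize
    if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
        return OpRole.Forward

roles = [int(OpRole.Forward), int(OpRole.Backward)]
print(valid_op_role(roles, 5))  # OpRole.Optimize
print(valid_op_role(roles, 1))  # OpRole.Backward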
@@ -428,7 +443,7 @@ def comm_analyse(main_program):
                                                        count))


-def add_sync_comm(program, dist_strategy):
+def add_sync_comm(program, nccl_ids):
     """
     When clone a test prog by clone from the sharding main prog,
     part of the sync_comm op maybe be pruned by mistake, this function
@@ -438,6 +453,9 @@ def add_sync_comm(program, dist_strategy):
     #NOTE (liangjianzhong): only support one comm stream by now, use more than one
     # comm streams will cause error. should be revise in future.

+    assert isinstance(
+        nccl_ids, list
+    ), "the second argument of this function should be a list of nccl_ids"
     block = program.global_block()
     not_sync_vars = set([])
     for op in block.ops:
@@ -448,7 +466,7 @@ def add_sync_comm(program, dist_strategy):
             for input_name in op.desc.input_arg_names():
                 not_sync_vars.remove(input_name)
     if not_sync_vars:
-        for nccl_id in range(dist_strategy.nccl_comm_num):
+        for nccl_id in nccl_ids:
             block.append_op(
                 type='c_sync_comm_stream',
                 inputs={'X': list(not_sync_vars)},
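With this signature change, callers pass the concrete list of nccl ring ids rather than a DistributedStrategy. A hedged usage sketch of the new interface; the ring id value and the stand-in program are assumptions for illustration, not a full sharding setup:

# Hedged usage sketch of the new signature: callers now pass the list of ring
# ids explicitly; the single ring id 0 below is assumed for illustration.
import paddle
from paddle.distributed.fleet.meta_optimizers.sharding.utils import add_sync_comm

paddle.enable_static()
test_prog = paddle.static.Program()   # stands in for a program cloned for eval
add_sync_comm(test_prog, [0])         # previously: add_sync_comm(test_prog, dist_strategy)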
@@ -467,6 +485,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
     This function handles the model saving for sharding training.
     """

+    if main_program._pipeline_opt:
+        main_program = main_program._pipeline_opt['section_program']['program']
+
     def is_opt_vars(var):
         # NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer
         # now only Momentum and adam are compatible with sharding
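When sharding is combined with pipeline parallelism, the user-facing program only wraps the per-stage "section program", so saving has to target the inner program. A plain-dict sketch of that unwrapping step; the `_pipeline_opt` layout is taken from the diff above, everything else is illustrative:

def resolve_program_to_save(main_program, pipeline_opt=None):
    # mirror of the new guard in save_persistables
    if pipeline_opt:
        return pipeline_opt['section_program']['program']
    return main_program

inner, wrapper = object(), object()   # stand-ins for Program objects
assert resolve_program_to_save(wrapper, {'section_program': {'program': inner}}) is inner
assert resolve_program_to_save(wrapper) is wrapper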
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @ 920806db

This diff is collapsed (+404 −18).
python/paddle/fluid/backward.py @ 920806db

@@ -115,7 +115,7 @@ class ProgramStats(object):
                 updated_min_idx = min_idx
                 while idx_ > pre_segment_end_idx:
                     if is_amp_cast(self.ops[idx_]):
-                        _logger.debug("found amp-cast op: {}, : {}".format(self.ops[
+                        _logger.info("found amp-cast op: {}, : {}".format(self.ops[
                             idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[
                                 0]))
                         updated_min_idx = idx_
@@ -155,7 +155,7 @@ class ProgramStats(object):
         sorted_checkpoints = []
         for name in checkpoints_name:
             if name not in self.var_op_deps:
-                _logger.debug(
+                _logger.info(
                     "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
                     % name)
             elif self.var_op_deps[name]["var_as_output_ops"] == []:
@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
             new_op_desc = block.desc.append_op()
             new_op_desc.copy_from(desc)
             new_op_desc._set_attr(op_role_attr_name, backward)
+            if desc.has_attr('op_device'):
+                new_op_desc._set_attr('op_device', desc.attr('op_device'))
             result_descs.append(new_op_desc)
     return result_descs
@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
         new_op_desc = block.desc.append_op()
         new_op_desc.copy_from(desc)
         new_op_desc._set_attr(op_role_attr_name, backward)
+        if desc.has_attr('op_device'):
+            new_op_desc._set_attr('op_device', desc.attr('op_device'))
         result_descs.append(new_op_desc)
     return result_descs
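Both helpers now carry the forward op's `op_device` attribute onto the copied backward descriptors, so pipeline stages keep gradient ops on the same device as their forward ops. A minimal stand-alone sketch of that copy-with-attribute pattern, using plain dicts in place of Paddle OpDesc objects:

# Plain-dict sketch of "copy the desc, then carry op_device over when present";
# the real code operates on OpDesc objects, this only shows the pattern.
def copy_desc_with_device(desc, op_role):
    new_desc = {'type': desc['type'], 'op_role': op_role}   # copy_from + role attr
    if 'op_device' in desc:                                  # has_attr('op_device')
        new_desc['op_device'] = desc['op_device']            # keep pipeline placement
    return new_desc

fwd = {'type': 'matmul', 'op_device': 'gpu:2'}
bwd = copy_desc_with_device(fwd, op_role='backward')
print(bwd['op_device'])  # gpu:2 -- the backward op stays on its forward op's device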
@@ -784,7 +788,6 @@ def _append_backward_ops_with_checkpoints_(
     start_idx = 0
     pre_segment_end_idx = -1
     while True:
-        _logger.debug("FW op range[0] - [{}]".format(len(ops)))
         if start_idx >= len(checkpoints_name) - 1:
             break
         # min_idx: checkpoint_1' s input op
@@ -797,6 +800,9 @@ def _append_backward_ops_with_checkpoints_(
             min_idx = program_stat._update_segment_start(
                 min_idx, pre_segment_end_idx)
             segments.append([min_idx, max_idx + 1])
+        else:
+            _logger.info("Could not recompute op range [{}] - [{}] ".format(
+                min_idx, max_idx + 1))

         start_idx += 1
@@ -806,15 +812,15 @@ def _append_backward_ops_with_checkpoints_(
         recompute_segments = segments

     for i, (idx1, idx2) in enumerate(recompute_segments):
-        _logger.debug("recompute segment[{}]".format(i))
-        _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
-        ), ops[idx1].desc.input_arg_names()))
-        _logger.debug("segment end op: [{}]: [{}]".format(ops[
-            idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
-        _logger.debug("recompute segment[{}]".format(i))
-        _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
-        ), ops[idx1].desc.input_arg_names()))
-        _logger.debug("segment end op: [{}]: [{}]".format(ops[
-            idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
+        _logger.info("recompute segment[{}]".format(i))
+        _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
+        ), ops[idx1].desc.input_arg_names()))
+        _logger.info("segment end op: [{}]: [{}]".format(ops[
+            idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
+        _logger.info("recompute segment[{}]".format(i))
+        _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
+        ), ops[idx1].desc.input_arg_names()))
+        _logger.info("segment end op: [{}]: [{}]".format(ops[
+            idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))

     # 2) go through all forward ops and induct all variables that will be hold in memory
@@ -825,9 +831,9 @@ def _append_backward_ops_with_checkpoints_(
             program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))

         cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
-        _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
-        len(cross_vars), cross_vars))
-        _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
-        len(cross_vars), cross_vars))
+        _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
+        len(cross_vars), cross_vars))
+        _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
+        len(cross_vars), cross_vars))

     # b. output of seed op should be kept in memory
@@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_(
     vars_in_memory = vars_should_be_hold + checkpoints_name

     max_calculated_op_position = len(ops)
+    device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
     if recompute_segments == []:
         gap_ops = ops[0:max_calculated_op_position]
         for op in reversed(gap_ops):
@@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_(
                     _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+            # Set device for grad_op according to forward Op
+            if op.desc.has_attr(device_attr_name):
+                op_device = op.desc.attr(device_attr_name)
+                for op_desc in grad_op_desc:
+                    op_desc._set_attr(device_attr_name, op_device)
             added_descs = _add_descs_to_block(grad_op_desc, local_block)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
@@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_(
                     _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+            # Set device for grad_op according to forward Op
+            if op.desc.has_attr(device_attr_name):
+                op_device = op.desc.attr(device_attr_name)
+                for op_desc in grad_op_desc:
+                    op_desc._set_attr(device_attr_name, op_device)
             added_descs = _add_descs_to_block(grad_op_desc, local_block)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
@@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_(
                 continue
             if name not in var_name_dict:
                 var_name_dict[name] = name + var_suffix
+
+                # we should create the rename var in subprog, otherwise its VarType will be BOOL
+                block.create_var(
+                    name=var_name_dict[name],
+                    shape=block.program.global_block().var(name).shape,
+                    dtype=block.program.global_block().var(name).dtype,
+                    type=block.program.global_block().var(name).type,
+                    persistable=block.program.global_block().var(name)
+                    .persistable,
+                    stop_gradient=block.program.global_block().var(name)
+                    .stop_gradient)

         # 3.a. add ops in current recompute_segment as forward recomputation ops
         buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
                                                   vars_in_memory)
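The recompute pass renames checkpoint vars inside the sub-block; without an explicit create_var carrying the original shape and dtype, the renamed var would be created lazily with the default (BOOL) type. A hedged sketch of "clone metadata under a new name" using the public static-graph API; the variable names 'x' and 'x.subprog_0' are made up for this example:

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[-1, 16], dtype='float32')

block = main.global_block()
src = block.var('x')
renamed = block.create_var(
    name='x.subprog_0',                 # name + var_suffix in the diff above
    shape=src.shape,
    dtype=src.dtype,
    type=src.type,
    persistable=src.persistable,
    stop_gradient=src.stop_gradient)
print(renamed.dtype)                    # float32 metadata instead of the BOOL default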