机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 920806db
Authored on Feb 07, 2021 by sandyhouse

Commit message: update

Parent: 70eb21c5
Showing 4 changed files with 533 additions and 94 deletions (+533, -94)
python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py       +3   -0
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py       +86  -65
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py   +404 -18
python/paddle/fluid/backward.py                                         +40  -11
python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py
@@ -126,6 +126,9 @@ class ProgramDeps(object):

     def should_remove_op(self, op_idx):
         op = self._block.ops[op_idx]
+        # remove check_finite_and_unscale op if its input 'X' is empty
+        if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
+            return True
         for output_name in op.desc.output_arg_names():
             if output_name not in self._should_removed_var:
                 return False
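The added early return encodes a simple pruning rule: once sharding has pruned every 'X' input of an AMP check_finite_and_unscale op, the op itself is dead and should be dropped as well. A minimal, self-contained sketch of that predicate follows; FakeOp and should_removed_var are stand-ins for illustration, not Paddle's real Block/Operator API.

from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class FakeOp:
    """Stand-in for an operator: a type name plus named input/output var lists."""
    type: str
    inputs: Dict[str, List[str]] = field(default_factory=dict)
    outputs: List[str] = field(default_factory=list)

    def input(self, name: str) -> List[str]:
        return self.inputs.get(name, [])


def should_remove_op(op: FakeOp, should_removed_var: Set[str]) -> bool:
    # Remove check_finite_and_unscale if its 'X' input list became empty after pruning.
    if op.type == "check_finite_and_unscale" and len(op.input("X")) == 0:
        return True
    # Otherwise the op is removable only if every output was already marked removable.
    return all(out in should_removed_var for out in op.outputs)


# The AMP op lost all its inputs during pruning, so it goes away too.
amp_op = FakeOp(type="check_finite_and_unscale", inputs={"X": []}, outputs=["found_inf"])
assert should_remove_op(amp_op, should_removed_var=set())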
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
@@ -28,17 +28,20 @@ def check_broadcast(block):
     if the broadcasted var has a fill_constant op, the fill_constant
     op should stay forward before the broadcast op, and before a
     sync_calc op. Otherwise, raise error.
+
+    should ignore and skip broadcast_op of inner_parallelism (e.g. Megatron)
     """
     broadcast_vars = {}
     for idx, op in enumerate(block.ops):
         if op.type == "c_broadcast":
+            if op.all_attrs()["use_calc_stream"] == False:
                 var_name = op.desc.input_arg_names()[0]
                 if "@BroadCast" in var_name:
                     if var_name in broadcast_vars:
-                        raise ValueError("var_name areadly exist: {}"
-                                         "the old pos is {}, the new pos is {}".
-                                         format(var_name, broadcast_vars[var_name][
-                                             "broadcast_pos"], idx))
+                        raise ValueError("var_name areadly exist: {}"
+                                         "the old pos is {}, the new pos is {}".
+                                         format(var_name, broadcast_vars[
+                                             var_name]["broadcast_pos"], idx))
                     broadcast_vars[var_name] = {
                         "fill_constant_pos": -1,
                         "broadcast_pos": idx,
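The new use_calc_stream guard makes the checker look only at broadcasts launched on a communication stream; c_broadcast ops issued with use_calc_stream=True (for example by an inner parallelism such as Megatron-style model parallel) are skipped. A hedged sketch of that filtering idea over a list of mock op records, not Paddle's real attribute API:

# Mock op records: (op_type, attrs, input_vars). Names are illustrative only.
ops = [
    ("c_broadcast", {"use_calc_stream": True}, ["inner_mp_param"]),       # inner parallelism: skip
    ("c_broadcast", {"use_calc_stream": False}, ["fc_0.w_0@BroadCast"]),  # sharding broadcast: check
    ("c_sync_comm_stream", {"use_calc_stream": False}, ["fc_0.w_0@BroadCast"]),
]

broadcast_vars = {}
for idx, (op_type, attrs, in_vars) in enumerate(ops):
    if op_type == "c_broadcast" and attrs["use_calc_stream"] is False:
        var_name = in_vars[0]
        if "@BroadCast" in var_name:
            if var_name in broadcast_vars:
                raise ValueError("duplicate broadcast of {} at {}".format(var_name, idx))
            broadcast_vars[var_name] = {"fill_constant_pos": -1, "broadcast_pos": idx}

print(broadcast_vars)  # only the sharding broadcast at idx 1 is recorded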
@@ -61,6 +64,7 @@ def check_broadcast(block):
             last_sync_calc_op_idx = idx
             continue
         if op.type == "c_broadcast":
+            if op.all_attrs()["use_calc_stream"] == False:
                 var_name = op.desc.input_arg_names()[0]
                 if "@BroadCast" in var_name:
                     if broadcast_vars[var_name]["fill_constant_pos"] != -1:
@@ -78,7 +82,7 @@ def check_broadcast(block):
     return


-def check_allreduce_sum(block, shard, dp_ring_id=-1):
+def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
     """
     the op order should be:
         grad:
@@ -89,14 +93,18 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
         - 4: allreuce_sum_dp (dp_grads)
         - 5: sync_comm (dp_grads)
         - 6: op that use Var (dp_grads & sum)
+
+    should ignore and skip allreduce_op of inner_parallelism (e.g. Megatron)
     """
     vars_status = {}
     dp_grads_status = {}
     idx_last_grad_allreduce = -1
     idx_amp_allreduce = -1
     idx_gradient_clip_allreduce = -1
     for idx, op in enumerate(block.ops):
         if op.type == "c_allreduce_sum":
+            if op.all_attrs()["use_calc_stream"] == False:
                 ring_id = op.desc.attr("ring_id")
                 var_name = op.desc.input_arg_names()[0]
                 param = var_name.split("@")[0]
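check_allreduce_sum is essentially a small state machine: each sharded gradient is expected to pass through allreduce and then a comm-stream sync before any op consumes it, with data-parallel gradients going through an extra round on the dp ring. A compact sketch of that status-tracking idea, simplified to a single ring and plain integer statuses with a made-up op trace (the real checker tracks more states and two rings):

# Status per gradient var: 0 = allreduce issued, 1 = comm stream synced.
SHARDING_RING = 1
trace = [
    ("c_allreduce_sum", SHARDING_RING, "fc_0.w_0@GRAD"),
    ("c_sync_comm_stream", SHARDING_RING, "fc_0.w_0@GRAD"),
    ("momentum", None, "fc_0.w_0@GRAD"),  # consumer op
]

status = {}
for op_type, ring_id, var in trace:
    if op_type == "c_allreduce_sum" and ring_id == SHARDING_RING:
        status[var] = 0
    elif op_type == "c_sync_comm_stream" and ring_id == SHARDING_RING:
        assert status.get(var) == 0, "allreduce must come before sync_comm"
        status[var] = 1
    else:
        # any op that reads the gradient must see it fully reduced and synced
        assert status.get(var) == 1, "grad {} used before allreduce+sync".format(var)
print("op order OK")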
@@ -107,7 +115,7 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
                 else:
                     dp_grads_status[var_name] = -1

-                if ring_id != 0:
+                if ring_id != sharding_ring_id:
                     assert shard.has_param(param)
                     assert ring_id == dp_ring_id
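Replacing the hard-coded ring_id != 0 with ring_id != sharding_ring_id is what lets the checker coexist with extra communicators (pipeline or megatron rings): an allreduce belongs to the sharding group only when its ring_id matches the sharding ring, and anything else must be the data-parallel ring and must touch a parameter owned by this shard. A small routing sketch under those assumptions; the ring ids and the ownership set below are made up for illustration:

SHARDING_RING_ID = 1   # assumed id of the sharding communicator
DP_RING_ID = 3         # assumed id of the data-parallel communicator
params_on_this_shard = {"fc_1.w_0"}  # hypothetical ownership set for this rank


def classify_allreduce(ring_id, grad_var):
    param = grad_var.split("@")[0]
    if ring_id == SHARDING_RING_ID:
        return "sharding grad allreduce"
    # anything not on the sharding ring must be a dp allreduce of an owned param
    assert param in params_on_this_shard, param
    assert ring_id == DP_RING_ID, ring_id
    return "data-parallel grad allreduce"


print(classify_allreduce(SHARDING_RING_ID, "fc_0.w_0@GRAD"))
print(classify_allreduce(DP_RING_ID, "fc_1.w_0@GRAD"))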
@@ -130,16 +138,18 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
                     dp_grads_status[var_name] = 1

         elif op.type == "c_allreduce_sum":
+            if op.all_attrs()["use_calc_stream"] == False:
                 var_name = op.desc.input_arg_names()[0]
                 ring_id = op.desc.attr("ring_id")
-                if ring_id == 0:
+                if ring_id == sharding_ring_id:
                     if var_name in vars_status:
                         _status = vars_status[var_name]
                     else:
                         _status = dp_grads_status[var_name]
                     if _status == -1:
                         raise ValueError("{} is not generated, but you are"
                                          "trying to all-reduce it".format(var_name))
                     if _status == 0:
                         raise ValueError("There should be a sync_calc op "
                                          "after generate Var: {} and before the"
@@ -159,7 +169,7 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
         elif op.type == "c_sync_comm_stream":
             var_name = op.desc.input_arg_names()[0]
             ring_id = op.desc.attr("ring_id")
-            if ring_id == 0:
+            if ring_id == sharding_ring_id:
                 for var_name in op.desc.input_arg_names():
                     if var_name in vars_status:
                         assert vars_status[var_name] == 2
@@ -217,9 +227,14 @@ def get_valid_op_role(block, insert_idx):
     return OpRole.Forward or OpRole.Backward
     """
     op_role = block.ops[insert_idx].attr('op_role')
-    if (insert_idx >= len(block.ops)) or (
-            op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
-        return OpRole.Backward
+    #if (insert_idx >= len(block.ops)) or (
+    #        op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
+    #    return OpRole.Backward
+    #if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
+    #    return OpRole.Forward
+    if insert_idx >= len(block.ops): return OpRole.Optimize
+    if op_role == int(OpRole.Backward): return OpRole.Backward
+    if op_role == int(OpRole.Optimize): return OpRole.Optimize

     if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
         return OpRole.Forward
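After this change the role lookup no longer folds everything non-forward into Backward: an insertion point past the end of the block is treated as Optimize, and Backward, Optimize, and Forward/Loss are reported separately. A stand-alone sketch of the same decision order, with a toy IntEnum in place of Paddle's OpRole and a fallback step assumed for this sketch:

from enum import IntEnum


class OpRole(IntEnum):
    Forward = 0
    Backward = 1
    Optimize = 2
    Loss = 256


def get_valid_op_role(op_roles, insert_idx):
    """Return the role an op inserted at insert_idx should carry."""
    if insert_idx >= len(op_roles):
        return OpRole.Optimize
    op_role = op_roles[insert_idx]
    if op_role == int(OpRole.Backward):
        return OpRole.Backward
    if op_role == int(OpRole.Optimize):
        return OpRole.Optimize
    if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
        return OpRole.Forward
    # assumed fallback for this sketch: look at the next op's role
    return get_valid_op_role(op_roles, insert_idx + 1)


roles = [int(OpRole.Forward), int(OpRole.Backward), int(OpRole.Optimize)]
assert get_valid_op_role(roles, 1) == OpRole.Backward
assert get_valid_op_role(roles, 99) == OpRole.Optimize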
@@ -428,7 +443,7 @@ def comm_analyse(main_program):
                                                count))


-def add_sync_comm(program, dist_strategy):
+def add_sync_comm(program, nccl_ids):
     """
     When clone a test prog by clone from the sharding main prog,
     part of the sync_comm op maybe be pruned by mistake, this function
@@ -438,6 +453,9 @@ def add_sync_comm(program, dist_strategy):
     #NOTE (liangjianzhong): only support one comm stream by now, use more than one
     # comm streams will cause error. should be revise in future.

+    assert isinstance(nccl_ids, list), "the second argument of this function should be a list of nccl_ids"
+
     block = program.global_block()
     not_sync_vars = set([])
     for op in block.ops:
@@ -448,7 +466,7 @@ def add_sync_comm(program, dist_strategy):
             for input_name in op.desc.input_arg_names():
                 not_sync_vars.remove(input_name)
     if not_sync_vars:
-        for nccl_id in range(dist_strategy.nccl_comm_num):
+        for nccl_id in nccl_ids:
             block.append_op(
                 type='c_sync_comm_stream',
                 inputs={'X': list(not_sync_vars)},
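The signature change means callers now pass the concrete list of nccl ring ids instead of the whole distributed strategy, and the new assert rejects the old calling convention. A hedged sketch of the call-site effect with plain data; the mock append_sync_ops below just records what would be appended and is not Paddle's Block API:

appended = []


def append_sync_ops(not_sync_vars, nccl_ids):
    assert isinstance(nccl_ids, list), (
        "the second argument of this function should be a list of nccl_ids")
    for ring_id in nccl_ids:
        # one c_sync_comm_stream per communicator ring, covering all unsynced vars
        appended.append({
            "type": "c_sync_comm_stream",
            "inputs": list(not_sync_vars),
            "ring_id": ring_id,
        })


# before: for nccl_id in range(dist_strategy.nccl_comm_num), i.e. implicit 0..N-1
# after: the caller decides exactly which rings need a trailing sync
append_sync_ops({"fc_0.w_0@GRAD"}, nccl_ids=[0, 1])
print(appended)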
@@ -467,6 +485,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
     This function handles the model saving for sharding training.
     """
+    if main_program._pipeline_opt:
+        main_program = main_program._pipeline_opt['section_program']['program']
+
     def is_opt_vars(var):
         # NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer
         # now only Momentum and adam are compatible with sharding
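The two added lines make saving work when sharding is combined with pipeline training: the program that actually holds the persistable variables is the per-section program stored under _pipeline_opt, so it is unwrapped before the usual save path runs. A minimal sketch of that unwrapping step, with a FakeProgram class standing in for a real Program:

def resolve_program_for_saving(main_program):
    """Return the program whose persistables should be saved."""
    pipeline_opt = getattr(main_program, "_pipeline_opt", None)
    if pipeline_opt:
        # pipeline training wraps the real program one level deeper
        return pipeline_opt["section_program"]["program"]
    return main_program


class FakeProgram:
    def __init__(self, name, pipeline_opt=None):
        self.name = name
        self._pipeline_opt = pipeline_opt


section = FakeProgram("section_program")
wrapper = FakeProgram("wrapper", {"section_program": {"program": section}})
assert resolve_program_for_saving(wrapper) is section
assert resolve_program_for_saving(section) is section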
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
(This diff is collapsed in the original commit view.)
python/paddle/fluid/backward.py
@@ -115,7 +115,7 @@ class ProgramStats(object):
                 updated_min_idx = min_idx
                 while idx_ > pre_segment_end_idx:
                     if is_amp_cast(self.ops[idx_]):
-                        _logger.debug("found amp-cast op: {}, : {}".format(self.ops[
+                        _logger.info("found amp-cast op: {}, : {}".format(self.ops[
                             idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[
                             0]))
                         updated_min_idx = idx_
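Throughout this file the recompute logger calls are promoted from debug to info, so segment and checkpoint messages show up under a default logger configuration instead of requiring the level to be lowered first. For reference, a tiny stand-alone example of that visibility difference using the standard library, not Paddle's own _logger setup:

import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger("recompute_demo")

logger.debug("found amp-cast op: %s", "cast_fp16")  # hidden at the default INFO level
logger.info("found amp-cast op: %s", "cast_fp16")   # printed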
@@ -155,7 +155,7 @@ class ProgramStats(object):
         sorted_checkpoints = []
         for name in checkpoints_name:
             if name not in self.var_op_deps:
-                _logger.debug(
+                _logger.info(
                     "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
                     % name)
             elif self.var_op_deps[name]["var_as_output_ops"] == []:
@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
             new_op_desc = block.desc.append_op()
             new_op_desc.copy_from(desc)
             new_op_desc._set_attr(op_role_attr_name, backward)
+            if desc.has_attr('op_device'):
+                new_op_desc._set_attr('op_device', desc.attr('op_device'))
             result_descs.append(new_op_desc)
     return result_descs
@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
             new_op_desc = block.desc.append_op()
             new_op_desc.copy_from(desc)
             new_op_desc._set_attr(op_role_attr_name, backward)
+            if desc.has_attr('op_device'):
+                new_op_desc._set_attr('op_device', desc.attr('op_device'))
             result_descs.append(new_op_desc)
     return result_descs
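Both helpers now copy the forward op's op_device attribute onto the descs they append, which is what keeps recomputed and backward ops pinned to the same pipeline device as their forward counterparts. A sketch of that copy with plain dictionaries standing in for op descs:

def copy_device_attr(src_desc: dict, new_desc: dict, attr_name: str = "op_device") -> dict:
    """Propagate the device placement attribute, if the source op carries one."""
    if attr_name in src_desc:  # mirrors desc.has_attr('op_device')
        new_desc[attr_name] = src_desc[attr_name]
    return new_desc


forward_desc = {"type": "matmul", "op_device": "gpu:2"}
grad_desc = copy_device_attr(forward_desc, {"type": "matmul_grad"})
assert grad_desc["op_device"] == "gpu:2"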
@@ -784,7 +788,6 @@ def _append_backward_ops_with_checkpoints_(
         start_idx = 0
         pre_segment_end_idx = -1
         while True:
-            _logger.debug("FW op range[0] - [{}]".format(len(ops)))
             if start_idx >= len(checkpoints_name) - 1:
                 break
             # min_idx: checkpoint_1' s input op
@@ -797,6 +800,9 @@ def _append_backward_ops_with_checkpoints_(
                 min_idx = program_stat._update_segment_start(
                     min_idx, pre_segment_end_idx)
                 segments.append([min_idx, max_idx + 1])
+            else:
+                _logger.info("Could not recompute op range [{}] - [{}] ".format(
+                    min_idx, max_idx + 1))

             start_idx += 1
@@ -806,15 +812,15 @@ def _append_backward_ops_with_checkpoints_(
         recompute_segments = segments

     for i, (idx1, idx2) in enumerate(recompute_segments):
-        _logger.debug("recompute segment[{}]".format(i))
-        _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
+        _logger.info("recompute segment[{}]".format(i))
+        _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
         ), ops[idx1].desc.input_arg_names()))
-        _logger.debug("segment end op: [{}]: [{}]".format(ops[
+        _logger.info("segment end op: [{}]: [{}]".format(ops[
             idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
-        _logger.debug("recompute segment[{}]".format(i))
-        _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
+        _logger.info("recompute segment[{}]".format(i))
+        _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
         ), ops[idx1].desc.input_arg_names()))
-        _logger.debug("segment end op: [{}]: [{}]".format(ops[
+        _logger.info("segment end op: [{}]: [{}]".format(ops[
             idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))

     # 2) go through all forward ops and induct all variables that will be hold in memory
@@ -825,9 +831,9 @@ def _append_backward_ops_with_checkpoints_(
             program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))

     cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
-    _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
+    _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
     len(cross_vars), cross_vars))
-    _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
+    _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
     len(cross_vars), cross_vars))

     # b. output of seed op should be kept in memory
@@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_(
     vars_in_memory = vars_should_be_hold + checkpoints_name

     max_calculated_op_position = len(ops)
+    device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
     if recompute_segments == []:
         gap_ops = ops[0:max_calculated_op_position]
         for op in reversed(gap_ops):
@@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_(
                     _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+            # Set device for grad_op according to forward Op
+            if op.desc.has_attr(device_attr_name):
+                op_device = op.desc.attr(device_attr_name)
+                for op_desc in grad_op_desc:
+                    op_desc._set_attr(device_attr_name, op_device)
             added_descs = _add_descs_to_block(grad_op_desc, local_block)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
@@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_(
                     _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+            # Set device for grad_op according to forward Op
+            if op.desc.has_attr(device_attr_name):
+                op_device = op.desc.attr(device_attr_name)
+                for op_desc in grad_op_desc:
+                    op_desc._set_attr(device_attr_name, op_device)
             added_descs = _add_descs_to_block(grad_op_desc, local_block)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
@@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_(
                     continue
                 if name not in var_name_dict:
                     var_name_dict[name] = name + var_suffix
+
+                    # we should create the rename var in subprog, otherwise its VarType will be BOOL
+                    block.create_var(
+                        name=var_name_dict[name],
+                        shape=block.program.global_block().var(name).shape,
+                        dtype=block.program.global_block().var(name).dtype,
+                        type=block.program.global_block().var(name).type,
+                        persistable=block.program.global_block().var(name).persistable,
+                        stop_gradient=block.program.global_block().var(name)
+                        .stop_gradient)
+
         # 3.a. add ops in current recompute_segment as forward recomputation ops
         buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
                                                   vars_in_memory)
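The long create_var call exists so the renamed checkpoint variable in the recompute sub-program inherits the shape, dtype, type, persistable and stop_gradient flags of the original; without it the sub-program would default the new variable (the comment notes its VarType would come out as BOOL). A small "clone the metadata, change only the name" sketch with a dataclass in place of a framework Variable; the suffix and field set are illustrative:

from dataclasses import dataclass, replace
from typing import Tuple


@dataclass
class VarMeta:
    name: str
    shape: Tuple[int, ...]
    dtype: str
    persistable: bool = False
    stop_gradient: bool = False


def make_renamed_copy(src: VarMeta, suffix: str = ".subprog_0") -> VarMeta:
    # keep every property of the source var, only the name changes
    return replace(src, name=src.name + suffix)


ckpt = VarMeta(name="hidden_1", shape=(8, 1024), dtype="float16")
renamed = make_renamed_copy(ckpt)
assert renamed.name == "hidden_1.subprog_0" and renamed.dtype == "float16"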