PaddlePaddle / Paddle, commit 920806db
Authored on Feb 07, 2021 by sandyhouse
Commit message: update
Parent: 70eb21c5
Showing 4 changed files with 533 additions and 94 deletions (+533 -94)
python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py        +3   -0
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py        +86  -65
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py    +404 -18
python/paddle/fluid/backward.py                                          +40  -11
python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py

@@ -126,6 +126,9 @@ class ProgramDeps(object):
     def should_remove_op(self, op_idx):
         op = self._block.ops[op_idx]
+        # remove check_finite_and_unscale op if its input 'X' is empty
+        if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
+            return True
         for output_name in op.desc.output_arg_names():
             if output_name not in self._should_removed_var:
                 return False
...
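Note (illustration only): a minimal standalone sketch of the removal rule this hunk introduces, written against a hypothetical FakeOp stand-in rather than the real Paddle block/op API. An op is dropped outright if it is a check_finite_and_unscale op whose 'X' input list is empty; otherwise it may only be removed when every one of its outputs is already marked removable.

    from collections import namedtuple

    # Hypothetical stand-in for a block op; only the fields the rule inspects.
    FakeOp = namedtuple("FakeOp", ["type", "inputs_x", "output_names"])

    def should_remove_op(op, should_removed_vars):
        # drop check_finite_and_unscale ops whose input 'X' is empty
        if op.type == "check_finite_and_unscale" and len(op.inputs_x) == 0:
            return True
        # otherwise removable only if every output is already marked removable
        for output_name in op.output_names:
            if output_name not in should_removed_vars:
                return False
        return True

    ops = [
        FakeOp("check_finite_and_unscale", [], ["found_inf"]),  # removed: empty 'X'
        FakeOp("elementwise_add", ["a", "b"], ["c"]),            # kept: 'c' still needed
        FakeOp("cast", ["d"], ["d.cast_fp16"]),                  # removed: output unused
    ]
    print([should_remove_op(op, {"d.cast_fp16"}) for op in ops])  # [True, False, True]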
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py

@@ -28,17 +28,20 @@ def check_broadcast(block):
     if the broadcasted var has a fill_constant op, the fill_constant
     op should stay forward before the broadcast op, and before a
     sync_calc op. Otherwise, raise error.
+
+    should ignore and skip broadcast_op of inner_parallelism (e.g. Megatron)
     """
     broadcast_vars = {}
     for idx, op in enumerate(block.ops):
         if op.type == "c_broadcast":
+            if op.all_attrs()["use_calc_stream"] == False:
                 var_name = op.desc.input_arg_names()[0]
                 if "@BroadCast" in var_name:
                     if var_name in broadcast_vars:
                         raise ValueError("var_name areadly exist: {}"
                                          "the old pos is {}, the new pos is {}".
                                          format(var_name, broadcast_vars[var_name][
                                              "broadcast_pos"], idx))
                     broadcast_vars[var_name] = {
                         "fill_constant_pos": -1,
                         "broadcast_pos": idx,
                     }
...
@@ -61,6 +64,7 @@ def check_broadcast(block):
             last_sync_calc_op_idx = idx
             continue
         if op.type == "c_broadcast":
+            if op.all_attrs()["use_calc_stream"] == False:
                 var_name = op.desc.input_arg_names()[0]
                 if "@BroadCast" in var_name:
                     if broadcast_vars[var_name]["fill_constant_pos"] != -1:
...
@@ -78,7 +82,7 @@ def check_broadcast(block):
     return


-def check_allreduce_sum(block, shard, dp_ring_id=-1):
+def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
    """
    the op order should be:
        grad:
...
@@ -89,14 +93,18 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
        - 4: allreuce_sum_dp (dp_grads)
        - 5: sync_comm (dp_grads)
        - 6: op that use Var (dp_grads & sum)
+
+    should ignore and skip allreduce_op of inner_parallelism (e.g. Megatron)
    """
    vars_status = {}
    dp_grads_status = {}
    idx_last_grad_allreduce = -1
    idx_amp_allreduce = -1
    idx_gradient_clip_allreduce = -1
    for idx, op in enumerate(block.ops):
        if op.type == "c_allreduce_sum":
+            if op.all_attrs()["use_calc_stream"] == False:
                ring_id = op.desc.attr("ring_id")
                var_name = op.desc.input_arg_names()[0]
                param = var_name.split("@")[0]
...
@@ -107,7 +115,7 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
                else:
                    dp_grads_status[var_name] = -1

-                if ring_id != 0:
+                if ring_id != sharding_ring_id:
                    assert shard.has_param(param)
                    assert ring_id == dp_ring_id
...
@@ -130,16 +138,18 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
                dp_grads_status[var_name] = 1

        elif op.type == "c_allreduce_sum":
+            if op.all_attrs()["use_calc_stream"] == False:
                var_name = op.desc.input_arg_names()[0]
                ring_id = op.desc.attr("ring_id")
-                if ring_id == 0:
+                if ring_id == sharding_ring_id:
                    if var_name in vars_status:
                        _status = vars_status[var_name]
                    else:
                        _status = dp_grads_status[var_name]
                    if _status == -1:
                        raise ValueError("{} is not generated, but you are"
                                         "trying to all-reduce it".format(var_name))
                    if _status == 0:
                        raise ValueError("There should be a sync_calc op "
                                         "after generate Var: {} and before the"
...
@@ -159,7 +169,7 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
        elif op.type == "c_sync_comm_stream":
            var_name = op.desc.input_arg_names()[0]
            ring_id = op.desc.attr("ring_id")
-            if ring_id == 0:
+            if ring_id == sharding_ring_id:
                for var_name in op.desc.input_arg_names():
                    if var_name in vars_status:
                        assert vars_status[var_name] == 2
...
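Note (illustration only): the hunks above make check_allreduce_sum key on the sharding ring id passed in by the caller instead of the previously hard-coded ring 0, and skip collectives that run on the calc stream (inner parallelism such as Megatron). A standalone sketch of that filtering idea, using a hypothetical FakeOp tuple rather than the real op descriptor:

    from collections import namedtuple

    # Hypothetical stand-in for an op; just the fields the check inspects.
    FakeOp = namedtuple("FakeOp", ["type", "ring_id", "use_calc_stream"])

    def sharding_allreduce_indices(ops, sharding_ring_id):
        """Indices of allreduce ops on the sharding ring that do not run on the
        calc stream (inner-parallelism collectives are skipped)."""
        return [
            idx for idx, op in enumerate(ops)
            if op.type == "c_allreduce_sum" and not op.use_calc_stream and
            op.ring_id == sharding_ring_id
        ]

    ops = [
        FakeOp("c_allreduce_sum", ring_id=0, use_calc_stream=True),   # inner parallelism, skipped
        FakeOp("c_allreduce_sum", ring_id=1, use_calc_stream=False),  # sharding ring
        FakeOp("c_allreduce_sum", ring_id=2, use_calc_stream=False),  # dp ring, different ring id
    ]
    print(sharding_allreduce_indices(ops, sharding_ring_id=1))  # [1]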
@@ -217,9 +227,14 @@ def get_valid_op_role(block, insert_idx):
        return OpRole.Forward or OpRole.Backward
    """
    op_role = block.ops[insert_idx].attr('op_role')
-    if (insert_idx >= len(block.ops)) or (
-            op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
-        return OpRole.Backward
+    #if (insert_idx >= len(block.ops)) or (
+    #        op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
+    #    return OpRole.Backward
+    #if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
+    #    return OpRole.Forward
+    if insert_idx >= len(block.ops): return OpRole.Optimize
+    if op_role == int(OpRole.Backward): return OpRole.Backward
+    if op_role == int(OpRole.Optimize): return OpRole.Optimize
    if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
        return OpRole.Forward
...
@@ -428,7 +443,7 @@ def comm_analyse(main_program):
                                                      count))


-def add_sync_comm(program, dist_strategy):
+def add_sync_comm(program, nccl_ids):
    """
    When clone a test prog by clone from the sharding main prog,
    part of the sync_comm op maybe be pruned by mistake, this function
...
@@ -438,6 +453,9 @@ def add_sync_comm(program, dist_strategy):
    #NOTE (liangjianzhong): only support one comm stream by now, use more than one
    # comm streams will cause error. should be revise in future.

+    assert isinstance(
+        nccl_ids, list
+    ), "the second argument of this function should be a list of nccl_ids"
    block = program.global_block()
    not_sync_vars = set([])
    for op in block.ops:
...
@@ -448,7 +466,7 @@ def add_sync_comm(program, dist_strategy):
            for input_name in op.desc.input_arg_names():
                not_sync_vars.remove(input_name)
    if not_sync_vars:
-        for nccl_id in range(dist_strategy.nccl_comm_num):
+        for nccl_id in nccl_ids:
            block.append_op(
                type='c_sync_comm_stream',
                inputs={'X': list(not_sync_vars)},
...
...
@@ -467,6 +485,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
This function handles the model saving for sharding training.
This function handles the model saving for sharding training.
"""
"""
if
main_program
.
_pipeline_opt
:
main_program
=
main_program
.
_pipeline_opt
[
'section_program'
][
'program'
]
def
is_opt_vars
(
var
):
def
is_opt_vars
(
var
):
# NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer
# NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer
# now only Momentum and adam are compatible with sharding
# now only Momentum and adam are compatible with sharding
...
...
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py

@@ -16,7 +16,7 @@ from paddle.fluid import unique_name, core
 import paddle.fluid as fluid
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
-from paddle.distributed.fleet.meta_optimizers.common import is_backward_op
+from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
 from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
 from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
 from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
...
@@ -39,6 +39,7 @@ class ShardingOptimizer(MetaOptimizerBase):
            "AMPOptimizer",
            "LarsOptimizer",
            "LambOptimizer",
+            "ModelParallelOptimizer",
        ]
        self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
        self._main_program = None
@@ -51,6 +52,10 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -51,6 +52,10 @@ class ShardingOptimizer(MetaOptimizerBase):
self
.
_reduced_grads_to_param
=
{}
self
.
_reduced_grads_to_param
=
{}
self
.
_shard
=
Shard
()
self
.
_shard
=
Shard
()
# use sharding as outer parallelism (e.g. inner:Megatron & outer sharding)
self
.
_as_outer_parallelism
=
False
self
.
_inner_parallelism_size
=
None
def
_can_apply
(
self
):
def
_can_apply
(
self
):
if
not
self
.
role_maker
.
_is_collective
:
if
not
self
.
role_maker
.
_is_collective
:
return
False
return
False
...
@@ -79,20 +84,61 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -79,20 +84,61 @@ class ShardingOptimizer(MetaOptimizerBase):
"fuse_broadcast_MB"
]
"fuse_broadcast_MB"
]
self
.
hybrid_dp
=
self
.
user_defined_strategy
.
sharding_configs
[
self
.
hybrid_dp
=
self
.
user_defined_strategy
.
sharding_configs
[
"hybrid_dp"
]
"hybrid_dp"
]
self
.
_as_outer_parallelism
=
self
.
user_defined_strategy
.
sharding_configs
[
"as_outer_parallelism"
]
self
.
_inner_parallelism_size
=
int
(
self
.
user_defined_strategy
.
sharding_configs
[
"inner_parallelism_size"
])
self
.
use_pipeline
=
self
.
user_defined_strategy
.
sharding_configs
[
"use_pipeline"
]
if
self
.
inner_opt
is
None
:
if
self
.
inner_opt
is
None
:
raise
ValueError
(
raise
ValueError
(
"self.inner_opt of ShardingOptimizer should not be None."
)
"self.inner_opt of ShardingOptimizer should not be None."
)
if
self
.
use_pipeline
:
pp_optimizer
=
fluid
.
optimizer
.
PipelineOptimizer
(
self
.
inner_opt
)
main_program
=
loss
.
block
.
program
main_program
.
_pipeline_opt
=
dict
()
pp_rank
=
self
.
role_maker
.
_worker_index
(
)
//
self
.
user_defined_strategy
.
sharding_configs
[
'sharding_group_size'
]
main_program
.
_pipeline_opt
[
'local_rank'
]
=
pp_rank
main_program
.
_pipeline_opt
[
'global_rank'
]
=
self
.
role_maker
.
_worker_index
()
main_program
.
_pipeline_opt
[
'use_sharding'
]
=
True
main_program
.
_pipeline_opt
[
'ring_id'
]
=
1
optimize_ops
,
params_grads
,
program_list
=
pp_optimizer
.
minimize
(
loss
,
startup_program
,
parameter_list
,
no_grad_set
)
self
.
pipeline_nodes
=
len
(
program_list
)
else
:
optimize_ops
,
params_grads
=
self
.
inner_opt
.
minimize
(
optimize_ops
,
params_grads
=
self
.
inner_opt
.
minimize
(
loss
,
startup_program
,
parameter_list
,
no_grad_set
)
loss
,
startup_program
,
parameter_list
,
no_grad_set
)
if
startup_program
is
None
:
if
startup_program
is
None
:
startup_program
=
default_startup_program
()
startup_program
=
default_startup_program
()
if
self
.
use_pipeline
:
startup_program
=
startup_program
.
_pipeline_opt
[
'startup_program'
]
#main_program = main_program._pipeline_opt['section_program']['program']
print
(
"pp_rank:"
,
pp_rank
)
main_program
=
program_list
[
pp_rank
][
'program'
]
with
open
(
"main_%d"
%
self
.
role_maker
.
_worker_index
(),
'w'
)
as
f
:
f
.
writelines
(
str
(
main_program
))
main_block
=
main_program
.
global_block
()
new_params_grads
=
[]
for
param
,
grad
in
params_grads
:
if
main_block
.
has_var
(
param
.
name
):
new_params_grads
.
append
((
param
,
grad
))
params_grads
=
new_params_grads
else
:
main_block
=
loss
.
block
main_block
=
loss
.
block
startup_block
=
startup_program
.
global_block
()
startup_block
=
startup_program
.
global_block
()
self
.
_main_program
=
main_block
.
program
self
.
_main_program
=
main_block
.
program
self
.
_startup_program
=
startup_program
self
.
_startup_program
=
startup_program
if
self
.
use_pipeline
:
pp_optimizer
.
_rename_gradient_var_name
(
main_block
)
# step1: set_up
# step1: set_up
self
.
_set_up
(
params_grads
)
self
.
_set_up
(
params_grads
)
...
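Note (illustration only): when use_pipeline is on, the pipeline rank above is derived from the global worker index and the configured sharding group size. A small worked sketch of that arithmetic with hypothetical sizes (8 workers, sharding groups of 4):

    # Hypothetical setup: 8 workers, sharding_group_size = 4, so 2 pipeline stages.
    sharding_group_size = 4
    for worker_index in range(8):
        pp_rank = worker_index // sharding_group_size        # which pipeline stage
        sharding_rank = worker_index % sharding_group_size   # rank inside the sharding group
        print(worker_index, "-> pp_rank", pp_rank, "sharding_rank", sharding_rank)
    # workers 0..3 form pipeline stage 0, workers 4..7 form stage 1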
@@ -105,17 +151,200 @@ class ShardingOptimizer(MetaOptimizerBase):
        startup_block._sync_with_cpp()

        # step4: insert reduce_sum for grad
-        insert_scale_loss_grad_ops(
-            main_block, scale=1.0 / self.role_maker._worker_num())
+        # grad_scale_coeff = self.role_maker._worker_num()
+        # if self._as_outer_parallelism:
+        #     grad_scale_coeff = grad_scale_coeff / self._inner_parallelism_size
+        # insert_scale_loss_grad_ops(main_block, scale=1.0 / grad_scale_coeff)
+        sharding_group_size = self.user_defined_strategy.sharding_configs[
+            'sharding_group_size']
+        insert_scale_loss_grad_ops(main_block, scale=1.0 / sharding_group_size)
        main_block._sync_with_cpp()

        # step5: remove unneeded ops and vars from block
        self._prune_main_program(main_block)
        self._prune_startup_program(startup_block)
+        if self.hybrid_dp:
+            self._initialization_broadcast(startup_program)
+
+        if self.use_pipeline:
+            # crop ops
+            for idx, op in reversed(list(enumerate(main_block.ops))):
# if op.type == 'fill_constant' and int(op.attr('op_role')) == 16:
# out_name = op.output_arg_names[0]
# if not 'GRAD' in out_name: continue
# param_name = out_name.strip("@GRAD")
# #if main_block.has_var(out_name): continue
# if self._shard.has_param(param_name): continue
# main_block._remove_op(idx)
+                if is_update_op(op):
+                    op_role_var = op.attr('op_role_var')
+                    param_name = op_role_var[0]
+                    if not self._shard.has_param(param_name):
+                        main_block._remove_op(idx)
+
+            param_list = []
+            for param_name, grad_name in params_grads:
+                if self._shard.has_param(param_name):
+                    param_list.append(param_name)
#pp_optimizer._clear_gradients(main_block, param_list)
+            pp_optimizer._accumulate_gradients(main_block)
#if not self._shard.has_param(param_name): continue
##if not main_block.has_var(grad_name): continue
#assert main_block.has_var(grad_name)
#grad_var = main_block.vars[grad_name]
#grad_var.persistable = True
#main_block._insert_op(
# index=0,
# type='fill_constant',
# inputs={},
# outputs={'Out': [grad_var]},
# attrs={
# 'shape': grad_var.shape,
# 'dtype': grad_var.dtype,
# 'value': float(0),
# #self._op_device_key: device,
# # a trick to run this op once per mini-batch
# 'op_role': core.op_proto_and_checker_maker.OpRole.LRSched,
# })
#def _create_var(block, ref_var, name):
# """
# Create a new var for block, which has the same type,
# shape and dtype as ref_var, then rename it with the
# name `name`.
# """
# new_var = block.create_var(
# name=name,
# shape=ref_var.shape,
# dtype=ref_var.dtype,
# type=ref_var.type,
# lod_level=ref_var.lod_level,
# persistable=ref_var.persistable,
# is_data=ref_var.is_data,
# need_check_feed=ref_var.desc.need_check_feed())
# new_var.stop_gradient = ref_var.stop_gradient
# return new_var
#def _rename_arg(op, old_name, new_name):
# op_desc = op.desc
# if isinstance(op_desc, tuple):
# op_desc = op_desc[0]
# op_desc._rename_input(old_name, new_name)
# op_desc._rename_output(old_name, new_name)
#print("params_grads:", params_grads)
#for param_name, grad_name in params_grads:
# if not self._shard.has_param(param_name): continue
# #if not main_block.has_var(grad_name): continue
# assert main_block.has_var(grad_name)
# use_fp16 = False
# fp16_grad_name = param_name + '.cast_fp16@GRAD'
# if main_block.has_var(grad_name):
# fp16_grad_var = main_block.vars[fp16_grad_name]
# use_fp16 = True
# grad_var = main_block.vars[grad_name]
# if use_fp16:
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# grad_name)
# cast_var = _create_var(main_block, fp16_grad_var,
# cast_grad_var_name)
# cast_var.persistable = False
# main_block.append_op(
# #index=offset + 1,
# type='cast',
# inputs={'X': grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# })
# #offset += 1
# main_block.append_op(
# #index=offset + 1,
# type='sum',
# inputs={'X': [fp16_grad_var, cast_var]},
# outputs={'Out': fp16_grad_var},
# attrs={
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# })
# for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
# offset = index
# if is_backward_op(op) and (
# 'op_role_var' in op.attr_names):
# op_role_var = op.all_attrs()['op_role_var']
# if len(op_role_var) == 0:
# continue
# assert len(op_role_var) % 2 == 0
# offset = index
# for i in range(0, len(op_role_var), 2):
# grad_name = op_role_var[i + 1]
# if not main_block.has_var(grad_name): continue
# grad_var = main_block.vars[grad_name]
# if not 'cast_fp16' in grad_name:
# new_grad_var_name = paddle.fluid.unique_name.generate(grad_name)
# new_var = _create_var(main_block, grad_var,
# new_grad_var_name)
# new_var.persistable = False
# _rename_arg(op, grad_name, new_grad_var_name)
# main_block._insert_op(
# index=offset + 1,
# type='sum',
# inputs={'X': [grad_var, new_var]},
# outputs={'Out': grad_var},
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# })
# offset += 1
# if 'cast_fp16' in grad_name:
# param_name = op_role_var[i]
# fp32_grad_var_name = param_name + "@GRAD"
# fp32_grad_var = main_block.vars[grad_name]
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# fp32_grad_var_name)
# cast_var = _create_var(main_block, grad_var,
# cast_grad_var_name)
# cast_var.persistable = False
# main_block._insert_op(
# index=offset + 1,
# type='cast',
# inputs={'X': fp32_grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': fp32_grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# # self._op_role_var_key: op_role_var
# })
# offset += 1
# main_block._insert_op(
# index=offset + 1,
# type='sum',
# inputs={'X': [grad_var, cast_var]},
# outputs={'Out': grad_var},
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var})
+            main_block._sync_with_cpp()
+
+        with open("start_sharding_%d" % self.role_maker._worker_index(),
+                  'w') as f:
+            f.writelines(str(startup_block.program))
+        with open("main_sharding_%d" % self.role_maker._worker_index(),
+                  'w') as f:
+            f.writelines(str(main_block.program))

        # check op dependecy
        check_broadcast(main_block)
-        check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
+        check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
+                            self.dp_ring_id)
+        #check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
        self._wait()

        return optimize_ops, params_grads
...
...
@@ -134,11 +363,23 @@ class ShardingOptimizer(MetaOptimizerBase):
self
.
_startup_program
,
self
.
current_endpoint
,
self
.
_startup_program
,
self
.
current_endpoint
,
self
.
sharding_group_endpoints
,
self
.
sharding_rank
,
self
.
sharding_group_endpoints
,
self
.
sharding_rank
,
self
.
sharding_ring_id
,
True
)
self
.
sharding_ring_id
,
True
)
# inner & outer model parallelism
if
self
.
_as_outer_parallelism
:
self
.
_collective_helper
.
_init_communicator
(
self
.
_startup_program
,
self
.
current_endpoint
,
self
.
mp_group_endpoints
,
self
.
mp_rank
,
self
.
mp_group_id
,
True
)
# dp
# dp
if
self
.
hybrid_dp
:
if
self
.
hybrid_dp
:
self
.
_collective_helper
.
_init_communicator
(
self
.
_collective_helper
.
_init_communicator
(
self
.
_startup_program
,
self
.
current_endpoint
,
self
.
_startup_program
,
self
.
current_endpoint
,
self
.
dp_group_endpoints
,
self
.
dp_rank
,
self
.
dp_ring_id
,
True
)
self
.
dp_group_endpoints
,
self
.
dp_rank
,
self
.
dp_ring_id
,
True
)
# pp
if
self
.
use_pipeline
:
self
.
_collective_helper
.
_init_communicator
(
self
.
_startup_program
,
self
.
current_endpoint
,
self
.
pp_group_endpoints
,
self
.
pp_rank
,
self
.
pp_ring_id
,
True
)
startup_block
=
self
.
_startup_program
.
global_block
()
startup_block
=
self
.
_startup_program
.
global_block
()
startup_block
.
_sync_with_cpp
()
startup_block
.
_sync_with_cpp
()
...
@@ -205,8 +446,8 @@ class ShardingOptimizer(MetaOptimizerBase):
                for i in range(0, len(op_role_var), 2):
                    param, reduced_grad = op_role_var[i], op_role_var[i + 1]
                    segment._allreduce_vars.append(reduced_grad)
-                    assert (
-                        reduced_grad not in self._reduced_grads_to_param)
+                    # assert (
+                    #     reduced_grad not in self._reduced_grads_to_param)
                    self._reduced_grads_to_param[reduced_grad] = param

        # find cast op
...
@@ -234,9 +475,14 @@ class ShardingOptimizer(MetaOptimizerBase):
        """
        weightdecay_helper = WeightDecayHelper()
        weightdecay_helper.prune_weight_decay(block, self._shard)
+        # NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism
+        # group. and each Data Parallelism group should have its own sync of FoundInfinite
+        Model_Paramllelism_ring_id = self.sharding_ring_id
+        if self._as_outer_parallelism:
+            Model_Paramllelism_ring_id = self.mp_group_id
        FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
-                             self.sharding_ring_id)
-        gradientclip_helper = GradientClipHelper(self.sharding_ring_id)
+                             Model_Paramllelism_ring_id)
+        gradientclip_helper = GradientClipHelper(Model_Paramllelism_ring_id)
        gradientclip_helper.prune_gradient_clip(block, self._shard)

        # build prog deps
...
@@ -264,8 +510,13 @@ class ShardingOptimizer(MetaOptimizerBase):
        # Prune
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type in [
                    "c_allreduce_sum",
                    "c_sync_comm_stream",
                    "c_calc_comm_stream",
                    "c_gen_nccl_id",
                    "c_comm_init",
+                    'send_v2',
+                    'recv_v2',
            ]:
                pass
            elif op.type == "conditional_block":
...
@@ -303,6 +554,14 @@ class ShardingOptimizer(MetaOptimizerBase):
                program_deps.remove_op(idx)

        block._sync_with_cpp()
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if op.type == 'concat' and is_optimizer_op(op):
+                # remove inputs that not on this card
+                reserved_x = []
+                for var_name in op.desc.input("X"):
+                    if block.has_var(var_name): reserved_x.append(var_name)
+                op.desc.set_input('X', reserved_x)
+        block._sync_with_cpp()
        return

    def _add_broadcast_allreduce(self, block):
...
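Note (illustration only): the new loop above trims an optimizer-side concat op's inputs down to the variables that still exist on this card after pruning. The same filtering idea as a tiny standalone sketch with hypothetical variable names:

    def filter_concat_inputs(input_names, vars_on_this_card):
        """Keep only the inputs that survived pruning on the local card."""
        return [name for name in input_names if name in vars_on_this_card]

    local_vars = {"fc_0.w_0", "fc_2.w_0"}
    print(filter_concat_inputs(["fc_0.w_0", "fc_1.w_0", "fc_2.w_0"], local_vars))
    # ['fc_0.w_0', 'fc_2.w_0']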
@@ -459,6 +718,7 @@ class ShardingOptimizer(MetaOptimizerBase):
    def _init_comm(self):
        if self.hybrid_dp:
+            assert self._as_outer_parallelism == False, "hybrid dp is conflict when using sharding as outer parallelism"
            self.sharding_group_size = self.user_defined_strategy.sharding_configs[
                "sharding_group_size"]
            self.sharding_ring_id = 0
...
@@ -485,13 +745,109 @@ class ShardingOptimizer(MetaOptimizerBase):
                self.global_word_size,
                self.sharding_group_size,
                self.dp_group_size)
+            self.pp_ring_id = -1
+            self.pp_rank = -1
+            self.pp_group_size = None
+            self.pp_group_endpoints = None
+            # sharding parallelism is the only model parallelism in the current setting
+            self.mp_group_id = self.sharding_ring_id
+            self.mp_rank = self.sharding_rank
+            self.mp_group_size = self.sharding_group_size
+            self.mp_group_endpoints = self.sharding_group_endpoints[:]
            logging.info("Using Sharing&DP mode !")
        else:
+            if self._as_outer_parallelism:
+                self.sharding_ring_id = 1
+                assert self.global_word_size > self._inner_parallelism_size, \
+                    "global_word_size: {} should be larger than inner_parallelism_size: {}".format(
+                        self.global_word_size, self._inner_parallelism_size)
+                assert self.global_word_size % self._inner_parallelism_size == 0, \
+                    "global_word_size: {} should be divisible to the inner_parallelism_size: {}".format(
+                        self.global_word_size, self._inner_parallelism_size)
+                self.sharding_rank = self.global_rank // self._inner_parallelism_size
+                self.sharding_group_size = self.role_maker._worker_num(
+                ) // self._inner_parallelism_size
+                _offset = self.global_rank % self._inner_parallelism_size
+                self.sharding_group_endpoints = [
+                    ep for idx, ep in enumerate(self.endpoints)
+                    if idx % self._inner_parallelism_size == _offset
+                ]
+                # the current entire model parallelism group is the combination of innert & sharding parallelism
+                self.mp_group_id = 2
+                self.mp_rank = self.global_rank
+                self.mp_group_size = self.role_maker._worker_num()
+                self.mp_group_endpoints = self.endpoints[:]
+                logging.info("Using Sharing as Outer parallelism mode !")
# print(
# "init the nccl comm for megatron paramllelism, this should be done in Megatron Metaoptimizer"
# )
# partition_idx = self.global_rank // self._inner_parallelism_size
# magetron_endpoints = self.endpoints[
# partition_idx * self._inner_parallelism_size:partition_idx *
# self._inner_parallelism_size + self._inner_parallelism_size]
# magetron_rank = self.global_rank % self._inner_parallelism_size
# self._collective_helper._init_communicator(
# program=self._startup_program,
# current_endpoint=self.current_endpoint,
# endpoints=magetron_endpoints,
# rank=magetron_rank,
# ring_id=0,
# wait_port=True)
# logging.info("megatron group size: {}".format(
# self._inner_parallelism_size))
# logging.info("megatron rank: {}".format(magetron_rank))
# logging.info("megatron endpoints: {}".format(
# magetron_endpoints))
+            if self.use_pipeline:
+                self.sharding_ring_id = 0
+                self.sharding_group_size = self.user_defined_strategy.sharding_configs[
+                    'sharding_group_size']
+                self.sharding_rank = self.global_rank % self.sharding_group_size
+                assert self.sharding_group_size * self.pipeline_nodes == self.role_maker._worker_num()
+                self.pp_ring_id = 1
+                self.pp_rank = self.global_rank // self.sharding_group_size
+                self.sharding_group_endpoints = [
+                    ep for idx, ep in enumerate(self.endpoints)
+                    if (idx // self.sharding_group_size) == self.pp_rank
+                ]
+                self.pp_group_size = self.pipeline_nodes
+                self.pp_group_endpoints = [
+                    ep for idx, ep in enumerate(self.endpoints)
+                    if (idx % self.sharding_group_size) == self.sharding_rank
+                ]
+                self.dp_ring_id = -1
+                self.dp_rank = -1
+                self.dp_group_size = None
+                self.dp_group_endpoints = None
+                logging.info("Using Sharing with pipeline !")
+            else:
                self.sharding_ring_id = 0
                self.sharding_rank = self.global_rank
                self.sharding_group_size = self.role_maker._worker_num()
                self.sharding_group_endpoints = self.endpoints
+                # sharding parallelism is the only model parallelism in the current setting
+                self.mp_group_id = self.sharding_ring_id
+                self.mp_rank = self.sharding_rank
+                self.mp_group_size = self.sharding_group_size
+                self.mp_group_endpoints = self.sharding_group_endpoints[:]
+                logging.info("Using Sharing alone mode !")
+
+                self.dp_ring_id = -1
+                self.dp_rank = -1
+                self.dp_group_size = None
+                self.dp_group_endpoints = None
+                self.pp_ring_id = -1
+                self.pp_rank = -1
+                self.pp_group_size = None
+                self.pp_group_endpoints = None

            self.dp_ring_id = -1
            self.dp_rank = -1
            self.dp_group_size = None
...
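Note (illustration only): the endpoint grouping in this hunk is plain index arithmetic over the trainer endpoint list. A standalone sketch with a hypothetical 8-endpoint job, showing both the outer-parallelism split (stride by inner_parallelism_size) and the pipeline split (contiguous blocks of sharding_group_size); addresses and sizes here are made up:

    endpoints = ["10.0.0.%d:6170" % i for i in range(8)]  # hypothetical trainer endpoints

    # Sharding as outer parallelism: peers share the same offset inside each inner group.
    inner_parallelism_size = 2
    global_rank = 5
    offset = global_rank % inner_parallelism_size
    sharding_group = [ep for idx, ep in enumerate(endpoints)
                      if idx % inner_parallelism_size == offset]

    # Sharding with pipeline: sharding peers sit in the same contiguous block,
    # pipeline peers share the same position across blocks.
    sharding_group_size = 4
    pp_rank = global_rank // sharding_group_size
    sharding_rank = global_rank % sharding_group_size
    sharding_group_pp = [ep for idx, ep in enumerate(endpoints)
                         if idx // sharding_group_size == pp_rank]
    pp_group = [ep for idx, ep in enumerate(endpoints)
                if idx % sharding_group_size == sharding_rank]

    print(sharding_group)     # every second endpoint, starting from index 1
    print(sharding_group_pp)  # endpoints 4..7
    print(pp_group)           # endpoints 1 and 5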
@@ -503,12 +859,42 @@ class ShardingOptimizer(MetaOptimizerBase):
        logging.info("global rank: {}".format(self.global_rank))
        logging.info("sharding group_size: {}".format(self.sharding_group_size))
        logging.info("sharding rank: {}".format(self.sharding_rank))
+        logging.info("current model parallelism group_size: {}".format(
+            self.mp_group_size))
+        logging.info("current model parallelism rank: {}".format(self.mp_rank))
        logging.info("dp group size: {}".format(self.dp_group_size))
        logging.info("dp rank: {}".format(self.dp_rank))
        logging.info("current endpoint: {}".format(self.current_endpoint))
+        logging.info("global word endpoints: {}".format(self.endpoints))
        logging.info("sharding group endpoints: {}".format(
            self.sharding_group_endpoints))
+        logging.info("current model parallelism group endpoints: {}".format(
+            self.mp_group_endpoints))
        logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
-        logging.info("global word endpoints: {}".format(self.endpoints))
        return

+    def _initialization_broadcast(self, startup_prog):
+        """
+        this funtion is to ensure the initialization between dp group to be
+        identical when hybrid-dp is used.
+        """
+        block = startup_prog.global_block()
+        params = []
+        for param in block.iter_parameters():
+            params.append(param)
+            block.append_op(
+                type='c_broadcast',
+                inputs={'X': param},
+                outputs={'Out': param},
+                attrs={
+                    'ring_id': self.dp_ring_id,
+                    'root': 0,
+                    OP_ROLE_KEY: OpRole.Forward
+                })
+        block.append_op(
+            type='c_sync_comm_stream',
+            inputs={'X': params},
+            outputs={'Out': params},
+            attrs={'ring_id': self.dp_ring_id,
+                   OP_ROLE_KEY: OpRole.Forward})
python/paddle/fluid/backward.py

@@ -115,7 +115,7 @@ class ProgramStats(object):
                updated_min_idx = min_idx
                while idx_ > pre_segment_end_idx:
                    if is_amp_cast(self.ops[idx_]):
-                        _logger.debug("found amp-cast op: {}, : {}".format(self.ops[
+                        _logger.info("found amp-cast op: {}, : {}".format(self.ops[
                            idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[
                                0]))
                        updated_min_idx = idx_
...
@@ -155,7 +155,7 @@ class ProgramStats(object):
        sorted_checkpoints = []
        for name in checkpoints_name:
            if name not in self.var_op_deps:
-                _logger.debug(
+                _logger.info(
                    "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
                    % name)
            elif self.var_op_deps[name]["var_as_output_ops"] == []:
...
@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
            new_op_desc = block.desc.append_op()
            new_op_desc.copy_from(desc)
            new_op_desc._set_attr(op_role_attr_name, backward)
+            if desc.has_attr('op_device'):
+                new_op_desc._set_attr('op_device', desc.attr('op_device'))
            result_descs.append(new_op_desc)
    return result_descs
...
@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
        new_op_desc = block.desc.append_op()
        new_op_desc.copy_from(desc)
        new_op_desc._set_attr(op_role_attr_name, backward)
+        if desc.has_attr('op_device'):
+            new_op_desc._set_attr('op_device', desc.attr('op_device'))
        result_descs.append(new_op_desc)
    return result_descs
...
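Note (illustration only): both helpers above now carry a forward op's 'op_device' attribute over to the descs they append, so copied and backward op descs keep the device placement of the op they came from. A tiny standalone sketch of the copy-if-present pattern, using plain dicts as stand-ins for op descs:

    def copy_op_device(forward_attrs, new_op_attrs, device_attr_name="op_device"):
        """Propagate the device attribute only when the forward op actually has one."""
        if device_attr_name in forward_attrs:
            new_op_attrs[device_attr_name] = forward_attrs[device_attr_name]
        return new_op_attrs

    print(copy_op_device({"op_device": "gpu:2"}, {}))  # {'op_device': 'gpu:2'}
    print(copy_op_device({}, {}))                      # {} (nothing to propagate)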
@@ -784,7 +788,6 @@ def _append_backward_ops_with_checkpoints_(
    start_idx = 0
    pre_segment_end_idx = -1
    while True:
-        _logger.debug("FW op range[0] - [{}]".format(len(ops)))
        if start_idx >= len(checkpoints_name) - 1:
            break
        # min_idx: checkpoint_1' s input op
...
@@ -797,6 +800,9 @@ def _append_backward_ops_with_checkpoints_(
            min_idx = program_stat._update_segment_start(
                min_idx, pre_segment_end_idx)
            segments.append([min_idx, max_idx + 1])
+        else:
+            _logger.info("Could not recompute op range [{}] - [{}] ".format(
+                min_idx, max_idx + 1))

        start_idx += 1
...
@@ -806,15 +812,15 @@ def _append_backward_ops_with_checkpoints_(
        recompute_segments = segments

    for i, (idx1, idx2) in enumerate(recompute_segments):
-        _logger.debug("recompute segment[{}]".format(i))
-        _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
-        ), ops[idx1].desc.input_arg_names()))
-        _logger.debug("segment end op: [{}]: [{}]".format(ops[
-            idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
-        _logger.debug("recompute segment[{}]".format(i))
-        _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
-        ), ops[idx1].desc.input_arg_names()))
-        _logger.debug("segment end op: [{}]: [{}]".format(ops[
-            idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
+        _logger.info("recompute segment[{}]".format(i))
+        _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
+        ), ops[idx1].desc.input_arg_names()))
+        _logger.info("segment end op: [{}]: [{}]".format(ops[
+            idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
+        _logger.info("recompute segment[{}]".format(i))
+        _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
+        ), ops[idx1].desc.input_arg_names()))
+        _logger.info("segment end op: [{}]: [{}]".format(ops[
+            idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
...
@@ -825,9 +831,9 @@ def _append_backward_ops_with_checkpoints_(
            program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))

    cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
-    _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
-    len(cross_vars), cross_vars))
-    _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
-    len(cross_vars), cross_vars))
+    _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
+    len(cross_vars), cross_vars))
+    _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
+    len(cross_vars), cross_vars))

    # b. output of seed op should be kept in memory
...
@@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_(
    vars_in_memory = vars_should_be_hold + checkpoints_name

    max_calculated_op_position = len(ops)
+    device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
    if recompute_segments == []:
        gap_ops = ops[0:max_calculated_op_position]
        for op in reversed(gap_ops):
@@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_(
...
@@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_
(
op
.
desc
,
"with_sub_block"
))
_pretty_op_desc_
(
op
.
desc
,
"with_sub_block"
))
grad_op_desc
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
grad_op_desc
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
op
.
desc
,
cpt
.
to_text
(
no_grad_dict
[
block
.
idx
]),
[])
op
.
desc
,
cpt
.
to_text
(
no_grad_dict
[
block
.
idx
]),
[])
# Set device for grad_op according to forward Op
if
op
.
desc
.
has_attr
(
device_attr_name
):
op_device
=
op
.
desc
.
attr
(
device_attr_name
)
for
op_desc
in
grad_op_desc
:
op_desc
.
_set_attr
(
device_attr_name
,
op_device
)
added_descs
=
_add_descs_to_block
(
grad_op_desc
,
local_block
)
added_descs
=
_add_descs_to_block
(
grad_op_desc
,
local_block
)
grad_op_descs
.
extend
(
added_descs
)
grad_op_descs
.
extend
(
added_descs
)
grad_to_var
.
update
(
op_grad_to_var
)
grad_to_var
.
update
(
op_grad_to_var
)
...
@@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_(
...
@@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_
(
op
.
desc
,
"with_sub_block"
))
_pretty_op_desc_
(
op
.
desc
,
"with_sub_block"
))
grad_op_desc
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
grad_op_desc
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
op
.
desc
,
cpt
.
to_text
(
no_grad_dict
[
block
.
idx
]),
[])
op
.
desc
,
cpt
.
to_text
(
no_grad_dict
[
block
.
idx
]),
[])
# Set device for grad_op according to forward Op
if
op
.
desc
.
has_attr
(
device_attr_name
):
op_device
=
op
.
desc
.
attr
(
device_attr_name
)
for
op_desc
in
grad_op_desc
:
op_desc
.
_set_attr
(
device_attr_name
,
op_device
)
added_descs
=
_add_descs_to_block
(
grad_op_desc
,
local_block
)
added_descs
=
_add_descs_to_block
(
grad_op_desc
,
local_block
)
grad_op_descs
.
extend
(
added_descs
)
grad_op_descs
.
extend
(
added_descs
)
grad_to_var
.
update
(
op_grad_to_var
)
grad_to_var
.
update
(
op_grad_to_var
)
...
@@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_(
...
@@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_(
continue
continue
if
name
not
in
var_name_dict
:
if
name
not
in
var_name_dict
:
var_name_dict
[
name
]
=
name
+
var_suffix
var_name_dict
[
name
]
=
name
+
var_suffix
# we should create the rename var in subprog, otherwise its VarType will be BOOL
block
.
create_var
(
name
=
var_name_dict
[
name
],
shape
=
block
.
program
.
global_block
().
var
(
name
).
shape
,
dtype
=
block
.
program
.
global_block
().
var
(
name
).
dtype
,
type
=
block
.
program
.
global_block
().
var
(
name
).
type
,
persistable
=
block
.
program
.
global_block
().
var
(
name
).
persistable
,
stop_gradient
=
block
.
program
.
global_block
().
var
(
name
)
.
stop_gradient
)
# 3.a. add ops in current recompute_segment as forward recomputation ops
# 3.a. add ops in current recompute_segment as forward recomputation ops
buffer_descs
=
_add_needed_descs_to_block
(
ff_ops
,
buffer_block
,
block
,
buffer_descs
=
_add_needed_descs_to_block
(
ff_ops
,
buffer_block
,
block
,
vars_in_memory
)
vars_in_memory
)
...
...