Commit 920806db

Project: 机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Authored Feb 07, 2021 by sandyhouse
Parent commit: 70eb21c5
Commit message: update

Showing 4 changed files with 533 additions and 94 deletions (+533 -94):

    python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py      +3    -0
    python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py      +86   -65
    python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py  +404  -18
    python/paddle/fluid/backward.py                                        +40   -11
python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py    View file @ 920806db
@@ -126,6 +126,9 @@ class ProgramDeps(object):

     def should_remove_op(self, op_idx):
         op = self._block.ops[op_idx]
+        # remove check_finite_and_unscale op if its input 'X' is empty
+        if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
+            return True
         for output_name in op.desc.output_arg_names():
             if output_name not in self._should_removed_var:
                 return False
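The new early-exit in should_remove_op can be exercised in isolation. The sketch below is illustrative only (MockOp and the op list are invented; the real check operates on Paddle block operators as in the hunk above):

# Illustrative sketch, not Paddle code.
class MockOp:
    def __init__(self, op_type, inputs):
        self.type = op_type
        self._inputs = inputs          # e.g. {'X': ['w@GRAD', ...]}

    def input(self, name):
        return self._inputs.get(name, [])


def should_remove(op):
    # remove check_finite_and_unscale op if its input 'X' is empty
    return op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0


ops = [
    MockOp('check_finite_and_unscale', {'X': []}),          # pruned away
    MockOp('check_finite_and_unscale', {'X': ['w@GRAD']}),  # kept
    MockOp('scale', {'X': ['loss@GRAD']}),                  # kept
]
print([should_remove(op) for op in ops])   # [True, False, False]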
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py    View file @ 920806db
@@ -28,21 +28,24 @@ def check_broadcast(block):
     if the broadcasted var has a fill_constant op, the fill_constant
     op should stay forward before the broadcast op, and before a
     sync_calc op. Otherwise, raise error.
+
+    should ignore and skip broadcast_op of inner_parallelism (e.g. Megatron)
     """
     broadcast_vars = {}
     for idx, op in enumerate(block.ops):
         if op.type == "c_broadcast":
-            var_name = op.desc.input_arg_names()[0]
-            if "@BroadCast" in var_name:
-                if var_name in broadcast_vars:
-                    raise ValueError("var_name areadly exist: {}"
-                                     "the old pos is {}, the new pos is {}".
-                                     format(var_name, broadcast_vars[var_name][
-                                         "broadcast_pos"], idx))
-                broadcast_vars[var_name] = {
-                    "fill_constant_pos": -1,
-                    "broadcast_pos": idx,
-                }
+            if op.all_attrs()["use_calc_stream"] == False:
+                var_name = op.desc.input_arg_names()[0]
+                if "@BroadCast" in var_name:
+                    if var_name in broadcast_vars:
+                        raise ValueError("var_name areadly exist: {}"
+                                         "the old pos is {}, the new pos is {}".
+                                         format(var_name, broadcast_vars[
+                                             var_name]["broadcast_pos"], idx))
+                    broadcast_vars[var_name] = {
+                        "fill_constant_pos": -1,
+                        "broadcast_pos": idx,
+                    }

     for idx, op in enumerate(block.ops):
         if op.type == "fill_constant":
@@ -61,14 +64,15 @@ def check_broadcast(block):
             last_sync_calc_op_idx = idx
             continue
         if op.type == "c_broadcast":
-            var_name = op.desc.input_arg_names()[0]
-            if "@BroadCast" in var_name:
-                if broadcast_vars[var_name]["fill_constant_pos"] != -1:
-                    assert (last_sync_calc_op_idx != -1)
-                    assert (broadcast_vars[var_name]["fill_constant_pos"] <
-                            last_sync_calc_op_idx)
-                    assert (last_sync_calc_op_idx < idx)
-                continue
+            if op.all_attrs()["use_calc_stream"] == False:
+                var_name = op.desc.input_arg_names()[0]
+                if "@BroadCast" in var_name:
+                    if broadcast_vars[var_name]["fill_constant_pos"] != -1:
+                        assert (last_sync_calc_op_idx != -1)
+                        assert (broadcast_vars[var_name]["fill_constant_pos"] <
+                                last_sync_calc_op_idx)
+                        assert (last_sync_calc_op_idx < idx)
+                    continue
         for input_name in op.desc.input_arg_names():
             if input_name in broadcast_vars:
                 assert (broadcast_vars[input_name]["broadcast_pos"] != -1)
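Both check_broadcast loops above now look only at c_broadcast ops whose use_calc_stream attribute is False, so broadcasts issued by an inner parallelism (e.g. Megatron), which run on the calc stream, are skipped. A toy filter over invented dictionaries (not real Paddle ops) shows the effect:

# Sketch only: dicts stand in for operators.
ops = [
    {"type": "c_broadcast", "use_calc_stream": True,
     "input": "linear_0.w_0"},                        # inner-parallelism broadcast: skipped
    {"type": "c_broadcast", "use_calc_stream": False,
     "input": "fc_0.b_0@BroadCast"},                  # sharding broadcast: checked
]

checked = [
    op["input"] for op in ops
    if op["type"] == "c_broadcast" and op["use_calc_stream"] == False
    and "@BroadCast" in op["input"]
]
print(checked)   # ['fc_0.b_0@BroadCast']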
@@ -78,7 +82,7 @@ def check_broadcast(block):
     return


-def check_allreduce_sum(block, shard, dp_ring_id=-1):
+def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
     """
     the op order should be:
         grad:
@@ -89,32 +93,36 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
             - 4: allreuce_sum_dp (dp_grads)
             - 5: sync_comm (dp_grads)
             - 6: op that use Var (dp_grads & sum)
+
+    should ignore and skip allreduce_op of inner_parallelism (e.g. Megatron)
     """
     vars_status = {}
     dp_grads_status = {}
     idx_last_grad_allreduce = -1
     idx_amp_allreduce = -1
     idx_gradient_clip_allreduce = -1
+
     for idx, op in enumerate(block.ops):
         if op.type == "c_allreduce_sum":
-            ring_id = op.desc.attr("ring_id")
-            var_name = op.desc.input_arg_names()[0]
-            param = var_name.split("@")[0]
-
-            assert 'sum' in var_name or ("@GRAD" in var_name)
-            if 'sum' in var_name or (not shard.has_param(param)):
-                vars_status[var_name] = -1
-            else:
-                dp_grads_status[var_name] = -1
-
-            if ring_id != 0:
-                assert shard.has_param(param)
-                assert ring_id == dp_ring_id
-
-            if "sum" in var_name:
-                idx_amp_allreduce = idx
-            elif "@GRAD":
-                idx_last_grad_allreduce = idx
+            if op.all_attrs()["use_calc_stream"] == False:
+                ring_id = op.desc.attr("ring_id")
+                var_name = op.desc.input_arg_names()[0]
+                param = var_name.split("@")[0]
+
+                assert 'sum' in var_name or ("@GRAD" in var_name)
+                if 'sum' in var_name or (not shard.has_param(param)):
+                    vars_status[var_name] = -1
+                else:
+                    dp_grads_status[var_name] = -1
+
+                if ring_id != sharding_ring_id:
+                    assert shard.has_param(param)
+                    assert ring_id == dp_ring_id
+
+                if "sum" in var_name:
+                    idx_amp_allreduce = idx
+                elif "@GRAD":
+                    idx_last_grad_allreduce = idx

         if op.type == "c_allreduce_max":
             idx_gradient_clip_allreduce = idx
@@ -130,36 +138,38 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
                     dp_grads_status[var_name] = 1

         elif op.type == "c_allreduce_sum":
-            var_name = op.desc.input_arg_names()[0]
-            ring_id = op.desc.attr("ring_id")
-            if ring_id == 0:
-                if var_name in vars_status:
-                    _status = vars_status[var_name]
-                else:
-                    _status = dp_grads_status[var_name]
-                if _status == -1:
-                    raise ValueError("{} is not generated, but you are"
-                                     "trying to all-reduce it".format(var_name))
-                if _status == 0:
-                    raise ValueError("There should be a sync_calc op "
-                                     "after generate Var: {} and before the"
-                                     "c_allreduce_sum op".format(var_name))
-                assert (_status == 1)
-                if var_name in vars_status:
-                    vars_status[var_name] = 2
-                else:
-                    dp_grads_status[var_name] = 2
-            else:
-                assert ring_id == dp_ring_id
-                param = var_name.split("@")[0]
-                assert shard.has_param(param)
-                assert dp_grads_status[var_name] == 3
-                dp_grads_status[var_name] = 4
+            if op.all_attrs()["use_calc_stream"] == False:
+                var_name = op.desc.input_arg_names()[0]
+                ring_id = op.desc.attr("ring_id")
+                if ring_id == sharding_ring_id:
+                    if var_name in vars_status:
+                        _status = vars_status[var_name]
+                    else:
+                        _status = dp_grads_status[var_name]
+                    if _status == -1:
+                        raise ValueError("{} is not generated, but you are"
+                                         "trying to all-reduce it".format(
+                                             var_name))
+                    if _status == 0:
+                        raise ValueError("There should be a sync_calc op "
+                                         "after generate Var: {} and before the"
+                                         "c_allreduce_sum op".format(var_name))
+                    assert (_status == 1)
+                    if var_name in vars_status:
+                        vars_status[var_name] = 2
+                    else:
+                        dp_grads_status[var_name] = 2
+                else:
+                    assert ring_id == dp_ring_id
+                    param = var_name.split("@")[0]
+                    assert shard.has_param(param)
+                    assert dp_grads_status[var_name] == 3
+                    dp_grads_status[var_name] = 4

         elif op.type == "c_sync_comm_stream":
             var_name = op.desc.input_arg_names()[0]
             ring_id = op.desc.attr("ring_id")
-            if ring_id == 0:
+            if ring_id == sharding_ring_id:
                 for var_name in op.desc.input_arg_names():
                     if var_name in vars_status:
                         assert vars_status[var_name] == 2
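With the new signature, grad allreduces are classified by comparing the op's ring_id against the explicit sharding_ring_id argument instead of the hard-coded 0, and anything else must sit on dp_ring_id. A toy routing function with invented ring ids (not the real state machine above) summarizes the rule:

# Sketch only; ring ids are made-up numbers.
SHARDING_RING_ID = 1    # e.g. when sharding is the outer parallelism
DP_RING_ID = 2


def classify_allreduce(ring_id):
    if ring_id == SHARDING_RING_ID:
        return "sharding grad allreduce (checked by the state machine above)"
    assert ring_id == DP_RING_ID, "unexpected ring for c_allreduce_sum"
    return "hybrid-dp allreduce on the parameters this rank owns"


print(classify_allreduce(1))
print(classify_allreduce(2))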
@@ -217,9 +227,14 @@ def get_valid_op_role(block, insert_idx):
         return OpRole.Forward or OpRole.Backward
     """
     op_role = block.ops[insert_idx].attr('op_role')
-    if (insert_idx >= len(block.ops)) or (
-            op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
-        return OpRole.Backward
+    #if (insert_idx >= len(block.ops)) or (
+    #        op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
+    #    return OpRole.Backward
+    #if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
+    #    return OpRole.Forward
+    if insert_idx >= len(block.ops): return OpRole.Optimize
+    if op_role == int(OpRole.Backward): return OpRole.Backward
+    if op_role == int(OpRole.Optimize): return OpRole.Optimize
     if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
         return OpRole.Forward
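The rewritten get_valid_op_role resolves the role in a fixed order: a past-the-end index now means Optimize, then Backward and Optimize are matched exactly, then Forward/Loss. A standalone rendering with a mock OpRole enum (values assumed, not imported from Paddle) and an assumed recursive fallback for other roles:

# Standalone sketch; OpRole here is a mock, not paddle's enum.
from enum import IntEnum


class OpRole(IntEnum):
    Forward = 0
    Backward = 1
    Optimize = 2
    Loss = 256


def valid_op_role(op_roles, insert_idx):
    if insert_idx >= len(op_roles): return OpRole.Optimize
    op_role = op_roles[insert_idx]
    if op_role == int(OpRole.Backward): return OpRole.Backward
    if op_role == int(OpRole.Optimize): return OpRole.Optimize
    if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
        return OpRole.Forward
    # assumed fallback: look at the next op, mirroring the surrounding function
    return valid_op_role(op_roles, insert_idx + 1)


print(valid_op_role([0, 1, 2], 5))   # OpRole.Optimize: index past the end of the block
print(valid_op_role([0, 1, 2], 1))   # OpRole.Backward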
@@ -428,7 +443,7 @@ def comm_analyse(main_program):
                                                count))


-def add_sync_comm(program, dist_strategy):
+def add_sync_comm(program, nccl_ids):
     """
     When clone a test prog by clone from the sharding main prog,
     part of the sync_comm op maybe be pruned by mistake, this function
@@ -438,6 +453,9 @@ def add_sync_comm(program, dist_strategy):
     #NOTE (liangjianzhong): only support one comm stream by now, use more than one
     # comm streams will cause error. should be revise in future.

+    assert isinstance(
+        nccl_ids, list
+    ), "the second argument of this function should be a list of nccl_ids"
     block = program.global_block()
     not_sync_vars = set([])
     for op in block.ops:
@@ -448,7 +466,7 @@ def add_sync_comm(program, dist_strategy):
             for input_name in op.desc.input_arg_names():
                 not_sync_vars.remove(input_name)
     if not_sync_vars:
-        for nccl_id in range(dist_strategy.nccl_comm_num):
+        for nccl_id in nccl_ids:
             block.append_op(
                 type='c_sync_comm_stream',
                 inputs={'X': list(not_sync_vars)},
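add_sync_comm now takes an explicit list of ring ids instead of a DistributedStrategy. The stub below only mimics the new precondition and the one-sync-op-per-ring shape; it is not the real function, which appends c_sync_comm_stream ops to the program:

# Sketch only; "program" is an opaque placeholder here.
def add_sync_comm_stub(program, nccl_ids):
    assert isinstance(
        nccl_ids, list
    ), "the second argument of this function should be a list of nccl_ids"
    # the real function appends one c_sync_comm_stream op per ring id
    return ["c_sync_comm_stream(ring_id={})".format(rid) for rid in nccl_ids]


print(add_sync_comm_stub("fake_program", [0, 1]))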
@@ -467,6 +485,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
     This function handles the model saving for sharding training.
     """

+    if main_program._pipeline_opt:
+        main_program = main_program._pipeline_opt['section_program']['program']
+
     def is_opt_vars(var):
         # NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer
         # now only Momentum and adam are compatible with sharding
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py    View file @ 920806db
@@ -16,7 +16,7 @@ from paddle.fluid import unique_name, core
 import paddle.fluid as fluid

 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
-from paddle.distributed.fleet.meta_optimizers.common import is_backward_op
+from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
 from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
 from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
 from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
@@ -39,6 +39,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             "AMPOptimizer",
             "LarsOptimizer",
             "LambOptimizer",
+            "ModelParallelOptimizer",
         ]
         self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
         self._main_program = None
@@ -51,6 +52,10 @@ class ShardingOptimizer(MetaOptimizerBase):
         self._reduced_grads_to_param = {}
         self._shard = Shard()

+        # use sharding as outer parallelism (e.g. inner:Megatron & outer sharding)
+        self._as_outer_parallelism = False
+        self._inner_parallelism_size = None
+
     def _can_apply(self):
         if not self.role_maker._is_collective:
             return False
@@ -79,20 +84,61 @@ class ShardingOptimizer(MetaOptimizerBase):
             "fuse_broadcast_MB"]
         self.hybrid_dp = self.user_defined_strategy.sharding_configs[
             "hybrid_dp"]
+        self._as_outer_parallelism = self.user_defined_strategy.sharding_configs[
+            "as_outer_parallelism"]
+        self._inner_parallelism_size = int(
+            self.user_defined_strategy.sharding_configs[
+                "inner_parallelism_size"])
+        self.use_pipeline = self.user_defined_strategy.sharding_configs[
+            "use_pipeline"]

         if self.inner_opt is None:
             raise ValueError(
                 "self.inner_opt of ShardingOptimizer should not be None.")
-        optimize_ops, params_grads = self.inner_opt.minimize(
-            loss, startup_program, parameter_list, no_grad_set)
+
+        if self.use_pipeline:
+            pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt)
+            main_program = loss.block.program
+            main_program._pipeline_opt = dict()
+            pp_rank = self.role_maker._worker_index(
+            ) // self.user_defined_strategy.sharding_configs[
+                'sharding_group_size']
+            main_program._pipeline_opt['local_rank'] = pp_rank
+            main_program._pipeline_opt[
+                'global_rank'] = self.role_maker._worker_index()
+            main_program._pipeline_opt['use_sharding'] = True
+            main_program._pipeline_opt['ring_id'] = 1
+            optimize_ops, params_grads, program_list = pp_optimizer.minimize(
+                loss, startup_program, parameter_list, no_grad_set)
+            self.pipeline_nodes = len(program_list)
+        else:
+            optimize_ops, params_grads = self.inner_opt.minimize(
+                loss, startup_program, parameter_list, no_grad_set)

         if startup_program is None:
             startup_program = default_startup_program()
-        main_block = loss.block
+
+        if self.use_pipeline:
+            startup_program = startup_program._pipeline_opt['startup_program']
+            #main_program = main_program._pipeline_opt['section_program']['program']
+            print("pp_rank:", pp_rank)
+            main_program = program_list[pp_rank]['program']
+            with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
+                f.writelines(str(main_program))
+            main_block = main_program.global_block()
+            new_params_grads = []
+            for param, grad in params_grads:
+                if main_block.has_var(param.name):
+                    new_params_grads.append((param, grad))
+            params_grads = new_params_grads
+        else:
+            main_block = loss.block
         startup_block = startup_program.global_block()
         self._main_program = main_block.program
         self._startup_program = startup_program

+        if self.use_pipeline:
+            pp_optimizer._rename_gradient_var_name(main_block)
+
         # step1: set_up
         self._set_up(params_grads)
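Worked example of the pipeline-rank arithmetic introduced above; the worker count and sharding_group_size are invented:

# Each pipeline stage holds one sharding group of size sharding_group_size.
worker_num = 8
sharding_group_size = 2
for worker_index in range(worker_num):
    pp_rank = worker_index // sharding_group_size
    print("worker", worker_index, "-> pp_rank", pp_rank)
# workers 0,1 share pp_rank 0; 2,3 -> 1; 4,5 -> 2; 6,7 -> 3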
@@ -105,17 +151,200 @@ class ShardingOptimizer(MetaOptimizerBase):
         startup_block._sync_with_cpp()

         # step4: insert reduce_sum for grad
-        insert_scale_loss_grad_ops(
-            main_block, scale=1.0 / self.role_maker._worker_num())
+        # grad_scale_coeff = self.role_maker._worker_num()
+        # if self._as_outer_parallelism:
+        #     grad_scale_coeff = grad_scale_coeff / self._inner_parallelism_size
+        # insert_scale_loss_grad_ops(main_block, scale=1.0 / grad_scale_coeff)
+        sharding_group_size = self.user_defined_strategy.sharding_configs[
+            'sharding_group_size']
+        insert_scale_loss_grad_ops(main_block, scale=1.0 / sharding_group_size)
         main_block._sync_with_cpp()

         # step5: remove unneeded ops and vars from block
         self._prune_main_program(main_block)
         self._prune_startup_program(startup_block)
+        if self.hybrid_dp:
+            self._initialization_broadcast(startup_program)
+
+        if self.use_pipeline:
+            # crop ops
+            for idx, op in reversed(list(enumerate(main_block.ops))):
+                # if op.type == 'fill_constant' and int(op.attr('op_role')) == 16:
+                #     out_name = op.output_arg_names[0]
+                #     if not 'GRAD' in out_name: continue
+                #     param_name = out_name.strip("@GRAD")
+                #     #if main_block.has_var(out_name): continue
+                #     if self._shard.has_param(param_name): continue
+                #     main_block._remove_op(idx)
+                if is_update_op(op):
+                    op_role_var = op.attr('op_role_var')
+                    param_name = op_role_var[0]
+                    if not self._shard.has_param(param_name):
+                        main_block._remove_op(idx)
+
+            param_list = []
+            for param_name, grad_name in params_grads:
+                if self._shard.has_param(param_name):
+                    param_list.append(param_name)
+            #pp_optimizer._clear_gradients(main_block, param_list)
+            pp_optimizer._accumulate_gradients(main_block)
#if not self._shard.has_param(param_name): continue
##if not main_block.has_var(grad_name): continue
#assert main_block.has_var(grad_name)
#grad_var = main_block.vars[grad_name]
#grad_var.persistable = True
#main_block._insert_op(
# index=0,
# type='fill_constant',
# inputs={},
# outputs={'Out': [grad_var]},
# attrs={
# 'shape': grad_var.shape,
# 'dtype': grad_var.dtype,
# 'value': float(0),
# #self._op_device_key: device,
# # a trick to run this op once per mini-batch
# 'op_role': core.op_proto_and_checker_maker.OpRole.LRSched,
# })
#def _create_var(block, ref_var, name):
# """
# Create a new var for block, which has the same type,
# shape and dtype as ref_var, then rename it with the
# name `name`.
# """
# new_var = block.create_var(
# name=name,
# shape=ref_var.shape,
# dtype=ref_var.dtype,
# type=ref_var.type,
# lod_level=ref_var.lod_level,
# persistable=ref_var.persistable,
# is_data=ref_var.is_data,
# need_check_feed=ref_var.desc.need_check_feed())
# new_var.stop_gradient = ref_var.stop_gradient
# return new_var
#def _rename_arg(op, old_name, new_name):
# op_desc = op.desc
# if isinstance(op_desc, tuple):
# op_desc = op_desc[0]
# op_desc._rename_input(old_name, new_name)
# op_desc._rename_output(old_name, new_name)
#print("params_grads:", params_grads)
#for param_name, grad_name in params_grads:
# if not self._shard.has_param(param_name): continue
# #if not main_block.has_var(grad_name): continue
# assert main_block.has_var(grad_name)
# use_fp16 = False
# fp16_grad_name = param_name + '.cast_fp16@GRAD'
# if main_block.has_var(grad_name):
# fp16_grad_var = main_block.vars[fp16_grad_name]
# use_fp16 = True
# grad_var = main_block.vars[grad_name]
# if use_fp16:
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# grad_name)
# cast_var = _create_var(main_block, fp16_grad_var,
# cast_grad_var_name)
# cast_var.persistable = False
# main_block.append_op(
# #index=offset + 1,
# type='cast',
# inputs={'X': grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# })
# #offset += 1
# main_block.append_op(
# #index=offset + 1,
# type='sum',
# inputs={'X': [fp16_grad_var, cast_var]},
# outputs={'Out': fp16_grad_var},
# attrs={
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# })
# for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
# offset = index
# if is_backward_op(op) and (
# 'op_role_var' in op.attr_names):
# op_role_var = op.all_attrs()['op_role_var']
# if len(op_role_var) == 0:
# continue
# assert len(op_role_var) % 2 == 0
# offset = index
# for i in range(0, len(op_role_var), 2):
# grad_name = op_role_var[i + 1]
# if not main_block.has_var(grad_name): continue
# grad_var = main_block.vars[grad_name]
# if not 'cast_fp16' in grad_name:
# new_grad_var_name = paddle.fluid.unique_name.generate(grad_name)
# new_var = _create_var(main_block, grad_var,
# new_grad_var_name)
# new_var.persistable = False
# _rename_arg(op, grad_name, new_grad_var_name)
# main_block._insert_op(
# index=offset + 1,
# type='sum',
# inputs={'X': [grad_var, new_var]},
# outputs={'Out': grad_var},
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# })
# offset += 1
# if 'cast_fp16' in grad_name:
# param_name = op_role_var[i]
# fp32_grad_var_name = param_name + "@GRAD"
# fp32_grad_var = main_block.vars[grad_name]
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# fp32_grad_var_name)
# cast_var = _create_var(main_block, grad_var,
# cast_grad_var_name)
# cast_var.persistable = False
# main_block._insert_op(
# index=offset + 1,
# type='cast',
# inputs={'X': fp32_grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': fp32_grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# # self._op_role_var_key: op_role_var
# })
# offset += 1
# main_block._insert_op(
# index=offset + 1,
# type='sum',
# inputs={'X': [grad_var, cast_var]},
# outputs={'Out': grad_var},
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var})
+        main_block._sync_with_cpp()
+
+        with open("start_sharding_%d" % self.role_maker._worker_index(),
+                  'w') as f:
+            f.writelines(str(startup_block.program))
+        with open("main_sharding_%d" % self.role_maker._worker_index(),
+                  'w') as f:
+            f.writelines(str(main_block.program))

         # check op dependecy
         check_broadcast(main_block)
-        check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
+        check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
+                            self.dp_ring_id)
+        #check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
         self._wait()
         return optimize_ops, params_grads
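The loss-gradient scale in step4 now uses sharding_group_size rather than the global worker count; with invented sizes (8 workers arranged as sharding groups of 2) the factor changes from 1/8 to 1/2:

# Invented sizes, only to show what the changed scale evaluates to.
worker_num = 8
sharding_group_size = 2
print(1.0 / worker_num)            # 0.125  (old behaviour)
print(1.0 / sharding_group_size)   # 0.5    (new behaviour)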
@@ -134,11 +363,23 @@ class ShardingOptimizer(MetaOptimizerBase):
             self._startup_program, self.current_endpoint,
             self.sharding_group_endpoints, self.sharding_rank,
             self.sharding_ring_id, True)

+        # inner & outer model parallelism
+        if self._as_outer_parallelism:
+            self._collective_helper._init_communicator(
+                self._startup_program, self.current_endpoint,
+                self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True)
+
         # dp
         if self.hybrid_dp:
             self._collective_helper._init_communicator(
                 self._startup_program, self.current_endpoint,
                 self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True)

+        # pp
+        if self.use_pipeline:
+            self._collective_helper._init_communicator(
+                self._startup_program, self.current_endpoint,
+                self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, True)
+
         startup_block = self._startup_program.global_block()
         startup_block._sync_with_cpp()
@@ -205,8 +446,8 @@ class ShardingOptimizer(MetaOptimizerBase):
                 for i in range(0, len(op_role_var), 2):
                     param, reduced_grad = op_role_var[i], op_role_var[i + 1]
                     segment._allreduce_vars.append(reduced_grad)
-                    assert (
-                        reduced_grad not in self._reduced_grads_to_param)
+                    # assert (
+                    #     reduced_grad not in self._reduced_grads_to_param)
                     self._reduced_grads_to_param[reduced_grad] = param

         # find cast op
@@ -234,9 +475,14 @@ class ShardingOptimizer(MetaOptimizerBase):
         """
         weightdecay_helper = WeightDecayHelper()
         weightdecay_helper.prune_weight_decay(block, self._shard)
+        # NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism
+        # group. and each Data Parallelism group should have its own sync of FoundInfinite
+        Model_Paramllelism_ring_id = self.sharding_ring_id
+        if self._as_outer_parallelism:
+            Model_Paramllelism_ring_id = self.mp_group_id
         FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
-                             self.sharding_ring_id)
-        gradientclip_helper = GradientClipHelper(self.sharding_ring_id)
+                             Model_Paramllelism_ring_id)
+        gradientclip_helper = GradientClipHelper(Model_Paramllelism_ring_id)
         gradientclip_helper.prune_gradient_clip(block, self._shard)

         # build prog deps
@@ -264,8 +510,13 @@ class ShardingOptimizer(MetaOptimizerBase):
         # Prune
         for idx, op in reversed(list(enumerate(block.ops))):
             if op.type in [
-                    "c_allreduce_sum", "c_sync_comm_stream",
-                    "c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init, c_comm_init_hcom"
+                    "c_allreduce_sum",
+                    "c_sync_comm_stream",
+                    "c_calc_comm_stream",
+                    "c_gen_nccl_id",
+                    "c_comm_init",
+                    'send_v2',
+                    'recv_v2',
             ]:
                 pass
             elif op.type == "conditional_block":
@@ -303,6 +554,14 @@ class ShardingOptimizer(MetaOptimizerBase):
                     program_deps.remove_op(idx)

         block._sync_with_cpp()
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if op.type == 'concat' and is_optimizer_op(op):
+                # remove inputs that not on this card
+                reserved_x = []
+                for var_name in op.desc.input("X"):
+                    if block.has_var(var_name): reserved_x.append(var_name)
+                op.desc.set_input('X', reserved_x)
+        block._sync_with_cpp()
         return

     def _add_broadcast_allreduce(self, block):
@@ -459,6 +718,7 @@ class ShardingOptimizer(MetaOptimizerBase):
     def _init_comm(self):

         if self.hybrid_dp:
+            assert self._as_outer_parallelism == False, "hybrid dp is conflict when using sharding as outer parallelism"
             self.sharding_group_size = self.user_defined_strategy.sharding_configs[
                 "sharding_group_size"]
             self.sharding_ring_id = 0
@@ -485,13 +745,109 @@ class ShardingOptimizer(MetaOptimizerBase):
                 self.global_word_size, self.sharding_group_size,
                 self.dp_group_size)
+
+            self.pp_ring_id = -1
+            self.pp_rank = -1
+            self.pp_group_size = None
+            self.pp_group_endpoints = None
+
+            # sharding parallelism is the only model parallelism in the current setting
+            self.mp_group_id = self.sharding_ring_id
+            self.mp_rank = self.sharding_rank
+            self.mp_group_size = self.sharding_group_size
+            self.mp_group_endpoints = self.sharding_group_endpoints[:]
+
             logging.info("Using Sharing&DP mode !")
         else:
             self.sharding_ring_id = 0
             self.sharding_rank = self.global_rank
             self.sharding_group_size = self.role_maker._worker_num()
             self.sharding_group_endpoints = self.endpoints
-            self.dp_ring_id = -1
-            self.dp_rank = -1
-            self.dp_group_size = None
+
+            if self._as_outer_parallelism:
+                self.sharding_ring_id = 1
+                assert self.global_word_size > self._inner_parallelism_size, \
+                    "global_word_size: {} should be larger than inner_parallelism_size: {}".format(
+                        self.global_word_size, self._inner_parallelism_size)
+                assert self.global_word_size % self._inner_parallelism_size == 0, \
+                    "global_word_size: {} should be divisible to the inner_parallelism_size: {}".format(
+                        self.global_word_size, self._inner_parallelism_size)
+                self.sharding_rank = self.global_rank // self._inner_parallelism_size
+                self.sharding_group_size = self.role_maker._worker_num(
+                ) // self._inner_parallelism_size
+                _offset = self.global_rank % self._inner_parallelism_size
+                self.sharding_group_endpoints = [
+                    ep for idx, ep in enumerate(self.endpoints)
+                    if idx % self._inner_parallelism_size == _offset
+                ]
+                # the current entire model parallelism group is the combination of innert & sharding parallelism
+                self.mp_group_id = 2
+                self.mp_rank = self.global_rank
+                self.mp_group_size = self.role_maker._worker_num()
+                self.mp_group_endpoints = self.endpoints[:]
+                logging.info("Using Sharing as Outer parallelism mode !")
# print(
# "init the nccl comm for megatron paramllelism, this should be done in Megatron Metaoptimizer"
# )
# partition_idx = self.global_rank // self._inner_parallelism_size
# magetron_endpoints = self.endpoints[
# partition_idx * self._inner_parallelism_size:partition_idx *
# self._inner_parallelism_size + self._inner_parallelism_size]
# magetron_rank = self.global_rank % self._inner_parallelism_size
# self._collective_helper._init_communicator(
# program=self._startup_program,
# current_endpoint=self.current_endpoint,
# endpoints=magetron_endpoints,
# rank=magetron_rank,
# ring_id=0,
# wait_port=True)
# logging.info("megatron group size: {}".format(
# self._inner_parallelism_size))
# logging.info("megatron rank: {}".format(magetron_rank))
# logging.info("megatron endpoints: {}".format(
# magetron_endpoints))
+            if self.use_pipeline:
+                self.sharding_ring_id = 0
+                self.sharding_group_size = self.user_defined_strategy.sharding_configs[
+                    'sharding_group_size']
+                self.sharding_rank = self.global_rank % self.sharding_group_size
+                assert self.sharding_group_size * self.pipeline_nodes == self.role_maker._worker_num(
+                )
+                self.pp_ring_id = 1
+                self.pp_rank = self.global_rank // self.sharding_group_size
+                self.sharding_group_endpoints = [
+                    ep for idx, ep in enumerate(self.endpoints)
+                    if (idx // self.sharding_group_size) == self.pp_rank
+                ]
+                self.pp_group_size = self.pipeline_nodes
+                self.pp_group_endpoints = [
+                    ep for idx, ep in enumerate(self.endpoints)
+                    if (idx % self.sharding_group_size) == self.sharding_rank
+                ]
+                self.dp_ring_id = -1
+                self.dp_rank = -1
+                self.dp_group_size = None
+                self.dp_group_endpoints = None
+                logging.info("Using Sharing with pipeline !")
+            else:
+                self.sharding_ring_id = 0
+                self.sharding_rank = self.global_rank
+                self.sharding_group_size = self.role_maker._worker_num()
+                self.sharding_group_endpoints = self.endpoints
+                # sharding parallelism is the only model parallelism in the current setting
+                self.mp_group_id = self.sharding_ring_id
+                self.mp_rank = self.sharding_rank
+                self.mp_group_size = self.sharding_group_size
+                self.mp_group_endpoints = self.sharding_group_endpoints[:]
+                logging.info("Using Sharing alone mode !")
+
+                self.dp_ring_id = -1
+                self.dp_rank = -1
+                self.dp_group_size = None
+                self.dp_group_endpoints = None
+                self.pp_ring_id = -1
+                self.pp_rank = -1
+                self.pp_group_size = None
+                self.pp_group_endpoints = None
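A worked example (endpoint list, ranks and sizes invented) of the endpoint-selection rules in the two new branches above, evaluated as plain Python:

# Invented 8-worker endpoint list, just to evaluate the selection rules.
endpoints = ["10.0.0.%d:6070" % i for i in range(8)]

# sharding as outer parallelism: inner_parallelism_size = 2, global_rank = 5
inner_parallelism_size = 2
global_rank = 5
_offset = global_rank % inner_parallelism_size           # 1
sharding_group_endpoints = [
    ep for idx, ep in enumerate(endpoints)
    if idx % inner_parallelism_size == _offset
]                                                         # ranks 1, 3, 5, 7

# sharding + pipeline: sharding_group_size = 2, global_rank = 5
sharding_group_size = 2
sharding_rank = global_rank % sharding_group_size         # 1
pp_rank = global_rank // sharding_group_size              # 2
pipeline_sharding_group = [
    ep for idx, ep in enumerate(endpoints)
    if (idx // sharding_group_size) == pp_rank
]                                                         # ranks 4, 5
pp_group_endpoints = [
    ep for idx, ep in enumerate(endpoints)
    if (idx % sharding_group_size) == sharding_rank
]                                                         # ranks 1, 3, 5, 7

print(sharding_group_endpoints)
print(pipeline_sharding_group)
print(pp_group_endpoints)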
@@ -503,12 +859,42 @@ class ShardingOptimizer(MetaOptimizerBase):
         logging.info("global rank: {}".format(self.global_rank))
         logging.info("sharding group_size: {}".format(
             self.sharding_group_size))
         logging.info("sharding rank: {}".format(self.sharding_rank))
+        logging.info("current model parallelism group_size: {}".format(
+            self.mp_group_size))
+        logging.info("current model parallelism rank: {}".format(self.mp_rank))
         logging.info("dp group size: {}".format(self.dp_group_size))
         logging.info("dp rank: {}".format(self.dp_rank))
         logging.info("current endpoint: {}".format(self.current_endpoint))
-        logging.info("global word endpoints: {}".format(self.endpoints))
         logging.info("sharding group endpoints: {}".format(
             self.sharding_group_endpoints))
+        logging.info("current model parallelism group endpoints: {}".format(
+            self.mp_group_endpoints))
         logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
+        logging.info("global word endpoints: {}".format(self.endpoints))

         return

+    def _initialization_broadcast(self, startup_prog):
+        """
+        this funtion is to ensure the initialization between dp group to be
+        identical when hybrid-dp is used.
+        """
+        block = startup_prog.global_block()
+        params = []
+        for param in block.iter_parameters():
+            params.append(param)
+            block.append_op(
+                type='c_broadcast',
+                inputs={'X': param},
+                outputs={'Out': param},
+                attrs={
+                    'ring_id': self.dp_ring_id,
+                    'root': 0,
+                    OP_ROLE_KEY: OpRole.Forward
+                })
+        block.append_op(
+            type='c_sync_comm_stream',
+            inputs={'X': params},
+            outputs={'Out': params},
+            attrs={'ring_id': self.dp_ring_id,
+                   OP_ROLE_KEY: OpRole.Forward})
python/paddle/fluid/backward.py    View file @ 920806db
@@ -115,7 +115,7 @@ class ProgramStats(object):
                 updated_min_idx = min_idx
                 while idx_ > pre_segment_end_idx:
                     if is_amp_cast(self.ops[idx_]):
-                        _logger.debug("found amp-cast op: {}, : {}".format(self.ops[
+                        _logger.info("found amp-cast op: {}, : {}".format(self.ops[
                             idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[
                                 0]))
                         updated_min_idx = idx_
@@ -155,7 +155,7 @@ class ProgramStats(object):
         sorted_checkpoints = []
         for name in checkpoints_name:
             if name not in self.var_op_deps:
-                _logger.debug(
+                _logger.info(
                     "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
                     % name)
             elif self.var_op_deps[name]["var_as_output_ops"] == []:
@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
             new_op_desc = block.desc.append_op()
             new_op_desc.copy_from(desc)
             new_op_desc._set_attr(op_role_attr_name, backward)
+            if desc.has_attr('op_device'):
+                new_op_desc._set_attr('op_device', desc.attr('op_device'))
             result_descs.append(new_op_desc)
     return result_descs
@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
             new_op_desc = block.desc.append_op()
             new_op_desc.copy_from(desc)
             new_op_desc._set_attr(op_role_attr_name, backward)
+            if desc.has_attr('op_device'):
+                new_op_desc._set_attr('op_device', desc.attr('op_device'))
             result_descs.append(new_op_desc)
     return result_descs
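Both desc-copy helpers now carry the forward op's op_device attribute onto the cloned descriptor, so recomputation and grad ops keep the same device placement as their forward op. A toy sketch over plain dicts (not real OpDesc objects) of the copy rule:

# Sketch only; dicts stand in for op descriptors.
def clone_for_backward(desc, backward_role=1):
    # mirrors copy_from + _set_attr(op_role_attr_name, backward)
    new_desc = {"type": desc["type"], "op_role": backward_role}
    # Set device for grad_op according to forward Op (the added lines above)
    if "op_device" in desc:
        new_desc["op_device"] = desc["op_device"]
    return new_desc


print(clone_for_backward({"type": "matmul", "op_device": "gpu:2"}))
print(clone_for_backward({"type": "matmul"}))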
@@ -784,7 +788,6 @@ def _append_backward_ops_with_checkpoints_(
     start_idx = 0
     pre_segment_end_idx = -1
     while True:
-        _logger.debug("FW op range[0] - [{}]".format(len(ops)))
         if start_idx >= len(checkpoints_name) - 1:
             break
         # min_idx: checkpoint_1' s input op
@@ -797,6 +800,9 @@ def _append_backward_ops_with_checkpoints_(
             min_idx = program_stat._update_segment_start(
                 min_idx, pre_segment_end_idx)
             segments.append([min_idx, max_idx + 1])
+        else:
+            _logger.info("Could not recompute op range [{}] - [{}] ".format(
+                min_idx, max_idx + 1))

         start_idx += 1
@@ -806,15 +812,15 @@ def _append_backward_ops_with_checkpoints_(
         recompute_segments = segments

     for i, (idx1, idx2) in enumerate(recompute_segments):
-        _logger.debug("recompute segment[{}]".format(i))
-        _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
+        _logger.info("recompute segment[{}]".format(i))
+        _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
         ), ops[idx1].desc.input_arg_names()))
-        _logger.debug("segment end op: [{}]: [{}]".format(ops[
+        _logger.info("segment end op: [{}]: [{}]".format(ops[
             idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
-        _logger.debug("recompute segment[{}]".format(i))
-        _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
+        _logger.info("recompute segment[{}]".format(i))
+        _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
        ), ops[idx1].desc.input_arg_names()))
-        _logger.debug("segment end op: [{}]: [{}]".format(ops[
+        _logger.info("segment end op: [{}]: [{}]".format(ops[
             idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))

     # 2) go through all forward ops and induct all variables that will be hold in memory
@@ -825,9 +831,9 @@ def _append_backward_ops_with_checkpoints_(
             program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))

     cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
-    _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
+    _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
     len(cross_vars), cross_vars))
-    _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
+    _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
     len(cross_vars), cross_vars))

     # b. output of seed op should be kept in memory
@@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_(
     vars_in_memory = vars_should_be_hold + checkpoints_name

     max_calculated_op_position = len(ops)
+    device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
     if recompute_segments == []:
         gap_ops = ops[0:max_calculated_op_position]
         for op in reversed(gap_ops):
@@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_(
                     _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+            # Set device for grad_op according to forward Op
+            if op.desc.has_attr(device_attr_name):
+                op_device = op.desc.attr(device_attr_name)
+                for op_desc in grad_op_desc:
+                    op_desc._set_attr(device_attr_name, op_device)
             added_descs = _add_descs_to_block(grad_op_desc, local_block)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
@@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_(
                     _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+            # Set device for grad_op according to forward Op
+            if op.desc.has_attr(device_attr_name):
+                op_device = op.desc.attr(device_attr_name)
+                for op_desc in grad_op_desc:
+                    op_desc._set_attr(device_attr_name, op_device)
             added_descs = _add_descs_to_block(grad_op_desc, local_block)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
@@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_(
                 continue
             if name not in var_name_dict:
                 var_name_dict[name] = name + var_suffix
+
+                # we should create the rename var in subprog, otherwise its VarType will be BOOL
+                block.create_var(
+                    name=var_name_dict[name],
+                    shape=block.program.global_block().var(name).shape,
+                    dtype=block.program.global_block().var(name).dtype,
+                    type=block.program.global_block().var(name).type,
+                    persistable=block.program.global_block().var(name)
+                    .persistable,
+                    stop_gradient=block.program.global_block().var(name)
+                    .stop_gradient)
+
         # 3.a. add ops in current recompute_segment as forward recomputation ops
         buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
                                                   vars_in_memory)