Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a4eadd15
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a4eadd15
编写于
9月 16, 2021
作者:
Y
Yuang Liu
提交者:
GitHub
9月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[hybrid] Fix mp multi gradient clip prob (#35713)
上级
4b683887
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
127 addition
and
23 deletion
+127
-23
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
...ed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+90
-19
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
...e/distributed/fleet/meta_optimizers/sharding_optimizer.py
+1
-2
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+9
-2
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
...uid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+27
-0
未找到文件。
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
浏览文件 @
a4eadd15
...
@@ -142,32 +142,103 @@ class GradientClipHelper(object):
...
@@ -142,32 +142,103 @@ class GradientClipHelper(object):
return
return
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
def
sync_global_norm
(
self
,
block
,
ring_ids
):
def
sync_global_norm
(
self
,
block
,
ring_ids
,
mp_rank
):
"""
"""
prune gradient_clip related ops for params that not belong to cur shard
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
keep: sum, sqrt, elementwise_max, elementwise_div
"""
"""
# FIXME(wangxi): mp should prune duplicated param_grads
is_clip_grad_by_global_norm
=
False
for
idx
,
op
in
list
(
enumerate
(
block
.
ops
)):
if
not
self
.
_is_gradient_clip_op
(
op
):
continue
if
op
.
type
==
'sum'
:
is_clip_grad_by_global_norm
=
True
break
if
not
is_clip_grad_by_global_norm
:
# TODO(Yuang Liu): need some extra handles when clip_grad_norm for mp
return
removed_op_idx
=
set
()
removed_tmp_var
=
set
()
for
idx
,
op
in
list
(
enumerate
(
block
.
ops
)):
if
not
self
.
_is_gradient_clip_op
(
op
):
continue
if
op
.
type
==
'sum'
:
break
for
input_name
in
op
.
input_arg_names
:
input_var
=
block
.
var
(
input_name
)
# NOTE: when mp_degree > 1, some vars will be split into each mp rank.
# However, there still some vars such as Scale, Bias are not split.
# Those not be split vars should only be counted once during grad clip
# by global norm. Those vars either doesn't have is_distributed attr
# or the is_distributed attr has been set as False.
# Therefore, we prune those duplicated vars for grad clip.
if
mp_rank
>=
1
and
(
not
(
hasattr
(
input_var
,
'is_distributed'
)
and
input_var
.
is_distributed
)):
removed_op_idx
.
add
(
idx
)
for
output_name
in
op
.
output_arg_names
:
removed_tmp_var
.
add
(
output_name
)
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
not
self
.
_is_gradient_clip_op
(
op
):
if
not
self
.
_is_gradient_clip_op
(
op
):
continue
continue
if
idx
in
removed_op_idx
:
block
.
_remove_op
(
idx
,
sync
=
False
)
if
op
.
type
==
"sum"
:
for
var_name
in
removed_tmp_var
:
sum_res
=
op
.
desc
.
output_arg_names
()[
0
]
block
.
_remove_var
(
var_name
,
sync
=
False
)
for
idx
,
op
in
list
(
enumerate
(
block
.
ops
)):
if
not
self
.
_is_gradient_clip_op
(
op
):
continue
if
op
.
type
==
'sum'
:
# If mp_rank == 0, no extra handles, just allreduce
# If mp_rank >= 1, some extra handles is needed
sum_rst_var
=
block
.
var
(
op
.
output_arg_names
[
0
])
if
mp_rank
>=
1
:
reserved_vars
=
[]
for
input_name
in
op
.
input_arg_names
:
if
input_name
not
in
removed_tmp_var
:
reserved_vars
.
append
(
input_name
)
if
len
(
reserved_vars
)
>
0
:
op
.
desc
.
set_input
(
"X"
,
reserved_vars
)
else
:
# If all input of sum op should be removed, then remove the sum op.
# And set the output's value of sum to 0.
namescope
=
op
.
attr
(
"op_namescope"
)
block
.
_remove_op
(
idx
,
sync
=
False
)
fill_constant_op
=
block
.
_insert_op_without_sync
(
idx
,
type
=
'fill_constant'
,
inputs
=
{},
outputs
=
{
'Out'
:
sum_rst_var
},
attrs
=
{
'shape'
:
sum_rst_var
.
shape
,
'dtype'
:
sum_rst_var
.
dtype
,
'value'
:
0.0
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
fill_constant_op
.
_set_attr
(
'op_namescope'
,
namescope
)
self
.
_insert_allreduce
(
block
,
ring_ids
,
idx
,
sum_rst_var
)
break
@
staticmethod
def
_insert_allreduce
(
block
,
ring_ids
,
idx
,
var
):
for
ring_id
in
ring_ids
:
for
ring_id
in
ring_ids
:
if
ring_id
==
-
1
:
continue
if
ring_id
==
-
1
:
continue
idx
=
idx
+
1
idx
=
idx
+
1
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
idx
,
idx
,
type
=
'c_allreduce_sum'
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
sum_res
},
inputs
=
{
'X'
:
var
},
outputs
=
{
'Out'
:
sum_res
},
outputs
=
{
'Out'
:
var
},
attrs
=
{
attrs
=
{
'ring_id'
:
ring_id
,
'ring_id'
:
ring_id
,
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'use_calc_stream'
:
True
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
})
})
return
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
浏览文件 @
a4eadd15
...
@@ -435,7 +435,6 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -435,7 +435,6 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block
=
self
.
_main_program
.
global_block
()
main_block
=
self
.
_main_program
.
global_block
()
startup_block
=
self
.
_startup_program
.
global_block
()
startup_block
=
self
.
_startup_program
.
global_block
()
# FIXME(wangxi): mp should prune duplicated param_grads when calc
# amp inf_var & clip global_norm_var
# amp inf_var & clip global_norm_var
rings
=
[
self
.
mp_ring_id
,
self
.
pp_ring_id
]
rings
=
[
self
.
mp_ring_id
,
self
.
pp_ring_id
]
...
@@ -446,7 +445,7 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -446,7 +445,7 @@ class ShardingOptimizer(MetaOptimizerBase):
gradientclip_helper
=
GradientClipHelper
(
None
)
gradientclip_helper
=
GradientClipHelper
(
None
)
gradientclip_helper
.
sync_global_norm
(
gradientclip_helper
.
sync_global_norm
(
main_block
,
[
self
.
mp_ring_id
,
self
.
pp_ring_id
])
main_block
,
[
self
.
mp_ring_id
,
self
.
pp_ring_id
]
,
self
.
mp_rank
)
def
_insert_loss_grad_scale_op
(
self
):
def
_insert_loss_grad_scale_op
(
self
):
main_block
=
self
.
_main_program
.
global_block
()
main_block
=
self
.
_main_program
.
global_block
()
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
a4eadd15
...
@@ -4381,7 +4381,7 @@ class PipelineOptimizer(object):
...
@@ -4381,7 +4381,7 @@ class PipelineOptimizer(object):
persistable
=
source_var
.
persistable
)
persistable
=
source_var
.
persistable
)
else
:
else
:
dest_var
=
block
.
_clone_variable
(
source_var
,
False
)
dest_var
=
block
.
_clone_variable
(
source_var
,
False
)
dest_var
.
stop_gradient
=
source_var
.
stop_gradient
self
.
_clone_var_attr
(
dest_var
,
source_var
)
# When use with sharding, allreduce_sum and allreduce_max
# When use with sharding, allreduce_sum and allreduce_max
# used for global gradient clip and amp will be added by sharding.
# used for global gradient clip and amp will be added by sharding.
op_idx
+=
1
op_idx
+=
1
...
@@ -4547,9 +4547,14 @@ class PipelineOptimizer(object):
...
@@ -4547,9 +4547,14 @@ class PipelineOptimizer(object):
persistable
=
ref_var
.
persistable
,
persistable
=
ref_var
.
persistable
,
is_data
=
ref_var
.
is_data
,
is_data
=
ref_var
.
is_data
,
need_check_feed
=
ref_var
.
desc
.
need_check_feed
())
need_check_feed
=
ref_var
.
desc
.
need_check_feed
())
new_var
.
stop_gradient
=
ref_var
.
stop_gradient
self
.
_clone_var_attr
(
new_var
,
ref_var
)
return
new_var
return
new_var
def
_clone_var_attr
(
self
,
dest
,
src
):
dest
.
stop_gradient
=
src
.
stop_gradient
if
hasattr
(
src
,
'is_distributed'
):
dest
.
is_distributed
=
src
.
is_distributed
def
_strip_grad_suffix
(
self
,
name
):
def
_strip_grad_suffix
(
self
,
name
):
"""
"""
Strip the grad suffix from the given variable name
Strip the grad suffix from the given variable name
...
@@ -5209,6 +5214,8 @@ class PipelineOptimizer(object):
...
@@ -5209,6 +5214,8 @@ class PipelineOptimizer(object):
persistable
=
True
,
persistable
=
True
,
stop_gradient
=
False
)
stop_gradient
=
False
)
real_param
=
main_block
.
var
(
param
)
real_param
=
main_block
.
var
(
param
)
if
hasattr
(
real_param
,
'is_distributed'
):
merged_grad_var
.
is_distributed
=
real_param
.
is_distributed
tmp_size
=
self
.
_get_var_size
(
real_grad
)
tmp_size
=
self
.
_get_var_size
(
real_grad
)
# two strategies for splitting the grad
# two strategies for splitting the grad
# 1. the current segment's size reach the user defined grad_size_in_MB
# 1. the current segment's size reach the user defined grad_size_in_MB
...
...
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
浏览文件 @
a4eadd15
...
@@ -658,6 +658,33 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
...
@@ -658,6 +658,33 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'c_gen_nccl_id'
,
'c_comm_init'
'c_gen_nccl_id'
,
'c_comm_init'
])
])
self
.
assertEqual
(
main_prog_op_types
,
[
'partial_recv'
,
'partial_allgather'
,
'cast'
,
'cast'
,
'mul'
,
'cast'
,
'elementwise_add'
,
'cast'
,
'tanh'
,
'cast'
,
'cast'
,
'mul'
,
'cast'
,
'elementwise_add'
,
'cast'
,
'tanh'
,
'cast'
,
'cast'
,
'mul'
,
'cast'
,
'elementwise_add'
,
'cast'
,
'tanh'
,
'cast'
,
'cast'
,
'mul'
,
'cast'
,
'elementwise_add'
,
'softmax'
,
'cast'
,
'cross_entropy2'
,
'mean'
,
'elementwise_mul'
,
'fill_constant'
,
'elementwise_mul_grad'
,
'mean_grad'
,
'cross_entropy_grad2'
,
'cast'
,
'softmax_grad'
,
'elementwise_add_grad'
,
'mul_grad'
,
'cast'
,
'tanh_grad'
,
'cast'
,
'elementwise_add_grad'
,
'mul_grad'
,
'cast'
,
'tanh_grad'
,
'cast'
,
'elementwise_add_grad'
,
'mul_grad'
,
'cast'
,
'tanh_grad'
,
'cast'
,
'elementwise_add_grad'
,
'mul_grad'
,
'cast'
,
'c_sync_calc_stream'
,
'partial_send'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'fill_constant'
,
'cast'
,
'sum'
,
'c_sync_comm_stream'
,
'check_finite_and_unscale'
,
'cast'
,
'c_allreduce_max'
,
'c_allreduce_max'
,
'cast'
,
'update_loss_scaling'
,
'fill_constant'
,
'c_allreduce_sum'
,
'c_allreduce_sum'
,
'sqrt'
,
'fill_constant'
,
'elementwise_max'
,
'elementwise_div'
,
'elementwise_mul'
,
'elementwise_mul'
,
'elementwise_mul'
,
'elementwise_mul'
,
'elementwise_mul'
,
'elementwise_mul'
,
'elementwise_mul'
,
'elementwise_mul'
,
'momentum'
,
'momentum'
,
'momentum'
,
'momentum'
,
'momentum'
,
'momentum'
,
'momentum'
,
'momentum'
])
# pp + mp, partial send recv
# pp + mp, partial send recv
self
.
assertIn
(
'partial_recv'
,
main_prog_op_types
)
self
.
assertIn
(
'partial_recv'
,
main_prog_op_types
)
self
.
assertIn
(
'partial_allgather'
,
main_prog_op_types
)
self
.
assertIn
(
'partial_allgather'
,
main_prog_op_types
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录