Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
601d7a35
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
601d7a35
编写于
6月 07, 2022
作者:
S
sneaxiy
提交者:
GitHub
6月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add use_master_acc_grad for DistributedFusedLamb (#43266)
* add use_master_acc_grad * add ut
上级
5dcebb9b
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
34 addition
and
6 deletion
+34
-6
paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
...e/fluid/operators/optimizers/distributed_fused_lamb_op.cc
+3
-0
paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
...e/fluid/operators/optimizers/distributed_fused_lamb_op.cu
+11
-4
python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py
...fluid/tests/unittests/distributed_fused_lamb_test_base.py
+8
-1
python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py
...sts/unittests/test_distributed_fused_lamb_op_with_clip.py
+3
-1
python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_gradient_merge.py
...sts/test_distributed_fused_lamb_op_with_gradient_merge.py
+6
-0
python/paddle/incubate/optimizer/distributed_fused_lamb.py
python/paddle/incubate/optimizer/distributed_fused_lamb.py
+3
-0
未找到文件。
paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
浏览文件 @
601d7a35
...
...
@@ -141,6 +141,9 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
"NCCL communication data. If it is false, it would be less accurate "
"and be less NCCL communication data."
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"use_master_acc_grad"
,
"Whether to use master gradient when acc_steps > 1."
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"is_grad_scaled_by_nranks"
,
"Whether the input gradient has been scaled by nranks."
)
.
SetDefault
(
true
);
...
...
paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
浏览文件 @
601d7a35
...
...
@@ -1193,7 +1193,9 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
platform
::
float16
*
fp16_acc_grad
=
nullptr
;
float
*
master_acc_grad
=
nullptr
;
bool
use_master_acc_grad
=
false
;
if
(
has_fp16_param
)
{
use_master_acc_grad
=
ctx
.
Attr
<
bool
>
(
"use_master_acc_grad"
);
auto
*
fp16_acc_grad_t
=
ctx
.
Output
<
framework
::
Tensor
>
(
"FP16AccFusedGrad"
);
PADDLE_ENFORCE_NOT_NULL
(
...
...
@@ -1201,13 +1203,18 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
"Output(FP16AccFusedGrad) cannot be nullptr "
"when Attr(acc_steps) > 1."
));
if
(
!
fp16_acc_grad_t
->
IsInitialized
())
{
fp16_acc_grad_t
->
Resize
({
static_cast
<
int64_t
>
(
3
*
fp16_numel
)});
auto
acc_grad_size
=
use_master_acc_grad
?
(
3
*
fp16_numel
)
:
fp16_numel
;
fp16_acc_grad_t
->
Resize
({
static_cast
<
int64_t
>
(
acc_grad_size
)});
fp16_acc_grad
=
fp16_acc_grad_t
->
mutable_data
<
platform
::
float16
>
(
place
);
}
else
{
fp16_acc_grad
=
fp16_acc_grad_t
->
data
<
platform
::
float16
>
();
}
master_acc_grad
=
reinterpret_cast
<
float
*>
(
fp16_acc_grad
+
fp16_numel
);
if
(
use_master_acc_grad
)
{
master_acc_grad
=
reinterpret_cast
<
float
*>
(
fp16_acc_grad
+
fp16_numel
);
}
}
// Inplace addto
...
...
@@ -1222,8 +1229,8 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
}
if
(
has_fp16_param
)
{
if
(
acc_steps
==
2
)
{
if
(
rounded_step
==
0
)
{
if
(
acc_steps
==
2
||
!
use_master_acc_grad
)
{
if
(
rounded_step
!=
1
)
{
LaunchElementwiseAddWithCastKernel
(
dev_ctx
,
fp16_acc_grad
,
fp16_grad
,
fp16_acc_grad
,
fp16_numel
,
stream
);
...
...
python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py
浏览文件 @
601d7a35
...
...
@@ -162,6 +162,7 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs):
kwargs
=
dict
(
kwargs
)
kwargs
.
pop
(
'clip_after_allreduce'
,
None
)
kwargs
.
pop
(
'alignment'
,
None
)
kwargs
.
pop
(
'use_master_acc_grad'
,
None
)
base_clip
=
grad_clip
if
grad_clip
is
not
None
else
IdentityGradClip
(
)
kwargs
[
'grad_clip'
]
=
GradClipDecorator
(
base_clip
,
...
...
@@ -271,6 +272,7 @@ class TestDistributedFusedLamb(unittest.TestCase):
distutils
.
util
.
strtobool
(
os
.
getenv
(
'CLIP_AFTER_ALLREDUCE'
,
'True'
)))
max_global_norm
=
float
(
os
.
getenv
(
'MAX_GLOBAL_NORM'
,
-
1.0
))
gm_steps
=
int
(
os
.
getenv
(
'GRADIENT_MERGE_STEPS'
,
1
))
use_master_acc_grad
=
bool
(
int
(
os
.
getenv
(
'USE_MASTER_ACC_GRAD'
,
'1'
)))
print
(
'clip_after_allreduce = {}, max_global_norm = {}'
.
format
(
clip_after_allreduce
,
max_global_norm
))
return
{
...
...
@@ -281,9 +283,14 @@ class TestDistributedFusedLamb(unittest.TestCase):
'grad_clip'
:
paddle
.
nn
.
ClipGradByGlobalNorm
(
max_global_norm
)
if
max_global_norm
>
0
else
None
,
'use_master_acc_grad'
:
use_master_acc_grad
,
}
def
run_main
(
self
,
use_fp16
,
use_master_param_norm
=
True
):
def
run_main
(
self
,
use_fp16
,
use_master_param_norm
=
True
,
use_master_acc_grad
=
True
):
if
not
paddle
.
is_compiled_with_cuda
():
return
...
...
python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py
浏览文件 @
601d7a35
...
...
@@ -36,7 +36,8 @@ def remove_file_if_exists(file_name):
def
run_test
(
clip_after_allreduce
=
True
,
max_global_norm
=-
1.0
,
gradient_merge_steps
=
1
):
gradient_merge_steps
=
1
,
use_master_acc_grad
=
True
):
if
not
paddle
.
is_compiled_with_cuda
():
return
if
os
.
name
==
'nt'
:
...
...
@@ -58,6 +59,7 @@ def run_test(clip_after_allreduce=True,
os
.
environ
[
'CLIP_AFTER_ALLREDUCE'
]
=
str
(
clip_after_allreduce
)
os
.
environ
[
'MAX_GLOBAL_NORM'
]
=
str
(
max_global_norm
)
os
.
environ
[
'GRADIENT_MERGE_STEPS'
]
=
str
(
gradient_merge_steps
)
os
.
environ
[
'USE_MASTER_ACC_GRAD'
]
=
str
(
1
if
use_master_acc_grad
else
0
)
touch_file_env
=
'SUCCESS_TOUCH_FILE'
touch_file_name
=
'distributed_fused_lamb_touch_file_{}'
.
format
(
os
.
getpid
())
...
...
python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_gradient_merge.py
浏览文件 @
601d7a35
...
...
@@ -23,6 +23,12 @@ class TestDistributedFusedLambGradientMerge(unittest.TestCase):
max_global_norm
=-
1.0
,
gradient_merge_steps
=
2
)
def
test_gm_with_fp16_acc_grad
(
self
):
run_test
(
clip_after_allreduce
=
True
,
max_global_norm
=-
1.0
,
gradient_merge_steps
=
2
,
use_master_acc_grad
=
False
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/incubate/optimizer/distributed_fused_lamb.py
浏览文件 @
601d7a35
...
...
@@ -40,6 +40,7 @@ class DistributedFusedLamb(Optimizer):
alignment
=
128
,
use_master_param_norm
=
True
,
gradient_accumulation_steps
=
1
,
use_master_acc_grad
=
True
,
name
=
None
):
assert
not
framework
.
_non_static_mode
(
),
"DistributedFusedLamb does not support dygraph mode"
...
...
@@ -67,6 +68,7 @@ class DistributedFusedLamb(Optimizer):
self
.
_ring_id
=
0
self
.
_use_master_param_norm
=
use_master_param_norm
self
.
_gradient_accumulation_steps
=
gradient_accumulation_steps
self
.
_use_master_acc_grad
=
use_master_acc_grad
assert
self
.
_gradient_accumulation_steps
>=
1
self
.
helper
=
LayerHelper
(
'distributed_fused_lamb'
)
...
...
@@ -353,5 +355,6 @@ class DistributedFusedLamb(Optimizer):
'use_master_param_norm'
:
self
.
_use_master_param_norm
,
'is_grad_scaled_by_nranks'
:
self
.
_is_grad_scaled_by_nranks
,
'acc_steps'
:
self
.
_gradient_accumulation_steps
,
'use_master_acc_grad'
:
self
.
_use_master_acc_grad
,
})
return
[
lamb_op
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录