Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5dab0b0d
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5dab0b0d
编写于
9月 27, 2022
作者:
Z
zhaoyingli
提交者:
GitHub
9月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AutoParallel] fix amp o1 (#46391) (#46481)
上级
5711bbee
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
27 addition
and
16 deletion
+27
-16
python/paddle/distributed/passes/auto_parallel_amp.py
python/paddle/distributed/passes/auto_parallel_amp.py
+24
-16
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
.../fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
+3
-0
未找到文件。
python/paddle/distributed/passes/auto_parallel_amp.py
浏览文件 @
5dab0b0d
...
...
@@ -38,14 +38,18 @@ class AMPState(object):
self
.
_op_fp16_dict
=
{
}
# op_id --> True/False. 'True' means that the current op is in fp16 mode.
self
.
_var_name_dict
=
{}
# fwd_op_id --> {old_name: cast_name}
self
.
is_train
=
False
def
_is_fp16_op
(
self
,
op_id
):
return
self
.
_op_fp16_dict
.
get
(
op_id
,
None
)
def
_build_stat
s
(
self
,
amp_lists
,
dist_context
):
def
_build_stat
e
(
self
,
amp_lists
,
dist_context
):
ops
=
self
.
_block
.
ops
dist_op_context
=
dist_context
.
dist_op_context
for
op
in
ops
:
if
int
(
op
.
attr
(
'op_role'
))
==
257
:
self
.
is_train
=
True
if
int
(
op
.
attr
(
'op_role'
))
==
int
(
OpRole
.
Forward
):
self
.
_mark_black_white_ops
(
amp_lists
)
elif
int
(
op
.
attr
(
'op_role'
))
==
int
(
OpRole
.
Backward
):
...
...
@@ -59,6 +63,8 @@ class AMPState(object):
elif
int
(
op
.
attr
(
'op_role'
))
==
int
(
OpRole
.
Optimize
):
break
return
self
.
is_train
def
_mark_black_white_ops
(
self
,
amp_lists
):
"""
this function is modified from paddle.fluid.contrib.mixed_precision
...
...
@@ -546,23 +552,25 @@ class AMPPass(PassBase):
set
(
self
.
get_attr
(
"custom_black_list"
)),
set
(
self
.
get_attr
(
"custom_black_varnames"
)))
amp_state
=
AMPState
(
main_program
.
global_block
())
amp_state
.
_build_stats
(
amp_lists
,
self
.
dist_context
)
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
amp_state
=
AMPState
(
main_program
.
global_block
())
is_train
=
amp_state
.
_build_state
(
amp_lists
,
self
.
dist_context
)
amp_state
.
cast_forward_program
(
self
.
dist_context
)
amp_state
.
cast_backward_program
(
params_grads
,
self
.
dist_context
)
# TODO (JZ-LIANG)support cast forward program only when inference
self
.
_init_amp_var
()
self
.
_scale_loss
()
if
self
.
get_attr
(
"use_dynamic_loss_scaling"
)
or
self
.
get_attr
(
"init_loss_scaling"
)
!=
1.0
:
grads
,
found_inf
=
_check_and_update_gradient
(
params_grads
,
self
.
_loss_scaling
,
self
.
dist_context
)
if
self
.
get_attr
(
"use_dynamic_loss_scaling"
):
self
.
_update_loss_scaling
(
grads
,
found_inf
)
if
is_train
:
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
amp_state
.
cast_backward_program
(
params_grads
,
self
.
dist_context
)
self
.
_init_amp_var
()
self
.
_scale_loss
()
if
self
.
get_attr
(
"use_dynamic_loss_scaling"
)
or
self
.
get_attr
(
"init_loss_scaling"
)
!=
1.0
:
grads
,
found_inf
=
_check_and_update_gradient
(
params_grads
,
self
.
_loss_scaling
,
self
.
dist_context
)
if
self
.
get_attr
(
"use_dynamic_loss_scaling"
):
self
.
_update_loss_scaling
(
grads
,
found_inf
)
def
_init_amp_var
(
self
):
self
.
_loss_scaling
=
paddle
.
static
.
create_global_var
(
...
...
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
浏览文件 @
5dab0b0d
...
...
@@ -97,6 +97,7 @@ class TestAMPPass(unittest.TestCase):
3
,
batch_size
=
self
.
batch_size
)
amp_o1_losses
=
np
.
array
(
amp_o1_losses
[
"loss"
])
amp_o1_engine
.
evaluate
(
self
.
dataset
,
3
,
batch_size
=
self
.
batch_size
)
# self.check_results(mp_losses, amp_o1_losses)
# mp2 amp-o2 training
...
...
@@ -105,6 +106,7 @@ class TestAMPPass(unittest.TestCase):
3
,
batch_size
=
self
.
batch_size
)
amp_o2_losses
=
np
.
array
(
amp_o2_losses
[
"loss"
])
amp_o2_engine
.
evaluate
(
self
.
dataset
,
3
,
batch_size
=
self
.
batch_size
)
# self.check_results(mp_losses, amp_o2_losses)
# mp2 amp-o3 training
...
...
@@ -113,6 +115,7 @@ class TestAMPPass(unittest.TestCase):
3
,
batch_size
=
self
.
batch_size
)
amp_o3_losses
=
np
.
array
(
amp_o3_losses
[
"loss"
])
amp_o3_engine
.
evaluate
(
self
.
dataset
,
3
,
batch_size
=
self
.
batch_size
)
# self.check_results(mp_losses, amp_o3_losses)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录