Unverified commit 282e09dc

fix pylayer problem with amp (#39950)

* fix pylayer problem with amp
* add ut
* refine code

Authored by Leo Chen, committed via GitHub on Feb 27, 2022.
Parent: b33a3c23

Showing 3 changed files with 50 additions and 0 deletions (+50 -0)
python/paddle/autograd/py_layer.py                                             +10  -0
python/paddle/fluid/dygraph/amp/auto_cast.py                                   +13  -0
python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py    +27  -0
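Background: a custom PyLayer saves its forward inputs as float32, but when the forward pass runs under paddle.amp.auto_cast() the incoming gradient in backward() is float16, and backward() used to execute outside any AMP scope, so mixing the two dtypes in a matmul raised an error. The sketch below is a minimal reproduction of that failure mode; it mirrors the new unit test added by this commit and assumes a CUDA-enabled Paddle build (auto_cast only casts to float16 on GPU). On builds without this commit the backward() call fails; with it, the run completes.

# Repro sketch mirroring the new TestPyLayerWithAmp case; assumes a CUDA build.
import paddle
from paddle.autograd import PyLayer


class MyMM(PyLayer):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)   # a, b stay float32
        return a.mm(b)                # runs in float16 under auto_cast

    @staticmethod
    def backward(ctx, grad):
        a, b = ctx.saved_tensor()
        # grad arrives as float16 from the autocast forward while a and b are
        # float32; without re-entering the AMP state these matmuls mix dtypes.
        return grad.mm(b.t()), a.t().mm(grad)


x = paddle.rand([10, 10])
y = paddle.rand([10, 10])
x.stop_gradient = False
y.stop_gradient = False

with paddle.amp.auto_cast():
    loss = paddle.mean(MyMM.apply(x, y))
loss.backward()   # raises a dtype-mismatch error before this commit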
python/paddle/autograd/py_layer.py (view file @ 282e09dc)

@@ -14,6 +14,8 @@
 import paddle
 from paddle.fluid.framework import dygraph_only
+from paddle.fluid.dygraph.amp.auto_cast import amp_state
+from paddle.amp.auto_cast import auto_cast
 from paddle.fluid import core
 
 __all__ = []

@@ -46,6 +48,7 @@ class PyLayerContext(object):
 
     def __init__(self):
         self.container = None
+        self._amp_state = amp_state()
 
     def save_for_backward(self, *tensors):
         """

@@ -178,6 +181,13 @@ class PyLayerBackward(PyLayerContext):
 
     def backward(self, *args, **kwargs):
         with paddle.fluid.dygraph.guard():
             with paddle.fluid.dygraph.no_grad():
+                if self._amp_state and 'enable' in self._amp_state and self._amp_state[
+                        'enable']:
+                    with auto_cast(**args[0]._amp_state):
+                        return self._forward_cls.backward(*args, **kwargs)
+                else:
+                    return self._forward_cls.backward(*args, **kwargs)
+
                 return self._forward_cls.backward(*args, **kwargs)
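The mechanism of the fix: PyLayerContext.__init__ now snapshots whatever amp_guard state is active when the layer's context is created, and PyLayerBackward.backward re-enters auto_cast(**state) around the user's backward code whenever AMP was enabled during forward, so backward sees the same casting rules as forward. The sketch below illustrates this capture-and-replay pattern in plain Python with hypothetical names (autocast, Context, run_backward); it is the shape of the technique, not Paddle's implementation.

# Capture-and-replay sketch with hypothetical names; not Paddle's API.
import contextlib

_current_state = None                 # module-level snapshot, analogous to _g_amp_state_


@contextlib.contextmanager
def autocast(enable=True, dtype="float16"):
    """Record the active state while the block runs, restore it afterwards."""
    global _current_state
    previous = _current_state
    _current_state = {"enable": enable, "dtype": dtype}
    try:
        yield
    finally:
        _current_state = previous


class Context:
    def __init__(self):
        # Snapshot the state active when forward builds this context,
        # analogous to self._amp_state = amp_state() in PyLayerContext.
        self._amp_state = _current_state

    def run_backward(self, fn, *args):
        # Replay the snapshot so the backward function sees the same casting
        # behaviour, even though it runs after the original block has exited.
        if self._amp_state and self._amp_state.get("enable"):
            with autocast(**self._amp_state):
                return fn(*args)
        return fn(*args)


with autocast(dtype="float16"):
    ctx = Context()                   # state captured here

# Later, outside the original scope, the state is replayed:
print(ctx.run_backward(lambda g: _current_state["dtype"], None))   # -> float16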
python/paddle/fluid/dygraph/amp/auto_cast.py (view file @ 282e09dc)

@@ -78,6 +78,13 @@ PURE_FP16_BLACK_LIST = {
 BF16_WHITE_LIST = {'conv2d'}
 BF16_BLACK_LIST = {' '}
 
+_g_amp_state_ = None
+
+
+def amp_state():
+    global _g_amp_state_
+    return _g_amp_state_
+
+
 #NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list
 # The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode.

@@ -240,6 +247,11 @@ def amp_guard(enable=True,
             print(conv.dtype) # FP32
 
     """
+    amp_state = locals()
+    global _g_amp_state_
+    original_state = _g_amp_state_
+    _g_amp_state_ = amp_state
+
    # check amp_level: O0-O2
     level = level.upper()
     if not (level in ['O0', 'O1', 'O2']):

@@ -349,6 +361,7 @@ def amp_guard(enable=True,
         yield
     finally:
         if tracer:
+            _g_amp_state_ = original_state
             tracer._amp_level = original_amp_level
             tracer._set_amp_op_list(original_white_list, original_black_list)
             # set_flags(original_flags)
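Two details of the amp_guard change are worth spelling out. First, amp_state = locals() at the top of the function captures amp_guard's arguments as a plain dict, which is exactly what PyLayerBackward later splats back into auto_cast(**state). Second, the previous value of the global is stashed in original_state and put back in the finally block, so nested guards unwind like a stack. A small self-contained sketch of both points, using hypothetical names (guard, current_state) rather than Paddle's:

# Sketch of the locals()-snapshot and save/restore pattern used by amp_guard;
# guard() and current_state() are hypothetical stand-ins, not Paddle's API.
import contextlib

_g_state = None                        # stand-in for _g_amp_state_


def current_state():
    return _g_state                    # stand-in for amp_state()


@contextlib.contextmanager
def guard(enable=True, level='O1', dtype='float16'):
    state = locals()                   # {'enable': ..., 'level': ..., 'dtype': ...}
    global _g_state
    original = _g_state                # like original_state = _g_amp_state_
    _g_state = state                   # like _g_amp_state_ = amp_state
    try:
        yield
    finally:
        _g_state = original            # like the new line in amp_guard's finally block


with guard(level='O2'):
    outer = current_state()            # {'enable': True, 'level': 'O2', 'dtype': 'float16'}
    with guard(enable=False):
        assert current_state()['enable'] is False
    assert current_state() is outer    # outer guard's snapshot restored on exit
assert current_state() is None         # back to the pre-guard state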
python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py (view file @ 282e09dc)

@@ -20,6 +20,7 @@ import six
 from test_imperative_resnet import ResNet, BottleneckBlock, ConvBNLayer, train_parameters, optimizer_setting
 import paddle.nn as nn
 from paddle.static import InputSpec
+from paddle.autograd import PyLayer
 
 if fluid.core.is_compiled_with_cuda():
     fluid.set_flags({"FLAGS_cudnn_deterministic": True})

@@ -1146,5 +1147,31 @@ class TestBf16(unittest.TestCase):
             self.assertTrue(
                 np.allclose(
                     out_fp32, out_bf16, rtol=1.e-3, atol=1.e-1))
 
+
+class TestPyLayerWithAmp(unittest.TestCase):
+    def test_pylayer(self):
+        class MyMM(PyLayer):
+            @staticmethod
+            def forward(ctx, a, b):
+                ctx.save_for_backward(a, b)
+                return a.mm(b)
+
+            @staticmethod
+            def backward(ctx, grad):
+                a, b = ctx.saved_tensor()
+                # NOTE(zhiqiu): a and b is float32 now, while grad is fp16 when forward runs with auto_cast()
+                # thus, the mm operation raise errors because of the dtype of inputs are inconsistent
+                return grad.mm(b.t()), a.t().mm(grad)
+
+        x = paddle.rand([10, 10])
+        y = paddle.rand([10, 10])
+        x.stop_gradient = False
+        y.stop_gradient = False
+
+        with paddle.amp.auto_cast():
+            res = MyMM.apply(x, y)
+            loss = paddle.mean(res)
+        loss.backward()
+
+
 if __name__ == '__main__':
     unittest.main()
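The new case can be run on its own once a Paddle build containing this commit is installed. A small driver like the one below, executed from python/paddle/fluid/tests/unittests/ so that the module's sibling imports (test_imperative_resnet) resolve, limits the run to TestPyLayerWithAmp:

# Hedged driver for running only the new test case; assumes the working
# directory is python/paddle/fluid/tests/unittests/ and that the installed
# Paddle includes this commit.
import unittest

from test_imperative_auto_mixed_precision import TestPyLayerWithAmp

if __name__ == '__main__':
    # defaultTest resolves against this module, where the class was imported.
    unittest.main(defaultTest='TestPyLayerWithAmp', verbosity=2)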