Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1980e33a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1980e33a
编写于
3月 02, 2022
作者:
L
Leo Chen
提交者:
GitHub
3月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add check for backward hook (#40041)
* add check for backward hook * refine ut
上级
09258040
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
62 addition
and
2 deletion
+62
-2
paddle/fluid/imperative/basic_engine.cc
paddle/fluid/imperative/basic_engine.cc
+1
-0
paddle/fluid/imperative/gradient_accumulator.cc
paddle/fluid/imperative/gradient_accumulator.cc
+1
-0
paddle/fluid/imperative/gradient_accumulator.h
paddle/fluid/imperative/gradient_accumulator.h
+24
-0
python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
...d/tests/unittests/test_imperative_auto_mixed_precision.py
+36
-2
未找到文件。
paddle/fluid/imperative/basic_engine.cc
浏览文件 @
1980e33a
...
@@ -317,6 +317,7 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
...
@@ -317,6 +317,7 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
auto
tmp_var
=
var
;
auto
tmp_var
=
var
;
for
(
const
auto
&
hook_pair
:
var
->
GetVariableWrapperHooks
())
{
for
(
const
auto
&
hook_pair
:
var
->
GetVariableWrapperHooks
())
{
tmp_var
=
(
*
hook_pair
.
second
)(
tmp_var
);
tmp_var
=
(
*
hook_pair
.
second
)(
tmp_var
);
CheckVar
(
var
,
tmp_var
);
}
}
(
*
tmp_ins_ptr
)[
pair
.
first
][
i
]
=
tmp_var
;
(
*
tmp_ins_ptr
)[
pair
.
first
][
i
]
=
tmp_var
;
}
}
...
...
paddle/fluid/imperative/gradient_accumulator.cc
浏览文件 @
1980e33a
...
@@ -732,6 +732,7 @@ void GradientAccumulator::CallGradientHooks() {
...
@@ -732,6 +732,7 @@ void GradientAccumulator::CallGradientHooks() {
<<
var_
->
GetVariableWrapperHooks
().
size
();
<<
var_
->
GetVariableWrapperHooks
().
size
();
for
(
const
auto
&
hook_pair
:
var_
->
GetVariableWrapperHooks
())
{
for
(
const
auto
&
hook_pair
:
var_
->
GetVariableWrapperHooks
())
{
tmp_var
=
(
*
hook_pair
.
second
)(
tmp_var
);
tmp_var
=
(
*
hook_pair
.
second
)(
tmp_var
);
CheckVar
(
inner_var_
,
tmp_var
);
}
}
inner_var_
=
tmp_var
;
inner_var_
=
tmp_var
;
}
}
...
...
paddle/fluid/imperative/gradient_accumulator.h
浏览文件 @
1980e33a
...
@@ -179,5 +179,29 @@ void SelectedRowsAddTensor(const VarType& src_selected_rows_var,
...
@@ -179,5 +179,29 @@ void SelectedRowsAddTensor(const VarType& src_selected_rows_var,
template
<
typename
VarType
>
template
<
typename
VarType
>
void
TensorAdd
(
const
VarType
&
src
,
VarType
*
dst
);
void
TensorAdd
(
const
VarType
&
src
,
VarType
*
dst
);
inline
void
CheckVar
(
const
std
::
shared_ptr
<
VariableWrapper
>&
pre
,
const
std
::
shared_ptr
<
VariableWrapper
>&
post
)
{
if
(
pre
->
IsEmpty
()
&&
!
post
->
IsEmpty
())
{
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"The tensor(%s) in before and after hook are not consistent"
,
pre
->
Name
()));
}
if
(
!
pre
->
IsEmpty
()
&&
!
post
->
IsEmpty
())
{
VLOG
(
4
)
<<
pre
->
DataType
()
<<
" "
<<
post
->
DataType
();
PADDLE_ENFORCE_EQ
(
pre
->
DataType
(),
post
->
DataType
(),
platform
::
errors
::
PermissionDenied
(
"The dtype of tensor(%s) before(%s) and after(%s) hook are not "
"consistent"
,
pre
->
Name
(),
framework
::
DataTypeToString
(
pre
->
DataType
()),
framework
::
DataTypeToString
(
post
->
DataType
())));
PADDLE_ENFORCE_EQ
(
pre
->
Place
(),
post
->
Place
(),
platform
::
errors
::
PermissionDenied
(
"The place of tensor(%s) before(%s) and after(%s) "
"hook are not consistent"
,
pre
->
Name
(),
pre
->
Place
(),
post
->
Place
()));
}
}
}
// namespace imperative
}
// namespace imperative
}
// namespace paddle
}
// namespace paddle
python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
浏览文件 @
1980e33a
...
@@ -1156,7 +1156,7 @@ class TestBf16(unittest.TestCase):
...
@@ -1156,7 +1156,7 @@ class TestBf16(unittest.TestCase):
out_fp32
,
out_bf16_O2
,
rtol
=
1.e-3
,
atol
=
1.e-1
))
out_fp32
,
out_bf16_O2
,
rtol
=
1.e-3
,
atol
=
1.e-1
))
class
Test
PyLayerWithAmp
(
unittest
.
TestCase
):
class
Test
AmpWithPyLyer
(
unittest
.
TestCase
):
def
test_pylayer
(
self
):
def
test_pylayer
(
self
):
class
MyMM
(
PyLayer
):
class
MyMM
(
PyLayer
):
@
staticmethod
@
staticmethod
...
@@ -1168,7 +1168,7 @@ class TestPyLayerWithAmp(unittest.TestCase):
...
@@ -1168,7 +1168,7 @@ class TestPyLayerWithAmp(unittest.TestCase):
def
backward
(
ctx
,
grad
):
def
backward
(
ctx
,
grad
):
a
,
b
=
ctx
.
saved_tensor
()
a
,
b
=
ctx
.
saved_tensor
()
# NOTE(zhiqiu): a and b is float32 now, while grad is fp16 when forward runs with auto_cast()
# NOTE(zhiqiu): a and b is float32 now, while grad is fp16 when forward runs with auto_cast()
# thus, the mm operation raise errors because of the dtype of inputs are inconsistent
# thus, the mm operation raise errors because of the dtype of inputs are inconsistent
before.
return
grad
.
mm
(
b
.
t
()),
a
.
t
().
mm
(
grad
)
return
grad
.
mm
(
b
.
t
()),
a
.
t
().
mm
(
grad
)
x
=
paddle
.
rand
([
10
,
10
])
x
=
paddle
.
rand
([
10
,
10
])
...
@@ -1182,5 +1182,39 @@ class TestPyLayerWithAmp(unittest.TestCase):
...
@@ -1182,5 +1182,39 @@ class TestPyLayerWithAmp(unittest.TestCase):
loss
.
backward
()
loss
.
backward
()
class
TestAmpWithHook
(
unittest
.
TestCase
):
def
test_hook_change_dtype
(
self
):
with
paddle
.
fluid
.
dygraph
.
guard
():
v
=
paddle
.
rand
([
3
,
3
])
v
.
stop_gradient
=
False
def
foo
(
grad
):
print
(
'grad'
,
grad
,
grad
.
dtype
)
# grad's dtype is float32
res
=
paddle
.
mm
(
grad
,
grad
)
# mm runs in fp16
print
(
'res'
,
res
,
res
.
dtype
)
# res's dtype is float16
return
res
v
.
register_hook
(
foo
)
with
paddle
.
amp
.
auto_cast
():
a
=
paddle
.
mm
(
v
,
v
)
loss
=
a
.
sum
()
self
.
assertRaises
(
RuntimeError
,
loss
.
backward
)
def
test_hook_change_place
(
self
):
with
paddle
.
fluid
.
dygraph
.
guard
():
v
=
paddle
.
rand
([
3
,
3
])
v
.
stop_gradient
=
False
def
foo
(
grad
):
res
=
grad
.
cpu
()
# change place
return
res
v
.
register_hook
(
foo
)
with
paddle
.
amp
.
auto_cast
():
a
=
paddle
.
mm
(
v
,
v
)
loss
=
a
.
sum
()
self
.
assertRaises
(
RuntimeError
,
loss
.
backward
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录