Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
31344ab7
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
31344ab7
编写于
11月 22, 2021
作者:
Z
Zhanlue Yang
提交者:
GitHub
11月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add backward function hook to dygraph (#37141)
上级
21957476
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
104 addition
and
0 deletion
+104
-0
paddle/fluid/imperative/basic_engine.cc
paddle/fluid/imperative/basic_engine.cc
+7
-0
paddle/fluid/imperative/op_base.h
paddle/fluid/imperative/op_base.h
+14
-0
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+20
-0
python/paddle/fluid/tests/unittests/test_function_hook.py
python/paddle/fluid/tests/unittests/test_function_hook.py
+63
-0
未找到文件。
paddle/fluid/imperative/basic_engine.cc
浏览文件 @
31344ab7
...
...
@@ -569,6 +569,13 @@ void BasicEngine::Execute() {
}
}
// Function Post Hook
if
(
cur_op
.
HasVoidFunctionPostHook
())
{
for
(
const
auto
&
hook
:
cur_op
.
GetVoidFunctionPostHooks
())
{
(
*
hook
)();
}
}
for
(
auto
&
pair
:
inplace_output_grad_var_list_
)
{
*
pair
.
first
=
std
::
move
(
*
pair
.
second
);
}
...
...
paddle/fluid/imperative/op_base.h
浏览文件 @
31344ab7
...
...
@@ -186,6 +186,19 @@ class OpBase {
static
pten
::
KernelContext
*
GetKernelContext
()
{
return
&
pt_kernel_context_
;
}
bool
HasVoidFunctionPostHook
()
const
{
return
!
void_function_post_hooks_
.
empty
();
}
void
AddVoidFunctionPostHook
(
std
::
shared_ptr
<
std
::
function
<
void
()
>>&&
hook
)
{
void_function_post_hooks_
.
emplace_back
(
std
::
move
(
hook
));
}
const
std
::
vector
<
std
::
shared_ptr
<
std
::
function
<
void
()
>>>&
GetVoidFunctionPostHooks
()
const
{
return
void_function_post_hooks_
;
}
private:
static
const
std
::
string
&
UnknownOpType
()
{
static
std
::
string
kUnknownOpType
{
"unknown"
};
...
...
@@ -203,6 +216,7 @@ class OpBase {
// In order to reduce the compatibility phase
// performance overhead, temporarily cache KernelContext
static
pten
::
KernelContext
pt_kernel_context_
;
std
::
vector
<
std
::
shared_ptr
<
std
::
function
<
void
()
>>>
void_function_post_hooks_
;
};
class
GradOpNode
{
...
...
paddle/fluid/pybind/imperative.cc
浏览文件 @
31344ab7
...
...
@@ -1640,6 +1640,26 @@ void BindImperative(py::module *m_ptr) {
"gradient or without gradient."
));
return
self
.
GradVarBase
()
->
RemoveVariableWrapperHook
(
hook_id
);
})
.
def
(
"_register_void_function_post_hook"
,
[](
imperative
::
VarBase
&
self
,
const
py
::
handle
&
hook
)
{
PADDLE_ENFORCE_EQ
(
!
self
.
OverridedStopGradient
()
&&
self
.
HasGradVar
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Cannot register void function post hook on a Tensor that "
"stop "
"gradient or without gradient."
));
auto
py_func
=
PyObjectCast
<
std
::
function
<
void
()
>>
(
hook
.
ptr
());
VLOG
(
1
)
<<
111
;
auto
grad_node
=
self
.
MutableGradVarBase
()
->
GradNode
();
VLOG
(
1
)
<<
222
;
VLOG
(
1
)
<<
(
grad_node
==
nullptr
);
for
(
auto
&
cur_op
:
*
grad_node
)
{
VLOG
(
1
)
<<
333
;
cur_op
.
AddVoidFunctionPostHook
(
std
::
make_shared
<
std
::
function
<
void
()
>>
(
py_func
));
VLOG
(
1
)
<<
444
;
}
})
.
def
(
"_register_backward_hook"
,
[](
imperative
::
VarBase
&
self
,
const
py
::
handle
&
hook
)
{
PADDLE_ENFORCE_EQ
(
...
...
python/paddle/fluid/tests/unittests/test_function_hook.py
0 → 100644
浏览文件 @
31344ab7
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle
import
numpy
as
np
import
paddle.fluid.core
as
core
from
paddle
import
_C_ops
class
TestCapture
:
def
__init__
(
self
):
self
.
list
=
[]
test_cap
=
TestCapture
()
def
test_hook
():
test_cap
.
list
.
append
(
1
)
def
grad_hook
(
grad
):
test_cap
.
list
.
append
(
2
)
return
grad
class
TestBakcwardFunctionHookError
(
unittest
.
TestCase
):
def
test_hook
(
self
):
input_data
=
np
.
ones
([
4
,
4
]).
astype
(
'float32'
)
x
=
paddle
.
to_tensor
(
input_data
.
astype
(
np
.
float32
),
stop_gradient
=
False
)
z
=
paddle
.
to_tensor
(
input_data
.
astype
(
np
.
float32
),
stop_gradient
=
False
)
y
=
_C_ops
.
sigmoid
(
x
)
out
=
_C_ops
.
matmul_v2
(
y
,
z
,
'trans_x'
,
False
,
'trans_y'
,
False
)
out
.
_register_void_function_post_hook
(
test_hook
)
y
.
_register_void_function_post_hook
(
test_hook
)
y
.
register_hook
(
grad_hook
)
out
.
backward
()
assert
test_cap
.
list
==
[
1
,
2
,
1
]
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录