Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
db30aa1d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2322
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
db30aa1d
编写于
4月 27, 2023
作者:
Y
yangguohao
提交者:
GitHub
4月 27, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Hackathon No.91】register_hook for static mode (#52948)
上级
cf6cbc34
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
302 addition
and
19 deletion
+302
-19
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+19
-2
python/paddle/fluid/tests/unittests/test_tensor_register_hook.py
...paddle/fluid/tests/unittests/test_tensor_register_hook.py
+15
-7
python/paddle/jit/dy2static/ast_utils.py
python/paddle/jit/dy2static/ast_utils.py
+84
-0
python/paddle/jit/dy2static/utils.py
python/paddle/jit/dy2static/utils.py
+15
-10
test/dygraph_to_static/test_tensor_hook.py
test/dygraph_to_static/test_tensor_hook.py
+169
-0
未找到文件。
python/paddle/fluid/framework.py
浏览文件 @
db30aa1d
...
@@ -1640,9 +1640,26 @@ class Variable(metaclass=VariableMetaClass):
...
@@ -1640,9 +1640,26 @@ class Variable(metaclass=VariableMetaClass):
"""
"""
pass
pass
@
fake_interface_only
def
register_hook
(
self
,
hook
):
def
register_hook
(
self
,
hook
):
pass
import
paddle
def
backward_hook_wrapper
(
dy
):
"""call the backward hook in ."""
import
numpy
as
np
return
hook
(
np
.
array
(
dy
))
def
forward_hook_wrapper
(
x
):
"""do nothing but return a new variable."""
return
x
paddle
.
static
.
py_func
(
func
=
forward_hook_wrapper
,
x
=
self
,
out
=
self
,
backward_func
=
backward_hook_wrapper
,
skip_vars_in_backward_input
=
[
self
],
)
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
_to_readable_code
()
return
self
.
_to_readable_code
()
...
...
python/paddle/fluid/tests/unittests/test_tensor_register_hook.py
浏览文件 @
db30aa1d
...
@@ -45,9 +45,10 @@ class SimpleNetForStatic(nn.Layer):
...
@@ -45,9 +45,10 @@ class SimpleNetForStatic(nn.Layer):
self
.
linear1
=
nn
.
Linear
(
in_size
,
in_size
)
self
.
linear1
=
nn
.
Linear
(
in_size
,
in_size
)
self
.
linear2
=
nn
.
Linear
(
in_size
,
out_size
)
self
.
linear2
=
nn
.
Linear
(
in_size
,
out_size
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
,
hook
=
False
):
ret1
=
self
.
linear1
(
x
)
ret1
=
self
.
linear1
(
x
)
ret1
.
register_hook
(
lambda
grad
:
grad
*
2
)
if
hook
:
ret1
.
register_hook
(
lambda
grad
:
grad
*
2
)
ret2
=
self
.
linear2
(
ret1
)
ret2
=
self
.
linear2
(
ret1
)
out
=
paddle
.
mean
(
ret2
,
axis
=-
1
)
out
=
paddle
.
mean
(
ret2
,
axis
=-
1
)
...
@@ -512,8 +513,7 @@ class TestTensorRegisterHook(unittest.TestCase):
...
@@ -512,8 +513,7 @@ class TestTensorRegisterHook(unittest.TestCase):
)
)
net
=
SimpleNetForStatic
(
self
.
in_size
,
self
.
out_size
)
net
=
SimpleNetForStatic
(
self
.
in_size
,
self
.
out_size
)
with
self
.
assertRaises
(
AssertionError
):
out
=
net
(
x
)
out
=
net
(
x
)
paddle
.
disable_static
()
paddle
.
disable_static
()
...
@@ -527,9 +527,17 @@ class TestTensorRegisterHook(unittest.TestCase):
...
@@ -527,9 +527,17 @@ class TestTensorRegisterHook(unittest.TestCase):
'float32'
'float32'
)
)
data_t
=
paddle
.
to_tensor
(
data
)
data_t
=
paddle
.
to_tensor
(
data
)
data_t2
=
paddle
.
to_tensor
(
data
)
with
self
.
assertRaises
(
AssertionError
):
data_t
.
stop_gradient
=
False
out
=
jit_net
(
data_t
)
data_t2
.
stop_gradient
=
False
out1
=
jit_net
(
data_t
)
out2
=
jit_net
(
data_t2
,
True
)
out1
.
backward
()
out2
.
backward
()
np
.
testing
.
assert_array_equal
(
2
*
data_t
.
grad
.
numpy
(),
data_t2
.
grad
.
numpy
()
)
HOOK_INIT_VALUE
=
10
HOOK_INIT_VALUE
=
10
...
...
python/paddle/jit/dy2static/ast_utils.py
浏览文件 @
db30aa1d
...
@@ -14,6 +14,9 @@
...
@@ -14,6 +14,9 @@
import
ast
import
ast
import
collections
import
inspect
import
textwrap
import
astor
import
astor
...
@@ -38,3 +41,84 @@ def ast_to_source_code(ast_node):
...
@@ -38,3 +41,84 @@ def ast_to_source_code(ast_node):
source_code
=
astor
.
to_source
(
ast_node
,
pretty_source
=
pretty_source
)
source_code
=
astor
.
to_source
(
ast_node
,
pretty_source
=
pretty_source
)
return
source_code
return
source_code
class
RegisterHookVisitor
(
gast
.
NodeVisitor
):
def
__init__
(
self
,
func_name
):
self
.
register_hook_pos_map
=
collections
.
defaultdict
(
list
)
self
.
assignment_pos_map
=
collections
.
defaultdict
(
list
)
self
.
func_name
=
func_name
def
visit_FunctionDef
(
self
,
func_def
):
# The inner function that has register_hook will not be processed
if
func_def
.
name
!=
self
.
func_name
:
return
register_hook_pos_map
=
self
.
register_hook_pos_map
assignment_pos_map
=
self
.
assignment_pos_map
for
i
in
range
(
len
(
func_def
.
body
)
-
1
,
-
1
,
-
1
):
body
=
func_def
.
body
[
i
]
# Check if the code body contains the register_hook
if
isinstance
(
body
,
ast
.
Expr
):
for
node
in
ast
.
walk
(
body
):
if
(
isinstance
(
node
,
ast
.
Attribute
)
and
node
.
attr
==
'register_hook'
):
# parameter name for register_hook
param_name
=
node
.
value
.
id
register_hook_pos_map
[
param_name
].
append
(
i
)
elif
isinstance
(
body
,
ast
.
Assign
):
for
target
in
body
.
targets
:
assignment_pos_map
[
target
.
id
].
append
(
i
)
# Confirm the order
order_map
=
{}
for
k
,
idx_list
in
register_hook_pos_map
.
items
():
for
idx
in
idx_list
:
if
k
not
in
assignment_pos_map
:
order_map
[
idx
]
=
1
else
:
for
assignment_idx
in
assignment_pos_map
[
k
]:
if
idx
>
assignment_idx
:
order_map
[
idx
]
=
assignment_idx
+
1
break
code_order
=
[
*
range
(
len
(
func_def
.
body
))]
for
k
,
v
in
sorted
(
order_map
.
items
(),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
):
if
k
==
v
:
continue
code_order
.
remove
(
k
)
code_order
.
insert
(
v
,
k
)
# rearrange the code according to the specified order
new_body
=
[
func_def
.
body
[
i
]
for
i
in
code_order
]
func_def
.
body
=
new_body
def
modify_function_code
(
func
):
"""
Modify the function code for the register hook
"""
func_ast
=
ast
.
parse
(
textwrap
.
dedent
(
inspect
.
getsource
(
func
)))
# check if there is register_hook on code after visit the tree.
check_register_hook
=
next
(
(
node
for
node
in
ast
.
walk
(
func_ast
)
if
isinstance
(
node
,
ast
.
Attribute
)
and
node
.
attr
==
'register_hook'
),
None
,
)
if
check_register_hook
is
None
:
return
visitor
=
RegisterHookVisitor
(
func
.
__name__
)
visitor
.
visit
(
func_ast
)
def
pretty_source
(
source
):
return
''
.
join
(
source
)
new_code
=
astor
.
to_source
(
func_ast
,
pretty_source
=
pretty_source
)
return
new_code
python/paddle/jit/dy2static/utils.py
浏览文件 @
db30aa1d
...
@@ -38,7 +38,7 @@ from paddle.fluid.layer_helper import LayerHelper
...
@@ -38,7 +38,7 @@ from paddle.fluid.layer_helper import LayerHelper
from
paddle.fluid.wrapped_decorator
import
signature_safe_contextmanager
from
paddle.fluid.wrapped_decorator
import
signature_safe_contextmanager
from
paddle.utils
import
gast
from
paddle.utils
import
gast
from
.ast_utils
import
ast_to_source_code
from
.ast_utils
import
ast_to_source_code
,
modify_function_code
from
.static_analysis
import
StaticAnalysisVisitor
from
.static_analysis
import
StaticAnalysisVisitor
from
.utils_helper
import
DYGRAPH_MODULE_PREFIX
# noqa: F401
from
.utils_helper
import
DYGRAPH_MODULE_PREFIX
# noqa: F401
from
.utils_helper
import
DYGRAPH_TO_STATIC_MODULE_PREFIX
# noqa: F401
from
.utils_helper
import
DYGRAPH_TO_STATIC_MODULE_PREFIX
# noqa: F401
...
@@ -643,15 +643,20 @@ def func_to_source_code(function, dedent=True):
...
@@ -643,15 +643,20 @@ def func_to_source_code(function, dedent=True):
type
(
function
).
__name__
type
(
function
).
__name__
)
)
)
)
source_code_list
,
_
=
inspect
.
getsourcelines
(
function
)
# return modified function source code if there is 'register_hook', otherwise return None
# Replace comments with blank lines so that error messages are not misplaced
source_code
=
modify_function_code
(
function
)
source_code_list
=
[
line
if
not
line
.
lstrip
().
startswith
(
'#'
)
else
'
\n
'
if
source_code
is
None
:
for
line
in
source_code_list
source_code_list
,
_
=
inspect
.
getsourcelines
(
function
)
]
# Replace comments with blank lines so that error messages are not misplaced
source_code
=
''
.
join
(
source_code_list
)
source_code_list
=
[
if
dedent
:
line
if
not
line
.
lstrip
().
startswith
(
'#'
)
else
'
\n
'
source_code
=
textwrap
.
dedent
(
source_code
)
for
line
in
source_code_list
]
source_code
=
''
.
join
(
source_code_list
)
if
dedent
:
source_code
=
textwrap
.
dedent
(
source_code
)
return
source_code
return
source_code
...
...
test/dygraph_to_static/test_tensor_hook.py
0 → 100644
浏览文件 @
db30aa1d
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle
from
paddle
import
nn
from
paddle.jit
import
to_static
class
TestStaticAnalysis
(
unittest
.
TestCase
):
def
test_hook_for_different_parameter
(
self
):
def
f
(
x
):
def
h
(
g
):
return
2
*
g
y
=
x
+
4
f
=
y
+
x
z
=
f
**
2
y
.
register_hook
(
h
)
f
.
register_hook
(
h
)
x
.
register_hook
(
h
)
return
z
x
=
paddle
.
to_tensor
([
2.0
])
x
.
stop_gradient
=
False
loss
=
f
(
x
)
loss
.
backward
()
x_jit
=
paddle
.
to_tensor
([
2.0
])
x_jit
.
stop_gradient
=
False
jit_f
=
to_static
(
f
)
loss
=
jit_f
(
x_jit
)
loss
.
backward
()
self
.
assertTrue
(
np
.
allclose
(
x
.
grad
.
numpy
(),
x_jit
.
grad
.
numpy
()))
def
test_hook_for_reassignment_parameter
(
self
):
def
f
(
x
):
def
h
(
g
):
return
2
*
g
y
=
x
+
4
x
=
y
*
5
z
=
x
**
2
x
.
register_hook
(
h
)
return
z
x
=
paddle
.
to_tensor
([
2.0
])
x
.
stop_gradient
=
False
loss
=
f
(
x
)
loss
.
backward
()
x_jit
=
paddle
.
to_tensor
([
2.0
])
x_jit
.
stop_gradient
=
False
jit_f
=
to_static
(
f
)
loss
=
jit_f
(
x_jit
)
loss
.
backward
()
self
.
assertTrue
(
np
.
allclose
(
x
.
grad
.
numpy
(),
x_jit
.
grad
.
numpy
()))
def
test_hook_for_repeat_register
(
self
):
def
f
(
x
):
def
h
(
g
):
return
2
*
g
y
=
x
+
4
z
=
y
**
2
x
.
register_hook
(
h
)
x
.
register_hook
(
h
)
return
z
x
=
paddle
.
to_tensor
([
2.0
])
x
.
stop_gradient
=
False
loss
=
f
(
x
)
loss
.
backward
()
x_jit
=
paddle
.
to_tensor
([
2.0
])
x_jit
.
stop_gradient
=
False
jit_f
=
to_static
(
f
)
loss
=
jit_f
(
x_jit
)
loss
.
backward
()
self
.
assertTrue
(
np
.
allclose
(
x
.
grad
.
numpy
(),
x_jit
.
grad
.
numpy
()))
def
test_hook_in_init_for_layer
(
self
):
def
hook
(
grad
):
return
grad
*
2
IMAGE_SIZE
=
784
CLASS_NUM
=
10
class
LinearNet
(
nn
.
Layer
):
def
__init__
(
self
):
super
().
__init__
()
self
.
_linear
=
nn
.
Linear
(
IMAGE_SIZE
,
CLASS_NUM
)
# register_hook in init
self
.
_linear
.
parameters
()[
0
].
register_hook
(
hook
)
def
forward
(
self
,
x
):
return
self
.
_linear
(
x
)
# create network
layer
=
LinearNet
()
jit_layer
=
to_static
(
LinearNet
())
data
=
np
.
random
.
random
([
IMAGE_SIZE
]).
astype
(
'float32'
)
image
=
paddle
.
to_tensor
(
data
)
image_jit
=
paddle
.
to_tensor
(
data
)
loss
=
layer
(
image
)
loss_jit
=
jit_layer
(
image_jit
)
loss_jit
.
backward
()
loss
.
backward
()
self
.
assertTrue
(
np
.
allclose
(
layer
.
parameters
()[
0
].
grad
.
numpy
(),
jit_layer
.
parameters
()[
0
].
grad
.
numpy
(),
)
)
# def test_hook_in_forward_for_layer(self):
#
# IMAGE_SIZE = 784
# CLASS_NUM = 10
#
# class LinearNet(nn.Layer):
# def __init__(self):
# super().__init__()
# self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
#
# def forward(self, x):
# def hook(grad):
# return grad * 2
#
# res = self._linear(x)
#
# # register_hook in forward
# self._linear.parameters()[0].register_hook(hook)
# return res
#
# # create network
# layer = LinearNet()
# jit_layer = to_static(LinearNet())
# data = np.random.random([IMAGE_SIZE]).astype('float32')
# image = paddle.to_tensor(data)
# image_jit = paddle.to_tensor(data)
# loss = layer(image)
# loss_jit = jit_layer(image_jit)
# loss_jit.backward()
# loss.backward()
# self.assertTrue(
# np.allclose(
# layer.parameters()[0].grad.numpy(),
# jit_layer.parameters()[0].grad.numpy(),
# )
# )
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录