Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f3f3d57a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f3f3d57a
编写于
5月 08, 2023
作者:
Y
yangguohao
提交者:
GitHub
5月 08, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2St]Following update of register_hook for static mode (#53572)
上级
2f503382
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
108 addition
and
108 deletion
+108
-108
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+0
-2
python/paddle/jit/dy2static/ast_transformer.py
python/paddle/jit/dy2static/ast_transformer.py
+2
-0
python/paddle/jit/dy2static/ast_utils.py
python/paddle/jit/dy2static/ast_utils.py
+0
-84
python/paddle/jit/dy2static/tensorhook_transformer.py
python/paddle/jit/dy2static/tensorhook_transformer.py
+89
-0
python/paddle/jit/dy2static/utils.py
python/paddle/jit/dy2static/utils.py
+11
-14
test/dygraph_to_static/test_tensor_hook.py
test/dygraph_to_static/test_tensor_hook.py
+6
-8
未找到文件。
python/paddle/fluid/framework.py
浏览文件 @
f3f3d57a
...
@@ -1645,8 +1645,6 @@ class Variable(metaclass=VariableMetaClass):
...
@@ -1645,8 +1645,6 @@ class Variable(metaclass=VariableMetaClass):
def
backward_hook_wrapper
(
dy
):
def
backward_hook_wrapper
(
dy
):
"""call the backward hook in ."""
"""call the backward hook in ."""
import
numpy
as
np
return
hook
(
np
.
array
(
dy
))
return
hook
(
np
.
array
(
dy
))
def
forward_hook_wrapper
(
x
):
def
forward_hook_wrapper
(
x
):
...
...
python/paddle/jit/dy2static/ast_transformer.py
浏览文件 @
f3f3d57a
...
@@ -38,6 +38,7 @@ from .loop_transformer import LoopTransformer
...
@@ -38,6 +38,7 @@ from .loop_transformer import LoopTransformer
from
.return_transformer
import
ReturnTransformer
from
.return_transformer
import
ReturnTransformer
from
.static_analysis
import
StaticAnalysisVisitor
from
.static_analysis
import
StaticAnalysisVisitor
from
.tensor_shape_transformer
import
TensorShapeTransformer
from
.tensor_shape_transformer
import
TensorShapeTransformer
from
.tensorhook_transformer
import
RegisterHookTransformer
from
.typehint_transformer
import
TypeHintTransformer
from
.typehint_transformer
import
TypeHintTransformer
from
.utils
import
ast_to_source_code
from
.utils
import
ast_to_source_code
...
@@ -92,6 +93,7 @@ class DygraphToStaticAst(BaseTransformer):
...
@@ -92,6 +93,7 @@ class DygraphToStaticAst(BaseTransformer):
self
.
visit
(
node_wrapper
.
node
)
self
.
visit
(
node_wrapper
.
node
)
transformers
=
[
transformers
=
[
RegisterHookTransformer
,
EarlyReturnTransformer
,
EarlyReturnTransformer
,
BasicApiTransformer
,
# Basic Api
BasicApiTransformer
,
# Basic Api
TensorShapeTransformer
,
# Tensor.shape -> paddle.shape(Tensor)
TensorShapeTransformer
,
# Tensor.shape -> paddle.shape(Tensor)
...
...
python/paddle/jit/dy2static/ast_utils.py
浏览文件 @
f3f3d57a
...
@@ -14,9 +14,6 @@
...
@@ -14,9 +14,6 @@
import
ast
import
ast
import
collections
import
inspect
import
textwrap
import
astor
import
astor
...
@@ -41,84 +38,3 @@ def ast_to_source_code(ast_node):
...
@@ -41,84 +38,3 @@ def ast_to_source_code(ast_node):
source_code
=
astor
.
to_source
(
ast_node
,
pretty_source
=
pretty_source
)
source_code
=
astor
.
to_source
(
ast_node
,
pretty_source
=
pretty_source
)
return
source_code
return
source_code
class
RegisterHookVisitor
(
gast
.
NodeVisitor
):
def
__init__
(
self
,
func_name
):
self
.
register_hook_pos_map
=
collections
.
defaultdict
(
list
)
self
.
assignment_pos_map
=
collections
.
defaultdict
(
list
)
self
.
func_name
=
func_name
def
visit_FunctionDef
(
self
,
func_def
):
# The inner function that has register_hook will not be processed
if
func_def
.
name
!=
self
.
func_name
:
return
register_hook_pos_map
=
self
.
register_hook_pos_map
assignment_pos_map
=
self
.
assignment_pos_map
for
i
in
range
(
len
(
func_def
.
body
)
-
1
,
-
1
,
-
1
):
body
=
func_def
.
body
[
i
]
# Check if the code body contains the register_hook
if
isinstance
(
body
,
ast
.
Expr
):
for
node
in
ast
.
walk
(
body
):
if
(
isinstance
(
node
,
ast
.
Attribute
)
and
node
.
attr
==
'register_hook'
):
# parameter name for register_hook
param_name
=
node
.
value
.
id
register_hook_pos_map
[
param_name
].
append
(
i
)
elif
isinstance
(
body
,
ast
.
Assign
):
for
target
in
body
.
targets
:
assignment_pos_map
[
target
.
id
].
append
(
i
)
# Confirm the order
order_map
=
{}
for
k
,
idx_list
in
register_hook_pos_map
.
items
():
for
idx
in
idx_list
:
if
k
not
in
assignment_pos_map
:
order_map
[
idx
]
=
1
else
:
for
assignment_idx
in
assignment_pos_map
[
k
]:
if
idx
>
assignment_idx
:
order_map
[
idx
]
=
assignment_idx
+
1
break
code_order
=
[
*
range
(
len
(
func_def
.
body
))]
for
k
,
v
in
sorted
(
order_map
.
items
(),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
):
if
k
==
v
:
continue
code_order
.
remove
(
k
)
code_order
.
insert
(
v
,
k
)
# rearrange the code according to the specified order
new_body
=
[
func_def
.
body
[
i
]
for
i
in
code_order
]
func_def
.
body
=
new_body
def
modify_function_code
(
func
):
"""
Modify the function code for the register hook
"""
func_ast
=
ast
.
parse
(
textwrap
.
dedent
(
inspect
.
getsource
(
func
)))
# check if there is register_hook on code after visit the tree.
check_register_hook
=
next
(
(
node
for
node
in
ast
.
walk
(
func_ast
)
if
isinstance
(
node
,
ast
.
Attribute
)
and
node
.
attr
==
'register_hook'
),
None
,
)
if
check_register_hook
is
None
:
return
visitor
=
RegisterHookVisitor
(
func
.
__name__
)
visitor
.
visit
(
func_ast
)
def
pretty_source
(
source
):
return
''
.
join
(
source
)
new_code
=
astor
.
to_source
(
func_ast
,
pretty_source
=
pretty_source
)
return
new_code
python/paddle/jit/dy2static/tensorhook_transformer.py
0 → 100644
浏览文件 @
f3f3d57a
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
collections
from
paddle.utils
import
gast
from
.base_transformer
import
BaseTransformer
class
RegisterHookTransformer
(
BaseTransformer
):
def
__init__
(
self
,
wrapper_root
):
self
.
register_hook_pos_map
=
collections
.
defaultdict
(
list
)
self
.
assignment_pos_map
=
collections
.
defaultdict
(
list
)
self
.
root
=
wrapper_root
.
node
def
transform
(
self
):
"""
Main function to transform AST.
"""
self
.
visit
(
self
.
root
)
def
visit_FunctionDef
(
self
,
func_def
):
# The inner function that has register_hook will not be processed
check_register_hook
=
next
(
(
node
for
node
in
gast
.
walk
(
func_def
)
if
isinstance
(
node
,
gast
.
Attribute
)
and
node
.
attr
==
'register_hook'
),
None
,
)
if
check_register_hook
is
None
:
return
func_def
register_hook_pos_map
=
self
.
register_hook_pos_map
assignment_pos_map
=
self
.
assignment_pos_map
for
i
in
range
(
len
(
func_def
.
body
)
-
1
,
-
1
,
-
1
):
body
=
func_def
.
body
[
i
]
# Check if the code body contains the register_hook
if
isinstance
(
body
,
gast
.
Expr
):
for
node
in
gast
.
walk
(
body
):
if
(
isinstance
(
node
,
gast
.
Attribute
)
and
node
.
attr
==
'register_hook'
):
# parameter name for register_hook
param_name
=
node
.
value
.
id
register_hook_pos_map
[
param_name
].
append
(
i
)
elif
isinstance
(
body
,
gast
.
Assign
):
for
target
in
body
.
targets
:
assignment_pos_map
[
target
.
id
].
append
(
i
)
# Confirm the order
order_map
=
{}
for
k
,
idx_list
in
register_hook_pos_map
.
items
():
for
idx
in
idx_list
:
if
k
not
in
assignment_pos_map
:
order_map
[
idx
]
=
1
else
:
for
assignment_idx
in
assignment_pos_map
[
k
]:
if
idx
>
assignment_idx
:
order_map
[
idx
]
=
assignment_idx
+
1
break
code_order
=
[
*
range
(
len
(
func_def
.
body
))]
for
k
,
v
in
sorted
(
order_map
.
items
(),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
):
if
k
==
v
:
continue
code_order
.
remove
(
k
)
code_order
.
insert
(
v
,
k
)
# rearrange the code according to the specified order
new_body
=
[
func_def
.
body
[
i
]
for
i
in
code_order
]
func_def
.
body
=
new_body
return
func_def
python/paddle/jit/dy2static/utils.py
浏览文件 @
f3f3d57a
...
@@ -38,7 +38,7 @@ from paddle.fluid.layer_helper import LayerHelper
...
@@ -38,7 +38,7 @@ from paddle.fluid.layer_helper import LayerHelper
from
paddle.fluid.wrapped_decorator
import
signature_safe_contextmanager
from
paddle.fluid.wrapped_decorator
import
signature_safe_contextmanager
from
paddle.utils
import
gast
from
paddle.utils
import
gast
from
.ast_utils
import
ast_to_source_code
,
modify_function_code
from
.ast_utils
import
ast_to_source_code
from
.static_analysis
import
StaticAnalysisVisitor
from
.static_analysis
import
StaticAnalysisVisitor
from
.utils_helper
import
DYGRAPH_MODULE_PREFIX
# noqa: F401
from
.utils_helper
import
DYGRAPH_MODULE_PREFIX
# noqa: F401
from
.utils_helper
import
DYGRAPH_TO_STATIC_MODULE_PREFIX
# noqa: F401
from
.utils_helper
import
DYGRAPH_TO_STATIC_MODULE_PREFIX
# noqa: F401
...
@@ -643,20 +643,17 @@ def func_to_source_code(function, dedent=True):
...
@@ -643,20 +643,17 @@ def func_to_source_code(function, dedent=True):
type
(
function
).
__name__
type
(
function
).
__name__
)
)
)
)
# return modified function source code if there is 'register_hook', otherwise return None
source_code
=
modify_function_code
(
function
)
if
source_code
is
None
:
source_code_list
,
_
=
inspect
.
getsourcelines
(
function
)
# Replace comments with blank lines so that error messages are not misplaced
source_code_list
=
[
line
if
not
line
.
lstrip
().
startswith
(
'#'
)
else
'
\n
'
for
line
in
source_code_list
]
source_code
=
''
.
join
(
source_code_list
)
if
dedent
:
source_code_list
,
_
=
inspect
.
getsourcelines
(
function
)
source_code
=
textwrap
.
dedent
(
source_code
)
# Replace comments with blank lines so that error messages are not misplaced
source_code_list
=
[
line
if
not
line
.
lstrip
().
startswith
(
'#'
)
else
'
\n
'
for
line
in
source_code_list
]
source_code
=
''
.
join
(
source_code_list
)
if
dedent
:
source_code
=
textwrap
.
dedent
(
source_code
)
return
source_code
return
source_code
...
...
test/dygraph_to_static/test_tensor_hook.py
浏览文件 @
f3f3d57a
...
@@ -45,7 +45,7 @@ class TestStaticAnalysis(unittest.TestCase):
...
@@ -45,7 +45,7 @@ class TestStaticAnalysis(unittest.TestCase):
jit_f
=
to_static
(
f
)
jit_f
=
to_static
(
f
)
loss
=
jit_f
(
x_jit
)
loss
=
jit_f
(
x_jit
)
loss
.
backward
()
loss
.
backward
()
self
.
assertTrue
(
np
.
allclose
(
x
.
grad
.
numpy
(),
x_jit
.
grad
.
numpy
()
))
np
.
testing
.
assert_allclose
(
x
.
grad
.
numpy
(),
x_jit
.
grad
.
numpy
(
))
def
test_hook_for_reassignment_parameter
(
self
):
def
test_hook_for_reassignment_parameter
(
self
):
def
f
(
x
):
def
f
(
x
):
...
@@ -68,7 +68,7 @@ class TestStaticAnalysis(unittest.TestCase):
...
@@ -68,7 +68,7 @@ class TestStaticAnalysis(unittest.TestCase):
jit_f
=
to_static
(
f
)
jit_f
=
to_static
(
f
)
loss
=
jit_f
(
x_jit
)
loss
=
jit_f
(
x_jit
)
loss
.
backward
()
loss
.
backward
()
self
.
assertTrue
(
np
.
allclose
(
x
.
grad
.
numpy
(),
x_jit
.
grad
.
numpy
()
))
np
.
testing
.
assert_allclose
(
x
.
grad
.
numpy
(),
x_jit
.
grad
.
numpy
(
))
def
test_hook_for_repeat_register
(
self
):
def
test_hook_for_repeat_register
(
self
):
def
f
(
x
):
def
f
(
x
):
...
@@ -91,7 +91,7 @@ class TestStaticAnalysis(unittest.TestCase):
...
@@ -91,7 +91,7 @@ class TestStaticAnalysis(unittest.TestCase):
jit_f
=
to_static
(
f
)
jit_f
=
to_static
(
f
)
loss
=
jit_f
(
x_jit
)
loss
=
jit_f
(
x_jit
)
loss
.
backward
()
loss
.
backward
()
self
.
assertTrue
(
np
.
allclose
(
x
.
grad
.
numpy
(),
x_jit
.
grad
.
numpy
()
))
np
.
testing
.
assert_allclose
(
x
.
grad
.
numpy
(),
x_jit
.
grad
.
numpy
(
))
def
test_hook_in_init_for_layer
(
self
):
def
test_hook_in_init_for_layer
(
self
):
def
hook
(
grad
):
def
hook
(
grad
):
...
@@ -120,11 +120,9 @@ class TestStaticAnalysis(unittest.TestCase):
...
@@ -120,11 +120,9 @@ class TestStaticAnalysis(unittest.TestCase):
loss_jit
=
jit_layer
(
image_jit
)
loss_jit
=
jit_layer
(
image_jit
)
loss_jit
.
backward
()
loss_jit
.
backward
()
loss
.
backward
()
loss
.
backward
()
self
.
assertTrue
(
np
.
testing
.
assert_allclose
(
np
.
allclose
(
layer
.
parameters
()[
0
].
grad
.
numpy
(),
layer
.
parameters
()[
0
].
grad
.
numpy
(),
jit_layer
.
parameters
()[
0
].
grad
.
numpy
(),
jit_layer
.
parameters
()[
0
].
grad
.
numpy
(),
)
)
)
# def test_hook_in_forward_for_layer(self):
# def test_hook_in_forward_for_layer(self):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录