未验证 提交 f3f3d57a 编写于 作者: Y yangguohao 提交者: GitHub

[Dy2St]Following update of register_hook for static mode (#53572)

上级 2f503382
...@@ -1645,8 +1645,6 @@ class Variable(metaclass=VariableMetaClass): ...@@ -1645,8 +1645,6 @@ class Variable(metaclass=VariableMetaClass):
def backward_hook_wrapper(dy): def backward_hook_wrapper(dy):
"""call the backward hook in .""" """call the backward hook in ."""
import numpy as np
return hook(np.array(dy)) return hook(np.array(dy))
def forward_hook_wrapper(x): def forward_hook_wrapper(x):
......
...@@ -38,6 +38,7 @@ from .loop_transformer import LoopTransformer ...@@ -38,6 +38,7 @@ from .loop_transformer import LoopTransformer
from .return_transformer import ReturnTransformer from .return_transformer import ReturnTransformer
from .static_analysis import StaticAnalysisVisitor from .static_analysis import StaticAnalysisVisitor
from .tensor_shape_transformer import TensorShapeTransformer from .tensor_shape_transformer import TensorShapeTransformer
from .tensorhook_transformer import RegisterHookTransformer
from .typehint_transformer import TypeHintTransformer from .typehint_transformer import TypeHintTransformer
from .utils import ast_to_source_code from .utils import ast_to_source_code
...@@ -92,6 +93,7 @@ class DygraphToStaticAst(BaseTransformer): ...@@ -92,6 +93,7 @@ class DygraphToStaticAst(BaseTransformer):
self.visit(node_wrapper.node) self.visit(node_wrapper.node)
transformers = [ transformers = [
RegisterHookTransformer,
EarlyReturnTransformer, EarlyReturnTransformer,
BasicApiTransformer, # Basic Api BasicApiTransformer, # Basic Api
TensorShapeTransformer, # Tensor.shape -> paddle.shape(Tensor) TensorShapeTransformer, # Tensor.shape -> paddle.shape(Tensor)
......
...@@ -14,9 +14,6 @@ ...@@ -14,9 +14,6 @@
import ast import ast
import collections
import inspect
import textwrap
import astor import astor
...@@ -41,84 +38,3 @@ def ast_to_source_code(ast_node): ...@@ -41,84 +38,3 @@ def ast_to_source_code(ast_node):
source_code = astor.to_source(ast_node, pretty_source=pretty_source) source_code = astor.to_source(ast_node, pretty_source=pretty_source)
return source_code return source_code
class RegisterHookVisitor(gast.NodeVisitor):
    """Reorder ``<name>.register_hook(...)`` statements inside one function.

    Only the top-level statements of the function called ``func_name`` are
    inspected. Each expression statement that calls ``register_hook`` on a
    plain variable is moved so that it sits directly after the closest
    preceding top-level assignment to that variable.

    Args:
        func_name: name of the function definition whose body is rewritten.
    """

    def __init__(self, func_name):
        # variable name -> indices of body statements calling name.register_hook
        self.register_hook_pos_map = collections.defaultdict(list)
        # variable name -> indices of body statements assigning to that name
        self.assignment_pos_map = collections.defaultdict(list)
        self.func_name = func_name

    @staticmethod
    def _assigned_names(target):
        """Yield the plain variable names bound by an assignment target.

        Tuple/list/starred targets are unpacked recursively. Attribute and
        subscript targets (``self.x = ...``, ``a[0] = ...``) bind no plain
        name and therefore yield nothing.
        """
        if isinstance(target, ast.Name):
            yield target.id
        elif isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                yield from RegisterHookVisitor._assigned_names(element)
        elif isinstance(target, ast.Starred):
            yield from RegisterHookVisitor._assigned_names(target.value)

    def visit_FunctionDef(self, func_def):
        """Rewrite ``func_def.body`` in place; other functions are ignored."""
        # Functions with a different name (e.g. inner helpers) are not
        # processed, and because we do not recurse, neither are their bodies.
        if func_def.name != self.func_name:
            return

        register_hook_pos_map = self.register_hook_pos_map
        assignment_pos_map = self.assignment_pos_map

        # Walk the body backwards so every index list ends up in descending
        # order; the search below relies on that to find the closest
        # preceding assignment first.
        for i in range(len(func_def.body) - 1, -1, -1):
            body = func_def.body[i]
            if isinstance(body, ast.Expr):
                for node in ast.walk(body):
                    if (
                        isinstance(node, ast.Attribute)
                        and node.attr == 'register_hook'
                        # Receivers such as ``self.w`` or ``f(x)`` have no
                        # ``id``; they are left where they are.
                        and isinstance(node.value, ast.Name)
                    ):
                        register_hook_pos_map[node.value.id].append(i)
            elif isinstance(body, ast.Assign):
                for target in body.targets:
                    for name in self._assigned_names(target):
                        assignment_pos_map[name].append(i)

        # Map: current index of a hook statement -> index it should move to.
        order_map = {}
        for k, idx_list in register_hook_pos_map.items():
            for idx in idx_list:
                if k not in assignment_pos_map:
                    # No top-level assignment to this name (e.g. a function
                    # argument): the hook is moved to index 1.
                    # NOTE(review): why 1 rather than 0 is not visible
                    # here - confirm the intent.
                    order_map[idx] = 1
                else:
                    for assignment_idx in assignment_pos_map[k]:
                        if idx > assignment_idx:
                            order_map[idx] = assignment_idx + 1
                            break

        code_order = [*range(len(func_def.body))]
        # Apply the moves from the largest destination downwards so earlier
        # insertions do not shift the destinations still to be processed.
        for k, v in sorted(order_map.items(), key=lambda x: x[1], reverse=True):
            if k == v:
                continue
            code_order.remove(k)
            code_order.insert(v, k)

        # Rearrange the statements according to the computed order.
        func_def.body = [func_def.body[i] for i in code_order]
def modify_function_code(func):
    """Return the source of ``func`` with its register_hook calls reordered.

    The function's source is dedented and parsed; if no attribute access
    named ``register_hook`` occurs anywhere in it, ``None`` is returned.
    Otherwise the tree is rewritten by ``RegisterHookVisitor`` and rendered
    back to source text with ``astor``.
    """
    tree = ast.parse(textwrap.dedent(inspect.getsource(func)))

    # Cheap pre-check: skip the rewrite entirely when nothing calls
    # register_hook.
    uses_register_hook = any(
        isinstance(candidate, ast.Attribute)
        and candidate.attr == 'register_hook'
        for candidate in ast.walk(tree)
    )
    if not uses_register_hook:
        return None

    RegisterHookVisitor(func.__name__).visit(tree)

    # astor passes the generated fragments to ``pretty_source``; joining
    # them verbatim keeps its output unchanged.
    return astor.to_source(tree, pretty_source=''.join)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from paddle.utils import gast
from .base_transformer import BaseTransformer
class RegisterHookTransformer(BaseTransformer):
    """Reorder ``<name>.register_hook(...)`` statements in function bodies.

    For every visited function definition that mentions ``register_hook``,
    each top-level expression statement calling ``register_hook`` on a plain
    variable is moved so that it sits directly after the closest preceding
    top-level assignment to that variable.

    Args:
        wrapper_root: object whose ``node`` attribute is the AST root to
            transform.
    """

    def __init__(self, wrapper_root):
        # variable name -> indices of body statements calling name.register_hook
        self.register_hook_pos_map = collections.defaultdict(list)
        # variable name -> indices of body statements assigning to that name
        self.assignment_pos_map = collections.defaultdict(list)
        self.root = wrapper_root.node

    def transform(self):
        """
        Main function to transform AST.
        """
        self.visit(self.root)

    @staticmethod
    def _assigned_names(target):
        """Yield the plain variable names bound by an assignment target.

        Tuple/list/starred targets are unpacked recursively. Attribute and
        subscript targets (``self.x = ...``, ``a[0] = ...``) bind no plain
        name and therefore yield nothing.
        """
        if isinstance(target, gast.Name):
            yield target.id
        elif isinstance(target, (gast.Tuple, gast.List)):
            for element in target.elts:
                yield from RegisterHookTransformer._assigned_names(element)
        elif isinstance(target, gast.Starred):
            yield from RegisterHookTransformer._assigned_names(target.value)

    def visit_FunctionDef(self, func_def):
        """Reorder the top-level statements of ``func_def`` and return it."""
        # Nothing to do when register_hook is not mentioned anywhere in the
        # function. Nested definitions are not visited separately because
        # this method does not recurse.
        check_register_hook = next(
            (
                node
                for node in gast.walk(func_def)
                if isinstance(node, gast.Attribute)
                and node.attr == 'register_hook'
            ),
            None,
        )
        if check_register_hook is None:
            return func_def

        # Start from empty maps for every function: the recorded indices are
        # positions in *this* body and must not leak into another one.
        register_hook_pos_map = collections.defaultdict(list)
        assignment_pos_map = collections.defaultdict(list)
        self.register_hook_pos_map = register_hook_pos_map
        self.assignment_pos_map = assignment_pos_map

        # Walk the body backwards so every index list ends up in descending
        # order; the search below relies on that to find the closest
        # preceding assignment first.
        for i in range(len(func_def.body) - 1, -1, -1):
            body = func_def.body[i]
            if isinstance(body, gast.Expr):
                for node in gast.walk(body):
                    if (
                        isinstance(node, gast.Attribute)
                        and node.attr == 'register_hook'
                        # Receivers such as ``self.w`` or ``f(x)`` have no
                        # ``id``; they are left where they are.
                        and isinstance(node.value, gast.Name)
                    ):
                        register_hook_pos_map[node.value.id].append(i)
            elif isinstance(body, gast.Assign):
                for target in body.targets:
                    for name in self._assigned_names(target):
                        assignment_pos_map[name].append(i)

        # Map: current index of a hook statement -> index it should move to.
        order_map = {}
        for k, idx_list in register_hook_pos_map.items():
            for idx in idx_list:
                if k not in assignment_pos_map:
                    # No top-level assignment to this name (e.g. a function
                    # argument): the hook is moved to index 1.
                    # NOTE(review): why 1 rather than 0 is not visible
                    # here - confirm the intent.
                    order_map[idx] = 1
                else:
                    for assignment_idx in assignment_pos_map[k]:
                        if idx > assignment_idx:
                            order_map[idx] = assignment_idx + 1
                            break

        code_order = [*range(len(func_def.body))]
        # Apply the moves from the largest destination downwards so earlier
        # insertions do not shift the destinations still to be processed.
        for k, v in sorted(order_map.items(), key=lambda x: x[1], reverse=True):
            if k == v:
                continue
            code_order.remove(k)
            code_order.insert(v, k)

        # Rearrange the statements according to the computed order.
        func_def.body = [func_def.body[i] for i in code_order]
        return func_def
...@@ -38,7 +38,7 @@ from paddle.fluid.layer_helper import LayerHelper ...@@ -38,7 +38,7 @@ from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.utils import gast from paddle.utils import gast
from .ast_utils import ast_to_source_code, modify_function_code from .ast_utils import ast_to_source_code
from .static_analysis import StaticAnalysisVisitor from .static_analysis import StaticAnalysisVisitor
from .utils_helper import DYGRAPH_MODULE_PREFIX # noqa: F401 from .utils_helper import DYGRAPH_MODULE_PREFIX # noqa: F401
from .utils_helper import DYGRAPH_TO_STATIC_MODULE_PREFIX # noqa: F401 from .utils_helper import DYGRAPH_TO_STATIC_MODULE_PREFIX # noqa: F401
...@@ -643,20 +643,17 @@ def func_to_source_code(function, dedent=True): ...@@ -643,20 +643,17 @@ def func_to_source_code(function, dedent=True):
type(function).__name__ type(function).__name__
) )
) )
# return modified function source code if there is 'register_hook', otherwise return None
source_code = modify_function_code(function)
if source_code is None:
source_code_list, _ = inspect.getsourcelines(function)
# Replace comments with blank lines so that error messages are not misplaced
source_code_list = [
line if not line.lstrip().startswith('#') else '\n'
for line in source_code_list
]
source_code = ''.join(source_code_list)
if dedent: source_code_list, _ = inspect.getsourcelines(function)
source_code = textwrap.dedent(source_code) # Replace comments with blank lines so that error messages are not misplaced
source_code_list = [
line if not line.lstrip().startswith('#') else '\n'
for line in source_code_list
]
source_code = ''.join(source_code_list)
if dedent:
source_code = textwrap.dedent(source_code)
return source_code return source_code
......
...@@ -45,7 +45,7 @@ class TestStaticAnalysis(unittest.TestCase): ...@@ -45,7 +45,7 @@ class TestStaticAnalysis(unittest.TestCase):
jit_f = to_static(f) jit_f = to_static(f)
loss = jit_f(x_jit) loss = jit_f(x_jit)
loss.backward() loss.backward()
self.assertTrue(np.allclose(x.grad.numpy(), x_jit.grad.numpy())) np.testing.assert_allclose(x.grad.numpy(), x_jit.grad.numpy())
def test_hook_for_reassignment_parameter(self): def test_hook_for_reassignment_parameter(self):
def f(x): def f(x):
...@@ -68,7 +68,7 @@ class TestStaticAnalysis(unittest.TestCase): ...@@ -68,7 +68,7 @@ class TestStaticAnalysis(unittest.TestCase):
jit_f = to_static(f) jit_f = to_static(f)
loss = jit_f(x_jit) loss = jit_f(x_jit)
loss.backward() loss.backward()
self.assertTrue(np.allclose(x.grad.numpy(), x_jit.grad.numpy())) np.testing.assert_allclose(x.grad.numpy(), x_jit.grad.numpy())
def test_hook_for_repeat_register(self): def test_hook_for_repeat_register(self):
def f(x): def f(x):
...@@ -91,7 +91,7 @@ class TestStaticAnalysis(unittest.TestCase): ...@@ -91,7 +91,7 @@ class TestStaticAnalysis(unittest.TestCase):
jit_f = to_static(f) jit_f = to_static(f)
loss = jit_f(x_jit) loss = jit_f(x_jit)
loss.backward() loss.backward()
self.assertTrue(np.allclose(x.grad.numpy(), x_jit.grad.numpy())) np.testing.assert_allclose(x.grad.numpy(), x_jit.grad.numpy())
def test_hook_in_init_for_layer(self): def test_hook_in_init_for_layer(self):
def hook(grad): def hook(grad):
...@@ -120,11 +120,9 @@ class TestStaticAnalysis(unittest.TestCase): ...@@ -120,11 +120,9 @@ class TestStaticAnalysis(unittest.TestCase):
loss_jit = jit_layer(image_jit) loss_jit = jit_layer(image_jit)
loss_jit.backward() loss_jit.backward()
loss.backward() loss.backward()
self.assertTrue( np.testing.assert_allclose(
np.allclose( layer.parameters()[0].grad.numpy(),
layer.parameters()[0].grad.numpy(), jit_layer.parameters()[0].grad.numpy(),
jit_layer.parameters()[0].grad.numpy(),
)
) )
# def test_hook_in_forward_for_layer(self): # def test_hook_in_forward_for_layer(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册