未验证 提交 f3f3d57a 编写于 作者: Y yangguohao 提交者: GitHub

[Dy2St]Following update of register_hook for static mode (#53572)

上级 2f503382
......@@ -1645,8 +1645,6 @@ class Variable(metaclass=VariableMetaClass):
def backward_hook_wrapper(dy):
"""call the backward hook in ."""
import numpy as np
return hook(np.array(dy))
def forward_hook_wrapper(x):
......
......@@ -38,6 +38,7 @@ from .loop_transformer import LoopTransformer
from .return_transformer import ReturnTransformer
from .static_analysis import StaticAnalysisVisitor
from .tensor_shape_transformer import TensorShapeTransformer
from .tensorhook_transformer import RegisterHookTransformer
from .typehint_transformer import TypeHintTransformer
from .utils import ast_to_source_code
......@@ -92,6 +93,7 @@ class DygraphToStaticAst(BaseTransformer):
self.visit(node_wrapper.node)
transformers = [
RegisterHookTransformer,
EarlyReturnTransformer,
BasicApiTransformer, # Basic Api
TensorShapeTransformer, # Tensor.shape -> paddle.shape(Tensor)
......
......@@ -14,9 +14,6 @@
import ast
import collections
import inspect
import textwrap
import astor
......@@ -41,84 +38,3 @@ def ast_to_source_code(ast_node):
source_code = astor.to_source(ast_node, pretty_source=pretty_source)
return source_code
class RegisterHookVisitor(gast.NodeVisitor):
def __init__(self, func_name):
self.register_hook_pos_map = collections.defaultdict(list)
self.assignment_pos_map = collections.defaultdict(list)
self.func_name = func_name
def visit_FunctionDef(self, func_def):
# The inner function that has register_hook will not be processed
if func_def.name != self.func_name:
return
register_hook_pos_map = self.register_hook_pos_map
assignment_pos_map = self.assignment_pos_map
for i in range(len(func_def.body) - 1, -1, -1):
body = func_def.body[i]
# Check if the code body contains the register_hook
if isinstance(body, ast.Expr):
for node in ast.walk(body):
if (
isinstance(node, ast.Attribute)
and node.attr == 'register_hook'
):
# parameter name for register_hook
param_name = node.value.id
register_hook_pos_map[param_name].append(i)
elif isinstance(body, ast.Assign):
for target in body.targets:
assignment_pos_map[target.id].append(i)
# Confirm the order
order_map = {}
for k, idx_list in register_hook_pos_map.items():
for idx in idx_list:
if k not in assignment_pos_map:
order_map[idx] = 1
else:
for assignment_idx in assignment_pos_map[k]:
if idx > assignment_idx:
order_map[idx] = assignment_idx + 1
break
code_order = [*range(len(func_def.body))]
for k, v in sorted(order_map.items(), key=lambda x: x[1], reverse=True):
if k == v:
continue
code_order.remove(k)
code_order.insert(v, k)
# rearrange the code according to the specified order
new_body = [func_def.body[i] for i in code_order]
func_def.body = new_body
def modify_function_code(func):
"""
Modify the function code for the register hook
"""
func_ast = ast.parse(textwrap.dedent(inspect.getsource(func)))
# check if there is register_hook on code after visit the tree.
check_register_hook = next(
(
node
for node in ast.walk(func_ast)
if isinstance(node, ast.Attribute) and node.attr == 'register_hook'
),
None,
)
if check_register_hook is None:
return
visitor = RegisterHookVisitor(func.__name__)
visitor.visit(func_ast)
def pretty_source(source):
return ''.join(source)
new_code = astor.to_source(func_ast, pretty_source=pretty_source)
return new_code
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from paddle.utils import gast
from .base_transformer import BaseTransformer
class RegisterHookTransformer(BaseTransformer):
def __init__(self, wrapper_root):
self.register_hook_pos_map = collections.defaultdict(list)
self.assignment_pos_map = collections.defaultdict(list)
self.root = wrapper_root.node
def transform(self):
"""
Main function to transform AST.
"""
self.visit(self.root)
def visit_FunctionDef(self, func_def):
# The inner function that has register_hook will not be processed
check_register_hook = next(
(
node
for node in gast.walk(func_def)
if isinstance(node, gast.Attribute)
and node.attr == 'register_hook'
),
None,
)
if check_register_hook is None:
return func_def
register_hook_pos_map = self.register_hook_pos_map
assignment_pos_map = self.assignment_pos_map
for i in range(len(func_def.body) - 1, -1, -1):
body = func_def.body[i]
# Check if the code body contains the register_hook
if isinstance(body, gast.Expr):
for node in gast.walk(body):
if (
isinstance(node, gast.Attribute)
and node.attr == 'register_hook'
):
# parameter name for register_hook
param_name = node.value.id
register_hook_pos_map[param_name].append(i)
elif isinstance(body, gast.Assign):
for target in body.targets:
assignment_pos_map[target.id].append(i)
# Confirm the order
order_map = {}
for k, idx_list in register_hook_pos_map.items():
for idx in idx_list:
if k not in assignment_pos_map:
order_map[idx] = 1
else:
for assignment_idx in assignment_pos_map[k]:
if idx > assignment_idx:
order_map[idx] = assignment_idx + 1
break
code_order = [*range(len(func_def.body))]
for k, v in sorted(order_map.items(), key=lambda x: x[1], reverse=True):
if k == v:
continue
code_order.remove(k)
code_order.insert(v, k)
# rearrange the code according to the specified order
new_body = [func_def.body[i] for i in code_order]
func_def.body = new_body
return func_def
......@@ -38,7 +38,7 @@ from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.utils import gast
from .ast_utils import ast_to_source_code, modify_function_code
from .ast_utils import ast_to_source_code
from .static_analysis import StaticAnalysisVisitor
from .utils_helper import DYGRAPH_MODULE_PREFIX # noqa: F401
from .utils_helper import DYGRAPH_TO_STATIC_MODULE_PREFIX # noqa: F401
......@@ -643,10 +643,7 @@ def func_to_source_code(function, dedent=True):
type(function).__name__
)
)
# return modified function source code if there is 'register_hook', otherwise return None
source_code = modify_function_code(function)
if source_code is None:
source_code_list, _ = inspect.getsourcelines(function)
# Replace comments with blank lines so that error messages are not misplaced
source_code_list = [
......
......@@ -45,7 +45,7 @@ class TestStaticAnalysis(unittest.TestCase):
jit_f = to_static(f)
loss = jit_f(x_jit)
loss.backward()
self.assertTrue(np.allclose(x.grad.numpy(), x_jit.grad.numpy()))
np.testing.assert_allclose(x.grad.numpy(), x_jit.grad.numpy())
def test_hook_for_reassignment_parameter(self):
def f(x):
......@@ -68,7 +68,7 @@ class TestStaticAnalysis(unittest.TestCase):
jit_f = to_static(f)
loss = jit_f(x_jit)
loss.backward()
self.assertTrue(np.allclose(x.grad.numpy(), x_jit.grad.numpy()))
np.testing.assert_allclose(x.grad.numpy(), x_jit.grad.numpy())
def test_hook_for_repeat_register(self):
def f(x):
......@@ -91,7 +91,7 @@ class TestStaticAnalysis(unittest.TestCase):
jit_f = to_static(f)
loss = jit_f(x_jit)
loss.backward()
self.assertTrue(np.allclose(x.grad.numpy(), x_jit.grad.numpy()))
np.testing.assert_allclose(x.grad.numpy(), x_jit.grad.numpy())
def test_hook_in_init_for_layer(self):
def hook(grad):
......@@ -120,12 +120,10 @@ class TestStaticAnalysis(unittest.TestCase):
loss_jit = jit_layer(image_jit)
loss_jit.backward()
loss.backward()
self.assertTrue(
np.allclose(
np.testing.assert_allclose(
layer.parameters()[0].grad.numpy(),
jit_layer.parameters()[0].grad.numpy(),
)
)
# def test_hook_in_forward_for_layer(self):
#
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册