Unverified commit db30aa1d, authored by yangguohao, committed by GitHub

【Hackathon No.91】register_hook for static mode (#52948)

Parent cf6cbc34
...@@ -1640,9 +1640,26 @@ class Variable(metaclass=VariableMetaClass):
         """
         pass
 
-    @fake_interface_only
     def register_hook(self, hook):
-        pass
+        import paddle
+
+        def backward_hook_wrapper(dy):
+            """call the backward hook in ."""
+            import numpy as np
+
+            return hook(np.array(dy))
+
+        def forward_hook_wrapper(x):
+            """do nothing but return a new variable."""
+            return x
+
+        paddle.static.py_func(
+            func=forward_hook_wrapper,
+            x=self,
+            out=self,
+            backward_func=backward_hook_wrapper,
+            skip_vars_in_backward_input=[self],
+        )
 
     def __str__(self):
         return self._to_readable_code()
......
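For context, here is a minimal sketch (not part of the commit) of how the new static-mode hook could be exercised. The program setup, variable names, and the relu/mean graph are illustrative assumptions; the doubled-gradient hook mirrors the tests changed below.

    import numpy as np
    import paddle

    paddle.enable_static()

    x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
    x.stop_gradient = False
    hidden = paddle.nn.functional.relu(x)
    # Registering the hook now inserts a py_func op whose backward_func
    # calls the Python hook on the incoming gradient (see the diff above).
    hidden.register_hook(lambda grad: grad * 2)
    loss = paddle.mean(hidden)
    x_grad = paddle.static.gradients(loss, x)[0]

    exe = paddle.static.Executor()
    exe.run(paddle.static.default_startup_program())
    (g,) = exe.run(
        feed={'x': np.ones([2, 4], dtype='float32')},
        fetch_list=[x_grad],
    )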
...@@ -45,9 +45,10 @@ class SimpleNetForStatic(nn.Layer):
         self.linear1 = nn.Linear(in_size, in_size)
         self.linear2 = nn.Linear(in_size, out_size)
 
-    def forward(self, x):
+    def forward(self, x, hook=False):
         ret1 = self.linear1(x)
-        ret1.register_hook(lambda grad: grad * 2)
+        if hook:
+            ret1.register_hook(lambda grad: grad * 2)
 
         ret2 = self.linear2(ret1)
         out = paddle.mean(ret2, axis=-1)
...@@ -512,8 +513,7 @@ class TestTensorRegisterHook(unittest.TestCase):
             )
 
             net = SimpleNetForStatic(self.in_size, self.out_size)
-            with self.assertRaises(AssertionError):
-                out = net(x)
+            out = net(x)
 
         paddle.disable_static()
...@@ -527,9 +527,17 @@
             'float32'
         )
         data_t = paddle.to_tensor(data)
-        with self.assertRaises(AssertionError):
-            out = jit_net(data_t)
+        data_t2 = paddle.to_tensor(data)
+        data_t.stop_gradient = False
+        data_t2.stop_gradient = False
+        out1 = jit_net(data_t)
+        out2 = jit_net(data_t2, True)
+        out1.backward()
+        out2.backward()
+        np.testing.assert_array_equal(
+            2 * data_t.grad.numpy(), data_t2.grad.numpy()
+        )
 
 HOOK_INIT_VALUE = 10
......
...@@ -14,6 +14,9 @@
 import ast
+import collections
+import inspect
+import textwrap
 
 import astor
...@@ -38,3 +41,84 @@ def ast_to_source_code(ast_node):
     source_code = astor.to_source(ast_node, pretty_source=pretty_source)
     return source_code
+
+
+class RegisterHookVisitor(gast.NodeVisitor):
+    def __init__(self, func_name):
+        self.register_hook_pos_map = collections.defaultdict(list)
+        self.assignment_pos_map = collections.defaultdict(list)
+        self.func_name = func_name
+
+    def visit_FunctionDef(self, func_def):
+        # The inner function that has register_hook will not be processed
+        if func_def.name != self.func_name:
+            return
+        register_hook_pos_map = self.register_hook_pos_map
+        assignment_pos_map = self.assignment_pos_map
+
+        for i in range(len(func_def.body) - 1, -1, -1):
+            body = func_def.body[i]
+            # Check if the code body contains the register_hook
+            if isinstance(body, ast.Expr):
+                for node in ast.walk(body):
+                    if (
+                        isinstance(node, ast.Attribute)
+                        and node.attr == 'register_hook'
+                    ):
+                        # parameter name for register_hook
+                        param_name = node.value.id
+                        register_hook_pos_map[param_name].append(i)
+            elif isinstance(body, ast.Assign):
+                for target in body.targets:
+                    assignment_pos_map[target.id].append(i)
+
+        # Confirm the order
+        order_map = {}
+        for k, idx_list in register_hook_pos_map.items():
+            for idx in idx_list:
+                if k not in assignment_pos_map:
+                    order_map[idx] = 1
+                else:
+                    for assignment_idx in assignment_pos_map[k]:
+                        if idx > assignment_idx:
+                            order_map[idx] = assignment_idx + 1
+                            break
+        code_order = [*range(len(func_def.body))]
+        for k, v in sorted(order_map.items(), key=lambda x: x[1], reverse=True):
+            if k == v:
+                continue
+            code_order.remove(k)
+            code_order.insert(v, k)
+
+        # rearrange the code according to the specified order
+        new_body = [func_def.body[i] for i in code_order]
+        func_def.body = new_body
+
+
+def modify_function_code(func):
+    """
+    Modify the function code for the register hook
+    """
+    func_ast = ast.parse(textwrap.dedent(inspect.getsource(func)))
+    # check if there is register_hook on code after visit the tree.
+    check_register_hook = next(
+        (
+            node
+            for node in ast.walk(func_ast)
+            if isinstance(node, ast.Attribute) and node.attr == 'register_hook'
+        ),
+        None,
+    )
+    if check_register_hook is None:
+        return
+
+    visitor = RegisterHookVisitor(func.__name__)
+    visitor.visit(func_ast)
+
+    def pretty_source(source):
+        return ''.join(source)
+
+    new_code = astor.to_source(func_ast, pretty_source=pretty_source)
+    return new_code
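To make the effect of this reordering concrete, here is an illustrative before/after (it mirrors the reassignment test added later in this commit). Given a decorated function whose hook registration appears only after the variable has already been consumed:

    def f(x):
        def h(g):
            return 2 * g

        y = x + 4
        x = y * 5
        z = x**2
        x.register_hook(h)
        return z

modify_function_code would return source in which the registration is hoisted to sit immediately after the assignment it refers to, roughly:

    def f(x):
        def h(g):
            return 2 * g

        y = x + 4
        x = y * 5
        x.register_hook(h)
        z = x**2
        return z

so that, when the rewritten source is transcribed for static mode, the py_func hook is inserted before the value of x is consumed by z.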
...@@ -38,7 +38,7 @@ from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
 from paddle.utils import gast
 
-from .ast_utils import ast_to_source_code
+from .ast_utils import ast_to_source_code, modify_function_code
 from .static_analysis import StaticAnalysisVisitor
 from .utils_helper import DYGRAPH_MODULE_PREFIX  # noqa: F401
 from .utils_helper import DYGRAPH_TO_STATIC_MODULE_PREFIX  # noqa: F401
...@@ -643,15 +643,20 @@ def func_to_source_code(function, dedent=True):
                 type(function).__name__
             )
         )
-
-    source_code_list, _ = inspect.getsourcelines(function)
-    # Replace comments with blank lines so that error messages are not misplaced
-    source_code_list = [
-        line if not line.lstrip().startswith('#') else '\n'
-        for line in source_code_list
-    ]
-    source_code = ''.join(source_code_list)
-    if dedent:
-        source_code = textwrap.dedent(source_code)
+    # return modified function source code if there is 'register_hook', otherwise return None
+    source_code = modify_function_code(function)
+
+    if source_code is None:
+        source_code_list, _ = inspect.getsourcelines(function)
+        # Replace comments with blank lines so that error messages are not misplaced
+        source_code_list = [
+            line if not line.lstrip().startswith('#') else '\n'
+            for line in source_code_list
+        ]
+        source_code = ''.join(source_code_list)
+        if dedent:
+            source_code = textwrap.dedent(source_code)
 
     return source_code
......
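In other words, the dygraph-to-static transcription now prefers the hook-reordered source and only falls back to the raw source when no register_hook call is present. A condensed sketch of that decision (pick_source is a hypothetical name, not the verbatim helper; modify_function_code is the ast_utils addition above):

    import inspect
    import textwrap

    def pick_source(function, dedent=True):
        # Hypothetical condensed view of the new branch in func_to_source_code.
        source_code = modify_function_code(function)  # None when no register_hook
        if source_code is None:
            lines, _ = inspect.getsourcelines(function)
            # Replace comments with blank lines so error messages stay aligned.
            lines = [l if not l.lstrip().startswith('#') else '\n' for l in lines]
            source_code = ''.join(lines)
            if dedent:
                source_code = textwrap.dedent(source_code)
        return source_code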
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle import nn
from paddle.jit import to_static
class TestStaticAnalysis(unittest.TestCase):
    def test_hook_for_different_parameter(self):
        def f(x):
            def h(g):
                return 2 * g

            y = x + 4
            f = y + x
            z = f**2
            y.register_hook(h)
            f.register_hook(h)
            x.register_hook(h)
            return z

        x = paddle.to_tensor([2.0])
        x.stop_gradient = False
        loss = f(x)
        loss.backward()

        x_jit = paddle.to_tensor([2.0])
        x_jit.stop_gradient = False
        jit_f = to_static(f)
        loss = jit_f(x_jit)
        loss.backward()
        self.assertTrue(np.allclose(x.grad.numpy(), x_jit.grad.numpy()))
    def test_hook_for_reassignment_parameter(self):
        def f(x):
            def h(g):
                return 2 * g

            y = x + 4
            x = y * 5
            z = x**2
            x.register_hook(h)
            return z

        x = paddle.to_tensor([2.0])
        x.stop_gradient = False
        loss = f(x)
        loss.backward()

        x_jit = paddle.to_tensor([2.0])
        x_jit.stop_gradient = False
        jit_f = to_static(f)
        loss = jit_f(x_jit)
        loss.backward()
        self.assertTrue(np.allclose(x.grad.numpy(), x_jit.grad.numpy()))
    def test_hook_for_repeat_register(self):
        def f(x):
            def h(g):
                return 2 * g

            y = x + 4
            z = y**2
            x.register_hook(h)
            x.register_hook(h)
            return z

        x = paddle.to_tensor([2.0])
        x.stop_gradient = False
        loss = f(x)
        loss.backward()

        x_jit = paddle.to_tensor([2.0])
        x_jit.stop_gradient = False
        jit_f = to_static(f)
        loss = jit_f(x_jit)
        loss.backward()
        self.assertTrue(np.allclose(x.grad.numpy(), x_jit.grad.numpy()))
    def test_hook_in_init_for_layer(self):
        def hook(grad):
            return grad * 2

        IMAGE_SIZE = 784
        CLASS_NUM = 10

        class LinearNet(nn.Layer):
            def __init__(self):
                super().__init__()
                self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
                # register_hook in init
                self._linear.parameters()[0].register_hook(hook)

            def forward(self, x):
                return self._linear(x)

        # create network
        layer = LinearNet()
        jit_layer = to_static(LinearNet())
        data = np.random.random([IMAGE_SIZE]).astype('float32')
        image = paddle.to_tensor(data)
        image_jit = paddle.to_tensor(data)
        loss = layer(image)
        loss_jit = jit_layer(image_jit)
        loss_jit.backward()
        loss.backward()
        self.assertTrue(
            np.allclose(
                layer.parameters()[0].grad.numpy(),
                jit_layer.parameters()[0].grad.numpy(),
            )
        )
    # def test_hook_in_forward_for_layer(self):
    #
    #     IMAGE_SIZE = 784
    #     CLASS_NUM = 10
    #
    #     class LinearNet(nn.Layer):
    #         def __init__(self):
    #             super().__init__()
    #             self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
    #
    #         def forward(self, x):
    #             def hook(grad):
    #                 return grad * 2
    #
    #             res = self._linear(x)
    #
    #             # register_hook in forward
    #             self._linear.parameters()[0].register_hook(hook)
    #             return res
    #
    #     # create network
    #     layer = LinearNet()
    #     jit_layer = to_static(LinearNet())
    #     data = np.random.random([IMAGE_SIZE]).astype('float32')
    #     image = paddle.to_tensor(data)
    #     image_jit = paddle.to_tensor(data)
    #     loss = layer(image)
    #     loss_jit = jit_layer(image_jit)
    #     loss_jit.backward()
    #     loss.backward()
    #     self.assertTrue(
    #         np.allclose(
    #             layer.parameters()[0].grad.numpy(),
    #             jit_layer.parameters()[0].grad.numpy(),
    #         )
    #     )


if __name__ == '__main__':
    unittest.main()