Unverified commit b056c9cb authored by Aurelius84, committed by GitHub

[Dy2stat]Add RollBack into original dygraph function for @to_static (#43284)

* [Dy2stat]Add RollBack into original dygraph function for @to_static

* fix unittest
Parent 1fbd4440
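
This commit adds a `rollback()` method so that a function or `paddle.nn.Layer` converted by `@to_static` / `paddle.jit.to_static` can be switched back to its original dygraph implementation. A minimal usage sketch, adapted from the docstring example and unit tests in this diff (the layer and the input shape are only illustrative):

# Usage sketch adapted from the docstring example added below; shapes are illustrative.
import paddle


class Net(paddle.nn.Layer):
    def __init__(self):
        super(Net, self).__init__()

    def forward(self, x, flag=True):
        if flag:
            out = x + 1
        else:
            out = x - 1
        return out


x = paddle.randn([10, 1], 'float32')

net = paddle.jit.to_static(Net())  # convert forward into static graph mode
static_out = net(x)

net.forward.rollback()             # restore the original dygraph forward
dygraph_out = net(x)               # runs eagerly again

For a plain function, `rollback()` returns the original function and must be reassigned (as the `TestRollBackPlainFunction` case below does with `st_foo = st_foo.rollback()`); for a Layer method, it also restores the original methods on the instance and its sublayers.
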
@@ -243,6 +243,7 @@ def convert_call(func):
        if hasattr(func, 'forward') and isinstance(func, Layer):
            try:
                _, forward_func = unwrap_decorators(func.forward)
                func._original_funcs['forward'] = forward_func.__func__
                forward_func = convert_to_static(forward_func)
                # A bound method will be converted into a plain function after `convert_to_static`,
                # so the descriptor mechanism is used to bind the `self` instance on the function to
@@ -255,6 +255,8 @@ class StaticFunction(object):
        if inspect.ismethod(function):
            self._dygraph_function = getattr(function, '__func__')
            self._class_instance = getattr(function, '__self__')
            self._class_instance._original_funcs[
                function.__name__] = self._dygraph_function
        else:
            self._dygraph_function = function
            self._class_instance = None
@@ -564,6 +566,60 @@ class StaticFunction(object):
                     partial_layer) = self._program_cache.last()
        return concrete_program

    def rollback(self):
        """
        Roll back into the original dygraph functions for the current class instance.

        Returns:
            Function or Method

        Example::
            .. code-block:: python

                import paddle

                class Net(paddle.nn.Layer):
                    def __init__(self):
                        super(Net, self).__init__()

                    def forward(self, x, flag=True):
                        if flag:
                            out = x + 1
                        else:
                            out = x - 1
                        return out

                x = paddle.randn([10, 1], 'float32')
                net = paddle.jit.to_static(Net())  # convert into static mode
                out = net(x)

                net.forward.rollback()  # rollback into dygraph mode
                out = net(x)
        """

        def rollback_impl(class_instance):
            for name, func in class_instance._original_funcs.items():
                setattr(class_instance, name, func.__get__(class_instance))

            for sublayer in class_instance.sublayers(include_self=False):
                rollback_impl(sublayer)

        if self._class_instance is None:
            return self._dygraph_function

        # Only roll back sub-functions on the call path of the top _dygraph_function.
        func_name = self._dygraph_function.__name__
        assert func_name in self._class_instance._original_funcs, "Function '{}' not found in class '{}'.".format(
            func_name, self._class_instance.__class__.__name__)

        func = self._class_instance._original_funcs[func_name]
        setattr(self._class_instance, func_name,
                func.__get__(self._class_instance))

        for sublayer in self._class_instance.sublayers(include_self=False):
            rollback_impl(sublayer)

        return getattr(self._class_instance, func_name)

    @property
    def inputs(self):
        """
@@ -127,6 +127,8 @@ class Layer(object):
        self._casted_by_pure_fp16 = False

        self._state_dict_hooks = collections.OrderedDict()
        # Records original functions after @to_static to support rollback
        self._original_funcs = collections.OrderedDict()

    def train(self):
        """
@@ -51,7 +51,7 @@ class TestConvertCall(unittest.TestCase):
        def forward_not_exist():
            return net()

        with self.assertRaises(TypeError):
        with self.assertRaises(AttributeError):
            forward_not_exist()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction


class Net(paddle.nn.Layer):
    def __init__(self):
        super(Net, self).__init__()
        self.sub = SubNet()

    def forward(self, x):
        x = self.sub(x)
        x = foo(x)
        out = self.sub.bar(x)
        return out

    def infer(self, x):
        x = self.sub.bar(x)
        out = foo(x)
        return out


class SubNet(paddle.nn.Layer):
    def __init__(self):
        super(SubNet, self).__init__()

    def forward(self, x, flag=True):
        if flag:
            out = x + 1
        else:
            out = x - 1
        return out

    def bar(self, x, flag=True):
        if flag:
            out = x + 2
        else:
            out = x - 2
        return out


def foo(x, flag=False):
    if flag:
        out = x * 2.
    else:
        out = x / 2.
    return out


class TestRollBackPlainFunction(unittest.TestCase):
    def setUp(self):
        paddle.set_device("cpu")

    def test_plain_func(self):
        st_foo = paddle.jit.to_static(foo)
        x = paddle.randn([3, 4])
        st_out = st_foo(x)

        self.assertTrue(isinstance(st_foo, StaticFunction))

        st_foo = st_foo.rollback()
        dy_out = st_foo(x)

        self.assertTrue(func_to_source_code(foo) == func_to_source_code(st_foo))
        self.assertTrue(np.array_equal(st_out.numpy(), dy_out.numpy()))


class TestRollBackNet(unittest.TestCase):
    def setUp(self):
        paddle.set_device("cpu")

    def test_net(self):
        net = paddle.jit.to_static(Net())
        x = paddle.randn([3, 4])
        st_fwd_out = net(x)

        # The forward method is converted in place.
        self.assertTrue(isinstance(net.forward, StaticFunction))
        self.assertTrue("true_fn" in func_to_source_code(net.sub.forward))
        # Other, non-forward methods are not converted in place.
        self.assertFalse("true_fn" in func_to_source_code(net.sub.bar))

        net.infer = paddle.jit.to_static(net.infer)
        st_infer_out = net.infer(x)
        self.assertTrue(isinstance(net.infer, StaticFunction))
        self.assertFalse("true_fn" in func_to_source_code(net.sub.bar))

        # Roll forward back into the original dygraph method.
        net.forward = net.forward.rollback()
        self.assertFalse(isinstance(net.forward, StaticFunction))
        self.assertFalse("true_fn" in func_to_source_code(net.sub.forward))
        dy_fwd_out = net(x)
        self.assertTrue(np.array_equal(st_fwd_out.numpy(), dy_fwd_out.numpy()))

        # Roll infer back into the original dygraph method.
        net.infer.rollback()
        self.assertFalse(isinstance(net.infer, StaticFunction))
        self.assertFalse("true_fn" in func_to_source_code(net.sub.forward))
        dy_infer_out = net.infer(x)
        self.assertTrue(
            np.array_equal(st_infer_out.numpy(), dy_infer_out.numpy()))


if __name__ == "__main__":
    unittest.main()