diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index cf3383f5d0638015a685dad42cadd6a8dacccbf5..e660a64ab363c10581520a3d4223994e009f769d 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -243,6 +243,7 @@ def convert_call(func): if hasattr(func, 'forward') and isinstance(func, Layer): try: _, forward_func = unwrap_decorators(func.forward) + func._original_funcs['forward'] = forward_func.__func__ forward_func = convert_to_static(forward_func) # Bound mothod will be convert into plain function after `convert_to_static`. # So descriptor mechanism is used to bound `self` instance on function to diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 54c2b2216cd1c1add124ebbb920c5c1b6c486d69..c5a39158024019549806a1dc751f811f931813ce 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -255,6 +255,8 @@ class StaticFunction(object): if inspect.ismethod(function): self._dygraph_function = getattr(function, '__func__') self._class_instance = getattr(function, '__self__') + self._class_instance._original_funcs[ + function.__name__] = self._dygraph_function else: self._dygraph_function = function self._class_instance = None @@ -564,6 +566,60 @@ class StaticFunction(object): partial_layer) = self._program_cache.last() return concrete_program + def rollback(self): + """ + Rollback into original dygraph functions for current class instance. + + Returns: + Function or Method + + Example:: + .. code-block:: python + + import paddle + + class Net(paddle.nn.Layer): + def __init__(self): + super(Net, self).__init__() + + def forward(self, x, flag=True): + if flag: + out = x + 1 + else: + out = x - 1 + return out + + x = paddle.randn([10, 1], 'float32') + net = paddle.jit.to_static(Net()) # convert into static mode + out = net(x) + + net.forward.rollback() # rollback into dygraph mode + out = net(x) + """ + + def rollback_impl(class_instance): + for name, func in class_instance._original_funcs.items(): + setattr(class_instance, name, func.__get__(class_instance)) + + for sublayer in class_instance.sublayers(include_self=False): + rollback_impl(sublayer) + + if self._class_instance is None: + return self._dygraph_function + + # only rollback sub-functions on path of top _dygraph_function + func_name = self._dygraph_function.__name__ + assert func_name in self._class_instance._original_funcs, "Not Found function '{}' in class '{}'.".format( + func_name, self._class_instance.__name__) + func = self._class_instance._original_funcs[func_name] + setattr(self._class_instance, func_name, + func.__get__(self._class_instance)) + + for sublayer in self._class_instance.sublayers(include_self=False): + rollback_impl(sublayer) + + return getattr(self._class_instance, func_name) + @property def inputs(self): """ diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index b67f7d0a91fee235d63773738763d32ce0268773..4a4bdf6e18e36060924087345e7924bcefdf5244 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -127,6 +127,8 @@ class Layer(object): self._casted_by_pure_fp16 = False self._state_dict_hooks = collections.OrderedDict() + # Records orignal functions after @to_static to support to rollback + self._original_funcs = collections.OrderedDict() def train(self): """ diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py index 6188d6a786b2e2f533029df8ae2b50423afcd62f..375873aa14fddef62d987915ea91ecd02f743db7 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py @@ -51,7 +51,7 @@ class TestConvertCall(unittest.TestCase): def forward_not_exist(): return net() - with self.assertRaises(TypeError): + with self.assertRaises(AttributeError): forward_not_exist() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_rollback.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_rollback.py new file mode 100644 index 0000000000000000000000000000000000000000..5277a50c299ea7ecb5db79c02aba984195978d45 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_rollback.py @@ -0,0 +1,126 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import numpy as np +from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code +from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction + + +class Net(paddle.nn.Layer): + + def __init__(self): + super(Net, self).__init__() + self.sub = SubNet() + + def forward(self, x): + x = self.sub(x) + x = foo(x) + out = self.sub.bar(x) + return out + + def infer(self, x): + x = self.sub.bar(x) + out = foo(x) + return out + + +class SubNet(paddle.nn.Layer): + + def __init__(self): + super(SubNet, self).__init__() + + def forward(self, x, flag=True): + if flag: + out = x + 1 + else: + out = x - 1 + return out + + def bar(self, x, flag=True): + if flag: + out = x + 2 + else: + out = x - 2 + return out + + +def foo(x, flag=False): + if flag: + out = x * 2. + else: + out = x / 2. + + return out + + +class TestRollBackPlainFunction(unittest.TestCase): + + def setUp(self): + paddle.set_device("cpu") + + def test_plain_func(self): + st_foo = paddle.jit.to_static(foo) + x = paddle.randn([3, 4]) + st_out = st_foo(x) + + self.assertTrue(isinstance(st_foo, StaticFunction)) + + st_foo = st_foo.rollback() + dy_out = st_foo(x) + + self.assertTrue(func_to_source_code(foo) == func_to_source_code(st_foo)) + self.assertTrue(np.array_equal(st_out.numpy(), dy_out.numpy())) + + +class TestRollBackNet(unittest.TestCase): + + def setUp(self): + paddle.set_device("cpu") + + def test_net(self): + net = paddle.jit.to_static(Net()) + x = paddle.randn([3, 4]) + st_fwd_out = net(x) + + # forward function is inplacly converted. + self.assertTrue(isinstance(net.forward, StaticFunction)) + self.assertTrue("true_fn" in func_to_source_code(net.sub.forward)) + # other non-forward function is not inplacly converted. + self.assertFalse("true_fn" in func_to_source_code(net.sub.bar)) + + net.infer = paddle.jit.to_static(net.infer) + st_infer_out = net.infer(x) + self.assertTrue(isinstance(net.infer, StaticFunction)) + self.assertFalse("true_fn" in func_to_source_code(net.sub.bar)) + + # rollback forward into original dygraph method + net.forward = net.forward.rollback() + self.assertFalse(isinstance(net.forward, StaticFunction)) + self.assertFalse("true_fn" in func_to_source_code(net.sub.forward)) + dy_fwd_out = net(x) + self.assertTrue(np.array_equal(st_fwd_out.numpy(), dy_fwd_out.numpy())) + + # rollback infer into original dygraph method + net.infer.rollback() + self.assertFalse(isinstance(net.infer, StaticFunction)) + self.assertFalse("true_fn" in func_to_source_code(net.sub.forward)) + dy_infer_out = net.infer(x) + self.assertTrue( + np.array_equal(st_infer_out.numpy(), dy_infer_out.numpy())) + + +if __name__ == "__main__": + unittest.main()