From e61b25f930032c9bfa7fcc5ce3b18a190158f305 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 9 Jun 2022 11:59:30 +0800 Subject: [PATCH] [Dy2Stat]Support deepcopy Net instance after @to_static (#43317) * [Dy2stat]Add RollBack into original dygraph function for @to_static * fix unittest * [Dy2Stat]Support deepcopy Net instance after @to_static --- .../dygraph_to_static/program_translator.py | 45 +++++++++++++++ .../dygraph_to_static/test_deepcopy.py | 55 +++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_deepcopy.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index c5a3915802..49a218412c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -620,6 +620,51 @@ class StaticFunction(object): return getattr(self._class_instance, func_name) + def __deepcopy__(self, memo): + """ + Customized behavior for copy.deepcopy, return original decorated function instead + of a new StaticFunction Object. StaticFunction itself is not copyable becuase it's + associated with class_instance. + + We add __deepcopy__ here only for the following usage: + + Example:: + .. code-block:: python + + import copy + import paddle + + class Net(paddle.nn.Layer): + def __init__(self): + super(Net, self).__init__() + + def forward(self, x, flag=True): + if flag: + out = x + 1 + else: + out = x - 1 + return out + + x = paddle.randn([10, 1], 'float32') + net = paddle.jit.to_static(Net()) # convert into static mode + + copy_net = copy.deepcopy(net) # deepcopy a new net without @to_static + + Please attention that original 'net' will unwrap @to_static and rollback into simple Layer. + """ + if self._class_instance is not None: + net_name = type(self._class_instance).__name__ + logging_utils.log( + level=-1, + msg="Not recommend to deepcopy '{}' decorated with @to_static, it has side effect that will" \ + " rollback into original state before @to_static. Please deepcopy '{}' before applying @to_static." + .format(net_name, net_name)) + self.rollback() + return self._dygraph_function.__get__(memo[id( + self._class_instance)]) + else: + return self._dygraph_function + @property def inputs(self): """ diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_deepcopy.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_deepcopy.py new file mode 100644 index 0000000000..dcc12e120d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_deepcopy.py @@ -0,0 +1,55 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import numpy as np +from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction + +from test_rollback import Net, foo +from copy import deepcopy + + +class TestDeepCopy(unittest.TestCase): + + def test_net(self): + net = Net() + net = paddle.jit.to_static(net) + + x = paddle.randn([3, 4]) + src_out = net(x) + self.assertTrue(isinstance(net.forward, StaticFunction)) + + copy_net = deepcopy(net) + copy_out = copy_net(x) + + self.assertFalse(isinstance(net.forward, StaticFunction)) + self.assertTrue(id(copy_net), id(copy_net.forward.__self__)) + self.assertTrue(np.array_equal(src_out.numpy(), copy_out.numpy())) + + def test_func(self): + st_foo = paddle.jit.to_static(foo) + x = paddle.randn([3, 4]) + st_out = st_foo(x) + + self.assertTrue(isinstance(st_foo, StaticFunction)) + + new_foo = deepcopy(st_foo) + self.assertFalse(isinstance(new_foo, StaticFunction)) + new_out = new_foo(x) + self.assertTrue(np.array_equal(st_out.numpy(), new_out.numpy())) + + +if __name__ == "__main__": + unittest.main() -- GitLab