diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index 4630cfcdabfd307ea03a7fd0c885c73ce4a4ea0b..c837c8eb123c2707d89a75a7489607f43a2e7501 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -31,6 +31,7 @@ from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticLayer from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static +from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators from paddle.fluid.dygraph.layers import Layer # TODO(liym27): A better way to do this. @@ -118,14 +119,9 @@ def convert_call(func): func_self = None converted_call = None - # Function in convert_call may be decorated by another `@declarative`, + # Function in convert_call may be decorated by another `@to_static`, # in this case, unwraps it into a raw method or function. - if isinstance(func, StaticLayer): - instance = func._class_instance - if instance is not None: - func = func.dygraph_function.__get__(instance) - else: - func = func.dygraph_function + _, func = unwrap_decorators(func) if is_builtin_len(func): return convert_len @@ -155,7 +151,8 @@ def convert_call(func): if inspect.isfunction(fn): global_functions.add(fn) elif isinstance(fn, StaticLayer): - global_functions.add(fn.dygraph_function) + _, fn = unwrap_decorators(fn) + global_functions.add(fn) if func in global_functions: converted_call = convert_to_static(func) @@ -189,7 +186,8 @@ def convert_call(func): elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'): if hasattr(func, 'forward') and isinstance(func, Layer): try: - forward_func = convert_to_static(func.forward) + _, forward_func = unwrap_decorators(func.forward) + forward_func = convert_to_static(forward_func) setattr(func, 'forward', forward_func) func_self = func except Exception: diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index cb489af44d0adc7da377f73a3205c3c264769b4d..3d27810f1db94c4f6c273399ec93b9335f5bb03a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -21,6 +21,7 @@ import six import textwrap import threading import warnings +import weakref import gast from paddle.fluid import framework @@ -245,6 +246,7 @@ class StaticLayer(object): self._input_spec = input_spec self._function_spec = FunctionSpec(function, input_spec) self._program_cache = ProgramCache() + self._descriptor_cache = weakref.WeakKeyDictionary() # Note: Hold a reference to ProgramTranslator for switching `enable_declarative`. self._program_trans = ProgramTranslator() @@ -271,8 +273,19 @@ class StaticLayer(object): of `Net` instance. After decorated by `@paddle.jit.to_static`, it will firstly to call `__get__` to parse the class instance correctly instead of the `StaticLayer` instance. """ - self._class_instance = instance - return self + if instance not in self._descriptor_cache: + if instance is None: + return self + # Note(Aurelius84): To construct new instance of StaticLayer when we + # first encouter the bound function of layer and cache it. + new_static_layer = self._clone() + new_static_layer._class_instance = instance + self._descriptor_cache[instance] = new_static_layer + + return self._descriptor_cache[instance] + + def _clone(self): + return self.__class__(self._dygraph_function, self._input_spec) def __call__(self, *args, **kwargs): """ diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py index 949286f63efb3357325f25b02f60e938eebd28e8..0b8df63d666b6547d5dccfc2ce0b420d653cc544 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py @@ -19,7 +19,7 @@ import paddle import paddle.fluid as fluid from paddle.static import InputSpec from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit -from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram +from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram, StaticLayer from test_basic_api_transformation import dyfunc_to_variable @@ -84,6 +84,23 @@ class SimpleNet(Layer): return z +class TestStaticLayerInstance(unittest.TestCase): + def test_instance_same_class(self): + with fluid.dygraph.guard(fluid.CPUPlace()): + net_1 = SimpleNet() + net_2 = SimpleNet() + + self.assertTrue(isinstance(net_1.forward, StaticLayer)) + self.assertTrue(isinstance(net_2.forward, StaticLayer)) + self.assertNotEqual(net_1.forward, net_2.forward) + + # convert layer into static progam of net_1 + net_1.forward.concrete_program + self.assertTrue(len(net_1.forward.program_cache) == 1) + # check no conversion applid with net_2 + self.assertTrue(len(net_2.forward.program_cache) == 0) + + class TestInputSpec(unittest.TestCase): def setUp(self): pass @@ -224,7 +241,6 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase): # 1. specific InputSpec for `x`/`y` concrete_program_1 = foo.get_concrete_program( InputSpec([None, 10]), InputSpec([10])) - print(concrete_program_1) self.assertTrue(len(foo.program_cache) == 1) # 2. specific `c`/`d` explicitly with same default value diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py index 6cf59c030c00384b225d5d13160f68a3558084b9..cf7708c675aa9c1fb8faf5f8585b458be88b6c83 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py @@ -133,7 +133,7 @@ class TestPartialProgramRaiseError(unittest.TestCase): x = fluid.dygraph.to_variable(x_data) out = net(x) - program_cache = SimpleFcLayer.forward.program_cache + program_cache = net.forward.program_cache _, (concrete_program, _) = program_cache.last() params = concrete_program.parameters