未验证 提交 3fb9ee6f 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat] Add descriptor cache for StaticLayer (#26987) (#27098)

* add descriptor cache

* fix self arugments

* deal case if instance is None

* clean code

* fix usage
上级 940e673b
...@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len ...@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticLayer from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticLayer
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
# TODO(liym27): A better way to do this. # TODO(liym27): A better way to do this.
...@@ -118,14 +119,9 @@ def convert_call(func): ...@@ -118,14 +119,9 @@ def convert_call(func):
func_self = None func_self = None
converted_call = None converted_call = None
# Function in convert_call may be decorated by another `@declarative`, # Function in convert_call may be decorated by another `@to_static`,
# in this case, unwraps it into a raw method or function. # in this case, unwraps it into a raw method or function.
if isinstance(func, StaticLayer): _, func = unwrap_decorators(func)
instance = func._class_instance
if instance is not None:
func = func.dygraph_function.__get__(instance)
else:
func = func.dygraph_function
if is_builtin_len(func): if is_builtin_len(func):
return convert_len return convert_len
...@@ -155,7 +151,8 @@ def convert_call(func): ...@@ -155,7 +151,8 @@ def convert_call(func):
if inspect.isfunction(fn): if inspect.isfunction(fn):
global_functions.add(fn) global_functions.add(fn)
elif isinstance(fn, StaticLayer): elif isinstance(fn, StaticLayer):
global_functions.add(fn.dygraph_function) _, fn = unwrap_decorators(fn)
global_functions.add(fn)
if func in global_functions: if func in global_functions:
converted_call = convert_to_static(func) converted_call = convert_to_static(func)
...@@ -189,7 +186,8 @@ def convert_call(func): ...@@ -189,7 +186,8 @@ def convert_call(func):
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'): elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
if hasattr(func, 'forward') and isinstance(func, Layer): if hasattr(func, 'forward') and isinstance(func, Layer):
try: try:
forward_func = convert_to_static(func.forward) _, forward_func = unwrap_decorators(func.forward)
forward_func = convert_to_static(forward_func)
setattr(func, 'forward', forward_func) setattr(func, 'forward', forward_func)
func_self = func func_self = func
except Exception: except Exception:
......
...@@ -21,6 +21,7 @@ import six ...@@ -21,6 +21,7 @@ import six
import textwrap import textwrap
import threading import threading
import warnings import warnings
import weakref
import gast import gast
from paddle.fluid import framework from paddle.fluid import framework
...@@ -245,6 +246,7 @@ class StaticLayer(object): ...@@ -245,6 +246,7 @@ class StaticLayer(object):
self._input_spec = input_spec self._input_spec = input_spec
self._function_spec = FunctionSpec(function, input_spec) self._function_spec = FunctionSpec(function, input_spec)
self._program_cache = ProgramCache() self._program_cache = ProgramCache()
self._descriptor_cache = weakref.WeakKeyDictionary()
# Note: Hold a reference to ProgramTranslator for switching `enable_declarative`. # Note: Hold a reference to ProgramTranslator for switching `enable_declarative`.
self._program_trans = ProgramTranslator() self._program_trans = ProgramTranslator()
...@@ -271,8 +273,19 @@ class StaticLayer(object): ...@@ -271,8 +273,19 @@ class StaticLayer(object):
of `Net` instance. After decorated by `@paddle.jit.to_static`, it will firstly to call `__get__` of `Net` instance. After decorated by `@paddle.jit.to_static`, it will firstly to call `__get__`
to parse the class instance correctly instead of the `StaticLayer` instance. to parse the class instance correctly instead of the `StaticLayer` instance.
""" """
self._class_instance = instance if instance not in self._descriptor_cache:
return self if instance is None:
return self
# Note(Aurelius84): To construct new instance of StaticLayer when we
# first encouter the bound function of layer and cache it.
new_static_layer = self._clone()
new_static_layer._class_instance = instance
self._descriptor_cache[instance] = new_static_layer
return self._descriptor_cache[instance]
def _clone(self):
return self.__class__(self._dygraph_function, self._input_spec)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
""" """
......
...@@ -19,7 +19,7 @@ import paddle ...@@ -19,7 +19,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram, StaticLayer
from test_basic_api_transformation import dyfunc_to_variable from test_basic_api_transformation import dyfunc_to_variable
...@@ -84,6 +84,23 @@ class SimpleNet(Layer): ...@@ -84,6 +84,23 @@ class SimpleNet(Layer):
return z return z
class TestStaticLayerInstance(unittest.TestCase):
def test_instance_same_class(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
net_1 = SimpleNet()
net_2 = SimpleNet()
self.assertTrue(isinstance(net_1.forward, StaticLayer))
self.assertTrue(isinstance(net_2.forward, StaticLayer))
self.assertNotEqual(net_1.forward, net_2.forward)
# convert layer into static progam of net_1
net_1.forward.concrete_program
self.assertTrue(len(net_1.forward.program_cache) == 1)
# check no conversion applid with net_2
self.assertTrue(len(net_2.forward.program_cache) == 0)
class TestInputSpec(unittest.TestCase): class TestInputSpec(unittest.TestCase):
def setUp(self): def setUp(self):
pass pass
...@@ -224,7 +241,6 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase): ...@@ -224,7 +241,6 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
# 1. specific InputSpec for `x`/`y` # 1. specific InputSpec for `x`/`y`
concrete_program_1 = foo.get_concrete_program( concrete_program_1 = foo.get_concrete_program(
InputSpec([None, 10]), InputSpec([10])) InputSpec([None, 10]), InputSpec([10]))
print(concrete_program_1)
self.assertTrue(len(foo.program_cache) == 1) self.assertTrue(len(foo.program_cache) == 1)
# 2. specific `c`/`d` explicitly with same default value # 2. specific `c`/`d` explicitly with same default value
......
...@@ -133,7 +133,7 @@ class TestPartialProgramRaiseError(unittest.TestCase): ...@@ -133,7 +133,7 @@ class TestPartialProgramRaiseError(unittest.TestCase):
x = fluid.dygraph.to_variable(x_data) x = fluid.dygraph.to_variable(x_data)
out = net(x) out = net(x)
program_cache = SimpleFcLayer.forward.program_cache program_cache = net.forward.program_cache
_, (concrete_program, _) = program_cache.last() _, (concrete_program, _) = program_cache.last()
params = concrete_program.parameters params = concrete_program.parameters
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册