未验证 提交 35f53ecd 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat] Add descriptor cache for StaticLayer (#26987)

* add descriptor cache

* fix self arugments

* deal case if instance is None

* clean code

* fix usage
上级 9373cf5a
......@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticLayer
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators
from paddle.fluid.dygraph.layers import Layer
# TODO(liym27): A better way to do this.
......@@ -118,14 +119,9 @@ def convert_call(func):
func_self = None
converted_call = None
# Function in convert_call may be decorated by another `@declarative`,
# Function in convert_call may be decorated by another `@to_static`,
# in this case, unwraps it into a raw method or function.
if isinstance(func, StaticLayer):
instance = func._class_instance
if instance is not None:
func = func.dygraph_function.__get__(instance)
else:
func = func.dygraph_function
_, func = unwrap_decorators(func)
if is_builtin_len(func):
return convert_len
......@@ -155,7 +151,8 @@ def convert_call(func):
if inspect.isfunction(fn):
global_functions.add(fn)
elif isinstance(fn, StaticLayer):
global_functions.add(fn.dygraph_function)
_, fn = unwrap_decorators(fn)
global_functions.add(fn)
if func in global_functions:
converted_call = convert_to_static(func)
......@@ -189,7 +186,8 @@ def convert_call(func):
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
if hasattr(func, 'forward') and isinstance(func, Layer):
try:
forward_func = convert_to_static(func.forward)
_, forward_func = unwrap_decorators(func.forward)
forward_func = convert_to_static(forward_func)
setattr(func, 'forward', forward_func)
func_self = func
except Exception:
......
......@@ -21,6 +21,7 @@ import six
import textwrap
import threading
import warnings
import weakref
import gast
from paddle.fluid import framework
......@@ -245,6 +246,7 @@ class StaticLayer(object):
self._input_spec = input_spec
self._function_spec = FunctionSpec(function, input_spec)
self._program_cache = ProgramCache()
self._descriptor_cache = weakref.WeakKeyDictionary()
# Note: Hold a reference to ProgramTranslator for switching `enable_declarative`.
self._program_trans = ProgramTranslator()
......@@ -271,8 +273,19 @@ class StaticLayer(object):
of `Net` instance. After decorated by `@paddle.jit.to_static`, it will firstly to call `__get__`
to parse the class instance correctly instead of the `StaticLayer` instance.
"""
self._class_instance = instance
if instance not in self._descriptor_cache:
if instance is None:
return self
# Note(Aurelius84): To construct new instance of StaticLayer when we
# first encouter the bound function of layer and cache it.
new_static_layer = self._clone()
new_static_layer._class_instance = instance
self._descriptor_cache[instance] = new_static_layer
return self._descriptor_cache[instance]
def _clone(self):
return self.__class__(self._dygraph_function, self._input_spec)
def __call__(self, *args, **kwargs):
"""
......
......@@ -19,7 +19,7 @@ import paddle
import paddle.fluid as fluid
from paddle.static import InputSpec
from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram, StaticLayer
from test_basic_api_transformation import dyfunc_to_variable
......@@ -84,6 +84,23 @@ class SimpleNet(Layer):
return z
class TestStaticLayerInstance(unittest.TestCase):
def test_instance_same_class(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
net_1 = SimpleNet()
net_2 = SimpleNet()
self.assertTrue(isinstance(net_1.forward, StaticLayer))
self.assertTrue(isinstance(net_2.forward, StaticLayer))
self.assertNotEqual(net_1.forward, net_2.forward)
# convert layer into static progam of net_1
net_1.forward.concrete_program
self.assertTrue(len(net_1.forward.program_cache) == 1)
# check no conversion applid with net_2
self.assertTrue(len(net_2.forward.program_cache) == 0)
class TestInputSpec(unittest.TestCase):
def setUp(self):
pass
......@@ -224,7 +241,6 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
# 1. specific InputSpec for `x`/`y`
concrete_program_1 = foo.get_concrete_program(
InputSpec([None, 10]), InputSpec([10]))
print(concrete_program_1)
self.assertTrue(len(foo.program_cache) == 1)
# 2. specific `c`/`d` explicitly with same default value
......
......@@ -133,7 +133,7 @@ class TestPartialProgramRaiseError(unittest.TestCase):
x = fluid.dygraph.to_variable(x_data)
out = net(x)
program_cache = SimpleFcLayer.forward.program_cache
program_cache = net.forward.program_cache
_, (concrete_program, _) = program_cache.last()
params = concrete_program.parameters
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册