未验证 提交 fc9d80bc 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]rename StaticLayer into StaticFunction (#27487)

* rename StaticLayer

* rename
上级 dc713116
...@@ -29,7 +29,7 @@ import six ...@@ -29,7 +29,7 @@ import six
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticLayer from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
...@@ -143,14 +143,14 @@ def convert_call(func): ...@@ -143,14 +143,14 @@ def convert_call(func):
# def foo(x): # def foo(x):
# return x # return x
# #
# `foo` will be converted into a wrapper class, suppose as `StaticLayer`. # `foo` will be converted into a wrapper class, suppose as `StaticFunction`.
# And `foo.__globals__['foo']` will still return this `StaticLayer` instead of # And `foo.__globals__['foo']` will still return this `StaticFunction` instead of
# `foo` function. So `isinstance(fn, StaticLayer)` is added here. # `foo` function. So `isinstance(fn, StaticFunction)` is added here.
global_functions = set() global_functions = set()
for fn in func.__globals__.values(): for fn in func.__globals__.values():
if inspect.isfunction(fn): if inspect.isfunction(fn):
global_functions.add(fn) global_functions.add(fn)
elif isinstance(fn, StaticLayer): elif isinstance(fn, StaticFunction):
_, fn = unwrap_decorators(fn) _, fn = unwrap_decorators(fn)
global_functions.add(fn) global_functions.add(fn)
......
...@@ -205,7 +205,7 @@ def unwrap_decorators(func): ...@@ -205,7 +205,7 @@ def unwrap_decorators(func):
decorators = [] decorators = []
cur = func cur = func
while True: while True:
if isinstance(cur, StaticLayer): if isinstance(cur, StaticFunction):
decorators.append(cur) decorators.append(cur)
# Note: if `cur` is a method, keep it as bound method of class. # Note: if `cur` is a method, keep it as bound method of class.
instance = cur._class_instance instance = cur._class_instance
...@@ -218,7 +218,7 @@ def unwrap_decorators(func): ...@@ -218,7 +218,7 @@ def unwrap_decorators(func):
return decorators, cur return decorators, cur
class StaticLayer(object): class StaticFunction(object):
""" """
Wrapper class to Manage program conversion of decorated function. Wrapper class to Manage program conversion of decorated function.
...@@ -226,7 +226,7 @@ class StaticLayer(object): ...@@ -226,7 +226,7 @@ class StaticLayer(object):
def __init__(self, function, input_spec=None): def __init__(self, function, input_spec=None):
""" """
Initializes a `StaticLayer`. Initializes a `StaticFunction`.
Args: Args:
function(callable): A function or method that will be converted into static program. function(callable): A function or method that will be converted into static program.
...@@ -268,12 +268,12 @@ class StaticLayer(object): ...@@ -268,12 +268,12 @@ class StaticLayer(object):
In above case, `net(x, y)` will call `net.forward(x, y)` firstly that is a bound method In above case, `net(x, y)` will call `net.forward(x, y)` firstly that is a bound method
of `Net` instance. After decorated by `@paddle.jit.to_static`, it will firstly to call `__get__` of `Net` instance. After decorated by `@paddle.jit.to_static`, it will firstly to call `__get__`
to parse the class instance correctly instead of the `StaticLayer` instance. to parse the class instance correctly instead of the `StaticFunction` instance.
""" """
if instance not in self._descriptor_cache: if instance not in self._descriptor_cache:
if instance is None: if instance is None:
return self return self
# Note(Aurelius84): To construct new instance of StaticLayer when we # Note(Aurelius84): To construct new instance of StaticFunction when we
# first encouter the bound function of layer and cache it. # first encouter the bound function of layer and cache it.
new_static_layer = self._clone() new_static_layer = self._clone()
new_static_layer._class_instance = instance new_static_layer._class_instance = instance
......
...@@ -28,7 +28,7 @@ from paddle.fluid.data_feeder import check_type ...@@ -28,7 +28,7 @@ from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticLayer, unwrap_decorators from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticFunction, unwrap_decorators
from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.executor import Executor, scope_guard from paddle.fluid.executor import Executor, scope_guard
...@@ -141,7 +141,7 @@ def copy_decorator_attrs(original_func, decorated_obj): ...@@ -141,7 +141,7 @@ def copy_decorator_attrs(original_func, decorated_obj):
Args: Args:
original_func(callable): the original decorated function. original_func(callable): the original decorated function.
decorated_obj(StaticLayer): the target decorated StaticLayer object. decorated_obj(StaticFunction): the target decorated StaticFunction object.
""" """
decorator_name = "declarative" decorator_name = "declarative"
...@@ -198,7 +198,7 @@ def declarative(function=None, input_spec=None): ...@@ -198,7 +198,7 @@ def declarative(function=None, input_spec=None):
def decorated(python_func): def decorated(python_func):
""" """
Decorates a python function into a StaticLayer object. Decorates a python function into a StaticFunction object.
""" """
# Step 1. unwrap the function if it is already decorated. # Step 1. unwrap the function if it is already decorated.
_, python_func = unwrap_decorators(python_func) _, python_func = unwrap_decorators(python_func)
...@@ -206,7 +206,7 @@ def declarative(function=None, input_spec=None): ...@@ -206,7 +206,7 @@ def declarative(function=None, input_spec=None):
# Step 2. copy some attributes from original python function. # Step 2. copy some attributes from original python function.
static_layer = copy_decorator_attrs( static_layer = copy_decorator_attrs(
original_func=python_func, original_func=python_func,
decorated_obj=StaticLayer( decorated_obj=StaticFunction(
function=python_func, input_spec=input_spec)) function=python_func, input_spec=input_spec))
return static_layer return static_layer
...@@ -214,7 +214,7 @@ def declarative(function=None, input_spec=None): ...@@ -214,7 +214,7 @@ def declarative(function=None, input_spec=None):
# for usage: `declarative(foo, ...)` # for usage: `declarative(foo, ...)`
if function is not None: if function is not None:
if isinstance(function, Layer): if isinstance(function, Layer):
if isinstance(function.forward, StaticLayer): if isinstance(function.forward, StaticFunction):
class_name = function.__class__.__name__ class_name = function.__class__.__name__
logging_utils.warn( logging_utils.warn(
"`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.". "`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.".
...@@ -868,7 +868,7 @@ def save(layer, model_path, input_spec=None, config=None): ...@@ -868,7 +868,7 @@ def save(layer, model_path, input_spec=None, config=None):
# 2. get program from Layer # 2. get program from Layer
# TODO(chenweihang): add support for other method, not only forward # TODO(chenweihang): add support for other method, not only forward
if isinstance(layer.forward, StaticLayer): if isinstance(layer.forward, StaticFunction):
concrete_program = layer.forward.concrete_program concrete_program = layer.forward.concrete_program
else: else:
# transform in jit.save, if input_spec is incomplete, declarative will throw error # transform in jit.save, if input_spec is incomplete, declarative will throw error
......
...@@ -19,7 +19,7 @@ import paddle ...@@ -19,7 +19,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram, StaticLayer from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram, StaticFunction
from test_basic_api_transformation import dyfunc_to_variable from test_basic_api_transformation import dyfunc_to_variable
...@@ -81,14 +81,14 @@ class SimpleNet(Layer): ...@@ -81,14 +81,14 @@ class SimpleNet(Layer):
return z return z
class TestStaticLayerInstance(unittest.TestCase): class TestStaticFunctionInstance(unittest.TestCase):
def test_instance_same_class(self): def test_instance_same_class(self):
with fluid.dygraph.guard(fluid.CPUPlace()): with fluid.dygraph.guard(fluid.CPUPlace()):
net_1 = SimpleNet() net_1 = SimpleNet()
net_2 = SimpleNet() net_2 = SimpleNet()
self.assertTrue(isinstance(net_1.forward, StaticLayer)) self.assertTrue(isinstance(net_1.forward, StaticFunction))
self.assertTrue(isinstance(net_2.forward, StaticLayer)) self.assertTrue(isinstance(net_2.forward, StaticFunction))
self.assertNotEqual(net_1.forward, net_2.forward) self.assertNotEqual(net_1.forward, net_2.forward)
# convert layer into static progam of net_1 # convert layer into static progam of net_1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册