未验证 提交 81c4def9 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Rename Dygraph To Static Decorators (#23880)

1. Rename Dygraph To Static Decorators to declarative
2. dygraph_to_static_func is still using in some training tests, I cannot delete it now.
3. Add some API docs
上级 752636f9
...@@ -43,10 +43,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static ...@@ -43,10 +43,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static
__all__ = ['DygraphToStaticAst', 'convert_to_static'] __all__ = ['DygraphToStaticAst', 'convert_to_static']
DECORATOR_NAMES = [ DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
'dygraph_to_static_code', 'dygraph_to_static_program',
'dygraph_to_static_func', 'dygraph_to_static_output'
]
class DygraphToStaticAst(gast.NodeTransformer): class DygraphToStaticAst(gast.NodeTransformer):
......
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
from __future__ import print_function from __future__ import print_function
import gast import gast
import inspect import inspect
import logging
import numpy import numpy
import textwrap import textwrap
import threading import threading
import warnings
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid import core, executor from paddle.fluid import core, executor
...@@ -32,6 +32,8 @@ from paddle.fluid.data_feeder import check_type ...@@ -32,6 +32,8 @@ from paddle.fluid.data_feeder import check_type
__all__ = ['ProgramTranslator', 'convert_function_with_cache'] __all__ = ['ProgramTranslator', 'convert_function_with_cache']
logger = logging.getLogger("fluid")
class FunctionCache(object): class FunctionCache(object):
""" """
...@@ -235,6 +237,10 @@ class ProgramCache(object): ...@@ -235,6 +237,10 @@ class ProgramCache(object):
class ProgramTranslator(object): class ProgramTranslator(object):
"""
Class to translate dygraph function into static graph function.
"""
_singleton_lock = threading.Lock() _singleton_lock = threading.Lock()
_instance = None _instance = None
...@@ -274,16 +280,37 @@ class ProgramTranslator(object): ...@@ -274,16 +280,37 @@ class ProgramTranslator(object):
self._loss_name = None self._loss_name = None
# Once startup_program is changed, should run startup_program. # Once startup_program is changed, should run startup_program.
self._prev_startup = None self._prev_startup = None
self.enable_declarative = True
def enable_declarative_function(self, enable_declarative):
"""
Enable or disable the converting from imperative to declarative by
ProgramTranslator globally.
Args:
enable_declarative (bool): True or False to enable or disable declarative
"""
self.enable_declarative = enable_declarative
def get_output(self, dygraph_func, *args, **kwargs): def get_output(self, dygraph_func, *args, **kwargs):
""" """
Returns the output tensors for dygraph function and its arguments Returns the output dygraph VarBase for dygraph function. The dygraph
function will be translated into static graph function so the under
beneath numerical result will be calculated by declarative mode.
Args:
dygraph_func (callable): the dygraph function.
*args, **kwargs : the input argument of dygraph_func.
Returns:
VarBase or tuple of VarBase: the dygraph VarBase containing digital
result.
""" """
if in_dygraph_mode(): if in_dygraph_mode() or not self.enable_declarative:
warnings.warn( logger.info(
"The ProgramTranslator.get_output doesn't work in dygraph " "The ProgramTranslator.get_output doesn't work in dygraph "
"mode. We will just return dygraph output. Use it in " "mode or set enable_declarative_function to False. We will "
"static mode if you would like to translate to static graph.") "just return dygraph output.")
return dygraph_func(*args, **kwargs) return dygraph_func(*args, **kwargs)
program_cache = self.get_program_cache() program_cache = self.get_program_cache()
...@@ -292,33 +319,60 @@ class ProgramTranslator(object): ...@@ -292,33 +319,60 @@ class ProgramTranslator(object):
if not program_cache.in_build_process: if not program_cache.in_build_process:
outputs = self.run(*args, **kwargs) outputs = self.run(*args, **kwargs)
with guard(): with guard():
outputs = [to_variable(x) for x in outputs] if len(outputs) == 1:
outputs = to_variable(outputs[0])
else:
outputs = tuple(to_variable(x) for x in outputs)
return outputs return outputs
def get_func(self, dygraph_func): def get_func(self, dygraph_func):
""" """
Returns the translated static function from dygraph function Returns a callable function which converts imperative dygraph APIs of
the input dygraph_func into declarative net-building APIs, which means
it doesn't return immediate digital result as get_output does.
Users should handle Program and Executor by themselves.
Args:
dygraph_func (callable): the dygraph function.
Returns:
callable: converting imperative dygraph APIs into declarative
net-building APIs.
""" """
if in_dygraph_mode(): if in_dygraph_mode() or not self.enable_declarative:
warnings.warn( logger.info(
"The ProgramTranslator.get_func doesn't work in dygraph " "The ProgramTranslator.get_func doesn't work in dygraph "
"mode. We will just return dygraph function. Use it in " "mode or set enable_declarative_function to False. We will "
"static mode if you would like to translate to static graph.") "just return dygraph output.")
return dygraph_func return dygraph_func
static_func = convert_function_with_cache(dygraph_func) static_func = convert_function_with_cache(dygraph_func)
return static_func return static_func
def get_program(self, dygraph_func, *args, **kwargs): def get_program(self, dygraph_func, *args, **kwargs):
""" """
Returns the translated static program and input/output variables from Returns the translated static program and input/output variables from
dygraph function. dygraph function. The users can use the program to run by executor.
"""
if in_dygraph_mode(): Args:
warnings.warn( dygraph_func (callable): the dygraph function.
*args, **kwargs : the input argument of dygraph_func.
Returns:
tuple of (main_program, startup_program, inputs, outputs) whose
types are (Program, Program, list of Variable, list of Variable).
main_program: the converted main program.
startup_program: the converted startup program.
inputs: list of input Variables which need to be fed.
outputs: list of output Variables which users can fetch.
"""
if in_dygraph_mode() or not self.enable_declarative:
logger.info(
"The ProgramTranslator.get_program doesn't work in dygraph " "The ProgramTranslator.get_program doesn't work in dygraph "
"mode. We will just return dygraph output. Use it in static " "mode or set enable_declarative_function to False. We will "
"mode if you would like to translate to static graph.") "just return dygraph output.")
return dygraph_func(*args, **kwargs) return dygraph_func(*args, **kwargs)
program_cache = self.get_program_cache() program_cache = self.get_program_cache()
outputs = program_cache.build_program_and_return_output(dygraph_func, outputs = program_cache.build_program_and_return_output(dygraph_func,
*args, **kwargs) *args, **kwargs)
...@@ -326,7 +380,13 @@ class ProgramTranslator(object): ...@@ -326,7 +380,13 @@ class ProgramTranslator(object):
def get_code(self, dygraph_func): def get_code(self, dygraph_func):
""" """
Returns the translated static function code from dygraph code Returns the translated static function string code from dygraph function.
Args:
dygraph_func (callable): the dygraph function.
Returns:
str: the string code of translated static function
""" """
# Gets AST from dygraph function # Gets AST from dygraph function
raw_code = inspect.getsource(dygraph_func) raw_code = inspect.getsource(dygraph_func)
......
...@@ -14,12 +14,9 @@ ...@@ -14,12 +14,9 @@
from __future__ import print_function from __future__ import print_function
__all__ = [ __all__ = ['TracedLayer', 'declarative', 'dygraph_to_static_func']
'TracedLayer', 'dygraph_to_static_code', 'dygraph_to_static_func',
'dygraph_to_static_output', 'dygraph_to_static_program'
]
import warnings import logging
from ..wrapped_decorator import wrap_decorator from ..wrapped_decorator import wrap_decorator
from .base import program_desc_tracing_guard, switch_to_static_graph from .base import program_desc_tracing_guard, switch_to_static_graph
...@@ -30,6 +27,8 @@ from paddle.fluid.executor import Executor, scope_guard ...@@ -30,6 +27,8 @@ from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.compiler import CompiledProgram from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
logger = logging.getLogger("fluid")
def create_program_from_desc(program_desc): def create_program_from_desc(program_desc):
program = Program() program = Program()
...@@ -54,45 +53,63 @@ def extract_vars(inputs): ...@@ -54,45 +53,63 @@ def extract_vars(inputs):
return result_list return result_list
def _dygraph_to_static_code_(dygraph_func): def _dygraph_to_static_func_(dygraph_func):
def __impl__(*args, **kwargs): """
program_translator = ProgramTranslator() Converts imperative dygraph APIs into declarative function APIs. Decorator
return program_translator.get_code(dygraph_func) @dygraph_to_static_func only converts imperative dygraph APIs into
declarative net-building APIs, which means it doesn't return immediate
digital result as imperative mode. Users should handle Program and Executor
by themselves.
return __impl__ Note:
This decorator is NOT our recommended way to transform imperative function
to declarative function. We will remove this decorator after we finalize
cleaning up code.
Args:
dygraph_func (callable): callable imperative function.
dygraph_to_static_code = wrap_decorator(_dygraph_to_static_code_) Returns:
Callable: converting imperative dygraph APIs into declarative
net-building APIs.
Examples:
.. code-block:: python
def _dygraph_to_static_program_(dygraph_func): import paddle.fluid as fluid
def __impl__(*args, **kwargs): import numpy as np
if in_dygraph_mode(): from paddle.fluid.dygraph.jit import dygraph_to_static_func
warnings.warn(
"The decorator 'dygraph_to_static_program' doesn't work in "
"dygraph mode. We will just return dygraph output. Use the "
"decorator in static mode if you would like to translate to "
"static graph.")
return dygraph_func(*args, **kwargs)
program_translator = ProgramTranslator()
return program_translator.get_program(dygraph_func, *args, **kwargs)
return __impl__ @dygraph_to_static_func
def func(x):
if fluid.layers.mean(x) < 0:
x_v = x - 1
else:
x_v = x + 1
return x_v
dygraph_to_static_program = wrap_decorator(_dygraph_to_static_program_) x = fluid.layers.fill_constant(shape=[3, 3], value=0, dtype='float64')
x_v = func(x)
exe = fluid.Executor(fluid.CPUPlace())
out = exe.run(fetch_list=[x_v])
print(out[0])
# [[1. 1. 1.]
# [1. 1. 1.]
# [1. 1. 1.]]
def _dygraph_to_static_func_(dygraph_func): """
# TODO: remove this decorator after we finalize training API
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
if in_dygraph_mode(): program_translator = ProgramTranslator()
warnings.warn( if in_dygraph_mode() or not program_translator.enable_declarative:
logger.info(
"The decorator 'dygraph_to_static_func' doesn't work in " "The decorator 'dygraph_to_static_func' doesn't work in "
"dygraph mode. We will just return dygraph output. Use the " "dygraph mode or set enable_declarative_function to False. "
"decorator in static mode if you would like to translate to " "We will just return dygraph output.")
"static graph.")
return dygraph_func(*args, **kwargs) return dygraph_func(*args, **kwargs)
program_translator = ProgramTranslator()
static_func = program_translator.get_func(dygraph_func) static_func = program_translator.get_func(dygraph_func)
return static_func(*args, **kwargs) return static_func(*args, **kwargs)
...@@ -102,14 +119,48 @@ def _dygraph_to_static_func_(dygraph_func): ...@@ -102,14 +119,48 @@ def _dygraph_to_static_func_(dygraph_func):
dygraph_to_static_func = wrap_decorator(_dygraph_to_static_func_) dygraph_to_static_func = wrap_decorator(_dygraph_to_static_func_)
def _dygraph_to_static_output_(dygraph_func): def _declarative_(dygraph_func):
"""
Converts imperative dygraph APIs into declarative function APIs. Decorator
@declarative handles the Program and Executor of static mode and returns
the result as a dygraph VarBase.
Args:
dygraph_func (callable): callable imperative function.
Returns:
VarBase: containing the numerical result.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
from paddle.fluid.dygraph.jit import declarative
@declarative
def func(x):
x = fluid.dygraph.to_variable(x)
if fluid.layers.mean(x) < 0:
x_v = x - 1
else:
x_v = x + 1
return x_v
x = np.ones([1, 2])
x_v = func(x)
print(x_v.numpy()) # [[2. 2.]]
"""
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
if in_dygraph_mode(): program_translator = ProgramTranslator()
warnings.warn( if in_dygraph_mode() or not program_translator.enable_declarative:
"The decorator 'dygraph_to_static_output' doesn't work in " logger.info(
"dygraph mode. We will just return dygraph output. Use the " "The decorator 'declarative' doesn't work in dygraph "
"decorator in static mode if you would like to translate to " "mode or set enable_declarative_function to False. We will "
"static graph.") "just return dygraph output.")
return dygraph_func(*args, **kwargs) return dygraph_func(*args, **kwargs)
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
return program_translator.get_output(dygraph_func, *args, **kwargs) return program_translator.get_output(dygraph_func, *args, **kwargs)
...@@ -117,7 +168,7 @@ def _dygraph_to_static_output_(dygraph_func): ...@@ -117,7 +168,7 @@ def _dygraph_to_static_output_(dygraph_func):
return __impl__ return __impl__
dygraph_to_static_output = wrap_decorator(_dygraph_to_static_output_) declarative = wrap_decorator(_declarative_)
@dygraph_only @dygraph_only
......
...@@ -20,7 +20,7 @@ from collections import Counter ...@@ -20,7 +20,7 @@ from collections import Counter
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_output from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static import convert_function_with_cache from paddle.fluid.dygraph.dygraph_to_static import convert_function_with_cache
...@@ -50,10 +50,14 @@ class TestCacheProgram(unittest.TestCase): ...@@ -50,10 +50,14 @@ class TestCacheProgram(unittest.TestCase):
op.type for op in fluid.default_main_program().block(0).ops op.type for op in fluid.default_main_program().block(0).ops
]) ])
if batch_id > 0: if batch_id > 0:
prev_out_numpy = prev_out[0].numpy() if isinstance(
prev_out, tuple) else prev_out.numpy()
cur_out_numpy = cur_out[0].numpy() if isinstance(
cur_out, tuple) else cur_out.numpy()
self.assertTrue( self.assertTrue(
np.allclose(prev_out[0].numpy(), cur_out[0].numpy()), np.allclose(prev_out_numpy, cur_out_numpy),
msg='Output in previous batch is {}\n Output in current batch is \n{}' msg='Output in previous batch is {}\n Output in current batch is \n{}'
.format(prev_out, cur_out)) .format(prev_out_numpy, cur_out_numpy))
self.assertEqual(prev_ops, cur_ops) self.assertEqual(prev_ops, cur_ops)
...@@ -139,7 +143,7 @@ class TestConvertWithCache(unittest.TestCase): ...@@ -139,7 +143,7 @@ class TestConvertWithCache(unittest.TestCase):
self.assertTrue(id(static_func), id(cached_func)) self.assertTrue(id(static_func), id(cached_func))
@dygraph_to_static_output @declarative
def sum_even_util_limit(max_len, limit): def sum_even_util_limit(max_len, limit):
ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32')) ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
for i in range(max_len): for i in range(max_len):
...@@ -152,7 +156,7 @@ def sum_even_util_limit(max_len, limit): ...@@ -152,7 +156,7 @@ def sum_even_util_limit(max_len, limit):
return ret_sum return ret_sum
@dygraph_to_static_output @declarative
def sum_under_while(limit): def sum_under_while(limit):
i = fluid.dygraph.to_variable(np.zeros((1)).astype('int32')) i = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32')) ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
...@@ -165,10 +169,10 @@ def sum_under_while(limit): ...@@ -165,10 +169,10 @@ def sum_under_while(limit):
class TestToOutputWithCache(unittest.TestCase): class TestToOutputWithCache(unittest.TestCase):
def test_output(self): def test_output(self):
ret = sum_even_util_limit(80, 10) ret = sum_even_util_limit(80, 10)
self.assertEqual(ret[0].numpy(), 30) self.assertEqual(ret.numpy(), 30)
ret = sum_under_while(100) ret = sum_under_while(100)
self.assertEqual(ret[0].numpy(), 5050) self.assertEqual(ret.numpy(), 5050)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from __future__ import print_function from __future__ import print_function
from paddle.fluid.dygraph.jit import dygraph_to_static_output from paddle.fluid.dygraph.jit import declarative
import numpy as np import numpy as np
import unittest import unittest
...@@ -30,7 +30,7 @@ class Pool2D(fluid.dygraph.Layer): ...@@ -30,7 +30,7 @@ class Pool2D(fluid.dygraph.Layer):
self.pool2d = fluid.dygraph.Pool2D( self.pool2d = fluid.dygraph.Pool2D(
pool_size=2, pool_type='avg', pool_stride=1, global_pooling=False) pool_size=2, pool_type='avg', pool_stride=1, global_pooling=False)
@dygraph_to_static_output @declarative
def forward(self, x): def forward(self, x):
inputs = fluid.dygraph.to_variable(x) inputs = fluid.dygraph.to_variable(x)
...@@ -54,7 +54,7 @@ class Linear(fluid.dygraph.Layer): ...@@ -54,7 +54,7 @@ class Linear(fluid.dygraph.Layer):
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.5))) value=0.5)))
@dygraph_to_static_output @declarative
def forward(self, x): def forward(self, x):
inputs = fluid.dygraph.to_variable(x) inputs = fluid.dygraph.to_variable(x)
pre = self.fc(inputs) pre = self.fc(inputs)
...@@ -82,7 +82,9 @@ class TestPool2D(unittest.TestCase): ...@@ -82,7 +82,9 @@ class TestPool2D(unittest.TestCase):
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
dy_layer = self.dygraph_class() dy_layer = self.dygraph_class()
out = dy_layer(x=self.data) out = dy_layer(x=self.data)
if isinstance(out, tuple):
return out[0].numpy() return out[0].numpy()
return out.numpy()
def test_static_output(self): def test_static_output(self):
dygraph_res = self.run_dygraph_mode() dygraph_res = self.run_dygraph_mode()
......
...@@ -17,15 +17,38 @@ from __future__ import print_function ...@@ -17,15 +17,38 @@ from __future__ import print_function
import astor import astor
import gast import gast
import inspect import inspect
import numpy as np
import textwrap import textwrap
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import dygraph_to_static_code from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.nn import Linear
from ifelse_simple_func import dyfunc_with_if_else from ifelse_simple_func import dyfunc_with_if_else
np.random.seed(0)
def simple_func(x, weight_numpy):
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
linear = Linear(32, 64, param_attr=weight_initalizer, bias_attr=False)
x = fluid.dygraph.to_variable(x)
y = linear(x)
z = linear(x)
return z
@declarative
def decorated_simple_func(x, weight_numpy):
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
linear = Linear(32, 64, param_attr=weight_initalizer, bias_attr=False)
x = fluid.dygraph.to_variable(x)
y = linear(x)
z = linear(x)
return z
def get_source_code(func): def get_source_code(func):
raw_code = inspect.getsource(func) raw_code = inspect.getsource(func)
...@@ -81,8 +104,9 @@ class TestDygraphToStaticCode(unittest.TestCase): ...@@ -81,8 +104,9 @@ class TestDygraphToStaticCode(unittest.TestCase):
def test_decorator(self): def test_decorator(self):
x_v = None x_v = None
program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_with_if_else)
answer = get_source_code(StaticCode1.dyfunc_with_if_else) answer = get_source_code(StaticCode1.dyfunc_with_if_else)
code = dygraph_to_static_code(dyfunc_with_if_else)(x_v)
self.assertEqual(answer, code) self.assertEqual(answer, code)
def test_program_translator(self): def test_program_translator(self):
...@@ -92,5 +116,80 @@ class TestDygraphToStaticCode(unittest.TestCase): ...@@ -92,5 +116,80 @@ class TestDygraphToStaticCode(unittest.TestCase):
self.assertEqual(answer, code) self.assertEqual(answer, code)
class TestEnableDeclarative(unittest.TestCase):
def test_enable_disable_get_output(self):
x = np.random.randn(30, 10, 32).astype('float32')
weight = np.random.randn(32, 64).astype('float32')
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
program_translator.enable_declarative_function(True)
static_output = program_translator.get_output(simple_func, x,
weight)
program_translator.enable_declarative_function(False)
with fluid.dygraph.guard():
dygraph_output = program_translator.get_output(simple_func, x,
weight)
self.assertTrue(
np.allclose(
static_output.numpy(), dygraph_output.numpy(), atol=1e-4))
def test_enable_disable_get_func(self):
x = np.random.randn(30, 10, 32).astype('float32')
weight = np.random.randn(32, 64).astype('float32')
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
program_translator.enable_declarative_function(True)
static_func = program_translator.get_func(simple_func)
self.assertTrue(callable(static_func))
static_output = static_func(x, weight)
self.assertTrue(isinstance(static_output, fluid.Variable))
program_translator.enable_declarative_function(False)
with fluid.dygraph.guard():
dygraph_func = program_translator.get_func(simple_func)
self.assertTrue(callable(dygraph_func))
dygraph_output = dygraph_func(x, weight)
self.assertTrue(isinstance(dygraph_output, fluid.core.VarBase))
def test_enable_disable_get_program(self):
x = np.random.randn(30, 10, 32).astype('float32')
weight = np.random.randn(32, 64).astype('float32')
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
program_translator.enable_declarative_function(True)
static_output = program_translator.get_program(simple_func, x,
weight)
self.assertTrue(isinstance(static_output, tuple))
self.assertEqual(len(static_output), 4)
self.assertTrue(isinstance(static_output[0], fluid.Program))
self.assertTrue(isinstance(static_output[1], fluid.Program))
program_translator.enable_declarative_function(False)
with fluid.dygraph.guard():
dygraph_output = program_translator.get_program(simple_func, x,
weight)
self.assertTrue(isinstance(dygraph_output, fluid.core.VarBase))
def test_enable_disable_declarative(self):
x = np.random.randn(30, 10, 32).astype('float32')
weight = np.random.randn(32, 64).astype('float32')
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
program_translator.enable_declarative_function(True)
static_output = decorated_simple_func(x, weight)
program_translator.enable_declarative_function(False)
with fluid.dygraph.guard():
dygraph_output = decorated_simple_func(x, weight)
self.assertTrue(
np.allclose(
static_output.numpy(), dygraph_output.numpy(), atol=1e-4))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import dygraph_to_static_output from paddle.fluid.dygraph.jit import declarative
np.random.seed(2020) np.random.seed(2020)
...@@ -34,7 +34,7 @@ class SimpleFcLayer(fluid.dygraph.Layer): ...@@ -34,7 +34,7 @@ class SimpleFcLayer(fluid.dygraph.Layer):
super(SimpleFcLayer, self).__init__() super(SimpleFcLayer, self).__init__()
self._linear = fluid.dygraph.Linear(fc_size, fc_size) self._linear = fluid.dygraph.Linear(fc_size, fc_size)
@dygraph_to_static_output @declarative
def forward(self, x): def forward(self, x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
y = self._linear(x) y = self._linear(x)
......
...@@ -20,7 +20,7 @@ import numpy as np ...@@ -20,7 +20,7 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
from paddle.fluid.dygraph.jit import dygraph_to_static_program from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.nn import Linear from paddle.fluid.dygraph.nn import Linear
np.random.seed(2020) np.random.seed(2020)
...@@ -38,7 +38,6 @@ def simple_func(x, weight_numpy): ...@@ -38,7 +38,6 @@ def simple_func(x, weight_numpy):
return z return z
@dygraph_to_static_program
def decorated_simple_func(x, weight_numpy): def decorated_simple_func(x, weight_numpy):
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy) weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
linear = Linear(32, 64, param_attr=weight_initalizer) linear = Linear(32, 64, param_attr=weight_initalizer)
...@@ -55,8 +54,8 @@ class TestDyToStaticSaveLoad(unittest.TestCase): ...@@ -55,8 +54,8 @@ class TestDyToStaticSaveLoad(unittest.TestCase):
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
dygraph_result = simple_func(x, weight) dygraph_result = simple_func(x, weight)
main_program, startup_program, inputs, outputs = decorated_simple_func( main_program, startup_program, inputs, outputs = ProgramTranslator(
x, weight) ).get_program(decorated_simple_func, x, weight)
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_program) exe.run(startup_program)
fluid.save(main_program, "./test_dy2stat_save_load") fluid.save(main_program, "./test_dy2stat_save_load")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册