diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index bb2e1ee0f3f3d4668f5b56b16d26ef336fe00d71..73eb8ec3a987a8fb8f33c389da6d2c4e3da410ea 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -25,7 +25,7 @@ from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor __all__ = ['DygraphToStaticAst'] -DECORATOR_NAME = 'dygraph_to_static_output' +DECORATOR_NAMES = ['dygraph_to_static_output', 'dygraph_to_static_graph'] class IfElseTransformer(gast.NodeTransformer): @@ -100,6 +100,7 @@ class DygraphToStaticAst(gast.NodeTransformer): self.static_analysis_root = StaticAnalysisVisitor( root).get_node_wrapper_root() self.decorate_func_name = None + self.arg_name_to_idx = {} self.transfer_from_node_type(self.static_analysis_root) return self.static_analysis_root @@ -118,16 +119,14 @@ class DygraphToStaticAst(gast.NodeTransformer): def visit_FunctionDef(self, node): if self.decorate_func_name is None: self.decorate_func_name = node.name - - self.arg_name_to_idx = {} - for idx, arg in enumerate(node.args.args): - self.arg_name_to_idx[arg.id] = idx + for idx, arg in enumerate(node.args.args): + self.arg_name_to_idx[arg.id] = idx self.generic_visit(node) # Remove the decorated name of dygraph_to_static if hasattr(node, 'decorator_list'): decorator_list = [ - d for d in node.decorator_list if d.id != DECORATOR_NAME + d for d in node.decorator_list if d.id not in DECORATOR_NAMES ] node.decorator_list = decorator_list return node @@ -170,7 +169,7 @@ class BasicApiTransformer(gast.NodeTransformer): self.generic_visit(node) if hasattr(node, 'decorator_list'): decorator_list = [ - d for d in node.decorator_list if d.id != DECORATOR_NAME + d for d in node.decorator_list if d.id not in DECORATOR_NAMES ] node.decorator_list = decorator_list return node diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index db0fe62ee606d878353ff2454991ffb19e67d4f5..423c6702881f251ff825ec6290fd2841b700ab88 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -14,11 +14,12 @@ from __future__ import print_function -__all__ = ['TracedLayer', 'dygraph_to_static_output'] +__all__ = ['TracedLayer', 'dygraph_to_static_output', 'dygraph_to_static_graph'] import gast import inspect import textwrap +import warnings from ..wrapped_decorator import wrap_decorator from .base import program_desc_tracing_guard, switch_to_static_graph @@ -29,7 +30,7 @@ from paddle.fluid import core from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode from paddle.fluid.executor import Executor, scope_guard from paddle.fluid.compiler import CompiledProgram -from paddle.fluid import program_guard, data +from paddle.fluid import program_guard, data, default_startup_program, default_main_program def create_program_from_desc(program_desc): @@ -55,43 +56,60 @@ def extract_vars(inputs): return result_list -def _dygraph_to_static_output_(dygraph_func): - def __impl__(*args, **kwargs): +def to_static_func(dygraph_func): + # Get AST from dygraph function + dygraph_code = inspect.getsource(dygraph_func) + dygraph_code = textwrap.dedent(dygraph_code) + root = gast.parse(dygraph_code) - # Get AST from dygraph function - dygraph_code = inspect.getsource(dygraph_func) - dygraph_code = textwrap.dedent(dygraph_code) - root = gast.parse(dygraph_code) + # Transform AST + dygraph_to_static = DygraphToStaticAst() + root_wrapper = dygraph_to_static.get_static_ast(root) - # Transform AST - dygraph_to_static = DygraphToStaticAst() - root_wrapper = dygraph_to_static.get_static_ast(root) + # Get static_func from AST + func_name = dygraph_to_static.get_module_name() + static_func, file_name = ast_to_func(root_wrapper.node, func_name) - # Get static_func from AST - func_name = dygraph_to_static.get_module_name() - static_func, file_name = ast_to_func(root_wrapper.node, func_name) + return static_func, dygraph_to_static + + +def _dygraph_to_static_graph_(dygraph_func): + def __impl__(*args, **kwargs): + if in_dygraph_mode(): + warnings.warn( + "The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode." + " Please use it in static mode.") + return dygraph_func(*args, **kwargs) + static_func, dygraph_to_static = to_static_func(dygraph_func) + return static_func(*args, **kwargs) + + return __impl__ - if not in_dygraph_mode(): - return static_func(*args, **kwargs) - else: - feed_name_to_idx = dygraph_to_static.get_feed_name_to_idx() - feed_dict = {} - for feed_name, idx in feed_name_to_idx.items(): - feed_dict[feed_name] = args[idx] - - # Run static_func in static mode - startup_program = Program() - main_program = Program() - static_res = run_static_func(main_program, startup_program, - static_func, args, kwargs, feed_dict, - feed_name_to_idx) +def _dygraph_to_static_output_(dygraph_func): + def __impl__(*args, **kwargs): + if in_dygraph_mode(): + warnings.warn( + "The decorator 'dygraph_to_static_output' doesn't work in dygraph mode." + " Please use it in static mode.") + return dygraph_func(*args, **kwargs) + + static_func, dygraph_to_static = to_static_func(dygraph_func) + feed_name_to_idx = dygraph_to_static.get_feed_name_to_idx() + feed_dict = {} + for feed_name, idx in feed_name_to_idx.items(): + feed_dict[feed_name] = args[idx] + + # Run static_func in static mode + startup_program = default_main_program() + main_program = default_startup_program() + static_res = run_static_func(main_program, startup_program, static_func, + args, kwargs, feed_dict, feed_name_to_idx) return static_res return __impl__ -@switch_to_static_graph def run_static_func(main_program, startup_program, static_func, args, kwargs, feed_dict, feed_name_to_idx): @@ -114,6 +132,7 @@ def run_static_func(main_program, startup_program, static_func, args, kwargs, dygraph_to_static_output = wrap_decorator(_dygraph_to_static_output_) +dygraph_to_static_graph = wrap_decorator(_dygraph_to_static_graph_) @dygraph_only diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py index b4f96d81c4b37b8fc52ee248ad82b472e4568495..0c5ffba32a53f57e5f35c6d295bcee465a016eb8 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py @@ -33,17 +33,19 @@ class Pool2D(fluid.dygraph.Layer): @dygraph_to_static_output def forward(self, x): inputs = fluid.dygraph.to_variable(x) - pre = self.pool2d(inputs) + + # Add func `get_result` for testing arg_name_to_idx in ast transformation. + def get_result(x): + return self.pool2d(x) + + pre = get_result(inputs) return pre class Linear(fluid.dygraph.Layer): def __init__(self): super(Linear, self).__init__() - - @dygraph_to_static_output - def forward(self, x): - fc = fluid.dygraph.Linear( + self.fc = fluid.dygraph.Linear( input_dim=10, output_dim=5, act='relu', @@ -51,8 +53,11 @@ class Linear(fluid.dygraph.Layer): value=0.99)), bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( value=0.5))) + + @dygraph_to_static_output + def forward(self, x): inputs = fluid.dygraph.to_variable(x) - pre = fc(inputs) + pre = self.fc(inputs) return pre @@ -67,7 +72,7 @@ class TestPool2D(unittest.TestCase): for _ in range(1): prediction = dy_layer(x=self.data) - return prediction + return prediction.numpy() def run_static_mode(self): startup_prog = fluid.Program() @@ -75,22 +80,20 @@ class TestPool2D(unittest.TestCase): with fluid.program_guard(main_prog, startup_prog): dy_layer = self.dygraph_class() out = dy_layer(x=self.data) - place = fluid.CPUPlace() - exe = fluid.Executor(place) - res = exe.run(main_prog, fetch_list=out) - return res + return out[0] def test_static_output(self): dygraph_res = self.run_dygraph_mode() static_res = self.run_static_mode() + self.assertTrue( - np.allclose(dygraph_res[0], static_res[0]), - msg='dygraph is {}\n static_res is \n{}'.format(dygraph_res, - static_res)) + np.allclose(dygraph_res, static_res), + msg='dygraph_res is {}\n static_res is \n{}'.format(dygraph_res, + static_res)) return -class TestLinear(unittest.TestCase): +class TestLinear(TestPool2D): def setUp(self): self.dygraph_class = Linear self.data = np.random.random((4, 10)).astype('float32') diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py index 5dc9ce45bcf5810feadb000bb0b8ac7c91f44bcf..4c1609a0dd45c98a56323960ad624d9f0771d051 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py @@ -20,7 +20,7 @@ import paddle.fluid as fluid from paddle.fluid.optimizer import AdamOptimizer from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear -from paddle.fluid.dygraph.jit import dygraph_to_static_output +from paddle.fluid.dygraph.jit import dygraph_to_static_graph import unittest @@ -66,7 +66,7 @@ class SimpleImgConvPool(fluid.dygraph.Layer): global_pooling=global_pooling, use_cudnn=use_cudnn) - @dygraph_to_static_output + @dygraph_to_static_graph def forward(self, inputs): x = self._conv2d(inputs) x = self._pool2d(x) @@ -94,7 +94,7 @@ class MNIST(fluid.dygraph.Layer): loc=0.0, scale=scale)), act="softmax") - @dygraph_to_static_output + @dygraph_to_static_graph def forward(self, inputs, label=None): x = self.inference(inputs) if label is not None: @@ -105,7 +105,7 @@ class MNIST(fluid.dygraph.Layer): else: return x - @dygraph_to_static_output + @dygraph_to_static_graph def inference(self, inputs): x = self._simple_img_conv_pool_1(inputs) x = self._simple_img_conv_pool_2(x) @@ -128,7 +128,7 @@ class TestMNIST(unittest.TestCase): class TestMNISTWithStaticMode(TestMNIST): """ - Tests model when using `dygraph_to_static_output` to convert dygraph into static + Tests model when using `dygraph_to_static_graph` to convert dygraph into static model. It allows user to add customized code to train static model, such as `with` and `Executor` statement. """ diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py index a74cd10be7d3da2e2952e203d0356066896d701f..0ef58d10c62344ec1f68fc5970daab6a82855c93 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py @@ -18,7 +18,7 @@ import numpy as np import paddle.fluid as fluid import unittest -from paddle.fluid.dygraph.jit import dygraph_to_static_output +from paddle.fluid.dygraph.jit import dygraph_to_static_graph np.random.seed(1) @@ -92,7 +92,7 @@ class TestDygraphIfElse(unittest.TestCase): with fluid.program_guard(main_program): x_v = fluid.layers.assign(self.x) # Transform into static graph - out = dygraph_to_static_output(self.dyfunc)(x_v) + out = dygraph_to_static_graph(self.dyfunc)(x_v) exe = fluid.Executor(place) ret = exe.run(main_program, fetch_list=out) return ret diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py index 55895507c2c39779b8f2ea3feee7a27186801cde..e60cca70b479443e20bfe35b1bcbeaf7cc3ed9ab 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py @@ -20,7 +20,7 @@ import unittest import inspect import gast -from paddle.fluid.dygraph.jit import dygraph_to_static_output +from paddle.fluid.dygraph.jit import dygraph_to_static_graph from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api SEED = 2020 @@ -47,7 +47,7 @@ class TestDygraphBasicApi_ToVariable(unittest.TestCase): main_program = fluid.Program() main_program.random_seed = SEED with fluid.program_guard(main_program): - static_out = dygraph_to_static_output(self.dygraph_func)(self.input) + static_out = dygraph_to_static_graph(self.dygraph_func)(self.input) exe = fluid.Executor(fluid.CPUPlace()) static_res = exe.run(main_program, fetch_list=static_out) @@ -190,7 +190,7 @@ class TestDygraphBasicApi(unittest.TestCase): main_program.random_seed = SEED with fluid.program_guard(main_program, startup_program): data = fluid.layers.assign(self.input) - static_out = dygraph_to_static_output(self.dygraph_func)(data) + static_out = dygraph_to_static_graph(self.dygraph_func)(data) exe = fluid.Executor(fluid.CPUPlace()) exe.run(startup_program) @@ -225,8 +225,8 @@ class TestDygraphBasicApi_BilinearTensorProduct(TestDygraphBasicApi): main_program = fluid.Program() main_program.random_seed = SEED with fluid.program_guard(main_program, startup_program): - static_out = dygraph_to_static_output(self.dygraph_func)( - self.input1, self.input2) + static_out = dygraph_to_static_graph(self.dygraph_func)(self.input1, + self.input2) exe = fluid.Executor(fluid.CPUPlace()) exe.run(startup_program) @@ -352,7 +352,7 @@ class TestDygraphBasicApi_CosineDecay(unittest.TestCase): main_program = fluid.Program() main_program.random_seed = SEED with fluid.program_guard(main_program, startup_program): - static_out = dygraph_to_static_output(self.dygraph_func)() + static_out = dygraph_to_static_graph(self.dygraph_func)() exe = fluid.Executor(fluid.CPUPlace()) exe.run(startup_program)