未验证 提交 68a92e46 编写于 作者: L liym27 提交者: GitHub

fix dygraph_to_static_ouput and add a new decorator. (#22785)

* change dygraph_to_static_output to dygraph_to_static_graph. test=develop

* Remove duplicate code. test=develop

* Follow comments from Liujie. test=develop

* change dygraph_to_static_output to dygraph_to_static_graph. test=develop
上级 5dbafe38
...@@ -25,7 +25,7 @@ from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor ...@@ -25,7 +25,7 @@ from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor
__all__ = ['DygraphToStaticAst'] __all__ = ['DygraphToStaticAst']
DECORATOR_NAME = 'dygraph_to_static_output' DECORATOR_NAMES = ['dygraph_to_static_output', 'dygraph_to_static_graph']
class IfElseTransformer(gast.NodeTransformer): class IfElseTransformer(gast.NodeTransformer):
...@@ -100,6 +100,7 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -100,6 +100,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.static_analysis_root = StaticAnalysisVisitor( self.static_analysis_root = StaticAnalysisVisitor(
root).get_node_wrapper_root() root).get_node_wrapper_root()
self.decorate_func_name = None self.decorate_func_name = None
self.arg_name_to_idx = {}
self.transfer_from_node_type(self.static_analysis_root) self.transfer_from_node_type(self.static_analysis_root)
return self.static_analysis_root return self.static_analysis_root
...@@ -118,16 +119,14 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -118,16 +119,14 @@ class DygraphToStaticAst(gast.NodeTransformer):
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
if self.decorate_func_name is None: if self.decorate_func_name is None:
self.decorate_func_name = node.name self.decorate_func_name = node.name
for idx, arg in enumerate(node.args.args):
self.arg_name_to_idx = {} self.arg_name_to_idx[arg.id] = idx
for idx, arg in enumerate(node.args.args):
self.arg_name_to_idx[arg.id] = idx
self.generic_visit(node) self.generic_visit(node)
# Remove the decorated name of dygraph_to_static # Remove the decorated name of dygraph_to_static
if hasattr(node, 'decorator_list'): if hasattr(node, 'decorator_list'):
decorator_list = [ decorator_list = [
d for d in node.decorator_list if d.id != DECORATOR_NAME d for d in node.decorator_list if d.id not in DECORATOR_NAMES
] ]
node.decorator_list = decorator_list node.decorator_list = decorator_list
return node return node
...@@ -170,7 +169,7 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -170,7 +169,7 @@ class BasicApiTransformer(gast.NodeTransformer):
self.generic_visit(node) self.generic_visit(node)
if hasattr(node, 'decorator_list'): if hasattr(node, 'decorator_list'):
decorator_list = [ decorator_list = [
d for d in node.decorator_list if d.id != DECORATOR_NAME d for d in node.decorator_list if d.id not in DECORATOR_NAMES
] ]
node.decorator_list = decorator_list node.decorator_list = decorator_list
return node return node
......
...@@ -14,11 +14,12 @@ ...@@ -14,11 +14,12 @@
from __future__ import print_function from __future__ import print_function
__all__ = ['TracedLayer', 'dygraph_to_static_output'] __all__ = ['TracedLayer', 'dygraph_to_static_output', 'dygraph_to_static_graph']
import gast import gast
import inspect import inspect
import textwrap import textwrap
import warnings
from ..wrapped_decorator import wrap_decorator from ..wrapped_decorator import wrap_decorator
from .base import program_desc_tracing_guard, switch_to_static_graph from .base import program_desc_tracing_guard, switch_to_static_graph
...@@ -29,7 +30,7 @@ from paddle.fluid import core ...@@ -29,7 +30,7 @@ from paddle.fluid import core
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
from paddle.fluid.executor import Executor, scope_guard from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.compiler import CompiledProgram from paddle.fluid.compiler import CompiledProgram
from paddle.fluid import program_guard, data from paddle.fluid import program_guard, data, default_startup_program, default_main_program
def create_program_from_desc(program_desc): def create_program_from_desc(program_desc):
...@@ -55,43 +56,60 @@ def extract_vars(inputs): ...@@ -55,43 +56,60 @@ def extract_vars(inputs):
return result_list return result_list
def _dygraph_to_static_output_(dygraph_func): def to_static_func(dygraph_func):
def __impl__(*args, **kwargs): # Get AST from dygraph function
dygraph_code = inspect.getsource(dygraph_func)
dygraph_code = textwrap.dedent(dygraph_code)
root = gast.parse(dygraph_code)
# Get AST from dygraph function # Transform AST
dygraph_code = inspect.getsource(dygraph_func) dygraph_to_static = DygraphToStaticAst()
dygraph_code = textwrap.dedent(dygraph_code) root_wrapper = dygraph_to_static.get_static_ast(root)
root = gast.parse(dygraph_code)
# Transform AST # Get static_func from AST
dygraph_to_static = DygraphToStaticAst() func_name = dygraph_to_static.get_module_name()
root_wrapper = dygraph_to_static.get_static_ast(root) static_func, file_name = ast_to_func(root_wrapper.node, func_name)
# Get static_func from AST return static_func, dygraph_to_static
func_name = dygraph_to_static.get_module_name()
static_func, file_name = ast_to_func(root_wrapper.node, func_name)
def _dygraph_to_static_graph_(dygraph_func):
def __impl__(*args, **kwargs):
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode."
" Please use it in static mode.")
return dygraph_func(*args, **kwargs)
static_func, dygraph_to_static = to_static_func(dygraph_func)
return static_func(*args, **kwargs)
return __impl__
if not in_dygraph_mode():
return static_func(*args, **kwargs)
else:
feed_name_to_idx = dygraph_to_static.get_feed_name_to_idx()
feed_dict = {}
for feed_name, idx in feed_name_to_idx.items():
feed_dict[feed_name] = args[idx]
# Run static_func in static mode
startup_program = Program()
main_program = Program()
static_res = run_static_func(main_program, startup_program,
static_func, args, kwargs, feed_dict,
feed_name_to_idx)
def _dygraph_to_static_output_(dygraph_func):
def __impl__(*args, **kwargs):
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_output' doesn't work in dygraph mode."
" Please use it in static mode.")
return dygraph_func(*args, **kwargs)
static_func, dygraph_to_static = to_static_func(dygraph_func)
feed_name_to_idx = dygraph_to_static.get_feed_name_to_idx()
feed_dict = {}
for feed_name, idx in feed_name_to_idx.items():
feed_dict[feed_name] = args[idx]
# Run static_func in static mode
startup_program = default_main_program()
main_program = default_startup_program()
static_res = run_static_func(main_program, startup_program, static_func,
args, kwargs, feed_dict, feed_name_to_idx)
return static_res return static_res
return __impl__ return __impl__
@switch_to_static_graph
def run_static_func(main_program, startup_program, static_func, args, kwargs, def run_static_func(main_program, startup_program, static_func, args, kwargs,
feed_dict, feed_name_to_idx): feed_dict, feed_name_to_idx):
...@@ -114,6 +132,7 @@ def run_static_func(main_program, startup_program, static_func, args, kwargs, ...@@ -114,6 +132,7 @@ def run_static_func(main_program, startup_program, static_func, args, kwargs,
dygraph_to_static_output = wrap_decorator(_dygraph_to_static_output_) dygraph_to_static_output = wrap_decorator(_dygraph_to_static_output_)
dygraph_to_static_graph = wrap_decorator(_dygraph_to_static_graph_)
@dygraph_only @dygraph_only
......
...@@ -33,17 +33,19 @@ class Pool2D(fluid.dygraph.Layer): ...@@ -33,17 +33,19 @@ class Pool2D(fluid.dygraph.Layer):
@dygraph_to_static_output @dygraph_to_static_output
def forward(self, x): def forward(self, x):
inputs = fluid.dygraph.to_variable(x) inputs = fluid.dygraph.to_variable(x)
pre = self.pool2d(inputs)
# Add func `get_result` for testing arg_name_to_idx in ast transformation.
def get_result(x):
return self.pool2d(x)
pre = get_result(inputs)
return pre return pre
class Linear(fluid.dygraph.Layer): class Linear(fluid.dygraph.Layer):
def __init__(self): def __init__(self):
super(Linear, self).__init__() super(Linear, self).__init__()
self.fc = fluid.dygraph.Linear(
@dygraph_to_static_output
def forward(self, x):
fc = fluid.dygraph.Linear(
input_dim=10, input_dim=10,
output_dim=5, output_dim=5,
act='relu', act='relu',
...@@ -51,8 +53,11 @@ class Linear(fluid.dygraph.Layer): ...@@ -51,8 +53,11 @@ class Linear(fluid.dygraph.Layer):
value=0.99)), value=0.99)),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.5))) value=0.5)))
@dygraph_to_static_output
def forward(self, x):
inputs = fluid.dygraph.to_variable(x) inputs = fluid.dygraph.to_variable(x)
pre = fc(inputs) pre = self.fc(inputs)
return pre return pre
...@@ -67,7 +72,7 @@ class TestPool2D(unittest.TestCase): ...@@ -67,7 +72,7 @@ class TestPool2D(unittest.TestCase):
for _ in range(1): for _ in range(1):
prediction = dy_layer(x=self.data) prediction = dy_layer(x=self.data)
return prediction return prediction.numpy()
def run_static_mode(self): def run_static_mode(self):
startup_prog = fluid.Program() startup_prog = fluid.Program()
...@@ -75,22 +80,20 @@ class TestPool2D(unittest.TestCase): ...@@ -75,22 +80,20 @@ class TestPool2D(unittest.TestCase):
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
dy_layer = self.dygraph_class() dy_layer = self.dygraph_class()
out = dy_layer(x=self.data) out = dy_layer(x=self.data)
place = fluid.CPUPlace() return out[0]
exe = fluid.Executor(place)
res = exe.run(main_prog, fetch_list=out)
return res
def test_static_output(self): def test_static_output(self):
dygraph_res = self.run_dygraph_mode() dygraph_res = self.run_dygraph_mode()
static_res = self.run_static_mode() static_res = self.run_static_mode()
self.assertTrue( self.assertTrue(
np.allclose(dygraph_res[0], static_res[0]), np.allclose(dygraph_res, static_res),
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_res, msg='dygraph_res is {}\n static_res is \n{}'.format(dygraph_res,
static_res)) static_res))
return return
class TestLinear(unittest.TestCase): class TestLinear(TestPool2D):
def setUp(self): def setUp(self):
self.dygraph_class = Linear self.dygraph_class = Linear
self.data = np.random.random((4, 10)).astype('float32') self.data = np.random.random((4, 10)).astype('float32')
......
...@@ -20,7 +20,7 @@ import paddle.fluid as fluid ...@@ -20,7 +20,7 @@ import paddle.fluid as fluid
from paddle.fluid.optimizer import AdamOptimizer from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph.jit import dygraph_to_static_output from paddle.fluid.dygraph.jit import dygraph_to_static_graph
import unittest import unittest
...@@ -66,7 +66,7 @@ class SimpleImgConvPool(fluid.dygraph.Layer): ...@@ -66,7 +66,7 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
global_pooling=global_pooling, global_pooling=global_pooling,
use_cudnn=use_cudnn) use_cudnn=use_cudnn)
@dygraph_to_static_output @dygraph_to_static_graph
def forward(self, inputs): def forward(self, inputs):
x = self._conv2d(inputs) x = self._conv2d(inputs)
x = self._pool2d(x) x = self._pool2d(x)
...@@ -94,7 +94,7 @@ class MNIST(fluid.dygraph.Layer): ...@@ -94,7 +94,7 @@ class MNIST(fluid.dygraph.Layer):
loc=0.0, scale=scale)), loc=0.0, scale=scale)),
act="softmax") act="softmax")
@dygraph_to_static_output @dygraph_to_static_graph
def forward(self, inputs, label=None): def forward(self, inputs, label=None):
x = self.inference(inputs) x = self.inference(inputs)
if label is not None: if label is not None:
...@@ -105,7 +105,7 @@ class MNIST(fluid.dygraph.Layer): ...@@ -105,7 +105,7 @@ class MNIST(fluid.dygraph.Layer):
else: else:
return x return x
@dygraph_to_static_output @dygraph_to_static_graph
def inference(self, inputs): def inference(self, inputs):
x = self._simple_img_conv_pool_1(inputs) x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x) x = self._simple_img_conv_pool_2(x)
...@@ -128,7 +128,7 @@ class TestMNIST(unittest.TestCase): ...@@ -128,7 +128,7 @@ class TestMNIST(unittest.TestCase):
class TestMNISTWithStaticMode(TestMNIST): class TestMNISTWithStaticMode(TestMNIST):
""" """
Tests model when using `dygraph_to_static_output` to convert dygraph into static Tests model when using `dygraph_to_static_graph` to convert dygraph into static
model. It allows user to add customized code to train static model, such as `with` model. It allows user to add customized code to train static model, such as `with`
and `Executor` statement. and `Executor` statement.
""" """
......
...@@ -18,7 +18,7 @@ import numpy as np ...@@ -18,7 +18,7 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import unittest import unittest
from paddle.fluid.dygraph.jit import dygraph_to_static_output from paddle.fluid.dygraph.jit import dygraph_to_static_graph
np.random.seed(1) np.random.seed(1)
...@@ -92,7 +92,7 @@ class TestDygraphIfElse(unittest.TestCase): ...@@ -92,7 +92,7 @@ class TestDygraphIfElse(unittest.TestCase):
with fluid.program_guard(main_program): with fluid.program_guard(main_program):
x_v = fluid.layers.assign(self.x) x_v = fluid.layers.assign(self.x)
# Transform into static graph # Transform into static graph
out = dygraph_to_static_output(self.dyfunc)(x_v) out = dygraph_to_static_graph(self.dyfunc)(x_v)
exe = fluid.Executor(place) exe = fluid.Executor(place)
ret = exe.run(main_program, fetch_list=out) ret = exe.run(main_program, fetch_list=out)
return ret return ret
......
...@@ -20,7 +20,7 @@ import unittest ...@@ -20,7 +20,7 @@ import unittest
import inspect import inspect
import gast import gast
from paddle.fluid.dygraph.jit import dygraph_to_static_output from paddle.fluid.dygraph.jit import dygraph_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api
SEED = 2020 SEED = 2020
...@@ -47,7 +47,7 @@ class TestDygraphBasicApi_ToVariable(unittest.TestCase): ...@@ -47,7 +47,7 @@ class TestDygraphBasicApi_ToVariable(unittest.TestCase):
main_program = fluid.Program() main_program = fluid.Program()
main_program.random_seed = SEED main_program.random_seed = SEED
with fluid.program_guard(main_program): with fluid.program_guard(main_program):
static_out = dygraph_to_static_output(self.dygraph_func)(self.input) static_out = dygraph_to_static_graph(self.dygraph_func)(self.input)
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
static_res = exe.run(main_program, fetch_list=static_out) static_res = exe.run(main_program, fetch_list=static_out)
...@@ -190,7 +190,7 @@ class TestDygraphBasicApi(unittest.TestCase): ...@@ -190,7 +190,7 @@ class TestDygraphBasicApi(unittest.TestCase):
main_program.random_seed = SEED main_program.random_seed = SEED
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
data = fluid.layers.assign(self.input) data = fluid.layers.assign(self.input)
static_out = dygraph_to_static_output(self.dygraph_func)(data) static_out = dygraph_to_static_graph(self.dygraph_func)(data)
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program) exe.run(startup_program)
...@@ -225,8 +225,8 @@ class TestDygraphBasicApi_BilinearTensorProduct(TestDygraphBasicApi): ...@@ -225,8 +225,8 @@ class TestDygraphBasicApi_BilinearTensorProduct(TestDygraphBasicApi):
main_program = fluid.Program() main_program = fluid.Program()
main_program.random_seed = SEED main_program.random_seed = SEED
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
static_out = dygraph_to_static_output(self.dygraph_func)( static_out = dygraph_to_static_graph(self.dygraph_func)(self.input1,
self.input1, self.input2) self.input2)
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program) exe.run(startup_program)
...@@ -352,7 +352,7 @@ class TestDygraphBasicApi_CosineDecay(unittest.TestCase): ...@@ -352,7 +352,7 @@ class TestDygraphBasicApi_CosineDecay(unittest.TestCase):
main_program = fluid.Program() main_program = fluid.Program()
main_program.random_seed = SEED main_program.random_seed = SEED
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
static_out = dygraph_to_static_output(self.dygraph_func)() static_out = dygraph_to_static_graph(self.dygraph_func)()
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program) exe.run(startup_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册