未验证 提交 68a92e46 编写于 作者: L liym27 提交者: GitHub

fix dygraph_to_static_ouput and add a new decorator. (#22785)

* change dygraph_to_static_output to dygraph_to_static_graph. test=develop

* Remove duplicate code. test=develop

* Follow comments from Liujie. test=develop

* change dygraph_to_static_output to dygraph_to_static_graph. test=develop
上级 5dbafe38
......@@ -25,7 +25,7 @@ from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor
__all__ = ['DygraphToStaticAst']
DECORATOR_NAME = 'dygraph_to_static_output'
DECORATOR_NAMES = ['dygraph_to_static_output', 'dygraph_to_static_graph']
class IfElseTransformer(gast.NodeTransformer):
......@@ -100,6 +100,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.static_analysis_root = StaticAnalysisVisitor(
root).get_node_wrapper_root()
self.decorate_func_name = None
self.arg_name_to_idx = {}
self.transfer_from_node_type(self.static_analysis_root)
return self.static_analysis_root
......@@ -118,16 +119,14 @@ class DygraphToStaticAst(gast.NodeTransformer):
def visit_FunctionDef(self, node):
if self.decorate_func_name is None:
self.decorate_func_name = node.name
self.arg_name_to_idx = {}
for idx, arg in enumerate(node.args.args):
self.arg_name_to_idx[arg.id] = idx
for idx, arg in enumerate(node.args.args):
self.arg_name_to_idx[arg.id] = idx
self.generic_visit(node)
# Remove the decorated name of dygraph_to_static
if hasattr(node, 'decorator_list'):
decorator_list = [
d for d in node.decorator_list if d.id != DECORATOR_NAME
d for d in node.decorator_list if d.id not in DECORATOR_NAMES
]
node.decorator_list = decorator_list
return node
......@@ -170,7 +169,7 @@ class BasicApiTransformer(gast.NodeTransformer):
self.generic_visit(node)
if hasattr(node, 'decorator_list'):
decorator_list = [
d for d in node.decorator_list if d.id != DECORATOR_NAME
d for d in node.decorator_list if d.id not in DECORATOR_NAMES
]
node.decorator_list = decorator_list
return node
......
......@@ -14,11 +14,12 @@
from __future__ import print_function
__all__ = ['TracedLayer', 'dygraph_to_static_output']
__all__ = ['TracedLayer', 'dygraph_to_static_output', 'dygraph_to_static_graph']
import gast
import inspect
import textwrap
import warnings
from ..wrapped_decorator import wrap_decorator
from .base import program_desc_tracing_guard, switch_to_static_graph
......@@ -29,7 +30,7 @@ from paddle.fluid import core
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid import program_guard, data
from paddle.fluid import program_guard, data, default_startup_program, default_main_program
def create_program_from_desc(program_desc):
......@@ -55,43 +56,60 @@ def extract_vars(inputs):
return result_list
def _dygraph_to_static_output_(dygraph_func):
def __impl__(*args, **kwargs):
def to_static_func(dygraph_func):
# Get AST from dygraph function
dygraph_code = inspect.getsource(dygraph_func)
dygraph_code = textwrap.dedent(dygraph_code)
root = gast.parse(dygraph_code)
# Get AST from dygraph function
dygraph_code = inspect.getsource(dygraph_func)
dygraph_code = textwrap.dedent(dygraph_code)
root = gast.parse(dygraph_code)
# Transform AST
dygraph_to_static = DygraphToStaticAst()
root_wrapper = dygraph_to_static.get_static_ast(root)
# Transform AST
dygraph_to_static = DygraphToStaticAst()
root_wrapper = dygraph_to_static.get_static_ast(root)
# Get static_func from AST
func_name = dygraph_to_static.get_module_name()
static_func, file_name = ast_to_func(root_wrapper.node, func_name)
# Get static_func from AST
func_name = dygraph_to_static.get_module_name()
static_func, file_name = ast_to_func(root_wrapper.node, func_name)
return static_func, dygraph_to_static
def _dygraph_to_static_graph_(dygraph_func):
def __impl__(*args, **kwargs):
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode."
" Please use it in static mode.")
return dygraph_func(*args, **kwargs)
static_func, dygraph_to_static = to_static_func(dygraph_func)
return static_func(*args, **kwargs)
return __impl__
if not in_dygraph_mode():
return static_func(*args, **kwargs)
else:
feed_name_to_idx = dygraph_to_static.get_feed_name_to_idx()
feed_dict = {}
for feed_name, idx in feed_name_to_idx.items():
feed_dict[feed_name] = args[idx]
# Run static_func in static mode
startup_program = Program()
main_program = Program()
static_res = run_static_func(main_program, startup_program,
static_func, args, kwargs, feed_dict,
feed_name_to_idx)
def _dygraph_to_static_output_(dygraph_func):
def __impl__(*args, **kwargs):
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_output' doesn't work in dygraph mode."
" Please use it in static mode.")
return dygraph_func(*args, **kwargs)
static_func, dygraph_to_static = to_static_func(dygraph_func)
feed_name_to_idx = dygraph_to_static.get_feed_name_to_idx()
feed_dict = {}
for feed_name, idx in feed_name_to_idx.items():
feed_dict[feed_name] = args[idx]
# Run static_func in static mode
startup_program = default_main_program()
main_program = default_startup_program()
static_res = run_static_func(main_program, startup_program, static_func,
args, kwargs, feed_dict, feed_name_to_idx)
return static_res
return __impl__
@switch_to_static_graph
def run_static_func(main_program, startup_program, static_func, args, kwargs,
feed_dict, feed_name_to_idx):
......@@ -114,6 +132,7 @@ def run_static_func(main_program, startup_program, static_func, args, kwargs,
dygraph_to_static_output = wrap_decorator(_dygraph_to_static_output_)
dygraph_to_static_graph = wrap_decorator(_dygraph_to_static_graph_)
@dygraph_only
......
......@@ -33,17 +33,19 @@ class Pool2D(fluid.dygraph.Layer):
@dygraph_to_static_output
def forward(self, x):
inputs = fluid.dygraph.to_variable(x)
pre = self.pool2d(inputs)
# Add func `get_result` for testing arg_name_to_idx in ast transformation.
def get_result(x):
return self.pool2d(x)
pre = get_result(inputs)
return pre
class Linear(fluid.dygraph.Layer):
def __init__(self):
super(Linear, self).__init__()
@dygraph_to_static_output
def forward(self, x):
fc = fluid.dygraph.Linear(
self.fc = fluid.dygraph.Linear(
input_dim=10,
output_dim=5,
act='relu',
......@@ -51,8 +53,11 @@ class Linear(fluid.dygraph.Layer):
value=0.99)),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.5)))
@dygraph_to_static_output
def forward(self, x):
inputs = fluid.dygraph.to_variable(x)
pre = fc(inputs)
pre = self.fc(inputs)
return pre
......@@ -67,7 +72,7 @@ class TestPool2D(unittest.TestCase):
for _ in range(1):
prediction = dy_layer(x=self.data)
return prediction
return prediction.numpy()
def run_static_mode(self):
startup_prog = fluid.Program()
......@@ -75,22 +80,20 @@ class TestPool2D(unittest.TestCase):
with fluid.program_guard(main_prog, startup_prog):
dy_layer = self.dygraph_class()
out = dy_layer(x=self.data)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
res = exe.run(main_prog, fetch_list=out)
return res
return out[0]
def test_static_output(self):
dygraph_res = self.run_dygraph_mode()
static_res = self.run_static_mode()
self.assertTrue(
np.allclose(dygraph_res[0], static_res[0]),
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_res,
static_res))
np.allclose(dygraph_res, static_res),
msg='dygraph_res is {}\n static_res is \n{}'.format(dygraph_res,
static_res))
return
class TestLinear(unittest.TestCase):
class TestLinear(TestPool2D):
def setUp(self):
self.dygraph_class = Linear
self.data = np.random.random((4, 10)).astype('float32')
......
......@@ -20,7 +20,7 @@ import paddle.fluid as fluid
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph.jit import dygraph_to_static_output
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
import unittest
......@@ -66,7 +66,7 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
global_pooling=global_pooling,
use_cudnn=use_cudnn)
@dygraph_to_static_output
@dygraph_to_static_graph
def forward(self, inputs):
x = self._conv2d(inputs)
x = self._pool2d(x)
......@@ -94,7 +94,7 @@ class MNIST(fluid.dygraph.Layer):
loc=0.0, scale=scale)),
act="softmax")
@dygraph_to_static_output
@dygraph_to_static_graph
def forward(self, inputs, label=None):
x = self.inference(inputs)
if label is not None:
......@@ -105,7 +105,7 @@ class MNIST(fluid.dygraph.Layer):
else:
return x
@dygraph_to_static_output
@dygraph_to_static_graph
def inference(self, inputs):
x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x)
......@@ -128,7 +128,7 @@ class TestMNIST(unittest.TestCase):
class TestMNISTWithStaticMode(TestMNIST):
"""
Tests model when using `dygraph_to_static_output` to convert dygraph into static
Tests model when using `dygraph_to_static_graph` to convert dygraph into static
model. It allows user to add customized code to train static model, such as `with`
and `Executor` statement.
"""
......
......@@ -18,7 +18,7 @@ import numpy as np
import paddle.fluid as fluid
import unittest
from paddle.fluid.dygraph.jit import dygraph_to_static_output
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
np.random.seed(1)
......@@ -92,7 +92,7 @@ class TestDygraphIfElse(unittest.TestCase):
with fluid.program_guard(main_program):
x_v = fluid.layers.assign(self.x)
# Transform into static graph
out = dygraph_to_static_output(self.dyfunc)(x_v)
out = dygraph_to_static_graph(self.dyfunc)(x_v)
exe = fluid.Executor(place)
ret = exe.run(main_program, fetch_list=out)
return ret
......
......@@ -20,7 +20,7 @@ import unittest
import inspect
import gast
from paddle.fluid.dygraph.jit import dygraph_to_static_output
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api
SEED = 2020
......@@ -47,7 +47,7 @@ class TestDygraphBasicApi_ToVariable(unittest.TestCase):
main_program = fluid.Program()
main_program.random_seed = SEED
with fluid.program_guard(main_program):
static_out = dygraph_to_static_output(self.dygraph_func)(self.input)
static_out = dygraph_to_static_graph(self.dygraph_func)(self.input)
exe = fluid.Executor(fluid.CPUPlace())
static_res = exe.run(main_program, fetch_list=static_out)
......@@ -190,7 +190,7 @@ class TestDygraphBasicApi(unittest.TestCase):
main_program.random_seed = SEED
with fluid.program_guard(main_program, startup_program):
data = fluid.layers.assign(self.input)
static_out = dygraph_to_static_output(self.dygraph_func)(data)
static_out = dygraph_to_static_graph(self.dygraph_func)(data)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
......@@ -225,8 +225,8 @@ class TestDygraphBasicApi_BilinearTensorProduct(TestDygraphBasicApi):
main_program = fluid.Program()
main_program.random_seed = SEED
with fluid.program_guard(main_program, startup_program):
static_out = dygraph_to_static_output(self.dygraph_func)(
self.input1, self.input2)
static_out = dygraph_to_static_graph(self.dygraph_func)(self.input1,
self.input2)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
......@@ -352,7 +352,7 @@ class TestDygraphBasicApi_CosineDecay(unittest.TestCase):
main_program = fluid.Program()
main_program.random_seed = SEED
with fluid.program_guard(main_program, startup_program):
static_out = dygraph_to_static_output(self.dygraph_func)()
static_out = dygraph_to_static_graph(self.dygraph_func)()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册