未验证 提交 8c381cd9 编写于 作者: L liym27 提交者: GitHub

support fetch feed in dygraph to static graph (#22767)

* Support fetch and run program in the process of dygraph_to_static_output. test=develop

* fix to_source(gast) and remove dygraph API such as Conv2D, Linear. test=develop
上级 88776e40
......@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import print_function
import astor
from .utils import *
import gast
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
......@@ -21,7 +20,7 @@ import gast
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else
from paddle.fluid import unique_name
from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor
__all__ = ['DygraphToStaticAst']
......@@ -109,7 +108,9 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.visit(node.node)
# Transform basic api of dygraph to static graph
BasicApiTransformer(node).ast_visit()
basic_api_trans = BasicApiTransformer(node)
basic_api_trans.ast_visit()
self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()
# Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer(node).ast_visit()
......@@ -117,6 +118,11 @@ class DygraphToStaticAst(gast.NodeTransformer):
def visit_FunctionDef(self, node):
if self.decorate_func_name is None:
self.decorate_func_name = node.name
self.arg_name_to_idx = {}
for idx, arg in enumerate(node.args.args):
self.arg_name_to_idx[arg.id] = idx
self.generic_visit(node)
# Remove the decorated name of dygraph_to_static
if hasattr(node, 'decorator_list'):
......@@ -135,6 +141,12 @@ class DygraphToStaticAst(gast.NodeTransformer):
assert self.decorate_func_name, "decorate_func_name shall not be None."
return self.decorate_func_name
def get_feed_name_to_idx(self):
feed_name_to_idx = {}
for feed_name, arg_name in self.feed_name_to_arg_name.items():
feed_name_to_idx[feed_name] = self.arg_name_to_idx.get(arg_name)
return feed_name_to_idx
class BasicApiTransformer(gast.NodeTransformer):
"""
......@@ -148,6 +160,7 @@ class BasicApiTransformer(gast.NodeTransformer):
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.class_node_dict = {}
self.feed_name_to_arg_id = {}
def ast_visit(self):
self.visit(self.root)
......@@ -189,10 +202,11 @@ class BasicApiTransformer(gast.NodeTransformer):
# Replace API `to_variable` with `fluid.layers.assign`
if is_to_variable(node):
self._update_feed_dict(node)
node = to_assign_node(node)
return node
func_name = astor.to_source(node.func)
func_name = astor.to_source(gast.gast_to_ast(node.func))
if self._is_dygraph_forward(func_name):
class_node = self._get_class_node(func_name)
static_node = to_static_ast(node, class_node)
......@@ -214,9 +228,29 @@ class BasicApiTransformer(gast.NodeTransformer):
return False
if is_dygraph_api(node_value):
dygraph_api = node_value.func.attr
if not dygraph_class_to_static_api.get(dygraph_api):
return False
update_args_of_func(node_value, node_value, "__init__")
target_str = astor.to_source(node.targets[0])
target_str = astor.to_source(gast.gast_to_ast(node.targets[0]))
self.class_node_dict[target_str] = node_value
return True
# TODO: node.value is not dygraph class
return False
def _update_feed_dict(self, node):
assert isinstance(node, gast.Call)
var_name = None
for kw in node.keywords:
if kw.arg == 'value':
var_name = kw.value.id # eg: 'a' for "value=a "
if not var_name:
var_name = node.args[0].id
feed_var_name = unique_name.generate(var_name) # eg: "a_0"
self.feed_name_to_arg_id[feed_var_name] = var_name # eg: "a_0" : "a"
def get_feed_name_to_arg_id(self):
return self.feed_name_to_arg_id
......@@ -19,55 +19,11 @@ import gast
import inspect
import six
import warnings
from .utils import is_paddle_api, is_dygraph_api, is_numpy_api
__all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor']
# TODO: _is_paddle_dygraph_api is duplicated in Yamei's utils.py. Merge the two
# function code together when Yamei finish her PR.
def _is_api_in_module_helper(obj, module_prefix):
m = inspect.getmodule(obj)
return m is not None and m.__name__.startswith(module_prefix)
# TODO: is_dygraph_api is duplicated in Yamei's utils.py. Merge the two
# function code together when Yamei finish her PR.
def is_api_in_module(node, module_prefix):
assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api"
func_str = astor.to_source(gast.gast_to_ast(node.func))
try:
import paddle.fluid as fluid
import paddle
return eval("_is_api_in_module_helper({}, '{}')".format(func_str,
module_prefix))
except NameError:
return False
def is_dygraph_api(node):
return is_api_in_module(node, "paddle.fluid.dygraph")
def is_paddle_api(node):
return is_api_in_module(node, "paddle.fluid")
# Is numpy_api cannot reuse is_api_in_module because of numpy module problem
def is_numpy_api(node):
assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api"
func_str = astor.to_source(gast.gast_to_ast(node.func))
try:
import numpy as np
module_result = eval("_is_api_in_module_helper({}, '{}')".format(
func_str, "numpy"))
# BUG: np.random.uniform doesn't have module and cannot be analyzed
# TODO: find a better way
if not module_result:
return func_str.startswith("numpy.") or func_str.startswith("np.")
except NameError:
return False
class NodeVarType(object):
"""
Enum class of python variable types. We have to know some variable types
......
......@@ -17,41 +17,62 @@ from __future__ import print_function
import inspect
import gast
import astor
import atexit
import os
import tempfile
import six
import imp
dygraph_class_to_static_api = {
"BatchNorm": "batch_norm",
"BilinearTensorProduct": "bilinear_tensor_product",
"Conv2D": "conv2d",
"Conv3D": "conv3d",
"Conv2DTranspose": "conv2d_transpose",
"Conv3DTranspose": "conv3d_transpose",
"CosineDecay": "cosine_decay",
"Embedding": "embedding",
"ExponentialDecay": "exponential_decay",
"GroupNorm": "group_norm",
"GRUUnit": "gru_unit",
"InverseTimeDecay": "inverse_time_decay",
"LayerNorm": "layer_norm",
"Linear": "fc",
"NaturalExpDecay": "natural_exp_decay",
"NCE": "nce",
"NoamDecay": "noam_decay",
"PiecewiseDecay": "piecewise_decay",
"PolynomialDecay": "polynomial_decay",
"Pool2D": "pool2d",
"PRelu": "prelu",
"SpectralNorm": "spectral_norm",
}
def _is_api_in_module_helper(obj, module_prefix):
m = inspect.getmodule(obj)
return m is not None and m.__name__.startswith(module_prefix)
def is_api_in_module(node, module_prefix):
assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api"
func_str = astor.to_source(gast.gast_to_ast(node.func))
try:
import paddle.fluid as fluid
import paddle
return eval("_is_api_in_module_helper({}, '{}')".format(func_str,
module_prefix))
except NameError:
return False
def is_dygraph_api(node):
return is_api_in_module(node, "paddle.fluid.dygraph")
def is_paddle_api(node):
return is_api_in_module(node, "paddle.fluid")
# Is numpy_api cannot reuse is_api_in_module because of numpy module problem
def is_numpy_api(node):
assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api"
func_str = astor.to_source(gast.gast_to_ast(node.func))
try:
import numpy as np
module_result = eval("_is_api_in_module_helper({}, '{}')".format(
func_str, "numpy"))
# BUG: np.random.uniform doesn't have module and cannot be analyzed
# TODO: find a better way
if not module_result:
return func_str.startswith("numpy.") or func_str.startswith("np.")
except NameError:
return False
def _delete_keywords_from(node):
assert isinstance(node, gast.Call)
func_src = astor.to_source(node.func)
func_src = astor.to_source(gast.gast_to_ast(node.func))
import paddle.fluid as fluid
full_args = eval("inspect.getargspec({})".format(func_src))
full_args_name = full_args[0]
......@@ -94,21 +115,6 @@ def _add_keywords_to(node, dygraph_api_name):
return
def _is_paddle_dygraph_api(obj):
m = inspect.getmodule(obj)
return m is not None and m.__name__.startswith("paddle.fluid.dygraph")
def is_dygraph_api(node):
assert isinstance(node, gast.Call)
func_src = astor.to_source(node.func)
try:
import paddle.fluid as fluid
return eval("_is_paddle_dygraph_api({})".format(func_src))
except NameError:
return False
def is_to_variable(node):
assert isinstance(node, gast.Call)
if is_dygraph_api(node):
......@@ -158,7 +164,7 @@ def update_args_of_func(node, dygraph_node, method_name):
"The method name of class to update args should be '__init__' or 'forward'"
)
class_src = astor.to_source(dygraph_node.func)
class_src = astor.to_source(gast.gast_to_ast(dygraph_node.func))
import paddle.fluid as fluid
if method_name == "__init__" or eval(
"issubclass({}, fluid.dygraph.Layer)".format(class_src)):
......
......@@ -29,6 +29,7 @@ from paddle.fluid import core
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid import program_guard, data
def create_program_from_desc(program_desc):
......@@ -56,6 +57,7 @@ def extract_vars(inputs):
def _dygraph_to_static_output_(dygraph_func):
def __impl__(*args, **kwargs):
# Get AST from dygraph function
dygraph_code = inspect.getsource(dygraph_func)
dygraph_code = textwrap.dedent(dygraph_code)
......@@ -64,14 +66,53 @@ def _dygraph_to_static_output_(dygraph_func):
# Transform AST
dygraph_to_static = DygraphToStaticAst()
root_wrapper = dygraph_to_static.get_static_ast(root)
# Get static_func from AST
func_name = dygraph_to_static.get_module_name()
static_func, file_name = ast_to_func(root_wrapper.node, func_name)
return static_func(*args, **kwargs)
if not in_dygraph_mode():
return static_func(*args, **kwargs)
else:
feed_name_to_idx = dygraph_to_static.get_feed_name_to_idx()
feed_dict = {}
for feed_name, idx in feed_name_to_idx.items():
feed_dict[feed_name] = args[idx]
# Run static_func in static mode
startup_program = Program()
main_program = Program()
static_res = run_static_func(main_program, startup_program,
static_func, args, kwargs, feed_dict,
feed_name_to_idx)
return static_res
return __impl__
@switch_to_static_graph
def run_static_func(main_program, startup_program, static_func, args, kwargs,
feed_dict, feed_name_to_idx):
with program_guard(main_program, startup_program):
args_list = list(args)
for var_name, value in feed_dict.items():
idx = feed_name_to_idx[var_name]
args_list[idx] = data(
name=var_name, shape=value.shape, dtype=str(value.dtype))
args = tuple(args_list)
static_out = static_func(*args, **kwargs)
if not isinstance(static_out, (list, tuple)):
static_out = [static_out]
exe = Executor(core.CPUPlace())
exe.run(startup_program)
static_res = exe.run(main_program,
fetch_list=static_out,
feed=feed_dict)
return static_res
dygraph_to_static_output = wrap_decorator(_dygraph_to_static_output_)
......
......@@ -337,6 +337,7 @@ set_tests_properties(test_parallel_executor_seresnext_with_reduce_cpu PROPERTIES
set_tests_properties(test_parallel_executor_seresnext_with_fuse_all_reduce_cpu PROPERTIES TIMEOUT 750)
add_subdirectory(sequence)
add_subdirectory(dygraph_to_static)
if (WITH_NGRAPH)
add_subdirectory(ngraph)
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.fluid.dygraph.jit import dygraph_to_static_output
import numpy as np
import unittest
import paddle.fluid as fluid
SEED = 2020
class Pool2D(fluid.dygraph.Layer):
def __init__(self):
super(Pool2D, self).__init__()
self.pool2d = fluid.dygraph.Pool2D(
pool_size=2, pool_type='avg', pool_stride=1, global_pooling=False)
@dygraph_to_static_output
def forward(self, x):
inputs = fluid.dygraph.to_variable(x)
pre = self.pool2d(inputs)
return pre
class Linear(fluid.dygraph.Layer):
def __init__(self):
super(Linear, self).__init__()
@dygraph_to_static_output
def forward(self, x):
fc = fluid.dygraph.Linear(
input_dim=10,
output_dim=5,
act='relu',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.99)),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.5)))
inputs = fluid.dygraph.to_variable(x)
pre = fc(inputs)
return pre
class TestPool2D(unittest.TestCase):
def setUp(self):
self.dygraph_class = Pool2D
self.data = np.random.random((1, 2, 4, 4)).astype('float32')
def run_dygraph_mode(self):
with fluid.dygraph.guard():
dy_layer = self.dygraph_class()
for _ in range(1):
prediction = dy_layer(x=self.data)
return prediction
def run_static_mode(self):
startup_prog = fluid.Program()
main_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
dy_layer = self.dygraph_class()
out = dy_layer(x=self.data)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
res = exe.run(main_prog, fetch_list=out)
return res
def test_static_output(self):
dygraph_res = self.run_dygraph_mode()
static_res = self.run_static_mode()
self.assertTrue(
np.allclose(dygraph_res[0], static_res[0]),
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_res,
static_res))
return
class TestLinear(unittest.TestCase):
def setUp(self):
self.dygraph_class = Linear
self.data = np.random.random((4, 10)).astype('float32')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册