From 8c381cd95716cfefedb9c8d7e7f32169b2520e14 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Thu, 27 Feb 2020 17:25:13 +0800 Subject: [PATCH] support fetch feed in dygraph to static graph (#22767) * Support fetch and run program in the process of dygraph_to_static_output. test=develop * fix to_source(gast) and remove dygraph API such as Conv2D, Linear. test=develop --- .../dygraph_to_static/ast_transformer.py | 44 +++++++- .../dygraph_to_static/static_analysis.py | 46 +------- .../fluid/dygraph/dygraph_to_static/utils.py | 80 +++++++------- python/paddle/fluid/dygraph/jit.py | 43 +++++++- .../fluid/tests/unittests/CMakeLists.txt | 1 + .../dygraph_to_static/CMakeLists.txt | 6 ++ .../dygraph_to_static/test_fetch_feed.py | 100 ++++++++++++++++++ 7 files changed, 232 insertions(+), 88 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index a2e3841291..bb2e1ee0f3 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import print_function -import astor from .utils import * import gast # gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST). @@ -21,7 +20,7 @@ import gast # as produced by ast.parse from the standard ast module. # See details in https://github.com/serge-sans-paille/gast/ from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else - +from paddle.fluid import unique_name from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor __all__ = ['DygraphToStaticAst'] @@ -109,7 +108,9 @@ class DygraphToStaticAst(gast.NodeTransformer): self.visit(node.node) # Transform basic api of dygraph to static graph - BasicApiTransformer(node).ast_visit() + basic_api_trans = BasicApiTransformer(node) + basic_api_trans.ast_visit() + self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id() # Transform all if/else statement of Dygraph into Static Graph. IfElseTransformer(node).ast_visit() @@ -117,6 +118,11 @@ class DygraphToStaticAst(gast.NodeTransformer): def visit_FunctionDef(self, node): if self.decorate_func_name is None: self.decorate_func_name = node.name + + self.arg_name_to_idx = {} + for idx, arg in enumerate(node.args.args): + self.arg_name_to_idx[arg.id] = idx + self.generic_visit(node) # Remove the decorated name of dygraph_to_static if hasattr(node, 'decorator_list'): @@ -135,6 +141,12 @@ class DygraphToStaticAst(gast.NodeTransformer): assert self.decorate_func_name, "decorate_func_name shall not be None." return self.decorate_func_name + def get_feed_name_to_idx(self): + feed_name_to_idx = {} + for feed_name, arg_name in self.feed_name_to_arg_name.items(): + feed_name_to_idx[feed_name] = self.arg_name_to_idx.get(arg_name) + return feed_name_to_idx + class BasicApiTransformer(gast.NodeTransformer): """ @@ -148,6 +160,7 @@ class BasicApiTransformer(gast.NodeTransformer): self.wrapper_root = wrapper_root self.root = wrapper_root.node self.class_node_dict = {} + self.feed_name_to_arg_id = {} def ast_visit(self): self.visit(self.root) @@ -189,10 +202,11 @@ class BasicApiTransformer(gast.NodeTransformer): # Replace API `to_variable` with `fluid.layers.assign` if is_to_variable(node): + self._update_feed_dict(node) node = to_assign_node(node) return node - func_name = astor.to_source(node.func) + func_name = astor.to_source(gast.gast_to_ast(node.func)) if self._is_dygraph_forward(func_name): class_node = self._get_class_node(func_name) static_node = to_static_ast(node, class_node) @@ -214,9 +228,29 @@ class BasicApiTransformer(gast.NodeTransformer): return False if is_dygraph_api(node_value): + dygraph_api = node_value.func.attr + if not dygraph_class_to_static_api.get(dygraph_api): + return False + update_args_of_func(node_value, node_value, "__init__") - target_str = astor.to_source(node.targets[0]) + target_str = astor.to_source(gast.gast_to_ast(node.targets[0])) self.class_node_dict[target_str] = node_value return True # TODO: node.value is not dygraph class return False + + def _update_feed_dict(self, node): + assert isinstance(node, gast.Call) + + var_name = None + for kw in node.keywords: + if kw.arg == 'value': + var_name = kw.value.id # eg: 'a' for "value=a " + if not var_name: + var_name = node.args[0].id + + feed_var_name = unique_name.generate(var_name) # eg: "a_0" + self.feed_name_to_arg_id[feed_var_name] = var_name # eg: "a_0" : "a" + + def get_feed_name_to_arg_id(self): + return self.feed_name_to_arg_id diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py index c8ab1a4586..34d05bed32 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py @@ -19,55 +19,11 @@ import gast import inspect import six import warnings +from .utils import is_paddle_api, is_dygraph_api, is_numpy_api __all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor'] -# TODO: _is_paddle_dygraph_api is duplicated in Yamei's utils.py. Merge the two -# function code together when Yamei finish her PR. -def _is_api_in_module_helper(obj, module_prefix): - m = inspect.getmodule(obj) - return m is not None and m.__name__.startswith(module_prefix) - - -# TODO: is_dygraph_api is duplicated in Yamei's utils.py. Merge the two -# function code together when Yamei finish her PR. -def is_api_in_module(node, module_prefix): - assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api" - func_str = astor.to_source(gast.gast_to_ast(node.func)) - try: - import paddle.fluid as fluid - import paddle - return eval("_is_api_in_module_helper({}, '{}')".format(func_str, - module_prefix)) - except NameError: - return False - - -def is_dygraph_api(node): - return is_api_in_module(node, "paddle.fluid.dygraph") - - -def is_paddle_api(node): - return is_api_in_module(node, "paddle.fluid") - - -# Is numpy_api cannot reuse is_api_in_module because of numpy module problem -def is_numpy_api(node): - assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api" - func_str = astor.to_source(gast.gast_to_ast(node.func)) - try: - import numpy as np - module_result = eval("_is_api_in_module_helper({}, '{}')".format( - func_str, "numpy")) - # BUG: np.random.uniform doesn't have module and cannot be analyzed - # TODO: find a better way - if not module_result: - return func_str.startswith("numpy.") or func_str.startswith("np.") - except NameError: - return False - - class NodeVarType(object): """ Enum class of python variable types. We have to know some variable types diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 4ebe58da57..d2d2750412 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -17,41 +17,62 @@ from __future__ import print_function import inspect import gast import astor -import atexit -import os -import tempfile -import six -import imp dygraph_class_to_static_api = { - "BatchNorm": "batch_norm", - "BilinearTensorProduct": "bilinear_tensor_product", - "Conv2D": "conv2d", - "Conv3D": "conv3d", - "Conv2DTranspose": "conv2d_transpose", - "Conv3DTranspose": "conv3d_transpose", "CosineDecay": "cosine_decay", - "Embedding": "embedding", "ExponentialDecay": "exponential_decay", - "GroupNorm": "group_norm", - "GRUUnit": "gru_unit", "InverseTimeDecay": "inverse_time_decay", - "LayerNorm": "layer_norm", - "Linear": "fc", "NaturalExpDecay": "natural_exp_decay", - "NCE": "nce", "NoamDecay": "noam_decay", "PiecewiseDecay": "piecewise_decay", "PolynomialDecay": "polynomial_decay", - "Pool2D": "pool2d", - "PRelu": "prelu", - "SpectralNorm": "spectral_norm", } +def _is_api_in_module_helper(obj, module_prefix): + m = inspect.getmodule(obj) + return m is not None and m.__name__.startswith(module_prefix) + + +def is_api_in_module(node, module_prefix): + assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api" + func_str = astor.to_source(gast.gast_to_ast(node.func)) + try: + import paddle.fluid as fluid + import paddle + return eval("_is_api_in_module_helper({}, '{}')".format(func_str, + module_prefix)) + except NameError: + return False + + +def is_dygraph_api(node): + return is_api_in_module(node, "paddle.fluid.dygraph") + + +def is_paddle_api(node): + return is_api_in_module(node, "paddle.fluid") + + +# Is numpy_api cannot reuse is_api_in_module because of numpy module problem +def is_numpy_api(node): + assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api" + func_str = astor.to_source(gast.gast_to_ast(node.func)) + try: + import numpy as np + module_result = eval("_is_api_in_module_helper({}, '{}')".format( + func_str, "numpy")) + # BUG: np.random.uniform doesn't have module and cannot be analyzed + # TODO: find a better way + if not module_result: + return func_str.startswith("numpy.") or func_str.startswith("np.") + except NameError: + return False + + def _delete_keywords_from(node): assert isinstance(node, gast.Call) - func_src = astor.to_source(node.func) + func_src = astor.to_source(gast.gast_to_ast(node.func)) import paddle.fluid as fluid full_args = eval("inspect.getargspec({})".format(func_src)) full_args_name = full_args[0] @@ -94,21 +115,6 @@ def _add_keywords_to(node, dygraph_api_name): return -def _is_paddle_dygraph_api(obj): - m = inspect.getmodule(obj) - return m is not None and m.__name__.startswith("paddle.fluid.dygraph") - - -def is_dygraph_api(node): - assert isinstance(node, gast.Call) - func_src = astor.to_source(node.func) - try: - import paddle.fluid as fluid - return eval("_is_paddle_dygraph_api({})".format(func_src)) - except NameError: - return False - - def is_to_variable(node): assert isinstance(node, gast.Call) if is_dygraph_api(node): @@ -158,7 +164,7 @@ def update_args_of_func(node, dygraph_node, method_name): "The method name of class to update args should be '__init__' or 'forward'" ) - class_src = astor.to_source(dygraph_node.func) + class_src = astor.to_source(gast.gast_to_ast(dygraph_node.func)) import paddle.fluid as fluid if method_name == "__init__" or eval( "issubclass({}, fluid.dygraph.Layer)".format(class_src)): diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index f84173b8c7..db0fe62ee6 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -29,6 +29,7 @@ from paddle.fluid import core from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode from paddle.fluid.executor import Executor, scope_guard from paddle.fluid.compiler import CompiledProgram +from paddle.fluid import program_guard, data def create_program_from_desc(program_desc): @@ -56,6 +57,7 @@ def extract_vars(inputs): def _dygraph_to_static_output_(dygraph_func): def __impl__(*args, **kwargs): + # Get AST from dygraph function dygraph_code = inspect.getsource(dygraph_func) dygraph_code = textwrap.dedent(dygraph_code) @@ -64,14 +66,53 @@ def _dygraph_to_static_output_(dygraph_func): # Transform AST dygraph_to_static = DygraphToStaticAst() root_wrapper = dygraph_to_static.get_static_ast(root) + + # Get static_func from AST func_name = dygraph_to_static.get_module_name() static_func, file_name = ast_to_func(root_wrapper.node, func_name) - return static_func(*args, **kwargs) + if not in_dygraph_mode(): + return static_func(*args, **kwargs) + else: + feed_name_to_idx = dygraph_to_static.get_feed_name_to_idx() + feed_dict = {} + for feed_name, idx in feed_name_to_idx.items(): + feed_dict[feed_name] = args[idx] + + # Run static_func in static mode + startup_program = Program() + main_program = Program() + static_res = run_static_func(main_program, startup_program, + static_func, args, kwargs, feed_dict, + feed_name_to_idx) + + return static_res return __impl__ +@switch_to_static_graph +def run_static_func(main_program, startup_program, static_func, args, kwargs, + feed_dict, feed_name_to_idx): + + with program_guard(main_program, startup_program): + args_list = list(args) + for var_name, value in feed_dict.items(): + idx = feed_name_to_idx[var_name] + args_list[idx] = data( + name=var_name, shape=value.shape, dtype=str(value.dtype)) + args = tuple(args_list) + static_out = static_func(*args, **kwargs) + if not isinstance(static_out, (list, tuple)): + static_out = [static_out] + exe = Executor(core.CPUPlace()) + exe.run(startup_program) + static_res = exe.run(main_program, + fetch_list=static_out, + feed=feed_dict) + return static_res + + dygraph_to_static_output = wrap_decorator(_dygraph_to_static_output_) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index f972627cde..3a503396ea 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -337,6 +337,7 @@ set_tests_properties(test_parallel_executor_seresnext_with_reduce_cpu PROPERTIES set_tests_properties(test_parallel_executor_seresnext_with_fuse_all_reduce_cpu PROPERTIES TIMEOUT 750) add_subdirectory(sequence) +add_subdirectory(dygraph_to_static) if (WITH_NGRAPH) add_subdirectory(ngraph) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt b/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt new file mode 100644 index 0000000000..f71e04c09a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) +endforeach(TEST_OP) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py new file mode 100644 index 0000000000..b4f96d81c4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py @@ -0,0 +1,100 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from paddle.fluid.dygraph.jit import dygraph_to_static_output + +import numpy as np +import unittest + +import paddle.fluid as fluid + +SEED = 2020 + + +class Pool2D(fluid.dygraph.Layer): + def __init__(self): + super(Pool2D, self).__init__() + self.pool2d = fluid.dygraph.Pool2D( + pool_size=2, pool_type='avg', pool_stride=1, global_pooling=False) + + @dygraph_to_static_output + def forward(self, x): + inputs = fluid.dygraph.to_variable(x) + pre = self.pool2d(inputs) + return pre + + +class Linear(fluid.dygraph.Layer): + def __init__(self): + super(Linear, self).__init__() + + @dygraph_to_static_output + def forward(self, x): + fc = fluid.dygraph.Linear( + input_dim=10, + output_dim=5, + act='relu', + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.99)), + bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.5))) + inputs = fluid.dygraph.to_variable(x) + pre = fc(inputs) + return pre + + +class TestPool2D(unittest.TestCase): + def setUp(self): + self.dygraph_class = Pool2D + self.data = np.random.random((1, 2, 4, 4)).astype('float32') + + def run_dygraph_mode(self): + with fluid.dygraph.guard(): + dy_layer = self.dygraph_class() + for _ in range(1): + + prediction = dy_layer(x=self.data) + return prediction + + def run_static_mode(self): + startup_prog = fluid.Program() + main_prog = fluid.Program() + with fluid.program_guard(main_prog, startup_prog): + dy_layer = self.dygraph_class() + out = dy_layer(x=self.data) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + res = exe.run(main_prog, fetch_list=out) + return res + + def test_static_output(self): + dygraph_res = self.run_dygraph_mode() + static_res = self.run_static_mode() + self.assertTrue( + np.allclose(dygraph_res[0], static_res[0]), + msg='dygraph is {}\n static_res is \n{}'.format(dygraph_res, + static_res)) + return + + +class TestLinear(unittest.TestCase): + def setUp(self): + self.dygraph_class = Linear + self.data = np.random.random((4, 10)).astype('float32') + + +if __name__ == '__main__': + unittest.main() -- GitLab