From 903039a3c6a74bf75bf34f9cfdbfe2e460ea97a0 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Mon, 10 Feb 2020 11:55:14 +0800 Subject: [PATCH] Add Simple Framework for Transforming Dygraph to Static Graph (#22491) This PR provides very basic and simple framework for transforming Dygraph to Static Graph. API names, final outputs are not determined yet. Feel free to modify or add class/function/type when you think the framework is not extendable for you. --- .../dygraph/dygraph_to_static/__init__.py | 21 +++++ .../dygraph_to_static/ast_transformer.py | 76 +++++++++++++++++++ python/paddle/fluid/dygraph/jit.py | 27 ++++++- .../unittests/test_dygraph_to_static_basic.py | 62 +++++++++++++++ python/setup.py.in | 1 + 5 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/__init__.py create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py create mode 100644 python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py b/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py new file mode 100644 index 0000000000..a36d2c220f --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from . import ast_transformer +from .ast_transformer import * + +__all__ = [] +__all__ += ast_transformer.__all__ diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py new file mode 100644 index 0000000000..2c5387f052 --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -0,0 +1,76 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import ast + +__all__ = ['DygraphToStaticAst'] + + +class NodeVarType(object): + """ + Enum class of python variable types. We have to know some variable types + during compile time to transfer AST. For example, a string variable and a + tensor variable in if clause may lead to different conversion from dygraph + to static graph. + """ + UNKNOWN = 0 # Reserve for AST nodes have not known the type + STATEMENT = 1 # For nodes representing statement (non-variable type) + PADDLE_DYGRAPH_API = 2 + PADDLE_CONTROL_IF = 3 + PADDLE_CONTROL_WHILE = 4 + PADDLE_CONTROL_FOR = 5 + + NONE = 100 + INT = 101 + FLOAT = 102 + STRING = 103 + TENSOR = 104 + + +class AstNodeWrapper(object): + """ + Wrapper for python ast.node. We need a node wrapper because ast.node + doesn't store all required information when we are transforming AST. + We should collect additional information which the actual transformation + needs. + """ + + def __init__(self, node): + self.node = node + self.parent = None + self.node_var_type = NodeVarType.UNKNOWN + + +class DygraphToStaticAst(ast.NodeTransformer): + """ + Main class to transform Dygraph to Static Graph + """ + + def get_static_ast(self, root): + # save root for some analysis may need global AST + self.root = root + self.static_analysis_root = AstNodeWrapper(root) + self.visit(root) + self.transfer_from_node_type(self.static_analysis_root) + return self.static_analysis_root + + def visit(self, node): + # TODO construct a tree whose nodes are AstNodeWrapper + # This step also does static node type analysis + print("Not implemented") + + def transfer_from_node_type(self, node): + print("Not implemented") diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index aab02d63ab..e067e70443 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -12,9 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ['TracedLayer'] +__all__ = ['TracedLayer', 'dygraph_to_static_output'] +import ast +import inspect + +from ..wrapped_decorator import wrap_decorator from .base import program_desc_tracing_guard, switch_to_static_graph +from .dygraph_to_static import DygraphToStaticAst from .layers import Layer from paddle.fluid import core from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode @@ -45,6 +50,26 @@ def extract_vars(inputs): return result_list +def _dygraph_to_static_output_(dygraph_func): + def __impl__(*args, **kwargs): + # Get AST from dygraph function + dygraph_code = inspect.getsource(dygraph_func) + root = ast.parse(dygraph_code) + + root = DygraphToStaticAst().get_static_ast(root) + + # TODO static_func should a callable from AST, like + # static_func = ast_to_func(root) + # currently just use dygraph_func + static_func = dygraph_func + return static_func(*args, **kwargs) + + return __impl__ + + +dygraph_to_static_output = wrap_decorator(_dygraph_to_static_output_) + + @dygraph_only def _trace(layer, inputs, diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py new file mode 100644 index 0000000000..39c8fbe6fc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py @@ -0,0 +1,62 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core +import unittest + +from paddle.fluid.dygraph.jit import dygraph_to_static_output + +np.random.seed(1) + + +def dyfunc(a, b): + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(a) + y = fluid.dygraph.to_variable(b) + x.stop_gradient = False + y.stop_gradient = False + + inputs = {'X': [x], 'Y': [y]} + loss = core.ops.elementwise_mul(inputs)['Out'][0] + + loss.backward() + x_grad = x.gradient() + y_grad = y.gradient() + return x_grad, y_grad + + +@dygraph_to_static_output +def dyfunc_to_static(a, b): + return dyfunc(a, b) + + +class TestBasicModel(unittest.TestCase): + def test_dygraph_static_same_output(self): + a = np.random.uniform( + low=0.1, high=1, size=(3, 4, 5)).astype(np.float32) + b = np.random.uniform( + low=0.1, high=1, size=(3, 4, 5)).astype(np.float32) + dy_output = dyfunc(a, b) + static_output = dyfunc_to_static(a, b) + self.assertTrue(np.array_equal(dy_output[0], static_output[0])) + self.assertTrue(np.array_equal(dy_output[1], static_output[1])) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 99d1ac6a7e..d2f4571b11 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -112,6 +112,7 @@ packages=['paddle', 'paddle.distributed', 'paddle.fluid', 'paddle.fluid.dygraph', + 'paddle.fluid.dygraph.dygraph_to_static', 'paddle.fluid.proto', 'paddle.fluid.proto.profiler', 'paddle.fluid.distributed', -- GitLab