提交 903039a3 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add Simple Framework for Transforming Dygraph to Static Graph (#22491)

This PR provides very basic and simple framework for transforming Dygraph to Static Graph.

API names, final outputs are not determined yet. Feel free to modify or add class/function/type when you think the framework is not extendable for you.
上级 a61d0952
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import ast_transformer
from .ast_transformer import *
__all__ = []
__all__ += ast_transformer.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import ast
__all__ = ['DygraphToStaticAst']
class NodeVarType(object):
"""
Enum class of python variable types. We have to know some variable types
during compile time to transfer AST. For example, a string variable and a
tensor variable in if clause may lead to different conversion from dygraph
to static graph.
"""
UNKNOWN = 0 # Reserve for AST nodes have not known the type
STATEMENT = 1 # For nodes representing statement (non-variable type)
PADDLE_DYGRAPH_API = 2
PADDLE_CONTROL_IF = 3
PADDLE_CONTROL_WHILE = 4
PADDLE_CONTROL_FOR = 5
NONE = 100
INT = 101
FLOAT = 102
STRING = 103
TENSOR = 104
class AstNodeWrapper(object):
"""
Wrapper for python ast.node. We need a node wrapper because ast.node
doesn't store all required information when we are transforming AST.
We should collect additional information which the actual transformation
needs.
"""
def __init__(self, node):
self.node = node
self.parent = None
self.node_var_type = NodeVarType.UNKNOWN
class DygraphToStaticAst(ast.NodeTransformer):
"""
Main class to transform Dygraph to Static Graph
"""
def get_static_ast(self, root):
# save root for some analysis may need global AST
self.root = root
self.static_analysis_root = AstNodeWrapper(root)
self.visit(root)
self.transfer_from_node_type(self.static_analysis_root)
return self.static_analysis_root
def visit(self, node):
# TODO construct a tree whose nodes are AstNodeWrapper
# This step also does static node type analysis
print("Not implemented")
def transfer_from_node_type(self, node):
print("Not implemented")
......@@ -12,9 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['TracedLayer']
__all__ = ['TracedLayer', 'dygraph_to_static_output']
import ast
import inspect
from ..wrapped_decorator import wrap_decorator
from .base import program_desc_tracing_guard, switch_to_static_graph
from .dygraph_to_static import DygraphToStaticAst
from .layers import Layer
from paddle.fluid import core
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
......@@ -45,6 +50,26 @@ def extract_vars(inputs):
return result_list
def _dygraph_to_static_output_(dygraph_func):
def __impl__(*args, **kwargs):
# Get AST from dygraph function
dygraph_code = inspect.getsource(dygraph_func)
root = ast.parse(dygraph_code)
root = DygraphToStaticAst().get_static_ast(root)
# TODO static_func should a callable from AST, like
# static_func = ast_to_func(root)
# currently just use dygraph_func
static_func = dygraph_func
return static_func(*args, **kwargs)
return __impl__
dygraph_to_static_output = wrap_decorator(_dygraph_to_static_output_)
@dygraph_only
def _trace(layer,
inputs,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import unittest
from paddle.fluid.dygraph.jit import dygraph_to_static_output
np.random.seed(1)
def dyfunc(a, b):
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(a)
y = fluid.dygraph.to_variable(b)
x.stop_gradient = False
y.stop_gradient = False
inputs = {'X': [x], 'Y': [y]}
loss = core.ops.elementwise_mul(inputs)['Out'][0]
loss.backward()
x_grad = x.gradient()
y_grad = y.gradient()
return x_grad, y_grad
@dygraph_to_static_output
def dyfunc_to_static(a, b):
return dyfunc(a, b)
class TestBasicModel(unittest.TestCase):
def test_dygraph_static_same_output(self):
a = np.random.uniform(
low=0.1, high=1, size=(3, 4, 5)).astype(np.float32)
b = np.random.uniform(
low=0.1, high=1, size=(3, 4, 5)).astype(np.float32)
dy_output = dyfunc(a, b)
static_output = dyfunc_to_static(a, b)
self.assertTrue(np.array_equal(dy_output[0], static_output[0]))
self.assertTrue(np.array_equal(dy_output[1], static_output[1]))
if __name__ == '__main__':
unittest.main()
......@@ -112,6 +112,7 @@ packages=['paddle',
'paddle.distributed',
'paddle.fluid',
'paddle.fluid.dygraph',
'paddle.fluid.dygraph.dygraph_to_static',
'paddle.fluid.proto',
'paddle.fluid.proto.profiler',
'paddle.fluid.distributed',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册