diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index c641c0a8c1349723da88080b1b10f6ff31213c2b..6305dcee70cca80915ad0ee5c113d82aa40cae04 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -14,19 +14,21 @@ from __future__ import print_function +import copy +import inspect +import textwrap + import astor # gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST). # It provides a compatibility layer between the AST of various Python versions, # as produced by ast.parse from the standard ast module. # See details in https://github.com/serge-sans-paille/gast/ import gast -import textwrap -import inspect from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else, ast_to_func -from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor +from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor from .utils import * __all__ = ['DygraphToStaticAst', 'convert_to_static'] @@ -121,8 +123,10 @@ class DygraphToStaticAst(gast.NodeTransformer): def get_static_ast(self, root): # save root for some analysis may need global AST self.root = root - self.static_analysis_root = StaticAnalysisVisitor( - root).get_node_wrapper_root() + self.static_analysis_visitor = StaticAnalysisVisitor(root) + self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root( + ) + self.decorate_func_name = None self.arg_name_to_idx = {} self.transfer_from_node_type(self.static_analysis_root) @@ -133,7 +137,8 @@ class DygraphToStaticAst(gast.NodeTransformer): self.visit(node_wrapper.node) # Transform basic api of dygraph to static graph - basic_api_trans = BasicApiTransformer(node_wrapper) + basic_api_trans = BasicApiTransformer(node_wrapper, + self.static_analysis_visitor) basic_api_trans.ast_visit() self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id() @@ -178,14 +183,31 @@ class BasicApiTransformer(gast.NodeTransformer): Class to transform basic API from dygraph to static graph. """ - def __init__(self, wrapper_root): + def __init__(self, wrapper_root, static_analysis_visitor): assert isinstance( wrapper_root, AstNodeWrapper ), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer." + self.wrapper_root = wrapper_root self.root = wrapper_root.node self.class_node_dict = {} + + # Used for transformation of data feed self.feed_name_to_arg_id = {} + self.name_to_tensor_shape = {} + + # Used for transformation of Tensor.shape + self.static_analysis_visitor = static_analysis_visitor + self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( + ) + self.scope_var_type_dict = {} + self._run_static_visitor() + + def _run_static_visitor(self): + var_env = copy.deepcopy(self.static_analysis_visitor.get_var_env()) + # TODO: Consider that Tensor.shape is used in sub function and sub_scopes is empty + var_env.cur_scope = var_env.cur_scope.sub_scopes[0] + self.scope_var_type_dict = var_env.get_scope_var_type() def ast_visit(self): self.visit(self.root) @@ -204,11 +226,12 @@ class BasicApiTransformer(gast.NodeTransformer): if self._update_class_node_dict(node): return None - value_node = node.value - for child_node in gast.walk(value_node): + if self._update_name_to_tensor_shape(node): + return node + + for child_node in gast.walk(node.value): if isinstance(child_node, gast.Call): self._visit_Call(child_node) - return node def visit_Expr(self, node): @@ -219,19 +242,41 @@ class BasicApiTransformer(gast.NodeTransformer): return else: self._visit_Call(child_node) + return node + + def visit_Attribute(self, node): + if self._used_by_paddle_api(node): + if self.is_tensor_shape(node): + return create_api_shape_node(node) + return node + def visit_Name(self, node): + if node.id in self.name_to_tensor_shape: + if self._used_by_paddle_api(node): + tensor_shape_node = self.name_to_tensor_shape[node.id] + if isinstance(tensor_shape_node, gast.Attribute): + return create_api_shape_node(tensor_shape_node) + elif isinstance(tensor_shape_node, gast.Subscript): + result_node = copy.deepcopy(tensor_shape_node) + result_node.value = create_api_shape_node( + tensor_shape_node.value) + return result_node return node def _visit_Call(self, node): assert isinstance(node, gast.Call) - # Replace API `to_variable` with `fluid.layers.assign` if is_to_variable(node): self._update_feed_dict(node) node = to_assign_node(node) return node + if is_paddle_api(node): + # Visit gast.Attribute and gast.Name to replace tensor.shape if necessary + self.generic_visit(node) + func_name = astor.to_source(gast.gast_to_ast(node.func)) + if self._is_dygraph_forward(func_name): class_node = self._get_class_node(func_name) static_node = to_static_ast(node, class_node) @@ -239,6 +284,53 @@ class BasicApiTransformer(gast.NodeTransformer): else: return node + def is_tensor_shape(self, node): + """ + Return True if node is like `x.shape` and x is Tensor, return False otherwise. + """ + assert isinstance(node, gast.Attribute) + if node.attr != 'shape': + return False + + try: + value_id = node.value.id + except AttributeError: + return False + + if value_id in self.name_to_tensor_shape: + return True + + # TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function + # Need a better way to confirm whether `value_id` is a Tensor. + try: + var_type_set = self.scope_var_type_dict[value_id] + except KeyError: + return False + + if NodeVarType.NUMPY_NDARRAY in var_type_set: + return False + if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set: + return False + + return True + + def _used_by_paddle_api(self, node): + assert isinstance(node, (gast.Attribute, gast.Name)) + wrapper_node = self.node_to_wrapper_map.get(node) + if not wrapper_node: + # Transformed node is not in node_to_wrapper_map + return False + while wrapper_node.parent: + parent_node = wrapper_node.parent.node + if isinstance(parent_node, gast.Call): + if is_paddle_api(parent_node): + return True + else: + return False + wrapper_node = wrapper_node.parent + + return False + def _is_dygraph_forward(self, func_id): return func_id in self.class_node_dict @@ -280,6 +372,32 @@ class BasicApiTransformer(gast.NodeTransformer): def get_feed_name_to_arg_id(self): return self.feed_name_to_arg_id + def _update_name_to_tensor_shape(self, node): + assert isinstance(node, gast.Assign) + # TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1] + target_node = node.targets[0] + try: + target_id = target_node.id + except AttributeError: + return False + value_node = node.value + + if isinstance(value_node, gast.Name): + if value_node.id in self.name_to_tensor_shape: + self.name_to_tensor_shape[ + target_id] = self.name_to_tensor_shape[value_node.id] + return True + if isinstance(value_node, gast.Attribute): + if self.is_tensor_shape(value_node): # eg: x.shape + self.name_to_tensor_shape[target_id] = value_node + return True + if isinstance(value_node, gast.Subscript): + if isinstance(value_node.value, gast.Attribute): + if self.is_tensor_shape(value_node.value): # eg: x.shape[0] + self.name_to_tensor_shape[target_id] = value_node + return True + return False + def convert_to_static(dyfunc): """ diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py index 283b5b1a9d5d60675a7979f27ed90e5e56a0ab2e..3f8a1699739b1bf4655271dc6b404f4e3d0870d5 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py @@ -360,7 +360,9 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True): # TODO(Aurelius84): more elegant way to transform ast into callable object import_str = "import paddle\n" \ "import paddle.fluid as fluid\n" \ - "import paddle.fluid.layers as layers\n" + "import paddle.fluid.layers as layers\n" \ + "import numpy as np\n" \ + "import numpy\n" with f: module_name = os.path.basename(f.name[:-3]) f.write(import_str) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index d2d275041213630c1f438190e1e9e439fc297134..cf54043dd37a8bff017fdc4e559a4892601f7994 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -181,3 +181,12 @@ def update_args_of_func(node, dygraph_node, method_name): node.args = [] node.keywords = added_keywords + node.keywords + + +def create_api_shape_node(tensor_shape_node): + assert isinstance(tensor_shape_node, gast.Attribute) + api_shape_node = gast.Call( + func=gast.parse('fluid.layers.shape').body[0].value, + args=[tensor_shape_node.value], + keywords=[]) + return api_shape_node diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py index 13f24f5aa56cb68b956a9875c395b4725a1b9936..cab2e7e78e1ce7e68d5362c1b978061b17a3486f 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..30f46e216f5ef2b0956fcfef87b33f3437cfcd32 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -0,0 +1,100 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy + +import unittest +import paddle.fluid as fluid +from paddle.fluid.dygraph.jit import dygraph_to_static_graph + + +def dyfunc_tensor_shape_1(x): + x = fluid.dygraph.to_variable(x) + res = fluid.layers.reshape(x, shape=x.shape) + return res + + +def dyfunc_tensor_shape_2(x): + x = fluid.dygraph.to_variable(x) + shape = x.shape + shape2 = shape + res = fluid.layers.reshape(x, shape2) + return res + + +def dyfunc_tensor_shape_3(x): + # Don't transform y.shape because y is numpy.ndarray + x = fluid.dygraph.to_variable(x) + y = numpy.ones(5) + res = fluid.layers.reshape(x, shape=y.shape) + return res + + +def dyfunc_tensor_shape_4(x): + x = fluid.dygraph.to_variable(x) + res = fluid.layers.reshape(x, shape=(-1, x.shape[0], len(x.shape))) + return res + + +def dyfunc_tensor_shape_5(x): + # `res = fluid.layers.reshape(x, shape=(-1, s))` to + # `res = fluid.layers.reshape(x, shape=(-1, fluid.layers.shape(x)[0]))` + x = fluid.dygraph.to_variable(x) + s = x.shape[0] + res = fluid.layers.reshape(x, shape=(-1, s)) + return res + + +test_funcs = [ + dyfunc_tensor_shape_1, dyfunc_tensor_shape_2, dyfunc_tensor_shape_3, + dyfunc_tensor_shape_4, dyfunc_tensor_shape_5 +] + + +class TestTensorShape(unittest.TestCase): + def setUp(self): + self.input = numpy.ones(5).astype("int32") + self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( + ) else fluid.CPUPlace() + + def get_dygraph_output(self): + with fluid.dygraph.guard(): + res = self.dygraph_func(self.input).numpy() + return res + + def get_static_output(self): + main_program = fluid.Program() + with fluid.program_guard(main_program): + static_out = dygraph_to_static_graph(self.dygraph_func)(self.input) + + exe = fluid.Executor(self.place) + static_res = exe.run(main_program, fetch_list=static_out) + + return static_res[0] + + def test_transformed_static_result(self): + for func in test_funcs: + self.dygraph_func = func + static_res = self.get_static_output() + dygraph_res = self.get_dygraph_output() + self.assertTrue( + numpy.allclose(dygraph_res, static_res), + msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res, + static_res)) + + +if __name__ == '__main__': + unittest.main()