From aca3f5311d37afeb30c40b3e2729337a9f04643f Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Thu, 5 Mar 2020 11:13:33 +0800 Subject: [PATCH] Support "while" in Dygraph to Static (#22841) Add basic support for while in translating dygraph to static 1. Analysis the variable liveness in class NameVisitor 2. Replace while key word using while_loop API --- .../dygraph/dygraph_to_static/__init__.py | 8 + .../dygraph_to_static/ast_transformer.py | 24 +- .../dygraph/dygraph_to_static/ast_utils.py | 2 +- .../dygraph_to_static/loop_transformer.py | 251 ++++++++++++++++++ .../dygraph_to_static/variable_trans_func.py | 46 ++++ .../unittests/test_dygraph_to_static_loop.py | 80 ++++++ 6 files changed, 401 insertions(+), 10 deletions(-) create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py create mode 100644 python/paddle/fluid/tests/unittests/test_dygraph_to_static_loop.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py b/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py index 9df7cb4e3c4..e39a68f96b8 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py @@ -20,10 +20,18 @@ from .ast_transformer import * from . import static_analysis from .static_analysis import * +from . import loop_transformer +from .loop_transformer import * + +from . import variable_trans_func +from .variable_trans_func import * + from . import cache_program from .cache_program import * __all__ = [] __all__ += ast_transformer.__all__ +__all__ += loop_transformer.__all__ __all__ += static_analysis.__all__ +__all__ += variable_trans_func.__all__ __all__ += cache_program.__all__ diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 86837e993d7..c641c0a8c13 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -13,17 +13,21 @@ # limitations under the License. from __future__ import print_function -from .utils import * -import gast -import textwrap -import inspect + +import astor # gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST). # It provides a compatibility layer between the AST of various Python versions, # as produced by ast.parse from the standard ast module. # See details in https://github.com/serge-sans-paille/gast/ -from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else, ast_to_func +import gast +import textwrap +import inspect + from paddle.fluid import unique_name +from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer +from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else, ast_to_func from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor +from .utils import * __all__ = ['DygraphToStaticAst', 'convert_to_static'] @@ -124,17 +128,19 @@ class DygraphToStaticAst(gast.NodeTransformer): self.transfer_from_node_type(self.static_analysis_root) return self.static_analysis_root - def transfer_from_node_type(self, node): + def transfer_from_node_type(self, node_wrapper): # Generic transformation - self.visit(node.node) + self.visit(node_wrapper.node) # Transform basic api of dygraph to static graph - basic_api_trans = BasicApiTransformer(node) + basic_api_trans = BasicApiTransformer(node_wrapper) basic_api_trans.ast_visit() self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id() # Transform all if/else statement of Dygraph into Static Graph. - IfElseTransformer(node).ast_visit() + IfElseTransformer(node_wrapper).ast_visit() + + LoopTransformer(node_wrapper).transform() def visit_FunctionDef(self, node): if self.decorate_func_name is None: diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py index 357f746eb5a..283b5b1a9d5 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py @@ -14,8 +14,8 @@ from __future__ import print_function -import astor import ast +import astor import gast import six import copy diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py new file mode 100644 index 00000000000..f3f297c6a63 --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -0,0 +1,251 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import copy +import gast + +from collections import defaultdict +from paddle.fluid import unique_name +from paddle.fluid.dygraph.dygraph_to_static.ast_utils import create_funcDef_node +from paddle.fluid.dygraph.dygraph_to_static.ast_utils import generate_name_node +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper +from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node +from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node + +__all__ = ['LoopTransformer', 'NameVisitor'] + +WHILE_CONDITION_PREFIX = 'while_condition' +WHILE_BODY_PREFIX = 'while_body' + + +def create_while_node(condition_name, body_name, loop_var_names): + while_args = [] + while_args.append( + gast.Name( + id=condition_name, + ctx=gast.Param(), + annotation=None, + type_comment=None)) + while_args.append( + gast.Name( + id=body_name, ctx=gast.Param(), annotation=None, type_comment=None)) + assign_targets = [ + gast.Name( + id=var_name, ctx=gast.Param(), annotation=None, type_comment=None) + for var_name in loop_var_names + ] + while_args.append(gast.List(elts=assign_targets, ctx=gast.Param())) + + while_func_id = gast.parse('fluid.layers.while_loop').body[0].value + while_node = gast.Call(func=while_func_id, args=while_args, keywords=[]) + assign_node = gast.Assign( + targets=[gast.Tuple( + elts=assign_targets, ctx=gast.Store())], + value=while_node) + return assign_node + + +class NameVisitor(gast.NodeVisitor): + ''' + Analysis name liveness for loop transformer + ''' + + def __init__(self, root_node): + # Set of gast.Name + self.current_seen_vars = set() + # List of gast.While/gast.For nodes + self.current_loop = [] + + # Mapping from gast.While/gast.For to string name of vars + self.before_loop_vars = defaultdict(set) + self.in_loop_vars = defaultdict(set) + + self.visit(root_node) + + def is_control_flow_loop(self, node): + # TODO: make a better condition + return True + + def get_loop_var_names(self, node): + assert isinstance(node, gast.While) or isinstance( + while_node, gast.For), "Input node is not gast loop node" + loop_var_names = set() + create_var_names = set() + read_context = {type(gast.Load), type(gast.AugLoad)} + + in_loop_vars = self.in_loop_vars[node] + in_loop_name_strs = set(name.id for name in in_loop_vars) + before_loop_vars = self.before_loop_vars[node] + before_loop_name_strs = set(name.id for name in before_loop_vars) + after_loop_vars = self.current_seen_vars - before_loop_vars - in_loop_vars + after_loop_name_strs = set( + name.id for name in after_loop_vars + if type(name.ctx) in read_context) + for name in in_loop_name_strs: + if name in before_loop_name_strs: + # If a variable is used in loop and created before loop, it + # should be in loop_var as input + loop_var_names.add(name) + elif name in after_loop_name_strs: + # If a variable is created in the while loop and read after + # loop, it should be in loop_var and we should create it + loop_var_names.add(name) + create_var_names.add(name) + return loop_var_names, create_var_names + + def visit_Name(self, node): + self.current_seen_vars.add(node) + for loop_node in self.current_loop: + self.in_loop_vars[loop_node].add(node) + self.generic_visit(node) + + def visit_For(self, node): + self.current_loop.append(node) + self.before_loop_vars[node] = copy.deepcopy(self.current_seen_vars) + self.generic_visit(node) + self.current_loop.pop() + + def visit_While(self, node): + self.current_loop.append(node) + self.before_loop_vars[node] = copy.deepcopy(self.current_seen_vars) + self.generic_visit(node) + self.current_loop.pop() + + +class LoopTransformer(gast.NodeTransformer): + """ + This class transforms python while/for statement into Static Graph Ast + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Input non-AstNodeWrapper node for the initialization of WhileTransformer." + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + self.name_visitor = NameVisitor(self.root) + + def transform(self): + self.visit(self.root) + + def get_for_stmt_nodes(self, node): + self.generic_visit(node) + # TODO + return node + + def visit(self, node): + self.generic_visit(node) + # All parent nodes that may contain gast.While/gast.For + if hasattr(node, 'body'): + self.replace_stmt_list(node.body) + if hasattr(node, 'orelse'): + self.replace_stmt_list(node.orelse) + return node + + def replace_stmt_list(self, body_list): + if not isinstance(body_list, list): + return + + i = 0 + while i < len(body_list): + if isinstance(body_list[i], gast.While): + new_stmts = self.get_while_stmt_nodes(body_list[i]) + body_list[i:i + 1] = new_stmts + i += len(new_stmts) + elif isinstance(body_list[i], gast.For): + # TODO + i += 1 + else: + i += 1 + + def get_while_stmt_nodes(self, node): + # TODO: consider while - else in python + # self.generic_visit(node) + + if not self.name_visitor.is_control_flow_loop(node): + return [node] + + loop_var_names, create_var_names = self.name_visitor.get_loop_var_names( + node) + new_stmts = [] + + # Python can create variable in loop and use it out of loop, E.g. + # + # while x < 10: + # x += 1 + # y = x + # z = y + # + # We need to create static variable for those variables + for name in create_var_names: + new_stmts.append(create_static_variable_gast_node(name)) + + # while x < 10 in dygraph should be convert into static tensor < 10 + for name in loop_var_names: + new_stmts.append(to_static_variable_gast_node(name)) + + condition_func_node = gast.FunctionDef( + name=unique_name.generate(WHILE_CONDITION_PREFIX), + args=gast.arguments( + args=[ + gast.Name( + id=name, + ctx=gast.Param(), + annotation=None, + type_comment=None) for name in loop_var_names + ], + posonlyargs=[], + vararg=None, + kwonlyargs=[], + kw_defaults=None, + kwarg=None, + defaults=[]), + body=[gast.Return(value=node.test)], + decorator_list=[], + returns=None, + type_comment=None) + new_stmts.append(condition_func_node) + + new_body = node.body + new_body.append( + gast.Return(value=generate_name_node( + loop_var_names, ctx=gast.Load()))) + body_func_node = gast.FunctionDef( + name=unique_name.generate(WHILE_BODY_PREFIX), + args=gast.arguments( + args=[ + gast.Name( + id=name, + ctx=gast.Param(), + annotation=None, + type_comment=None) for name in loop_var_names + ], + posonlyargs=[], + vararg=None, + kwonlyargs=[], + kw_defaults=None, + kwarg=None, + defaults=[]), + body=new_body, + decorator_list=[], + returns=None, + type_comment=None) + new_stmts.append(body_func_node) + + while_loop_node = create_while_node(condition_func_node.name, + body_func_node.name, loop_var_names) + new_stmts.append(while_loop_node) + return new_stmts diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py new file mode 100644 index 00000000000..621299ddda2 --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py @@ -0,0 +1,46 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import gast + +from paddle.fluid.layers import fill_constant + +__all__ = ['to_static_variable_gast_node', 'create_static_variable_gast_node'] + + +def to_static_variable_gast_node(name): + func_code = "{} = fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({})".format( + name, name) + return gast.parse(func_code) + + +def create_static_variable_gast_node(name): + func_code = "{} = fluid.layers.data(name='{}', shape=[-1], dtype='float32')".format( + name, name) + return gast.parse(func_code) + + +def to_static_variable(x): + ''' + Translate a Python variable to PaddlePaddle static graph variable + ''' + if isinstance(x, bool): + return fill_constant(shape=[1], dtype='bool', value=x) + if isinstance(x, int): + return fill_constant(shape=[1], dtype='int64', value=x) + if isinstance(x, float): + return fill_constant(shape=[1], dtype='float64', value=x) + return x diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_loop.py b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_loop.py new file mode 100644 index 00000000000..9ca551f91b8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_loop.py @@ -0,0 +1,80 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import gast +import inspect +import numpy as np +import paddle.fluid as fluid +import unittest + +from paddle.fluid.dygraph.jit import dygraph_to_static_graph +#from paddle.fluid.dygraph.dygraph_to_static import NameVistor + +SEED = 2020 +np.random.seed(SEED) + + +def while_loop_dyfunc(x): + i = fluid.dygraph.to_variable(x) + while x < 10: + i = i + x + x = x + 1 + return i + + +class TestNameVisitor(unittest.TestCase): + def test_loop_vars(self): + #TODO + pass + + +class TestTransformWhile(unittest.TestCase): + def setUp(self): + self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( + ) else fluid.CPUPlace() + self.x = np.zeros(shape=(1), dtype=np.int32) + + def _run_static(self): + main_program = fluid.Program() + with fluid.program_guard(main_program): + x_var = fluid.layers.assign(self.x) + static_func = dygraph_to_static_graph(while_loop_dyfunc) + + out = static_func(x_var) + exe = fluid.Executor(self.place) + ret = exe.run(main_program, fetch_list=out) + return ret + + def _run_dygraph(self): + with fluid.dygraph.guard(self.place): + ret = while_loop_dyfunc(fluid.dygraph.to_variable(self.x)) + return ret.numpy() + + def test_ast_to_func(self): + static_numpy = self._run_static() + self.assertTrue( + np.allclose( + np.full( + shape=(1), fill_value=45, dtype=np.int32), static_numpy)) + + # Enable next lines after Paddle dygraph supports while x < 10 + # + # self._run_dygraph() + # self.assertTrue(np.allclose(self._run_dygraph(), self._run_static())) + + +if __name__ == '__main__': + unittest.main() -- GitLab