From fb7b008acc9f35e60928d49d94b1c145fb718b3d Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Thu, 19 Mar 2020 10:01:15 +0800 Subject: [PATCH] Add Support for Break and Continue in Dygraph to Static (#23067) 1. Add support for Break and Continue in Dygraph to Static 2. Also add support for gast.Not in NodeTestTransformer 3. Also add support for logical op transformation in LoopTransformer --- .../dygraph_to_static/ast_transformer.py | 30 +- .../break_continue_transformer.py | 353 ++++++++++++++++++ .../dygraph_to_static/cache_program.py | 4 +- .../dygraph_to_static/ifelse_transformer.py | 18 +- .../dygraph_to_static/loop_transformer.py | 65 +++- .../fluid/dygraph/dygraph_to_static/utils.py | 8 + .../dygraph_to_static/variable_trans_func.py | 26 +- .../dygraph_to_static/test_break_continue.py | 211 +++++++++++ .../unittests/dygraph_to_static/test_loop.py | 21 +- .../unittests/dygraph_to_static/test_utils.py | 33 ++ .../test_variable_trans_func.py | 53 +++ 11 files changed, 799 insertions(+), 23 deletions(-) create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_variable_trans_func.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index ea25c3d715f..eaa16575c66 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -14,28 +14,31 @@ from __future__ import print_function -import copy -import inspect -import textwrap - import astor +import copy # gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST). # It provides a compatibility layer between the AST of various Python versions, # as produced by ast.parse from the standard ast module. # See details in https://github.com/serge-sans-paille/gast/ import gast +import inspect +import textwrap from paddle.fluid import unique_name -from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func -from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable -from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func -from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api -from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer + +from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer +from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer -from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer +from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer + from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func +from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable +from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func +from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api __all__ = ['DygraphToStaticAst', 'convert_to_static'] @@ -74,12 +77,15 @@ class DygraphToStaticAst(gast.NodeTransformer): # Transform list used in control flow ListTransformer(node_wrapper).transform() - # Transform all if/else statement of Dygraph into Static Graph. - IfElseTransformer(node_wrapper).transform() + # Transform break/continue in loops + BreakContinueTransformer(node_wrapper).transform() # Transform for loop and while loop LoopTransformer(node_wrapper).transform() + # Transform all if/else statement of Dygraph into Static Graph. + IfElseTransformer(node_wrapper).transform() + def visit_FunctionDef(self, node): if self.decorate_func_name is None: self.decorate_func_name = node.name diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py new file mode 100644 index 00000000000..b8c59b9976f --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py @@ -0,0 +1,353 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import gast + +from paddle.fluid import unique_name +from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import NodeTestTransformer +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code +from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node +from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list +from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node + +__all__ = ['BreakContinueTransformer'] + +BREAK_NAME_PREFIX = '__break' +CONTINUE_NAME_PREFIX = '__continue' + + +class ForToWhileTransformer(gast.NodeTransformer): + """ + Transform python for loop into while loop and add condition node in the + loop test + """ + + def __init__(self, parent_node, loop_node, condition_node): + assert isinstance( + loop_node, + gast.For), "loop_node is not gast.For in ForToWhileTransformer" + self.parent_node = parent_node + self.loop_node = loop_node + self.condition_node = condition_node + + def transform(self): + if hasattr(self.parent_node, 'body'): + body_list = self.parent_node.body + i = index_in_list(body_list, self.loop_node) + if i != -1: + new_stmts = self.get_for_stmt_nodes(body_list[i]) + body_list[i:i + 1] = new_stmts + i += len(new_stmts) + return + if hasattr(self.parent_node, 'orelse'): + body_list = self.parent_node.orelse + i = index_in_list(body_list, self.loop_node) + if i != -1: + new_stmts = self.get_for_stmt_nodes(body_list[i]) + body_list[i:i + 1] = new_stmts + i += len(new_stmts) + return + raise ValueError( + "parent_node doesn't contain the loop_node in ForToWhileTransformer") + + def get_for_range_node(self, node): + if not isinstance(node.iter, gast.Call): + return None + if not isinstance(node.iter.func, gast.Name): + return None + if node.iter.func.id != "range": + return None + return node.iter + + def get_for_args_stmts(self, iter_name, args_list): + ''' + Returns 3 gast stmt nodes for argument. + 1. Initailize of iterate variable + 2. Condition for the loop + 3. Statement for changing of iterate variable during the loop + ''' + len_range_args = len(args_list) + assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments" + if len_range_args == 1: + init_stmt = get_constant_variable_node(iter_name, 0) + else: + init_stmt = gast.Assign( + targets=[ + gast.Name( + id=iter_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + value=args_list[0]) + + range_max_node = args_list[0] if len_range_args == 1 else args_list[1] + step_node = args_list[2] if len_range_args == 3 else gast.Constant( + value=1, kind=None) + + old_cond_stmt = gast.Compare( + left=gast.BinOp( + left=gast.Name( + id=iter_name, + ctx=gast.Load(), + annotation=None, + type_comment=None), + op=gast.Add(), + right=step_node), + ops=[gast.LtE()], + comparators=[range_max_node]) + cond_stmt = gast.BoolOp( + op=gast.And(), values=[old_cond_stmt, self.condition_node]) + + change_stmt = gast.AugAssign( + target=gast.Name( + id=iter_name, + ctx=gast.Store(), + annotation=None, + type_comment=None), + op=gast.Add(), + value=step_node) + + return init_stmt, cond_stmt, change_stmt + + def get_for_stmt_nodes(self, node): + assert isinstance( + node, gast.For), "Input node is NOT gast.For in get_for_stmt_nodes" + + # TODO: support non-range case + range_call_node = self.get_for_range_node(node) + if range_call_node is None: + return [node] + + if not isinstance(node.target, gast.Name): + return [node] + iter_var_name = node.target.id + + init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts( + iter_var_name, range_call_node.args) + + new_body = node.body + new_body.append(change_stmt) + while_node = gast.While( + test=cond_stmt, body=new_body, orelse=node.orelse) + return [init_stmt, while_node] + + +class BreakContinueTransformer(gast.NodeTransformer): + """ + Rewrite 'break' and 'continue' key words in a if-else python way to make + it equivalent to original control flow + + The main idea of this class is: + + 1. Map the 'break/continue' stmt with an unique boolean variable V. + + 2. Find the first ancestor block containing this 'break/continue', a + block can be a node containing stmt list. We should remove all stmts + after the 'break/continue' and set the V to True here. + + 3. Add 'if V' for stmts in ancestor blocks between the first one + (exclusive) and the ancestor loop (inclusive) + + 4. For 'break' add break into condition of the loop. For 'continue', + set continue to False at the beginning of each loop + + TODO: more details should be summarized as design document + """ + + def __init__(self, wrapper_root): + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + + self.ancestor_nodes = [] + + def transform(self): + self.visit(self.root) + + def generic_visit(self, node): + # TODO: because we change ancestor nodes during visit_Break/Continue, + # not current node, so generic_visit of NodeTransformer will visit node + # which may be deleted. To prevent that node being added into + # transformed AST, I have to self-write a generic_visit, but this is + # NOT a good thing. Considering refactorying this whole class. + for field, value in gast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, gast.AST): + self.visit(item) + elif isinstance(value, gast.AST): + self.visit(value) + + def visit(self, node): + self.ancestor_nodes.append(node) + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + ret = visitor(node) + self.ancestor_nodes.pop() + return ret + + def visit_Break(self, node): + loop_node_index = self._find_ancestor_loop_index(node) + assert loop_node_index != -1, "SyntaxError: 'break' outside loop" + loop_node = self.ancestor_nodes[loop_node_index] + + # 1. Map the 'break/continue' stmt with an unique boolean variable V. + variable_name = unique_name.generate(BREAK_NAME_PREFIX) + + # 2. Find the first ancestor block containing this 'break/continue', a + # block can be a node containing stmt list. We should remove all stmts + # after the 'break/continue' and set the V to True here. + first_block_index = self._remove_stmts_after_break_continue( + node, variable_name, loop_node_index) + + # 3. Add 'if V' for stmts in ancestor blocks between the first one + # (exclusive) and the ancestor loop (inclusive) + self._replace_if_stmt(loop_node_index, first_block_index, variable_name) + + # 4. For 'break' add break into condition of the loop. + assign_false_node = create_fill_constant_node(variable_name, False) + self._add_stmt_before_cur_node(loop_node_index, assign_false_node) + + cond_var_node = gast.UnaryOp( + op=gast.Not(), + operand=gast.Name( + id=variable_name, + ctx=gast.Load(), + annotation=None, + type_comment=None)) + if isinstance(loop_node, gast.While): + loop_node.test = gast.BoolOp( + op=gast.And(), values=[loop_node.test, cond_var_node]) + elif isinstance(loop_node, gast.For): + parent_node = self.ancestor_nodes[loop_node_index - 1] + for_to_while = ForToWhileTransformer(parent_node, loop_node, + cond_var_node) + for_to_while.transform() + + def visit_Continue(self, node): + loop_node_index = self._find_ancestor_loop_index(node) + assert loop_node_index != -1, "SyntaxError: 'continue' outside loop" + loop_node = self.ancestor_nodes[loop_node_index] + + # 1. Map the 'break/continue' stmt with an unique boolean variable V. + variable_name = unique_name.generate(CONTINUE_NAME_PREFIX) + + # 2. Find the first ancestor block containing this 'break/continue', a + # block can be a node containing stmt list. We should remove all stmts + # after the 'break/continue' and set the V to True here. + first_block_index = self._remove_stmts_after_break_continue( + node, variable_name, loop_node_index) + + # 3. Add 'if V' for stmts in ancestor blocks between the first one + # (exclusive) and the ancestor loop (inclusive) + self._replace_if_stmt(loop_node_index, first_block_index, variable_name) + + # 4. For 'continue', set continue to False at the beginning of each loop + assign_false_node = create_fill_constant_node(variable_name, False) + loop_node.body.insert(0, assign_false_node) + + def _remove_stmts_after_break_continue( + self, break_continue_node, break_continue_name, loop_node_index): + for first_block_index in range( + len(self.ancestor_nodes) - 1, loop_node_index - 1, -1): + first_block = self.ancestor_nodes[first_block_index] + if hasattr(first_block, + "body") and self._replace_break_continue_in_stmt_list( + first_block.body, break_continue_node, + break_continue_name): + return first_block_index + + if hasattr(first_block, + "orelse") and self._replace_break_continue_in_stmt_list( + first_block.orelse, break_continue_node, + break_continue_name): + return first_block_index + + return first_block_index + + def _replace_break_continue_in_stmt_list( + self, stmt_list, break_continue_node, break_continue_name): + i = index_in_list(stmt_list, break_continue_node) + if i == -1: + return False + assign_true_node = create_fill_constant_node(break_continue_name, True) + stmt_list[i:] = [assign_true_node] + return True + + def _replace_if_stmt(self, loop_node_index, first_block_index, + break_continue_name): + for i in range(first_block_index - 1, loop_node_index - 1, -1): + cur_node = self.ancestor_nodes[i] + son_node = self.ancestor_nodes[i + 1] + if hasattr(cur_node, + 'body') and self._replace_after_node_to_if_in_stmt_list( + cur_node.body, son_node, break_continue_name): + continue + if hasattr( + cur_node, + 'orelse') and self._replace_after_node_to_if_in_stmt_list( + cur_node.orelse, son_node, break_continue_name): + continue + + def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node, + break_continue_name): + i = index_in_list(stmt_list, node) + if i == -1: + return False + + if i == len(stmt_list) - 1: + # No need to add, we consider this as added successfully + return True + + if_stmt = gast.If(test=gast.UnaryOp( + op=gast.Not(), + operand=gast.Name( + id=break_continue_name, + ctx=gast.Store(), + annotation=None, + type_comment=None)), + body=stmt_list[i + 1:], + orelse=[]) + stmt_list[i + 1:] = [] + stmt_list.append(if_stmt) + return True + + def _add_stmt_before_cur_node(self, cur_node_index, stmt_node): + cur_node = self.ancestor_nodes[cur_node_index] + parent_node = self.ancestor_nodes[cur_node_index - 1] + if hasattr(parent_node, + "body") and self._add_stmt_into_list_before_node( + parent_node.body, cur_node, stmt_node): + return True + if hasattr(parent_node, + "orelse") and self._add_stmt_into_list_before_node( + parent_node.orelse, cur_node, stmt_node): + return True + return False + + def _add_stmt_into_list_before_node(self, stmt_list, node, stmt_node): + i = index_in_list(stmt_list, node) + if i == -1: + return False + stmt_list.insert(i, stmt_node) + return True + + def _find_ancestor_loop_index(self, node): + for i in range(len(self.ancestor_nodes) - 1, -1, -1): + if isinstance(self.ancestor_nodes[i], (gast.For, gast.While)): + return i + return -1 diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/cache_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/cache_program.py index 60d66b99c85..1024adf3c11 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/cache_program.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/cache_program.py @@ -310,8 +310,8 @@ class AutoTracer(object): if not isinstance(loss_name, six.string_types): raise ValueError( - "Type of input loss_name should type(str), but received {}." - .format(type(loss_name))) + "Type of input loss_name should type(str), but received {}.". + format(type(loss_name))) self._loss_name = loss_name def _add_optimizer(self): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index 82ee4ac123c..aa3edb0a789 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -36,6 +36,7 @@ TRUE_FUNC_PREFIX = 'true_fn' FALSE_FUNC_PREFIX = 'false_fn' LOGIC_AND_PREFIX = 'logic_and' LOGIC_OR_PREFIX = 'logic_or' +LOGIC_NOT_PREFIX = 'logic_not' PLAIN_TENSOR_PREFIX = 'bool_tensor' @@ -129,7 +130,7 @@ def is_candidate_node(node): """ Nodes with specified type will be dependent on tensor. """ - return isinstance(node, (gast.Compare, gast.BoolOp)) + return isinstance(node, (gast.Compare, gast.BoolOp, gast.UnaryOp)) def compare_with_none(node): @@ -268,6 +269,21 @@ class NodeTestTransformer(gast.NodeTransformer): def transform(self): return self.visit(self.ast_root) + def visit_UnaryOp(self, node): + self.generic_visit(node) + if isinstance(node.op, gast.Not): + arg = ast_to_source_code(node.operand) + new_node_str = "fluid.layers.logical_not({})".format(arg) + # gast.parse returns Module(body=[expr(value=...)]) + new_node = gast.parse(new_node_str).body[0].value + logic_tensor_name = unique_name.generate(LOGIC_NOT_PREFIX) + assign_name, assign_node = create_assign_node(logic_tensor_name, + new_node) + self._new_assign_nodes.append(assign_node) + return assign_name + + return node + def visit_BoolOp(self, node): for i, child in enumerate(node.values): if not is_candidate_node(child): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index fd471f0431c..f7097b11985 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -19,8 +19,10 @@ import gast from collections import defaultdict from paddle.fluid import unique_name +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node @@ -62,6 +64,55 @@ def create_while_node(condition_name, body_name, loop_var_names): return assign_node +class LogicalOpTransformer(gast.NodeTransformer): + """ + Transform python boolean op into Paddle logical op + """ + + def __init__(self, node): + self.root = node + + def transform(self): + return self.visit(self.root) + + def visit_UnaryOp(self, node): + self.generic_visit(node) + if isinstance(node.op, gast.Not): + arg = ast_to_source_code(node.operand) + new_node_str = "fluid.layers.logical_not({})".format(arg) + # gast.parse returns Module(body=[expr(value=...)]) + new_node = gast.parse(new_node_str).body[0].value + return new_node + return node + + def visit_BoolOp(self, node): + self.generic_visit(node) + if isinstance(node.op, gast.And): + new_node = self._create_bool_op_node(node.values, 'and') + elif isinstance(node.op, gast.Or): + new_node = self._create_bool_op_node(node.values, 'or') + else: + raise TypeError( + "Only supports and/or syntax in control flow if statement.") + return new_node + + def _create_bool_op_node(self, nodes, api_type): + assert len( + nodes + ) > 1, "The length of BoolOp should be at least 2, but received {}.".format( + len(nodes)) + if len(nodes) > 2: + # Creates logic_and/logic_or node recursively. + pre_assign_node = self._create_bool_op_node(nodes[:2], api_type) + nodes = [pre_assign_node] + nodes[2:] + args = [ast_to_source_code(child) for child in nodes] + new_node_str = "fluid.layers.logical_{}(x={}, y={})".format( + api_type, args[0], args[1]) + # gast.parse return Module(body=[expr(...)]) + new_node = gast.parse(new_node_str).body[0].value + return new_node + + class NameVisitor(gast.NodeVisitor): ''' Analysis name liveness for loop transformer @@ -89,8 +140,8 @@ class NameVisitor(gast.NodeVisitor): return True def get_loop_var_names(self, node): - assert isinstance(node, (gast.While, - gast.For)), "Input node is not gast loop node" + assert isinstance( + node, (gast.While, gast.For)), "Input node is not gast loop node" loop_var_names = set() create_var_names = set() read_context = {type(gast.Load()), type(gast.AugLoad())} @@ -118,6 +169,9 @@ class NameVisitor(gast.NodeVisitor): if self._is_call_func_name_node(node): self.generic_visit(node) return + if node.id == "False" or node.id == "True": + self.generic_visit(node) + return self.current_seen_vars.add(node) for loop_node in self.current_loop: @@ -390,6 +444,9 @@ class LoopTransformer(gast.NodeTransformer): for name in loop_var_names: new_stmts.append(to_static_variable_gast_node(name)) + logical_op_transformer = LogicalOpTransformer(node.test) + cond_value_node = logical_op_transformer.transform() + condition_func_node = gast.FunctionDef( name=unique_name.generate(WHILE_CONDITION_PREFIX), args=gast.arguments( @@ -406,7 +463,7 @@ class LoopTransformer(gast.NodeTransformer): kw_defaults=None, kwarg=None, defaults=[]), - body=[gast.Return(value=node.test)], + body=[gast.Return(value=cond_value_node)], decorator_list=[], returns=None, type_comment=None) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index ed14bafb17d..59820291848 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -299,6 +299,14 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids): return func_def_node +def index_in_list(array_list, item): + try: + return array_list.index(item) + except ValueError: + # Item not in array_list + return -1 + + def ast_to_func(ast_root, func_name, delete_on_exit=True): """ Transform modified AST of decorated function into python callable object. diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py index c2979e86e77..ccf8991545d 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py @@ -25,13 +25,35 @@ __all__ = ['to_static_variable_gast_node', 'create_static_variable_gast_node'] def to_static_variable_gast_node(name): func_code = "{} = fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({})".format( name, name) - return gast.parse(func_code) + return gast.parse(func_code).body[0] def create_static_variable_gast_node(name): func_code = "{} = fluid.layers.data(name='{}', shape=[-1], dtype='float32')".format( name, name) - return gast.parse(func_code) + return gast.parse(func_code).body[0] + + +def create_fill_constant_node(name, value): + func_code = "{} = fluid.layers.fill_constant(shape=[1], ".format(name) + if isinstance(value, bool): + func_code += "dtype='bool', value={})".format(value) + return gast.parse(func_code).body[0] + if isinstance(value, float): + func_code += "dtype='float64', value={})".format(value) + return gast.parse(func_code).body[0] + + if six.PY2: + if isinstance(value, int): + func_code += "dtype='int32', value={})".format(value) + return gast.parse(func_code).body[0] + if isinstance(value, long): + func_code += "dtype='int64', value={})".format(value) + return gast.parse(func_code).body[0] + else: + if isinstance(value, int): + func_code += "dtype='int64', value={})".format(value) + return gast.parse(func_code).body[0] def to_static_variable(x): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py new file mode 100644 index 00000000000..31870146149 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py @@ -0,0 +1,211 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.dygraph.jit import dygraph_to_static_graph + +SEED = 2020 +np.random.seed(SEED) + + +def test_continue_in_for(x): + x = fluid.dygraph.to_variable(x) + for i in range(10): + x += 1 + if i > 5: + continue + x += 10086 + x += i + return x + + +def test_continue_in_for_at_end(x): + x = fluid.dygraph.to_variable(x) + for i in range(10): + x += 1 + if i > 5: + continue + return x + + +def test_continue_in_while(x): + x = fluid.dygraph.to_variable(x) + i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0) + while i < 10: + i += 1 + if i > 5: + continue + x += 10086 + x += i + return x + + +def test_break_in_for(x): + x = fluid.dygraph.to_variable(x) + for i in range(10): + x += 1 + if i > 5: + break + x += 10086 + x += i + return x + + +def test_break_in_for_at_end(x): + x = fluid.dygraph.to_variable(x) + for i in range(10): + x += 1 + if i > 5: + break + return x + + +def test_break_in_while(x): + x = fluid.dygraph.to_variable(x) + i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0) + while i < 10: + i += 1 + if i > 5: + break + x += 10086 + x += i + return x + + +def test_break_continue_in_for(x): + x = fluid.dygraph.to_variable(x) + for i in range(1, 10, 1): + if i <= 4: + x += 1 + continue + else: + x += 10010 + break + x += 10086 + return x + + +def test_for_in_else(x): + x = fluid.dygraph.to_variable(x) + # + # TODO: Huihuang founds that if we put the for range in else body + # the testcase will fail. Enable this test case after fixing it. + # + #if False: + # pass + #else: + # for i in range(0, 10): + # if i > 5: + # x += 1 + # break + # x += i + # + if False: + pass + else: + for i in range(0, 10): + x += 1 + break + x += i + return x + + +class TestContinueInFor(unittest.TestCase): + def setUp(self): + self.input = np.zeros((1)).astype('int32') + self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( + ) else fluid.CPUPlace() + self.init_dygraph_func() + + def init_dygraph_func(self): + self.dygraph_func = test_continue_in_for + + def run_dygraph_mode(self): + with fluid.dygraph.guard(): + res = self.dygraph_func(self.input) + return res.numpy() + + def run_static_mode(self): + main_program = fluid.Program() + with fluid.program_guard(main_program): + res = dygraph_to_static_graph(self.dygraph_func)(self.input) + exe = fluid.Executor(self.place) + static_res = exe.run(main_program, fetch_list=[res]) + + return static_res[0] + + def test_transformed_static_result(self): + static_res = self.run_static_mode() + dygraph_res = self.run_dygraph_mode() + self.assertTrue( + np.allclose(dygraph_res, static_res), + msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res, + static_res)) + + +class TestContinueInForAtEnd(TestContinueInFor): + def init_dygraph_func(self): + self.dygraph_func = test_continue_in_for_at_end + + +class TestBreakInFor(TestContinueInFor): + def init_dygraph_func(self): + self.dygraph_func = test_break_in_for + + +class TestBreakInForAtEnd(TestContinueInFor): + def init_dygraph_func(self): + self.dygraph_func = test_break_in_for_at_end + + +class TestBreakContinueInFor(TestContinueInFor): + def init_dygraph_func(self): + self.dygraph_func = test_break_continue_in_for + + +class TestForInElse(TestContinueInFor): + def init_dygraph_func(self): + self.dygraph_func = test_for_in_else + + +class TestContinueInWhile(TestContinueInFor): + def init_dygraph_func(self): + self.dygraph_func = test_continue_in_while + + def test_transformed_static_result(self): + # TODO: while i < 10 in dygraph will be supported after PR22892 + # so currently we just assert static result. + # remove this overrided function after PR22892 is merged + static_res = self.run_static_mode() + self.assertEqual(15, static_res[0]) + + +class TestBreakInWhile(TestContinueInWhile): + def init_dygraph_func(self): + self.dygraph_func = test_break_in_while + + def test_transformed_static_result(self): + # TODO: while i < 10 in dygraph will be supported after PR22892 + # so currently we just assert static result. + # remove this overrided function after PR22892 is merged + static_res = self.run_static_mode() + self.assertEqual(15, static_res[0]) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index 3bb7a356288..920161db53f 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -42,6 +42,14 @@ def for_loop_dyfunc(max_len): return ret +def while_loop_bool_op(x): + i = fluid.dygraph.to_variable(x) + while (x >= 0 and x < 10) or x <= -1 or x < -3 or (x < -7 or x < -5): + i = i + x + x = x + 1 + return i + + class TestNameVisitor(unittest.TestCase): def setUp(self): self.loop_funcs = [while_loop_dyfunc, for_loop_dyfunc] @@ -67,12 +75,16 @@ class TestTransformWhileLoop(unittest.TestCase): self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( ) else fluid.CPUPlace() self.x = np.zeros(shape=(1), dtype=np.int32) + self._init_dyfunc() + + def _init_dyfunc(self): + self.dyfunc = while_loop_dyfunc def _run_static(self): main_program = fluid.Program() with fluid.program_guard(main_program): x_var = fluid.layers.assign(self.x) - static_func = dygraph_to_static_graph(while_loop_dyfunc) + static_func = dygraph_to_static_graph(self.dyfunc) out = static_func(x_var) exe = fluid.Executor(self.place) @@ -81,7 +93,7 @@ class TestTransformWhileLoop(unittest.TestCase): def _run_dygraph(self): with fluid.dygraph.guard(self.place): - ret = while_loop_dyfunc(fluid.dygraph.to_variable(self.x)) + ret = self.dyfunc(fluid.dygraph.to_variable(self.x)) return ret.numpy() def test_ast_to_func(self): @@ -97,6 +109,11 @@ class TestTransformWhileLoop(unittest.TestCase): # self.assertTrue(np.allclose(self._run_dygraph(), self._run_static())) +class TestWhileLoopBoolOp(TestTransformWhileLoop): + def _init_dyfunc(self): + self.dyfunc = while_loop_bool_op + + class TestTransformForLoop(unittest.TestCase): def setUp(self): self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py new file mode 100644 index 00000000000..0c8ebb163ba --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py @@ -0,0 +1,33 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list + + +class TestIndexInList(unittest.TestCase): + def test_index_in_list(self): + list_to_test = [1, 2, 3, 4, 5] + self.assertEqual(index_in_list(list_to_test, 4), 3) + self.assertEqual(index_in_list(list_to_test, 1), 0) + self.assertEqual(index_in_list(list_to_test, 5), 4) + self.assertEqual(index_in_list(list_to_test, 0), -1) + self.assertEqual(index_in_list(list_to_test, 6), -1) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_variable_trans_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_variable_trans_func.py new file mode 100644 index 00000000000..7f2220c1111 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_variable_trans_func.py @@ -0,0 +1,53 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import gast +import six +import unittest + +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code +from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node + + +class TestVariableTransFunc(unittest.TestCase): + def test_create_fill_constant_node(self): + node = create_fill_constant_node("a", 1.0) + source = "a = fluid.layers.fill_constant(shape=[1], dtype='float64', value=1.0)" + self.assertEqual(ast_to_source_code(node).strip(), source) + + node = create_fill_constant_node("b", True) + source = "b = fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)" + self.assertEqual(ast_to_source_code(node).strip(), source) + + if six.PY2: + node = create_fill_constant_node("c", 214) + source = "c = fluid.layers.fill_constant(shape=[1], dtype='int32', value=214)" + self.assertEqual(ast_to_source_code(node).strip(), source) + + node = create_fill_constant_node("d", long(10086)) + source = "d = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10086)" + self.assertEqual(ast_to_source_code(node).strip(), source) + else: + node = create_fill_constant_node("c", 4293) + source = "c = fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293)" + self.assertEqual(ast_to_source_code(node).strip(), source) + + self.assertIsNone(create_fill_constant_node("e", None)) + self.assertIsNone(create_fill_constant_node("e", [])) + + +if __name__ == '__main__': + unittest.main() -- GitLab