From 03ba5b748d569309f29da8577c51237ab1105eda Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 18 May 2020 14:09:12 +0800 Subject: [PATCH] [Dy2static] Add for enumerate Variable support (#24398) * initial test * for enumerate basic implement, test=develop * update unittests, test=develop * refine unittests to adapt new training mode, test=develop * refactor for node stmts parsing code, test=develop * self-review & polish details, test=develop --- .../break_continue_transformer.py | 86 +---- .../dygraph_to_static/loop_transformer.py | 119 +++---- .../fluid/dygraph/dygraph_to_static/utils.py | 330 +++++++++++++++++- .../dygraph_to_static/test_for_enumerate.py | 278 +++++++++++++++ 4 files changed, 657 insertions(+), 156 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py index f83cd6d302..e9280954bf 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py @@ -19,6 +19,7 @@ import gast from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list +from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeParser from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node __all__ = ['BreakContinueTransformer'] @@ -61,87 +62,26 @@ class ForToWhileTransformer(gast.NodeTransformer): raise ValueError( "parent_node doesn't contain the loop_node in ForToWhileTransformer") - def get_for_range_node(self, node): - if not isinstance(node.iter, gast.Call): - return None - if not isinstance(node.iter.func, gast.Name): - return None - if node.iter.func.id != "range": - return None - return node.iter - - def get_for_args_stmts(self, iter_name, args_list): - ''' - Returns 3 gast stmt nodes for argument. - 1. Initailize of iterate variable - 2. Condition for the loop - 3. Statement for changing of iterate variable during the loop - ''' - len_range_args = len(args_list) - assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments" - if len_range_args == 1: - init_stmt = get_constant_variable_node(iter_name, 0) - else: - init_stmt = gast.Assign( - targets=[ - gast.Name( - id=iter_name, - ctx=gast.Store(), - annotation=None, - type_comment=None) - ], - value=args_list[0]) - - range_max_node = args_list[0] if len_range_args == 1 else args_list[1] - step_node = args_list[2] if len_range_args == 3 else gast.Constant( - value=1, kind=None) - - old_cond_stmt = gast.Compare( - left=gast.BinOp( - left=gast.Name( - id=iter_name, - ctx=gast.Load(), - annotation=None, - type_comment=None), - op=gast.Add(), - right=step_node), - ops=[gast.LtE()], - comparators=[range_max_node]) - cond_stmt = gast.BoolOp( - op=gast.And(), values=[old_cond_stmt, self.condition_node]) - - change_stmt = gast.AugAssign( - target=gast.Name( - id=iter_name, - ctx=gast.Store(), - annotation=None, - type_comment=None), - op=gast.Add(), - value=step_node) - - return init_stmt, cond_stmt, change_stmt - def get_for_stmt_nodes(self, node): assert isinstance( node, gast.For), "Input node is NOT gast.For in get_for_stmt_nodes" - # TODO: support non-range case - range_call_node = self.get_for_range_node(node) - if range_call_node is None: - return [node] - - if not isinstance(node.target, gast.Name): + # 1. parse current gast.For node + current_for_node_parser = ForNodeParser(node) + stmts_tuple = current_for_node_parser.parse() + if stmts_tuple is None: return [node] - iter_var_name = node.target.id + init_stmts, cond_stmt, body_stmts = stmts_tuple - init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts( - iter_var_name, range_call_node.args) + # 2. append break statement + new_cond_stmt = gast.BoolOp( + op=gast.And(), values=[cond_stmt, self.condition_node]) - new_body = node.body - new_body.append(change_stmt) + # 3. construct gast.While node while_node = gast.While( - test=cond_stmt, body=new_body, orelse=node.orelse) - return [init_stmt, while_node] + test=new_cond_stmt, body=body_stmts, orelse=node.orelse) + init_stmts.append(while_node) + return init_stmts class BreakContinueTransformer(gast.NodeTransformer): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 900e48269d..6c8b27625b 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -23,10 +23,12 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code +from paddle.fluid.dygraph.dygraph_to_static.utils import create_api_shape_node from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform +from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeParser from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node @@ -321,7 +323,7 @@ class LoopTransformer(gast.NodeTransformer): def __init__(self, wrapper_root): assert isinstance( wrapper_root, AstNodeWrapper - ), "Input non-AstNodeWrapper node for the initialization of WhileTransformer." + ), "Input non-AstNodeWrapper node for the initialization of LoopTransformer." self.wrapper_root = wrapper_root self.root = wrapper_root.node self.name_visitor = NameVisitor(self.root) @@ -355,86 +357,45 @@ class LoopTransformer(gast.NodeTransformer): else: i += 1 - def get_for_range_node(self, node): - if not isinstance(node.iter, gast.Call): - return None - if not isinstance(node.iter.func, gast.Name): - return None - if node.iter.func.id != "range": - return None - return node.iter - - def get_for_args_stmts(self, iter_name, args_list): - ''' - Returns 3 gast stmt nodes for argument. - 1. Initailize of iterate variable - 2. Condition for the loop - 3. Statement for changing of iterate variable during the loop - NOTE(TODO): Python allows to access iteration variable after loop, such - as "for i in range(10)" will create i = 9 after the loop. But using - current conversion will make i = 10. We should find a way to change it - ''' - len_range_args = len(args_list) - assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments" - if len_range_args == 1: - init_stmt = get_constant_variable_node(iter_name, 0) - else: - init_stmt = gast.Assign( - targets=[ - gast.Name( - id=iter_name, - ctx=gast.Store(), - annotation=None, - type_comment=None) - ], - value=args_list[0]) - - range_max_node = args_list[0] if len_range_args == 1 else args_list[1] - step_node = args_list[2] if len_range_args == 3 else gast.Constant( - value=1, kind=None) - - cond_stmt = gast.Compare( - left=gast.BinOp( - left=gast.Name( - id=iter_name, - ctx=gast.Load(), - annotation=None, - type_comment=None), - op=gast.Add(), - right=step_node), - ops=[gast.LtE()], - comparators=[range_max_node]) - - change_stmt = gast.AugAssign( - target=gast.Name( - id=iter_name, - ctx=gast.Store(), - annotation=None, - type_comment=None), - op=gast.Add(), - value=step_node) - - return init_stmt, cond_stmt, change_stmt - def get_for_stmt_nodes(self, node): # TODO: consider for - else in python - if not self.name_visitor.is_control_flow_loop(node): - return [node] - # TODO: support non-range case - range_call_node = self.get_for_range_node(node) - if range_call_node is None: + # 1. check whether need to transform + # NOTE: Current need transform cases: + # 1). for x in range(VarBase.numpy()[0]) + # 2). for x in VarBase.numpy() + # 3). for i, x in enumerate(VarBase.numpy()) + if not self.name_visitor.is_control_flow_loop(node): return [node] - if not isinstance(node.target, gast.Name): + # 2. get key statements for different cases + # NOTE: three key statements: + # 1). init_stmts: list[node], prepare nodes of for loop, may not only one + # 2). cond_stmt: node, condition node to judge whether continue loop + # 3). body_stmts: list[node], updated loop body, sometimes we should change + # the original statement in body, not just append new statement + current_for_node_parser = ForNodeParser(node) + stmts_tuple = current_for_node_parser.parse() + if stmts_tuple is None: return [node] - iter_var_name = node.target.id - - init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts( - iter_var_name, range_call_node.args) + init_stmts, cond_stmt, body_stmts = stmts_tuple + # 3. get original loop vars loop_var_names, create_var_names = self.name_visitor.get_loop_var_names( node) + # NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases, + # we need append new loop var & remove useless loop var + # 1. for x in var -> x is no need + # 2. for i, x in enumerate(var) -> x is no need + if current_for_node_parser.is_for_iter( + ) or current_for_node_parser.is_for_enumerate_iter(): + iter_var_name = current_for_node_parser.iter_var_name + iter_idx_name = current_for_node_parser.iter_idx_name + loop_var_names.add(iter_idx_name) + if iter_var_name not in create_var_names: + loop_var_names.remove(iter_var_name) + + # 4. prepare result statement list new_stmts = [] # Python can create variable in loop and use it out of loop, E.g. # @@ -447,12 +408,13 @@ class LoopTransformer(gast.NodeTransformer): if "." not in name: new_stmts.append(create_static_variable_gast_node(name)) - new_stmts.append(init_stmt) - + # 5. append init statements + new_stmts.extend(init_stmts) # for x in range(10) in dygraph should be convert into static tensor + 1 <= 10 for name in loop_var_names: new_stmts.append(to_static_variable_gast_node(name)) + # 6. create & append condition function node condition_func_node = gast.FunctionDef( name=unique_name.generate(FOR_CONDITION_PREFIX), args=gast.arguments( @@ -480,9 +442,9 @@ class LoopTransformer(gast.NodeTransformer): name, unique_name.generate(GENERATE_VARIABLE_PREFIX)) new_stmts.append(condition_func_node) - new_body = node.body - new_body.append(change_stmt) - new_body.append( + # 7. create & append loop body function node + # append return values for loop body + body_stmts.append( gast.Return(value=generate_name_node( loop_var_names, ctx=gast.Load()))) body_func_node = gast.FunctionDef( @@ -501,7 +463,7 @@ class LoopTransformer(gast.NodeTransformer): kw_defaults=None, kwarg=None, defaults=[]), - body=new_body, + body=body_stmts, decorator_list=[], returns=None, type_comment=None) @@ -512,6 +474,7 @@ class LoopTransformer(gast.NodeTransformer): name, unique_name.generate(GENERATE_VARIABLE_PREFIX)) new_stmts.append(body_func_node) + # 8. create & append while loop node while_loop_node = create_while_node(condition_func_node.name, body_func_node.name, loop_var_names) new_stmts.append(while_loop_node) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 29eb767320..0372dab799 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -25,6 +25,8 @@ import os import six import tempfile +from paddle.fluid import unique_name + dygraph_class_to_static_api = { "CosineDecay": "cosine_decay", "ExponentialDecay": "exponential_decay", @@ -35,6 +37,9 @@ dygraph_class_to_static_api = { "PolynomialDecay": "polynomial_decay", } +FOR_ITER_INDEX_PREFIX = '__for_loop_var_index' +FOR_ITER_VAR_SHAPE_PREFIX = '__for_loop_var_shape' + def _is_api_in_module_helper(obj, module_prefix): m = inspect.getmodule(obj) @@ -504,12 +509,22 @@ class IsControlFlowVisitor(gast.NodeVisitor): assert isinstance(node, gast.For) if not isinstance(node.iter, gast.Call): return - if not isinstance(node.iter.func, gast.Name): - return - if node.iter.func.id != "range": + + # for in range(v.numpy()) or for in enumerate(v.numpy()) + if isinstance(node.iter.func, gast.Name): + if node.iter.func.id == "range" or node.iter.func.id == "enumerate": + for arg in node.iter.args: + self.visit(arg) + else: + return + # for in v.numpy() + elif isinstance(node.iter.func, gast.Attribute): + if node.iter.func.attr == 'numpy': + self._visit_Call(node.iter) + else: + return + else: return - for arg in node.iter.args: - self.visit(arg) for child_node in gast.walk(node): if isinstance(child_node, (gast.Continue, gast.Break)): @@ -609,3 +624,308 @@ class IsControlFlowVisitor(gast.NodeVisitor): def get_compare_nodes_with_tensor(self): return self._compare_node_tenor_set + + +class NameNodeReplaceTransformer(gast.NodeTransformer): + """ + This class transform specfice gast.Name node to replace node + """ + + def __init__(self, root_node, target_name, replace_node): + assert isinstance(target_name, str) + self.target_name = target_name + self.replace_node = replace_node + + self.visit(root_node) + + def visit_Name(self, node): + if node.id == self.target_name: + return self.replace_node + return node + + +class ForNodeParser(object): + """ + This class parse python for statement, get transformed 3 statement components of for node + three key statements: + 1). init_stmts: list[node], prepare nodes of for loop, may not only one + 2). cond_stmt: node, condition node to judge whether continue loop + 3). body_stmts: list[node], updated loop body, sometimes we should change + the original statement in body, not just append new statement + + In this process, the semantics of for does not change. + + Now only can parse 3 type statements: + 1). for x in range(***) + 2). for x in var.numpy() + 3). for i, x enumerate(var.numpy()) + """ + + def __init__(self, for_node): + assert isinstance( + for_node, gast.For + ), "Input node for the initialization of ForNodeParser is not gast.For node." + # 1. original for node + self.node = for_node + + # 2. gast.For node main parts + self.target = for_node.target + # NOTE: type may be Node or list[Node] + self.iter_args = for_node.iter if self.is_for_iter( + ) else for_node.iter.args + self.body = for_node.body + + # 3. key shared node or names + # - x: + # - for x in range(***) + # - for x in var.numpy() + # - for i, x enumerate(var.numpy()) + self.iter_var_name = self._get_iter_var_name() + + # - created index var to slice Variable: __for_loop_var_index_0 + # - for x in var.numpy() + # - for i, x enumerate(var.numpy()) + self.iter_idx_name = unique_name.generate(FOR_ITER_INDEX_PREFIX) + + # - created shape var to build loop condition: __for_loop_var_shape_0 + # - for x in var.numpy() + # - for i, x enumerate(var.numpy()) + self.iter_var_shape_name = unique_name.generate( + FOR_ITER_VAR_SHAPE_PREFIX) + + # - var.numpy() + # - for x in var.numpy() + # - for i, x enumerate(var.numpy()) + self.iter_node = self._get_iter_node() + + # - enumeate i: + # - for i, x enumerate(var.numpy()) + self.enum_idx_name = self._get_enum_idx_name() + + # - range/enumerate args length + self.args_length = None + + def parse(self): + self._args_check() + if self.is_for_range_iter(): + return self._parse_for_range_stmts() + elif self.is_for_iter(): + return self._parse_for_stmts() + elif self.is_for_enumerate_iter(): + return self._parse_for_enumerate_stmts() + else: + raise None + + def is_for_range_iter(self): + return isinstance(self.node.iter.func, + gast.Name) and self.node.iter.func.id == "range" + + def is_for_iter(self): + return isinstance( + self.node.iter.func, + gast.Attribute) and self.node.iter.func.attr == 'numpy' + + def is_for_enumerate_iter(self): + return isinstance(self.node.iter.func, + gast.Name) and self.node.iter.func.id == "enumerate" + + def _args_check(self): + if self.is_for_range_iter(): + self.args_length = len(self.iter_args) + assert self.args_length >= 1 and self.args_length <= 3, "range() function takes 1 to 3 arguments" + elif self.is_for_enumerate_iter(): + self.args_length = len(self.iter_args) + assert self.args_length >= 1 and self.args_length <= 2, "enumerate() function takes 1 to 2 arguments" + else: + self.args_length = None + + def _parse_for_range_stmts(self): + init_stmts = [] + init_stmts.append(self._build_index_init_node()) + + compare_node = self._build_compare_node() + step_node = self._build_step_node() + cond_stmt = self._build_cond_stmt(step_node, compare_node) + + body_stmts = self.body + body_stmts.append(self._build_index_increase_node(step_node)) + + return init_stmts, cond_stmt, body_stmts + + def _parse_for_stmts(self): + init_stmts = [] + init_stmts.append(self._build_index_init_node()) + init_stmts.append(self._build_var_shape_assign_node()) + + compare_node = self._build_compare_node() + step_node = self._build_step_node() + cond_stmt = self._build_cond_stmt(step_node, compare_node) + + body_stmts = self.body + var_slice_node = self._build_var_slice_node() + for body_node in body_stmts: + NameNodeReplaceTransformer(body_node, self.iter_var_name, + var_slice_node) + body_stmts.append(self._build_index_increase_node(step_node)) + + return init_stmts, cond_stmt, body_stmts + + def _parse_for_enumerate_stmts(self): + init_stmts = [] + init_stmts.append(self._build_index_init_node()) + init_stmts.append(self._build_var_shape_assign_node()) + init_stmts.append(self._build_enum_init_node()) + + compare_node = self._build_compare_node() + step_node = self._build_step_node() + cond_stmt = self._build_cond_stmt(step_node, compare_node) + + body_stmts = self.body + var_slice_node = self._build_var_slice_node() + for body_node in body_stmts: + NameNodeReplaceTransformer(body_node, self.iter_var_name, + var_slice_node) + body_stmts.append(self._build_index_increase_node(step_node)) + body_stmts.append(self._build_enum_increase_node()) + + return init_stmts, cond_stmt, body_stmts + + def _build_index_init_node(self): + if self.is_for_range_iter(): + if self.args_length == 1: + index_init_node = get_constant_variable_node(self.iter_var_name, + 0) + else: + index_init_node = gast.Assign( + targets=[ + gast.Name( + id=self.iter_var_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + value=self.iter_args[0]) + else: + # TODO: slice bug, only support int32 index + index_init_node = get_constant_variable_node( + self.iter_idx_name, 0, dtype='int32') + return index_init_node + + def _build_var_shape_assign_node(self): + # get variable shape as iter length + return gast.Assign( + targets=[ + gast.Name( + id=self.iter_var_shape_name, + ctx=gast.Load(), + annotation=None, + type_comment=None) + ], + value=create_api_shape_node(self.iter_node.func)) + + def _build_enum_init_node(self): + enum_init_node = get_constant_variable_node( + name=self.enum_idx_name, value=0) + if self.is_for_enumerate_iter() and self.args_length != 1: + enum_init_node = gast.Assign( + targets=[ + gast.Name( + id=self.enum_idx_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + value=self.iter_args[1]) + return enum_init_node + + def _build_compare_node(self): + if self.is_for_range_iter(): + compare_node = self.iter_args[ + 0] if self.args_length == 1 else self.iter_args[1] + else: + compare_node = gast.Subscript( + value=gast.Name( + id=self.iter_var_shape_name, + ctx=gast.Load(), + annotation=None, + type_comment=None), + slice=gast.Index(value=gast.Constant( + value=0, kind=None)), + ctx=gast.Load()) + return compare_node + + def _build_step_node(self): + if self.is_for_range_iter(): + step_node = self.iter_args[ + 2] if self.args_length == 3 else gast.Constant( + value=1, kind=None) + else: + step_node = gast.Constant(value=1, kind=None) + return step_node + + def _build_cond_stmt(self, step_node, compare_node): + return gast.Compare( + left=gast.BinOp( + left=gast.Name( + id=self.iter_var_name + if self.is_for_range_iter() else self.iter_idx_name, + ctx=gast.Load(), + annotation=None, + type_comment=None), + op=gast.Add(), + right=step_node), + ops=[gast.LtE()], + comparators=[compare_node]) + + def _build_index_increase_node(self, step_node): + return gast.AugAssign( + target=gast.Name( + id=self.iter_var_name + if self.is_for_range_iter() else self.iter_idx_name, + ctx=gast.Store(), + annotation=None, + type_comment=None), + op=gast.Add(), + value=step_node) + + def _build_var_slice_node(self): + return gast.Subscript( + value=self.iter_node, + slice=gast.Index(value=gast.Name( + id=self.iter_idx_name, + ctx=gast.Load(), + annotation=None, + type_comment=None)), + ctx=gast.Load()) + + def _build_enum_increase_node(self): + return gast.AugAssign( + target=gast.Name( + id=self.enum_idx_name, + ctx=gast.Store(), + annotation=None, + type_comment=None), + op=gast.Add(), + value=gast.Constant( + value=1, kind=None)) + + def _get_iter_var_name(self): + if self.is_for_range_iter(): + return self.target.id + elif self.is_for_iter(): + return self.target.id + elif self.is_for_enumerate_iter(): + return self.target.elts[1].id + return None + + def _get_iter_node(self): + if self.is_for_iter(): + return self.iter_args + elif self.is_for_enumerate_iter(): + return self.iter_args[0] + return None + + def _get_enum_idx_name(self): + if self.is_for_enumerate_iter(): + return self.target.elts[0].id + return None diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py new file mode 100644 index 0000000000..dc160cdbe5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py @@ -0,0 +1,278 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest + +import paddle.fluid as fluid +from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator +from paddle.fluid.dygraph.jit import declarative + +program_translator = ProgramTranslator() + + +# 0. for in range with var case +@declarative +def dygraph_for_in_range(x): + z = fluid.layers.fill_constant([1], 'int32', 0) + x = fluid.dygraph.to_variable(x) + for i in range(x.numpy()[0]): + z = z + i + return z + + +# 1. for iter list +@declarative +def dygraph_for_iter_list(x_array): + z = fluid.layers.fill_constant([1], 'int32', 0) + for x in x_array: + z = z + x + return z + + +# 2. for enumerate list +@declarative +def dygraph_for_enumerate_list(x_array): + z = fluid.layers.fill_constant([1], 'int32', 0) + for i, x in enumerate(x_array): + z = z + x + i + return z + + +# 3. for iter var.numpy() +@declarative +def dygraph_for_iter_var_numpy(x_array): + z = fluid.layers.fill_constant([1], 'int32', 0) + x_array = fluid.dygraph.to_variable(x_array) + for x in x_array.numpy(): + z = z + x + return z + + +# 4. for enumerate var.numpy() +@declarative +def dygraph_for_enumerate_var_numpy(x_array): + y = fluid.layers.fill_constant([1], 'int32', 0) + z = fluid.layers.fill_constant([1], 'int32', 0) + x_array = fluid.dygraph.to_variable(x_array) + for i, x in enumerate(x_array.numpy()): + y = y + i + z = z + x + return y, z + + +# 5. for enumerate var.numpy() with start +@declarative +def dygraph_for_enumerate_var_numpy_with_start(x_array): + y = fluid.layers.fill_constant([1], 'int32', 0) + z = fluid.layers.fill_constant([1], 'int32', 0) + x_array = fluid.dygraph.to_variable(x_array) + for i, x in enumerate(x_array.numpy(), 1): + y = y + i + z = z + x + return y, z + + +# 6. for in range with break +@declarative +def dygraph_for_in_range_with_break(x): + z = fluid.layers.fill_constant([1], 'int32', 0) + x = fluid.dygraph.to_variable(x) + for i in range(x.numpy()[0]): + z = z + i + if i > 2: + break + return z + + +# 7. for enumerate var.numpy() with break +@declarative +def dygraph_for_enumerate_var_numpy_with_break(x_array): + y = fluid.layers.fill_constant([1], 'int32', 0) + z = fluid.layers.fill_constant([1], 'int32', 0) + x_array = fluid.dygraph.to_variable(x_array) + for i, x in enumerate(x_array.numpy()): + y = y + i + z = z + x + if i > 2: + break + return y, z + + +# 8. for enumerate var.numpy() with continue +@declarative +def dygraph_for_enumerate_var_numpy_with_continue(x_array): + y = fluid.layers.fill_constant([1], 'int32', 0) + z = fluid.layers.fill_constant([1], 'int32', 0) + x_array = fluid.dygraph.to_variable(x_array) + for i, x in enumerate(x_array.numpy()): + y = y + i + if i > 2: + continue + z = z + x + return y, z + + +# 9. for enumerate var.numpy() with start & break +@declarative +def dygraph_for_enumerate_var_numpy_with_start_break(x_array): + y = fluid.layers.fill_constant([1], 'int32', 0) + z = fluid.layers.fill_constant([1], 'int32', 0) + x_array = fluid.dygraph.to_variable(x_array) + for i, x in enumerate(x_array.numpy(), 1): + y = y + i + z = z + x + if i > 2: + break + return y, z + + +# 10. for enumerate var.numpy() with start & continue +@declarative +def dygraph_for_enumerate_var_numpy_with_start_continue(x_array): + y = fluid.layers.fill_constant([1], 'int32', 0) + z = fluid.layers.fill_constant([1], 'int32', 0) + x_array = fluid.dygraph.to_variable(x_array) + for i, x in enumerate(x_array.numpy(), 1): + y = y + i + if i > 2: + continue + z = z + x + return y, z + + +class TestTransformBase(unittest.TestCase): + def setUp(self): + self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( + ) else fluid.CPUPlace() + self.set_input() + self.set_test_func() + + def set_input(self): + self.input = [1, 2, 3] + + def set_test_func(self): + raise NotImplementedError( + "For Enumerate test should implement set_test_func") + + def _run(self, to_static): + program_translator.enable(to_static) + with fluid.dygraph.guard(): + return self.dygraph_func(self.input) + + def get_dygraph_output(self): + return self._run(to_static=False) + + def get_static_output(self): + return self._run(to_static=True) + + +class TestTransform(TestTransformBase): + def transformed_result_compare(self): + dy_outs = self.get_dygraph_output() + if not isinstance(dy_outs, tuple): + dy_outs = (dy_outs, ) + + # NOTE: return type is difference + st_outs = self.get_static_output() + if not isinstance(st_outs, list): + st_outs = (st_outs, ) + else: + st_outs = tuple(st_outs) + + for x, y in zip(dy_outs, st_outs): + self.assertTrue(np.allclose(x.numpy(), y.numpy())) + + +class TestTransformError(TestTransformBase): + def transformed_error(self, etype): + with self.assertRaises(etype): + dy_out = self.get_dygraph_output() + st_out = self.get_static_output() + + +class TestForInRange(TestTransform): + def set_input(self): + self.input = np.array([5]) + + def set_test_func(self): + self.dygraph_func = dygraph_for_in_range + + def test_transformed_result_compare(self): + self.transformed_result_compare() + + +class TestForIterList(TestTransform): + def set_test_func(self): + self.dygraph_func = dygraph_for_iter_list + + def test_transformed_result_compare(self): + self.transformed_result_compare() + + +class TestForEnumerateSimple(TestForIterList): + def set_test_func(self): + self.dygraph_func = dygraph_for_enumerate_list + + +class TestForInRangeWithBreak(TestForInRange): + def set_test_func(self): + self.dygraph_func = dygraph_for_in_range_with_break + + +class TestForIterVarNumpy(TestTransform): + def set_input(self): + self.input = np.array([1, 2, 3, 4, 5]) + + def set_test_func(self): + self.dygraph_func = dygraph_for_iter_var_numpy + + def test_transformed_result_compare(self): + self.transformed_result_compare() + + +class TestForEnumerateVarNumpy(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = dygraph_for_enumerate_var_numpy + + +class TestForEnumerateVarNumpyWithStart(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start + + +class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = dygraph_for_enumerate_var_numpy_with_break + + +class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = dygraph_for_enumerate_var_numpy_with_continue + + +class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start_break + + +class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start_continue + + +if __name__ == '__main__': + unittest.main() -- GitLab