From c5c6026e128017af59a8a908c1ee10fc6f37240d Mon Sep 17 00:00:00 2001
From: xiongkun
Date: Tue, 12 Jul 2022 19:32:33 +0800
Subject: [PATCH] [ Dy2Static ] Change NameVisitor in while to
 FunctionScopeAnalysis (#44155)

* change NameVisitor to FunctionScopeAnalysis
* polish the logic of undefined vars in while_loop; create vars after body execution
* replace the old NameVisitor in while and fix all CI
* together with CreateVariableTransformer
* add create_variable_transformer
* fix bugs
* merge
* fix some errors; TODO: ForNodePreTransform ahead
* merge for the unified PR
* fix conflict with the base_transformer PR
* fix CI errors; fix the [for i in range()] error
* fix according to code review
---
 .../dygraph_to_static/ast_transformer.py      |   3 +-
 .../dygraph_to_static/base_transformer.py     | 169 ++++++------------
 .../dygraph_to_static/call_transformer.py     |  15 +-
 .../dygraph_to_static/convert_call_func.py    |  34 ++--
 .../dygraph_to_static/convert_operators.py    | 104 ++++++++++-
 .../create_variable_transformer.py            |  48 +++++
 .../dygraph_to_static/ifelse_transformer.py   |  23 ++-
 .../dygraph_to_static/loop_transformer.py     |  58 ++----
 .../fluid/dygraph/dygraph_to_static/utils.py  |  19 ++
 python/paddle/fluid/layers/control_flow.py    |  43 ++++-
 .../seq2seq_dygraph_model.py                  |   1 +
 .../unittests/dygraph_to_static/test_loop.py  |   7 -
 .../test_program_translator.py                |  12 +-
 .../dygraph_to_static/test_tensor_shape.py    |   6 +-
 .../transformer_dygraph_model.py              |   4 +-
 python/paddle/jit/dy2static/__init__.py       |   3 +-
 .../paddle/jit/dy2static/convert_operators.py |   1 +
 17 files changed, 337 insertions(+), 213 deletions(-)
 create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/create_variable_transformer.py

diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
index f1ab097758b..a9e8f447e99 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
@@ -35,6 +35,7 @@ from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTr
 from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
 from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer
 from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ReturnTransformer
+from paddle.fluid.dygraph.dygraph_to_static.create_variable_transformer import CreateVariableTransformer
 from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
 from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer

@@ -96,7 +97,7 @@ class DygraphToStaticAst(BaseTransformer):
             BreakContinueTransformer,  # break/continue in loops
             ReturnTransformer,  # return in functions
             LogicalTransformer,  # logical and/or/not
-            #CreateVariableTransformer,  # create undefined var for if / while / for
+            CreateVariableTransformer,  # create undefined var for if / while / for
             LoopTransformer,  # for/while -> while_op
             IfElseTransformer,  # if/else -> cond_op
             AssertTransformer,  # assert statement
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/base_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/base_transformer.py
index a3c2c0c69ef..9df7e8d9b4f 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/base_transformer.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/base_transformer.py
@@ -24,6 +24,8 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_TUPLE_INDEX_PR
 from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_VAR_LEN_PREFIX
 from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_VAR_NAME_PREFIX
 from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_ZIP_TO_LIST_PREFIX
+from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_TARGET_PREFIX
+from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_ITERATOR_PREFIX


 class BaseTransformer(gast.NodeTransformer):
@@ -119,32 +121,20 @@ class NameNodeReplaceTransformer(BaseTransformer):


 class ForLoopTuplePreTransformer(BaseTransformer):
-    """
-    ForNodeVisitor parses 3 type statements (Here var is VarBase(Tensor) or python variable):
-    1). for x in range(var[*]|var.numpy()[*])
-    2). for x in var|var.numpy()
-    3). for i, x in enumerate(var|var.numpy())
-
-    We chose these 3 types because they are easier (x can be variable name iterating in var).
-    However, users can write tuples in Python for loop, such as
-    1). for var1, var2 in var|var.numpy()
-    2). for t in enumerate(var|var.numpy())
-    2). for i, (var1, var2, va3) in enumerate(var|var.numpy())
-
-    To handle these case, this method will do the rewrite tuple pre-process:
-    1). Non-enumerate case: for var1, var2 in var|var.numpy() will be re-written as:
-        for FOR_ITER_TUPLE_PREFIX_x in var | var.numpy():
-            var1 = FOR_ITER_TUPLE_PREFIX_x[0]
-            var2 = FOR_ITER_TUPLE_PREFIX_x[1]
-    2). Enumerate out tuple case: for t in enumerate(var|var.numpy) will be rewritten as:
-        for FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x in enumerate(var|var.numpy):
-            t = (FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x)
-    3). Enumerate inner tuple case: for i, (var1, (var2, va3)) in enumerate(var|var.numpy()) will
-        be re-written as:
-        for i, FOR_ITER_TUPLE_PREFIX_x in var | var.numpy():
-            var1 = FOR_ITER_TUPLE_PREFIX_x[0]
-            var2 = FOR_ITER_TUPLE_PREFIX_x[1][0]
-            var3 = FOR_ITER_TUPLE_PREFIX_x[1][1]
+    """ Pre-process for loops.
+    >>> for A in B:
+    >>>    C
+
+    will be changed into:
+
+    >>> UUID_iterator = _jst.Indexable(B)  # wrap iterator-only objects into an indexable list
+    >>> for UUID_target in UUID_iterator:
+    >>>    A = _jst.Unpack(UUID_target, structure)
+    >>>    C
+
+    so that the later loop transform sees a single unified form:
+    >>> for target in iter:
+    >>>    body
     """

     def __init__(self, wrapper_root):
@@ -155,104 +145,45 @@ class ForLoopTuplePreTransformer(BaseTransformer):
         self.visit(self.root)

     def visit_For(self, node):
-        if self.is_for_enumerate_iter(node):
-            if isinstance(node.target, (gast.Name, gast.Attribute)):
-                # Out tuple case
-                out_tuple_name = ast_to_source_code(node.target).strip()
-                tuple_iter_name = unique_name.generate(
-                    FOR_ITER_TUPLE_INDEX_PREFIX)
-                tuple_var_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
-                node.target = gast.Tuple(elts=[
-                    gast.Name(id=tuple_iter_name,
-                              ctx=gast.Store(),
-                              annotation=None,
-                              type_comment=None),
-                    gast.Name(id=tuple_var_name,
-                              ctx=gast.Store(),
+        self.generic_visit(node)
+        tuple_target = unique_name.generate(FOR_ITER_TARGET_PREFIX)
+        tuple_iterator = unique_name.generate(FOR_ITER_ITERATOR_PREFIX)
+        origin_tuple_node = node.target
+        assign_iterator_node = gast.parse(
+            f"{tuple_iterator} = _jst.Indexable({ast_to_source_code(node.iter).strip()})"
+        ).body[0]
+        node.target = gast.Name(id=tuple_target,
+                                ctx=gast.Store(),
+                                annotation=None,
+                                type_comment=None)
+        node.iter = gast.Name(id=tuple_iterator,
+                              ctx=gast.Load(),
                               annotation=None,
                               type_comment=None)
-                ],
-                                         ctx=gast.Store())
-                node.body.insert(
-                    0,
-                    gast.Assign(targets=[
-                        gast.Name(id=out_tuple_name,
-                                  ctx=gast.Store(),
-                                  annotation=None,
-                                  type_comment=None)
-                    ],
-                                value=gast.Tuple(elts=[
-                                    gast.Name(id=tuple_iter_name,
-                                              ctx=gast.Load(),
-                                              annotation=None,
-                                              type_comment=None),
-                                    gast.Name(id=tuple_var_name,
-                                              ctx=gast.Load(),
-                                              annotation=None,
-                                              type_comment=None)
-                                ],
-                                                 ctx=gast.Load())))
-            elif isinstance(node.target, (gast.List, gast.Tuple)) and len(
-                    node.target.elts) >= 2 and isinstance(
-                        node.target.elts[1], (gast.List, gast.Tuple)):
-                # Inner tuple case
-                inner_tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
-                origin_inner_tuple_node = node.target.elts[1]
-                node.target.elts[1] = gast.Name(id=inner_tuple_name,
-                                                ctx=gast.Store(),
-                                                annotation=None,
-                                                type_comment=None)
-                node.body[0:0] = self.tuple_to_stmts(origin_inner_tuple_node,
-                                                     inner_tuple_name)
-        elif self.is_for_iter(node) and isinstance(node.target,
-                                                   (gast.List, gast.Tuple)):
-            # Non-enumrate case:
-            tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
-            origin_tuple_node = node.target
-            node.target = gast.Name(id=tuple_name,
-                                    ctx=gast.Store(),
-                                    annotation=None,
-                                    type_comment=None)
-            node.body[0:0] = self.tuple_to_stmts(origin_tuple_node, tuple_name)
-        return node
-
-    def tuple_to_stmts(self, node, tuple_name, idx=[]):
-        if not isinstance(node, (gast.Tuple, gast.List)):
-            value_node_str = tuple_name
-            for i in idx:
-                value_node_str = value_node_str + "[{}]".format(i)
-
-            node_str = ast_to_source_code(node).strip()
-            assign_node_str = "{} = {}".format(node_str, value_node_str)
-            assign_node = gast.parse(assign_node_str).body[0]
-            return [assign_node]
-
-        # isinstance(node, (gast.Tuple, gast.List))
+        node.body[0:0] = self.tuple_to_stmts(origin_tuple_node, tuple_target)
+        # Returning a list makes the visitor insert these nodes in place of the original For node.
+        return [assign_iterator_node, node]
+
+    def tuple_node_to_unpack_structure(self, node):
+        """ Create a nested sequence that mirrors the structure of the target.
+        For example, `a, (b, c), [d, e, f]` is represented by
+        `[1, [1, 1], [1, 1, 1]]`; the `1` is just a placeholder.
+
+        In particular, a bare name `a` is represented by `1`.
+        """
         ret = []
-        for i, element in enumerate(node.elts):
-            ret += self.tuple_to_stmts(node.elts[i], tuple_name, idx + [i])
+        if not isinstance(node, (gast.Tuple, gast.List)):
+            return 1
+        for element in node.elts:
+            ret.append(self.tuple_node_to_unpack_structure(element))
         return ret

-    def is_for_iter(self, for_node):
-        assert isinstance(for_node,
-                          gast.For), "Input node is not gast.For node."
-        if isinstance(for_node.iter, (gast.Name, gast.Attribute)):
-            return True
-        elif isinstance(for_node.iter, gast.Call) and isinstance(
-                for_node.iter.func,
-                gast.Attribute) and for_node.iter.func.attr == 'numpy':
-            return True
-        elif isinstance(for_node.iter, gast.Subscript):
-            return True
-        else:
-            return False
-
-    def is_for_enumerate_iter(self, for_node):
-        assert isinstance(for_node,
-                          gast.For), "Input node is not gast.For node."
-        return isinstance(for_node.iter, gast.Call) and isinstance(
-            for_node.iter.func,
-            gast.Name) and for_node.iter.func.id == "enumerate"
+    def tuple_to_stmts(self, node, tuple_name):
+        structure_str = str(self.tuple_node_to_unpack_structure(node))
+        node_str = ast_to_source_code(node).strip()
+        assign_node_str = f"{node_str} = _jst.Unpack({tuple_name}, {structure_str})"
+        assign_node = gast.parse(assign_node_str).body[0]
+        return [assign_node]


 class SplitAssignTransformer(BaseTransformer):
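[Illustration, not part of the patch: the pre-transform above reduces every for-loop target to one generated name plus an `_jst.Unpack` assignment. Below is a pure-Python sketch of the structure encoding; `structure_of` is a hypothetical stand-in for `tuple_node_to_unpack_structure`, which operates on AST nodes rather than values.]

    # Before the pre-transform:
    #     for i, (a, b) in enumerate(x):
    #         use(i, a, b)
    #
    # After it (generated names shortened for readability):
    #     __iter = _jst.Indexable(enumerate(x))
    #     for __target in __iter:
    #         i, (a, b) = _jst.Unpack(__target, [1, [1, 1]])
    #         use(i, a, b)

    def structure_of(target):
        """Return 1 for a bare name, else a nested list mirroring the target."""
        if not isinstance(target, (tuple, list)):
            return 1
        return [structure_of(t) for t in target]

    print(structure_of(("a", ("b", "c"), ["d", "e", "f"])))  # [1, [1, 1], [1, 1, 1]]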
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py
index c9f56287ed3..15b909f3d3d 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py
@@ -40,7 +40,7 @@ class CallTransformer(BaseTransformer):
         Determines whether a function needs to be transformed by `convert_call`.
         It doesn't need to be transformed when a function satisfies the following conditions:
           1. It's a api of paddle
-          2. It's a python builtin function not include `len` and `zip`
+          2. It's a Python builtin function, excluding `len`, `zip`, `range` and `enumerate`
         """
         assert isinstance(node, gast.Call)
         if is_paddle_api(node):
@@ -48,11 +48,16 @@ class CallTransformer(BaseTransformer):

         func_str = ast_to_source_code(node.func).strip()
         try:
-            from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin, is_builtin_zip
+            from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin
+            need_convert_builtin_func_list = {
+                'len',
+                'zip',
+                'range',
+                'enumerate',
+            }
             is_builtin = eval("is_builtin({})".format(func_str))
-            is_builtin_len = eval("is_builtin_len({})".format(func_str))
-            is_builtin_zip = eval("is_builtin_zip({})".format(func_str))
-            return is_builtin and not is_builtin_len and not is_builtin_zip
+            need_convert = func_str in need_convert_builtin_func_list
+            return is_builtin and not need_convert
         except Exception:
             return False

diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
index e660a64ab36..5bb75bda8de 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
@@ -28,6 +28,7 @@ import six

 from paddle.fluid.dygraph.container import Sequential
 from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len, convert_zip
+from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_range, convert_enumerate
 from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger
 from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
 from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
@@ -64,25 +65,22 @@ class ConversionOptions(object):
         self.not_convert = not_convert


-def is_builtin(func):
-    if isinstance(func, types.BuiltinFunctionType):
+def is_builtin(func, name=None):
+    """ Check whether `func` is a builtin function whose name equals `name`.
+    If `name` is None, any builtin function returns True.
+    """
+
+    def name_judge():
+        return name is None or func.__name__ == name
+
+    if isinstance(func, types.BuiltinFunctionType) and name_judge():
         return True
-    elif func in six.moves.builtins.__dict__.values():
+    elif func in six.moves.builtins.__dict__.values() and name_judge():
        return True
     else:
         return False


-def is_builtin_len(func):
-    if isinstance(func, types.BuiltinFunctionType) and func.__name__ == 'len':
-        return True
-    return False
-
-
-def is_builtin_zip(func):
-    return is_builtin(func) and func.__name__ == 'zip'
-
-
 def is_unsupported(func):
     """
     Checks whether the func is supported by dygraph to static graph.
@@ -165,12 +163,18 @@ def convert_call(func):
                 .format(func))
             return func

-    if is_builtin_len(func):
+    if is_builtin(func, "len"):
         return convert_len

-    if is_builtin_zip(func):
+    if is_builtin(func, "zip"):
         return convert_zip

+    if is_builtin(func, "range"):
+        return convert_range
+
+    if is_builtin(func, "enumerate"):
+        return convert_enumerate
+
     if is_builtin(func) or is_unsupported(func):
         return func
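[Illustration, not part of the patch: the unified `is_builtin(func, name=None)` predicate replaces the old `is_builtin_len` / `is_builtin_zip` helpers. A minimal self-contained sketch of its behavior:]

    import types
    import six

    def is_builtin(func, name=None):
        """True if func is a builtin; if name is given, the name must match too."""
        def name_judge():
            return name is None or func.__name__ == name

        if isinstance(func, types.BuiltinFunctionType) and name_judge():
            return True
        elif func in six.moves.builtins.__dict__.values() and name_judge():
            return True
        return False

    assert is_builtin(len)                # any builtin matches when name is None
    assert is_builtin(len, "len")         # builtin with a matching name
    assert not is_builtin(len, "zip")     # builtin, but the name differs
    assert not is_builtin(lambda x: x)    # not a builtin at all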
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
index 583db5c0dcd..e0b46fe2341 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
@@ -13,11 +13,12 @@
 # limitations under the License.

 import re
-
+import paddle
 from paddle.fluid.data_feeder import convert_dtype
 from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
 from paddle.fluid.framework import core, Variable
 from paddle.fluid.layers import Assert, Print
+from paddle.fluid.layers import range as paddle_range
 from paddle.fluid.layers import array_length, array_read, array_write, create_array
 from paddle.fluid.layers import assign, fill_constant, slice, reduce_all, reduce_any
 from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
@@ -26,6 +27,45 @@ from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_
 from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, Dygraph2StaticException


+def indexable(x, code=None):
+    if isinstance(x, Variable): return x
+    if hasattr(x, '__len__') and hasattr(x, '__getitem__'): return x
+    if hasattr(x, '__iter__'):
+        return [i for i in x]
+    else:
+        raise RuntimeError("x cannot be converted into an indexable object.")
+
+
+def unpack_by_structure(target, structure):
+    """ Unified unpack interface for paddle and python.
+    """
+    if isinstance(target, Variable):
+        return _unpack_by_structure_paddle(target, structure)
+    else:
+        return _unpack_by_structure_python(target, structure)
+
+
+def _unpack_by_structure_python(target, structure):
+    """ TODO(xiongkun): analyze the differences between python and paddle unpack.
+    """
+    return _unpack_by_structure_paddle(target, structure)
+
+
+def _unpack_by_structure_paddle(target, structure):
+    if structure == 1:
+        return target
+    ret = []
+    for idx, ele in enumerate(structure):
+        if ele == 1:
+            ret.append(target[idx])
+            continue
+        if isinstance(ele, list):
+            ret.append(unpack_by_structure(target[idx], ele))
+            continue
+        assert False, "structure element must be 1 or list"
+    return ret
+
+
 def convert_while_loop(cond, body, getter, setter):
     """
     A function representation of a Python ``while`` statement.
@@ -50,12 +90,26 @@

 def _run_paddle_while(cond, body, getter, setter):
     # NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors.
-
-    # UndefinedVar will become data layer not check.
-    loop_vars = [to_static_variable(var) for var in getter()]
+    def new_body_fn(*args):
+        """ Wrap body() and return the loop vars, since `while_loop`
+        requires the body to return them.
+        """
+        body()
+        return getter()
+
+    def new_cond_fn(*args):
+        """ Adapt the zero-argument cond() to the positional signature
+        expected by `while_loop`.
+        """
+        return cond()
+
+    # An UndefinedVar is kept as-is here; it later becomes an unchecked data
+    # layer holding value=NO_VALUE_MAGIC instead of being converted eagerly.
+    loop_vars = [
+        to_static_variable(var) if not isinstance(var, UndefinedVar) else var
+        for var in getter()
+    ]
     setter(loop_vars)  # change the non-local var to variable
     # variables may be modified into inner vars; write them back after the loop.
-    loop_vars = control_flow.while_loop(cond, body, loop_vars)
+    loop_vars = control_flow.while_loop(new_cond_fn, new_body_fn, loop_vars)
     setter(loop_vars)  # change the non-local var to variable
     return loop_vars
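[Illustration, not part of the patch: the getter/setter contract that `convert_while_loop` relies on, sketched in pure Python. `convert_while_loop_py` stands in for the eager fallback path, which this diff does not show.]

    def convert_while_loop_py(cond, body, getter, setter):
        # Eager fallback: just run the loop. getter/setter only matter on the
        # static-graph path, where while_loop returns fresh output tensors
        # that must be written back into the enclosing scope.
        while cond():
            body()
        return getter()

    # What dy2static conceptually emits for:  i = 0; while i < 3: i += 1
    def rewritten():
        i = 0

        def cond():
            return i < 3

        def body():
            nonlocal i
            i += 1

        def getter():
            return (i,)

        def setter(values):
            nonlocal i
            (i,) = values

        convert_while_loop_py(cond, body, getter, setter)
        return i

    assert rewritten() == 3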
@@ -368,6 +422,8 @@ def convert_len(var):
             'len(var) only supports LoDTensor/LoDTensorArray/SelectedRows, but received %s.'
             % type(var))
     else:
+        if isinstance(var, VariableTuple):
+            return var.__len__()
         return len(var)


@@ -380,6 +436,44 @@ def convert_zip(*args):
     return zip(*args)


+# TODO(xiongkun): delete when list is ready.
+class VariableTuple:
+    """
+    This class prevents an `enumerate` result from being wrapped by other
+    iterator-transforming functions; this will be fixed once list support lands.
+    VariableTuple can only deal with variables whose length is fixed.
+    """
+
+    def __init__(self, var, start=0):
+        self.var = var
+        self.len = convert_len(var)
+        self.rag = paddle_range(start, start + self.len, 1, paddle.int64)
+
+    def __getitem__(self, idx):
+        return self.rag[idx], self.var[idx]
+
+    def __len__(self):
+        return self.len
+
+
+def convert_enumerate(*args):
+    has_variable = any(map(lambda x: isinstance(x, Variable), args))
+    if has_variable:
+        return VariableTuple(*args)
+    return enumerate(*args)
+
+
+def convert_range(*args):
+    has_variable = any(map(lambda x: isinstance(x, Variable), args))
+    if has_variable:
+        if len(args) == 1: return paddle_range(0, args[0], 1, paddle.int64)
+        if len(args) == 2:
+            return paddle_range(args[0], args[1], 1, paddle.int64)
+        if len(args) == 3:
+            return paddle_range(args[0], args[1], args[2], paddle.int64)
+    return range(*args)
+
+
 def convert_shape(x):
     """
     A function representation of the shape of variable.
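[Illustration, not part of the patch: a pure-Python model of how `VariableTuple` pairs an int64 index range with the enumerated values. `FakeVariableTuple` is a stand-in that uses lists instead of tensors.]

    class FakeVariableTuple:
        """Models VariableTuple: `rag` plays the role of the int64 range tensor."""

        def __init__(self, var, start=0):
            self.var = var
            self.len = len(var)
            self.rag = list(range(start, start + self.len))

        def __getitem__(self, idx):
            return self.rag[idx], self.var[idx]

        def __len__(self):
            return self.len

    vt = FakeVariableTuple(["a", "b", "c"], start=1)
    print([vt[i] for i in range(len(vt))])  # [(1, 'a'), (2, 'b'), (3, 'c')]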
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/create_variable_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/create_variable_transformer.py
new file mode 100644
index 00000000000..8ae4c12eb8e
--- /dev/null
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/create_variable_transformer.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+from paddle.utils import gast
+from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
+from paddle.fluid.dygraph.dygraph_to_static.utils import FunctionNameLivenessAnalysis
+from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
+from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
+
+
+class CreateVariableTransformer(BaseTransformer):
+    """ Create undefined variables for names first assigned in if / while / for.
+    """
+
+    def __init__(self, wrapper_root):
+        assert isinstance(
+            wrapper_root, AstNodeWrapper
+        ), "Type of input node should be AstNodeWrapper, but received %s ." % type(
+            wrapper_root)
+        self.root = wrapper_root.node
+        FunctionNameLivenessAnalysis(self.root)
+
+    def transform(self):
+        """
+        Main function to transform AST.
+        """
+        self.visit(self.root)
+
+    def visit_FunctionDef(self, node):
+        # attributes = set(filter(lambda x: '.' in x, node.pd_scope.modified_vars()))
+        bodys = node.body
+        names = sorted(node.pd_scope.created_vars())
+        for name in names:
+            bodys[0:0] = [create_undefined_var(name)]
+        return node
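[Illustration, not part of the patch: what CreateVariableTransformer does to a function, shown as before/after source; compare the relocated `_jst.UndefinedVar` lines in the test_program_translator.py expectations below. Assumes `import paddle.jit.dy2static as _jst`.]

    # Before:
    def f(x):
        if x > 0:
            y = 1
        return y                      # `y` would be unbound on the false branch

    # After (conceptually):
    def f_transformed(x):
        y = _jst.UndefinedVar('y')    # hoisted to function entry
        if x > 0:                     # later lowered to a cond op
            y = 1
        return y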
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
index a65e86f8e82..07d4920d433 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
@@ -34,6 +34,7 @@ from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_un
 from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_nodes
 from paddle.fluid.dygraph.dygraph_to_static.utils import create_get_args_node, create_set_args_node
 from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
+from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_INDEX_PREFIX, FOR_ITER_TUPLE_PREFIX, FOR_ITER_TUPLE_INDEX_PREFIX, FOR_ITER_VAR_LEN_PREFIX, FOR_ITER_VAR_NAME_PREFIX, FOR_ITER_ZIP_TO_LIST_PREFIX, FOR_ITER_TARGET_PREFIX, FOR_ITER_ITERATOR_PREFIX

 TRUE_FUNC_PREFIX = 'true_fn'
 FALSE_FUNC_PREFIX = 'false_fn'
@@ -304,7 +305,6 @@ def transform_if_else(node, root):
     """

     # TODO(liym27): Consider variable like `self.a` modified in if/else node.
-    new_vars_to_create = sorted(list(node.pd_scope.created_vars()))
     return_name_ids = sorted(list(node.pd_scope.modified_vars()))
     # NOTE: Python can create variable only in if body or only in else body, and use it out of if/else.
     # E.g.
@@ -315,10 +315,6 @@ def transform_if_else(node, root):
     #
     # Create static variable for those variables
     create_new_vars_in_parent_stmts = []
-    for name in new_vars_to_create:
-        # NOTE: Consider variable like `self.a` modified in if/else node.
-        if "." not in name:
-            create_new_vars_in_parent_stmts.append(create_undefined_var(name))

     nonlocal_names = list(return_name_ids)
     nonlocal_names.sort()
@@ -326,8 +322,21 @@ def transform_if_else(node, root):
     nonlocal_names = _valid_nonlocal_names(return_name_ids, nonlocal_names)

     # TODO(dev): Need a better way to deal this.
-    if ARGS_NAME in nonlocal_names:
-        nonlocal_names.remove(ARGS_NAME)
+    # LoopTransformer creates some special vars that are not visible to users, so it is safe to filter them out here.
+    filter_names = [
+        ARGS_NAME, FOR_ITER_INDEX_PREFIX, FOR_ITER_TUPLE_PREFIX,
+        FOR_ITER_TARGET_PREFIX, FOR_ITER_ITERATOR_PREFIX,
+        FOR_ITER_TUPLE_INDEX_PREFIX, FOR_ITER_VAR_LEN_PREFIX,
+        FOR_ITER_VAR_NAME_PREFIX, FOR_ITER_ZIP_TO_LIST_PREFIX
+    ]
+
+    def remove_if(x):
+        for name in filter_names:
+            if x.startswith(name):
+                return False
+        return True
+
+    nonlocal_names = list(filter(remove_if, nonlocal_names))
+
+    return_name_ids = nonlocal_names

     nonlocal_stmt_node = create_nonlocal_stmt_nodes(nonlocal_names)
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
index 29ac905074e..099f6697480 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
@@ -26,8 +26,8 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
 from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
 from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
 from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
-from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
 from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_nodes, create_get_args_node, create_set_args_node
+from paddle.fluid.dygraph.dygraph_to_static.utils import FunctionNameLivenessAnalysis
 from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME
 from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
 from paddle.fluid.dygraph.dygraph_to_static.base_transformer import RenameTransformer
@@ -483,10 +483,10 @@ class LoopTransformer(BaseTransformer):
         ), "Input non-AstNodeWrapper node for the initialization of LoopTransformer."
         self.wrapper_root = wrapper_root
         self.root = wrapper_root.node
+        FunctionNameLivenessAnalysis(self.root)

     def transform(self):
         ForLoopTuplePreTransformer(self.wrapper_root).transform()
-        self.name_visitor = NameVisitor(self.root)
         self.visit(self.root)

     def visit_While(self, node):
@@ -537,19 +537,19 @@ class LoopTransformer(BaseTransformer):
             return [node]
         init_stmts, cond_stmt, body_stmts = stmts_tuple
         # 2. get original loop vars
-        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
-            node)
+        loop_var_names, create_var_names = node.pd_scope.modified_vars(
+        ), node.pd_scope.created_vars()
+
+        # TODO: Remove this block of code? Loops now have the unified form `for A in B:`.
         # NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
         # we need append new loop var & remove useless loop var
         # 1. for x in var -> x is no need
         # 2. for i, x in enumerate(var) -> x is no need
-        if current_for_node_parser.is_for_iter(
-        ) or current_for_node_parser.is_for_enumerate_iter():
+        if current_for_node_parser.is_for_iter():
             iter_var_name = current_for_node_parser.iter_var_name
             iter_idx_name = current_for_node_parser.iter_idx_name
             loop_var_names.add(iter_idx_name)
-            if iter_var_name not in create_var_names:
-                loop_var_names.remove(iter_var_name)
+            if current_for_node_parser.enum_idx_name is not None:
+                loop_var_names.add(current_for_node_parser.enum_idx_name)

         # 3. prepare result statement list
         new_stmts = []
@@ -559,10 +559,8 @@ class LoopTransformer(BaseTransformer):
         #       y += x
         #       print(x) # x = 10
         #
-        # We need to create static variable for those variables
-        for name in create_var_names:
-            if "." not in name:
-                new_stmts.append(create_undefined_var(name))
+        # We don't need to create static variables for them here, because
+        # CreateVariableTransformer already does this.

         # create non-local statement for body and cond.
         nonlocal_names = list(loop_var_names | create_var_names)
@@ -581,10 +579,7 @@ class LoopTransformer(BaseTransformer):
             name=unique_name.generate(FOR_CONDITION_PREFIX),
             args=gast.arguments(args=[],
                                 posonlyargs=[],
-                                vararg=gast.Name(id=ARGS_NAME,
-                                                 ctx=gast.Param(),
-                                                 annotation=None,
-                                                 type_comment=None),
+                                vararg=None,
                                 kwonlyargs=[],
                                 kw_defaults=None,
                                 kwarg=None,
@@ -597,17 +592,11 @@ class LoopTransformer(BaseTransformer):

         # 6. create & append loop body function node
         # append return values for loop body
-        body_stmts.append(
-            gast.Return(value=generate_name_node(
-                nonlocal_names, ctx=gast.Load(), gen_tuple_if_single=True)))
         body_func_node = gast.FunctionDef(
             name=unique_name.generate(FOR_BODY_PREFIX),
             args=gast.arguments(args=[],
                                 posonlyargs=[],
-                                vararg=gast.Name(id=ARGS_NAME,
-                                                 ctx=gast.Param(),
-                                                 annotation=None,
-                                                 type_comment=None),
+                                vararg=None,
                                 kwonlyargs=[],
                                 kw_defaults=None,
                                 kwarg=None,
@@ -632,8 +621,8 @@ class LoopTransformer(BaseTransformer):
         return new_stmts

     def get_while_stmt_nodes(self, node):
-        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
-            node)
+        loop_var_names, create_var_names = node.pd_scope.modified_vars(
+        ), node.pd_scope.created_vars()
         new_stmts = []

         # create non-local statement for body and cond.
@@ -652,19 +641,14 @@ class LoopTransformer(BaseTransformer):
         #    y = x
         #    z = y
         #
-        # We need to create static variable for those variables
-        for name in create_var_names:
-            if "." not in name:
-                new_stmts.append(create_fill_constant_node(name))
+        # We don't need to create static variables for those variables here,
+        # because CreateVariableTransformer already does this.

         condition_func_node = gast.FunctionDef(
             name=unique_name.generate(WHILE_CONDITION_PREFIX),
             args=gast.arguments(args=[],
                                 posonlyargs=[],
-                                vararg=gast.Name(id=ARGS_NAME,
-                                                 ctx=gast.Param(),
-                                                 annotation=None,
-                                                 type_comment=None),
+                                vararg=None,
                                 kwonlyargs=[],
                                 kw_defaults=None,
                                 kwarg=None,
@@ -677,17 +661,11 @@ class LoopTransformer(BaseTransformer):
         new_stmts.append(condition_func_node)

         new_body = node.body
-        new_body.append(
-            gast.Return(value=generate_name_node(
-                nonlocal_names, ctx=gast.Load(), gen_tuple_if_single=True)))
         body_func_node = gast.FunctionDef(
             name=unique_name.generate(WHILE_BODY_PREFIX),
             args=gast.arguments(args=[],
                                 posonlyargs=[],
-                                vararg=gast.Name(id=ARGS_NAME,
-                                                 ctx=gast.Param(),
-                                                 annotation=None,
-                                                 type_comment=None),
+                                vararg=None,
                                 kwonlyargs=[],
                                 kw_defaults=None,
                                 kwarg=None,
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py
index 9f390252f3a..ed7faf83cef 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py
@@ -82,6 +82,8 @@ dygraph_class_to_static_api = {

 FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
 FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
+FOR_ITER_TARGET_PREFIX = '__for_loop_iter_target'
+FOR_ITER_ITERATOR_PREFIX = '__for_loop_iter_iterator'
 FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index'
 FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
 FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
@@ -1099,6 +1101,18 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
         if isinstance(node, gast.FunctionDef):
             return self._get_name_scope(node)

+    def visit_ListComp(self, node):
+        """ e.g. [i for i in range(10)]
+        Here `i` is not created in the enclosing FunctionScope,
+        so we skip generic_visit and do not collect `i`.
+        """
+        pass
+
+    def visit_DictComp(self, node):
+        """ Same as visit_ListComp.
+        """
+        pass
+
     def visit_Name(self, node):
         self.generic_visit(node)
         write_context = (gast.Store, gast.AugStore, gast.Del)
@@ -1149,8 +1163,13 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):

         def post_func():
             self._father_name_scope().merge_from(self._current_name_scope())
+            self._nearest_function_scope().merge_from(
+                self._current_name_scope())
             self._current_name_scope().created = self._nearest_function_scope(
             ).existed_vars() - node.before_created
+            # gather created vars into the father scope; used by CreateVariableTransformer
+            self._nearest_function_scope().created |= self._current_name_scope(
+            ).created

         def pre_func():
             setattr(node, "before_created",
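[Illustration, not part of the patch: a hedged sketch of querying the analysis, using only interfaces that appear in this patch (`FunctionNameLivenessAnalysis`, `pd_scope.created_vars`); the exact set contents may differ by version.]

    import textwrap
    from paddle.utils import gast
    from paddle.fluid.dygraph.dygraph_to_static.utils import FunctionNameLivenessAnalysis

    code = textwrap.dedent("""
        def f(x):
            ys = [i for i in range(10)]   # `i` lives only in the comprehension
            if x > 0:
                z = 1                     # `z` is first assigned inside control flow
            return ys
        """)
    root = gast.parse(code)
    FunctionNameLivenessAnalysis(root)
    func = root.body[0]
    print(func.pd_scope.created_vars())   # expected to contain 'z' but not 'i'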
diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py
index bc1a2c15dd3..d7b85961247 100755
--- a/python/paddle/fluid/layers/control_flow.py
+++ b/python/paddle/fluid/layers/control_flow.py
@@ -108,7 +108,6 @@ def select_input(inputs, mask):

 def select_input_with_buildin_type(inputs, mask):
     from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
     from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_var_like
-    support_ret_buildin_type = (bool, float, six.integer_types)
     false_var, true_var = inputs

     if isinstance(false_var, UndefinedVar) and isinstance(
@@ -1182,12 +1181,16 @@ class While(object):
         })


+support_ret_buildin_type = (bool, float, six.integer_types)
+
+
 def assign_skip_lod_tensor_array(input, output):
     """
     Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.
     """
     if not isinstance(input, (Variable, core.VarBase)):
-        if isinstance(output, Variable):
+        if isinstance(output, Variable) and isinstance(
+                input, support_ret_buildin_type):
             assign(input, output)
         else:
             output = input
@@ -1297,6 +1300,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
         if not isinstance(output_vars, (list, tuple)):
             output_vars = [output_vars]
         try:
+            loop_vars = _deal_with_undefined_var(output_vars, loop_vars)
             assert_same_structure(output_vars, loop_vars, check_types=False)
         except ValueError as e:
             raise ValueError(
@@ -1308,6 +1312,36 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
     return loop_vars


+def _deal_with_undefined_var(output_vars, loop_vars):
+    """ Deal with undefined-var cases: we create variables based on the results of body().
+    In Dy2Static, an UndefinedVar represents a variable created inside control flow. This
+    function expands the loop_vars accordingly and replaces the original loop_vars.
+        1. UndefinedVar = Variable  # create a variable
+        2. UndefinedVar = None      # create an undefined var with RETURN_NO_VALUE_MAGIC_NUM
+        3. UndefinedVar = List(int) # create a list of variables
+        4. UndefinedVar = value     # create a variable
+    """
+    from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_variable
+
+    def create_var_like(o_var):
+        if isinstance(o_var,
+                      (Variable, ) + support_ret_buildin_type) or o_var is None:
+            return create_undefined_variable()
+        if isinstance(o_var, (tuple, list)):
+            return [create_undefined_variable() for i in range(len(o_var))]
+
+    if len(output_vars) != len(loop_vars):
+        raise ValueError("output_vars and loop_vars must have the same length.")
+
+    results = []
+    for o_var, l_var in zip(output_vars, loop_vars):
+        if isinstance(l_var, UndefinedVar) or l_var is None:
+            results.append(create_var_like(o_var))
+        else:
+            results.append(l_var)
+    return results
+
+
 def lod_rank_table(x, level=0):
     """
     LoD Rank Table Operator. Given an input variable **x** and a level number
@@ -2616,6 +2650,11 @@ def change_none_to_undefinedvar(nest1, nest2):


 def expand_undefined_var(nest1, nest2, names):
+    """ TODO: make this function recursive.
+        nest1: Var1, (UndefinedVar, [1,2,3])
+        nest2: Var2, ([1,2,3,4], UndefinedVar)
+        In this case, we should not expand recursively.
+    """
     from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
     from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_VALUE_PREFIX
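[Illustration, not part of the patch: the pairing rule of `_deal_with_undefined_var`, restated in pure Python with stand-in types so it runs anywhere.]

    class UndefinedVar:
        """Stand-in for the dy2static marker class."""

    def deal_with_undefined_var(output_vars, loop_vars, make_var):
        if len(output_vars) != len(loop_vars):
            raise ValueError("output_vars and loop_vars must have the same length.")
        results = []
        for o_var, l_var in zip(output_vars, loop_vars):
            if isinstance(l_var, UndefinedVar) or l_var is None:
                # one fresh variable, or a list of them for list-shaped outputs
                if isinstance(o_var, (tuple, list)):
                    results.append([make_var() for _ in o_var])
                else:
                    results.append(make_var())
            else:
                results.append(l_var)
        return results

    out = deal_with_undefined_var([3, [1, 2], "kept"],
                                  [UndefinedVar(), None, "kept"],
                                  make_var=lambda: "fresh")
    print(out)  # ['fresh', ['fresh', 'fresh'], 'kept']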
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_dygraph_model.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_dygraph_model.py
index b544ca9bd83..ce322db06cf 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_dygraph_model.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_dygraph_model.py
@@ -385,6 +385,7 @@ class BaseModel(fluid.dygraph.Layer):
                     dropout_implementation='upscale_in_train')
             else:
                 step_input = new_hidden
+
             cell_outputs = self._split_batch_beams(step_input)
             cell_outputs = self.fc(cell_outputs)

diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
index 683135b9078..ff3e0da6fea 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
@@ -442,13 +442,6 @@ class TestErrorInForLoop(TestTransformForLoop):
     def _init_dyfunc(self):
         self.dyfunc = for_loop_dyfunc_not_support

-    def test_ast_to_func(self):
-        with self.assertRaisesRegexp(
-                NotImplementedError,
-                "Dynamic-to-Static only supports the step value is a constant or negative constant "
-        ):
-            self._run_static()
-

 if __name__ == '__main__':
     with fluid.framework._test_eager_guard():
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py
index c7cecab04f5..27debe00af1 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py
@@ -66,6 +66,9 @@ def get_source_code(func):
 class StaticCode1():

     def dyfunc_with_if_else(x_v, label=None):
+        loss = _jst.UndefinedVar('loss')
+        __return_1 = _jst.UndefinedVar('__return_1')
+        __return_0 = _jst.UndefinedVar('__return_0')
         __return_value_0 = None

         def get_args_0():
@@ -89,9 +92,6 @@ class StaticCode1():
         _jst.IfElse(
             paddle.mean(x_v)[0] > 5, true_fn_0, false_fn_0, get_args_0,
             set_args_0, ('x_v', ))
-        __return_0 = _jst.UndefinedVar('__return_0')
-        __return_1 = _jst.UndefinedVar('__return_1')
-        loss = _jst.UndefinedVar('loss')

         def get_args_1():
             nonlocal __return_0, __return_1, __return_value_0, loss
@@ -123,6 +123,9 @@ class StaticCode2():
     # TODO: Transform return statement

     def dyfunc_with_if_else(x_v, label=None):
+        loss = _jst.UndefinedVar('loss')
+        __return_3 = _jst.UndefinedVar('__return_3')
+        __return_2 = _jst.UndefinedVar('__return_2')
         __return_value_1 = None

         def get_args_2():
@@ -146,9 +149,6 @@ class StaticCode2():
         _jst.IfElse(
             paddle.mean(x_v)[0] > 5, true_fn_2, false_fn_2, get_args_2,
             set_args_2, ('x_v', ))
-        __return_2 = _jst.UndefinedVar('__return_2')
-        __return_3 = _jst.UndefinedVar('__return_3')
-        loss = _jst.UndefinedVar('loss')

         def get_args_3():
             nonlocal __return_2, __return_3, __return_value_1, loss
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
index 9b1cde6dcc5..0d1dc69823a 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
@@ -578,8 +578,8 @@ class TestOpNumWithTensorShapeInFor1(TestOpNumBasicWithTensorShape):
         self.dygraph_func = dyfunc_with_for_1

     def _set_expected_op_num(self):
-        self.expected_op_num = 22
-        self.expected_shape_op_num = 3
+        self.expected_op_num = 29
+        self.expected_shape_op_num = 2
         self.expected_slice_op_num = 3


@@ -589,7 +589,7 @@ class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape):
         self.dygraph_func = dyfunc_with_while_1

     def _set_expected_op_num(self):
-        self.expected_op_num = 22
+        self.expected_op_num = 21
         self.expected_shape_op_num = 3
         self.expected_slice_op_num = 3

diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/transformer_dygraph_model.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/transformer_dygraph_model.py
index 2239c6544f2..57b6fc55efb 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/transformer_dygraph_model.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/transformer_dygraph_model.py
@@ -21,7 +21,7 @@ import paddle.fluid.layers as layers
 from paddle.fluid.dygraph import Embedding, Layer, LayerNorm, Linear, to_variable
 from paddle.fluid.dygraph.jit import dygraph_to_static_func
 from paddle.fluid.layers.utils import map_structure
-from paddle.fluid.layers.tensor import range as pd_range
+import paddle


 def position_encoding_init(n_position, d_pos_vec):
@@ -634,7 +634,7 @@ class Transformer(Layer):
                 value=0),
         } for i in range(self.n_layer)]

-        for i in pd_range(0, max_len, 1, dtype="int32"):
+        for i in range(paddle.to_tensor(max_len)):
             trg_pos = layers.fill_constant(shape=trg_word.shape,
                                            dtype="int64",
                                            value=i)
diff --git a/python/paddle/jit/dy2static/__init__.py b/python/paddle/jit/dy2static/__init__.py
index 0a51a3e265e..ebb4d30a412 100644
--- a/python/paddle/jit/dy2static/__init__.py
+++ b/python/paddle/jit/dy2static/__init__.py
@@ -26,7 +26,8 @@ from .convert_operators import convert_pop as Pop  # noqa: F401
 from .convert_operators import convert_print as Print  # noqa: F401
 from .convert_operators import convert_shape as Shape  # noqa: F401
 from .convert_operators import convert_while_loop as While  # noqa: F401
-
+from .convert_operators import unpack_by_structure as Unpack  # noqa: F401
+from .convert_operators import indexable as Indexable  # noqa: F401
 from .variable_trans_func import create_bool_as_type  # noqa: F401
 from .variable_trans_func import to_static_variable  # noqa: F401
 from .convert_operators import convert_shape_compare  # noqa: F401
diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py
index 59ffedef0a9..691c8c0cfbe 100644
--- a/python/paddle/jit/dy2static/convert_operators.py
+++ b/python/paddle/jit/dy2static/convert_operators.py
@@ -26,5 +26,6 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_c
 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype  # noqa: F401
 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape  # noqa: F401
 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop  # noqa: F401
+from ...fluid.dygraph.dygraph_to_static.convert_operators import unpack_by_structure, indexable  # noqa: F401

 __all__ = []
-- 
GitLab
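
[Illustration, not part of the patch: the two new exports are consumed through the `_jst` alias in generated code, as seen in the test_program_translator.py expectations above. A small usage sketch, assuming `_jst` is `paddle.jit.dy2static`:]

    import paddle.jit.dy2static as _jst

    seq = _jst.Indexable(iter([10, 20, 30]))           # iterator -> indexable list
    a, (b, c) = _jst.Unpack([1, (2, 3)], [1, [1, 1]])  # unpack by structure
    print(seq, a, b, c)                                # [10, 20, 30] 1 2 3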