diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index cbb4655f354a5a9f9e3e851e0c1dd81b18936ef4..a6cab0db513803c678d89f501d5a2a512ecb6b2a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -24,7 +24,7 @@ from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_ from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar -def convert_while_loop(cond, body, loop_vars): +def convert_while_loop(cond, body, getter, setter): """ A function representation of a Python ``while`` statement. @@ -39,25 +39,36 @@ def convert_while_loop(cond, body, loop_vars): # NOTE: It may be slower if cond is very expensive, but usually cond is just O(1). # If loop_vars is changed during cond callable, then it causes bug, but current logical_and/logical_not/... doesn't change the loop_vars. - pred = cond(*loop_vars) + pred = cond() if isinstance(pred, Variable): - loop_vars = _run_paddle_while_loop(cond, body, loop_vars) + loop_vars = _run_paddle_while(cond, body, getter, setter) else: - loop_vars = _run_py_while(cond, body, loop_vars) + loop_vars = _run_py_while(cond, body, getter, setter) return loop_vars -def _run_paddle_while_loop(cond, body, loop_vars): +def _run_paddle_while(cond, body, getter, setter): # NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors. - loop_vars = [to_static_variable(var) for var in loop_vars] + def to_list(x): + if isinstance(x, (tuple, list)): return x + return [x] + + # UndefinedVar will become data layer not check. + loop_vars = [to_static_variable(var) for var in to_list(getter())] + setter(loop_vars if len(loop_vars) > 1 else + loop_vars[0]) # change the non-local var to variable + # variable maybe modified to inner var. change it into loop_vars = control_flow.while_loop(cond, body, loop_vars) + setter(loop_vars if len(loop_vars) > 1 else + loop_vars[0]) # change the non-local var to variable return loop_vars -def _run_py_while(cond, body, loop_vars): - while cond(*loop_vars): - loop_vars = body(*loop_vars) +def _run_py_while(cond, body, getter, setter): + loop_vars = getter() + while cond(): + loop_vars = body() return loop_vars diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index 1935629f54e865040900cb8eb934bfc3e86988d5..d4449f6dfc24ef6887f9df66427cfa4d6e5c7c36 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -31,7 +31,8 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var -from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_nonlocal_stmt_node +from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_node +from paddle.fluid.dygraph.dygraph_to_static.utils import create_get_args_node, create_set_args_node TRUE_FUNC_PREFIX = 'true_fn' FALSE_FUNC_PREFIX = 'false_fn' @@ -415,17 +416,22 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, # modified vars body_modified_vars = _modified_vars(if_vars_dict, parent_vars_dict) + body_modified_vars = set( + filter(lambda x: x != ARGS_NAME, body_modified_vars)) orelse_modified_vars = _modified_vars(else_vars_dict, parent_vars_dict) + orelse_modified_vars = set( + filter(lambda x: x != ARGS_NAME, orelse_modified_vars)) modified_vars = body_modified_vars | orelse_modified_vars # new vars + # TODO(remove __args when new FunctionScopeAnalysis has been used.) body_new_vars = set([ var for var in _vars_with_store(if_vars_dict) - if var not in parent_vars_dict + if var not in parent_vars_dict and var != ARGS_NAME ]) orelse_new_vars = set([ var for var in _vars_with_store(else_vars_dict) - if var not in parent_vars_dict + if var not in parent_vars_dict and var != ARGS_NAME ]) new_vars_in_body_or_orelse = body_new_vars | orelse_new_vars new_vars_in_one_of_body_or_orelse = body_new_vars ^ orelse_new_vars @@ -511,11 +517,11 @@ def transform_if_else(node, root): if any([not isinstance(ctx, gast.Load) for ctx in ctxs]): parent_ids_set.add(k) - trun_args = parse_cond_args(parent_ids_set, body_name_ids, + true_args = parse_cond_args(parent_ids_set, body_name_ids, modified_name_ids_from_parent) false_args = parse_cond_args(parent_ids_set, orelse_name_ids, modified_name_ids_from_parent) - nonlocal_names = list(trun_args | false_args | new_vars_to_create) + nonlocal_names = list(true_args | false_args | new_vars_to_create) nonlocal_names.sort() # NOTE: All var in return_name_ids should be in nonlocal_names. nonlocal_names = _valid_nonlocal_names(return_name_ids, nonlocal_names) @@ -552,70 +558,6 @@ def transform_if_else(node, root): return create_new_vars_in_parent_stmts, true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids -def create_get_args_node(names): - """ - Create get_args function as follows: - - def get_args_0(): - nonlocal x, y - return x, y - """ - - def empty_node(): - func_def = """ - def {func_name}(): - return - """.format(func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX)) - return gast.parse(textwrap.dedent(func_def)).body[0] - - assert isinstance(names, (list, tuple)) - if not names: - return empty_node() - - template = """ - def {func_name}(): - nonlocal {vars} - return {vars} - """ - func_def = template.format( - func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX), - vars=",".join(names)) - return gast.parse(textwrap.dedent(func_def)).body[0] - - -def create_set_args_node(names): - """ - Create set_args function as follows: - - def set_args_0(__args): - nonlocal x, y - x, y = __args - """ - - def empty_node(): - func_def = """ - def {func_name}({args}): - pass - """.format(func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), - args=ARGS_NAME) - return gast.parse(textwrap.dedent(func_def)).body[0] - - assert isinstance(names, (list, tuple)) - if not names: - return empty_node() - - template = """ - def {func_name}({args}): - nonlocal {vars} - {vars} = {args} - """ - func_def = template.format( - func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), - args=ARGS_NAME, - vars=",".join(names)) - return gast.parse(textwrap.dedent(func_def)).body[0] - - def create_convert_ifelse_node(return_name_ids, pred, true_func, diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 832c502c0aa5cd582250dcf161ca3c243d931402..63fc4f0489acba5d90e7d5f58147eaf019efaded 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -28,7 +28,10 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name from paddle.fluid.dygraph.dygraph_to_static.utils import ForLoopTuplePreTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer +from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node +from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_node, create_get_args_node, create_set_args_node +from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME __all__ = ['LoopTransformer', 'NameVisitor'] @@ -37,12 +40,10 @@ WHILE_BODY_PREFIX = 'while_body' FOR_CONDITION_PREFIX = 'for_loop_condition' FOR_BODY_PREFIX = 'for_loop_body' -GENERATE_VARIABLE_PREFIX = 'generate_variable' -ATTRIBUTE_VARIABLE_PREFIX = '__attribute_variable' - -def create_while_nodes(condition_name, body_name, loop_var_names): +def create_while_nodes(condition_name, body_name, loop_var_names, getter_name, + setter_name): """ Returns a list of gast.Node which represents the calling of Paddle controlflow while_loop. @@ -74,37 +75,20 @@ def create_while_nodes(condition_name, body_name, loop_var_names): # # For example: loop_var_names = [a, b, foo.x], the type of `a` or `b` is gast.Name, # but the type of `foo.x` gast.Attribute. - - unique_name_to_origin = {} # We have to make loop_var_names and assign_loop_var_names with same order # set doesn't have order so we convert it to list loop_var_names = list(loop_var_names) assign_loop_var_names = [] for name in (loop_var_names): - if "." in name: - # name is an attribute variable such as foo.x - tmp_attr_name = unique_name.generate(ATTRIBUTE_VARIABLE_PREFIX) - unique_name_to_origin[tmp_attr_name] = name - assign_loop_var_names.append(tmp_attr_name) - else: - assign_loop_var_names.append(name) + assign_loop_var_names.append(name) while_func_name = "_jst.While" - while_node_str = "[{}] = {}({}, {}, [{}])".format( - ",".join(assign_loop_var_names), while_func_name, condition_name, - body_name, ",".join(loop_var_names)) + while_node_str = "{}({}, {}, {}, {})".format(while_func_name, + condition_name, body_name, + getter_name, setter_name) while_node = gast.parse(while_node_str).body[0] ret = [while_node] - for tmp_attr_name in unique_name_to_origin: - origin_attr_var = unique_name_to_origin[tmp_attr_name] - dot_pos = origin_attr_var.rindex(".") - obj_name = origin_attr_var[0:dot_pos] - attr_name = origin_attr_var[dot_pos + 1:] - assign_if_not_prop_str = "if not isinstance(getattr(type({}), '{}', None), property): {} = {}".format( - obj_name, attr_name, origin_attr_var, tmp_attr_name) - assign_if_not_prop_node = gast.parse(assign_if_not_prop_str).body[0] - ret.append(assign_if_not_prop_node) return ret @@ -117,8 +101,10 @@ class NameScope: self.globals = set() self.nonlocals = set() self.args = set() - self.w_vars = set() # all vars been stored, + # all vars been stored, # may be globals or non-locals + self.w_vars = set() + def created_vars(self): return self.w_vars - self.globals - self.nonlocals - self.args @@ -282,9 +268,7 @@ class NameVisitor(gast.NodeVisitor): # If this var is a basic variable and read-only and not # condition var, it may not be loop_var else it should # be in loop_var as input - if (not name in condition_names) and ( - not name in write_names - ) and self._node_var_type_is_basic(name_to_type[name]): + if (not name in condition_names) and (not name in write_names): continue loop_var_names.add(name) @@ -645,7 +629,6 @@ class LoopTransformer(gast.NodeTransformer): if stmts_tuple is None: return [node] init_stmts, cond_stmt, body_stmts = stmts_tuple - # 2. get original loop vars loop_var_names, create_var_names = self.name_visitor.get_loop_var_names( node) @@ -672,7 +655,16 @@ class LoopTransformer(gast.NodeTransformer): # We need to create static variable for those variables for name in create_var_names: if "." not in name: - new_stmts.append(create_fill_constant_node(name)) + new_stmts.append(create_undefined_var(name)) + + # create non-local statement for body and cond. + nonlocal_names = list(loop_var_names | create_var_names) + nonlocal_names.sort() + # TODO(dev): Need a better way to deal this. + if ARGS_NAME in nonlocal_names: + nonlocal_names.remove(ARGS_NAME) + + nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)] # 4. append init statements new_stmts.extend(init_stmts) @@ -680,63 +672,54 @@ class LoopTransformer(gast.NodeTransformer): # 5. create & append condition function node condition_func_node = gast.FunctionDef( name=unique_name.generate(FOR_CONDITION_PREFIX), - args=gast.arguments(args=[ - gast.Name(id=name, - ctx=gast.Param(), - annotation=None, - type_comment=None) for name in loop_var_names - ], + args=gast.arguments(args=[], posonlyargs=[], - vararg=None, + vararg=gast.Name(id=ARGS_NAME, + ctx=gast.Param(), + annotation=None, + type_comment=None), kwonlyargs=[], kw_defaults=None, kwarg=None, defaults=[]), - body=[gast.Return(value=cond_stmt)], + body=nonlocal_stmt_node + [gast.Return(value=cond_stmt)], decorator_list=[], returns=None, type_comment=None) - for name in loop_var_names: - if "." in name: - rename_transformer = RenameTransformer(condition_func_node) - rename_transformer.rename( - name, unique_name.generate(GENERATE_VARIABLE_PREFIX)) new_stmts.append(condition_func_node) # 6. create & append loop body function node # append return values for loop body body_stmts.append( gast.Return(value=generate_name_node( - loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True))) + nonlocal_names, ctx=gast.Load(), gen_tuple_if_single=True))) body_func_node = gast.FunctionDef( name=unique_name.generate(FOR_BODY_PREFIX), - args=gast.arguments(args=[ - gast.Name(id=name, - ctx=gast.Param(), - annotation=None, - type_comment=None) for name in loop_var_names - ], + args=gast.arguments(args=[], posonlyargs=[], - vararg=None, + vararg=gast.Name(id=ARGS_NAME, + ctx=gast.Param(), + annotation=None, + type_comment=None), kwonlyargs=[], kw_defaults=None, kwarg=None, defaults=[]), - body=body_stmts, + body=nonlocal_stmt_node + body_stmts, decorator_list=[], returns=None, type_comment=None) - for name in loop_var_names: - if "." in name: - rename_transformer = RenameTransformer(body_func_node) - rename_transformer.rename( - name, unique_name.generate(GENERATE_VARIABLE_PREFIX)) new_stmts.append(body_func_node) + get_args_node = create_get_args_node(nonlocal_names) + set_args_node = create_set_args_node(nonlocal_names) # 7. create & append while loop node while_loop_nodes = create_while_nodes(condition_func_node.name, body_func_node.name, - loop_var_names) + nonlocal_names, + get_args_node.name, + set_args_node.name) + new_stmts.extend([get_args_node, set_args_node]) new_stmts.extend(while_loop_nodes) return new_stmts @@ -746,6 +729,15 @@ class LoopTransformer(gast.NodeTransformer): node) new_stmts = [] + # create non-local statement for body and cond. + nonlocal_names = list(loop_var_names | create_var_names) + nonlocal_names.sort() + # TODO(dev): Need a better way to deal this. + if ARGS_NAME in nonlocal_names: + nonlocal_names.remove(ARGS_NAME) + + nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)] + # Python can create variable in loop and use it out of loop, E.g. # # while x < 10: @@ -760,61 +752,52 @@ class LoopTransformer(gast.NodeTransformer): condition_func_node = gast.FunctionDef( name=unique_name.generate(WHILE_CONDITION_PREFIX), - args=gast.arguments(args=[ - gast.Name(id=name, - ctx=gast.Param(), - annotation=None, - type_comment=None) for name in loop_var_names - ], + args=gast.arguments(args=[], posonlyargs=[], - vararg=None, + vararg=gast.Name(id=ARGS_NAME, + ctx=gast.Param(), + annotation=None, + type_comment=None), kwonlyargs=[], kw_defaults=None, kwarg=None, defaults=[]), - body=[gast.Return(value=node.test)], + body=nonlocal_stmt_node + [gast.Return(value=node.test)], decorator_list=[], returns=None, type_comment=None) - for name in loop_var_names: - if "." in name: - rename_transformer = RenameTransformer(condition_func_node) - rename_transformer.rename( - name, unique_name.generate(GENERATE_VARIABLE_PREFIX)) new_stmts.append(condition_func_node) new_body = node.body new_body.append( gast.Return(value=generate_name_node( - loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True))) + nonlocal_names, ctx=gast.Load(), gen_tuple_if_single=True))) body_func_node = gast.FunctionDef( name=unique_name.generate(WHILE_BODY_PREFIX), - args=gast.arguments(args=[ - gast.Name(id=name, - ctx=gast.Param(), - annotation=None, - type_comment=None) for name in loop_var_names - ], + args=gast.arguments(args=[], posonlyargs=[], - vararg=None, + vararg=gast.Name(id=ARGS_NAME, + ctx=gast.Param(), + annotation=None, + type_comment=None), kwonlyargs=[], kw_defaults=None, kwarg=None, defaults=[]), - body=new_body, + body=nonlocal_stmt_node + new_body, decorator_list=[], returns=None, type_comment=None) - for name in loop_var_names: - if "." in name: - rename_transformer = RenameTransformer(body_func_node) - rename_transformer.rename( - name, unique_name.generate(GENERATE_VARIABLE_PREFIX)) new_stmts.append(body_func_node) + get_args_node = create_get_args_node(nonlocal_names) + set_args_node = create_set_args_node(nonlocal_names) while_loop_nodes = create_while_nodes(condition_func_node.name, body_func_node.name, - loop_var_names) + nonlocal_names, + get_args_node.name, + set_args_node.name) + new_stmts.extend([get_args_node, set_args_node]) new_stmts.extend(while_loop_nodes) return new_stmts diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 8dd11c06e463fe74aa3ee0bd7404eb7136aaca48..466e9ee4d34c1db39b915ba6476237422bcafd0b 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -30,6 +30,8 @@ import numpy as np import paddle from paddle.fluid import unique_name from paddle.fluid.data_feeder import convert_dtype +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid import core # Note(Aurelius): Do not forget the dot `.` to distinguish other # module such as paddlenlp. @@ -59,6 +61,51 @@ class BaseNodeVisitor(gast.NodeVisitor): return ret +def data_layer_not_check(name, shape, dtype='float32', lod_level=0): + """ + This function creates a Tensor on the global block. The created Tensor + doesn't check the dtype and the shape of feed data because dygraph input + data can be various-length. This API is used in translating dygraph into + static graph. + + Note: + The default :code:`stop_gradient` attribute of the Tensor created by + this API is true, which means the gradient won't be passed backward + through the data Tensor. Set :code:`var.stop_gradient = False` If + user would like to pass backward gradient. + + Args: + name (str): The name/alias of the Tensor, see :ref:`api_guide_Name` + for more details. + shape (list|tuple): List|Tuple of integers declaring the shape. You can + set "None" at a dimension to indicate the dimension can be of any + size. For example, it is useful to set changeable batch size as "None" + dtype (np.dtype|VarType|str, optional): The type of the data. Supported + dtype: bool, float16, float32, float64, int8, int16, int32, int64, + uint8. Default: float32 + lod_level (int, optional): The LoD level of the LoDTensor. Usually users + don't have to set this value. For more details about when and how to + use LoD level, see :ref:`user_guide_lod_tensor` . Default: 0 + + Returns: + Tensor: The global Tensor that gives access to the data. + """ + helper = LayerHelper('data', **locals()) + shape = list(shape) + for i in six.moves.range(len(shape)): + if shape[i] is None: + shape[i] = -1 + + return helper.create_variable(name=name, + shape=shape, + dtype=dtype, + type=core.VarDesc.VarType.LOD_TENSOR, + stop_gradient=True, + lod_level=lod_level, + is_data=True, + need_check_feed=False) + + # imp is deprecated in python3 from importlib.machinery import SourceFileLoader @@ -412,10 +459,16 @@ def generate_name_node(name_ids, ctx=gast.Load(), gen_tuple_if_single=False): raise TypeError( 'name_ids must be list or tuple or set, but received %s' % type(type(name_ids))) - gast_names = [ - gast.Name(id=name_id, ctx=ctx, annotation=None, type_comment=None) - for name_id in name_ids - ] + + def create_node_for_name(name): + if '.' not in name: + return gast.Name(id=name, + ctx=ctx, + annotation=None, + type_comment=None) + return gast.parse(name).body[0].value + + gast_names = [create_node_for_name(name_id) for name_id in name_ids] if len(gast_names) == 1 and not gen_tuple_if_single: name_node = gast_names[0] else: @@ -842,6 +895,16 @@ class NameNodeReplaceTransformer(gast.NodeTransformer): return self.replace_node return node + def visit_Nonlocal(self, node): + names = node.names + + def replace(s): + if s == self.target_name: return self.replace_node.id + return s + + node.names = list(map(replace, names)) + return node + class ForLoopTuplePreTransformer(gast.NodeTransformer): """ @@ -1527,3 +1590,93 @@ def slice_is_num(slice_node): return True return False + + +def create_get_args_node(names): + """ + Create get_args function as follows: + + def get_args_0(): + nonlocal x, y + return x, y + """ + + def empty_node(): + func_def = """ + def {func_name}(): + return + """.format(func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX)) + return gast.parse(textwrap.dedent(func_def)).body[0] + + assert isinstance(names, (list, tuple)) + if not names: + return empty_node() + + mapped = list(filter(lambda n: '.' not in n, names)) + nonlocal_names = sorted( + mapped, + key=mapped.index) # to keep the order, we can't use set() to unique + template = """ + def {func_name}(): + nonlocal {nonlocal_vars} + return {vars} + """ + func_def = template.format( + func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX), + nonlocal_vars=','.join(nonlocal_names), + vars=",".join(names)) + return gast.parse(textwrap.dedent(func_def)).body[0] + + +GET_ARGS_FUNC_PREFIX = 'get_args' +SET_ARGS_FUNC_PREFIX = 'set_args' +ARGS_NAME = '__args' + + +def create_set_args_node(names): + """ + Create set_args function as follows: + + def set_args_0(__args): + nonlocal x, y + x, y = __args + """ + + def empty_node(): + func_def = """ + def {func_name}({args}): + pass + """.format(func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), + args=ARGS_NAME) + return gast.parse(textwrap.dedent(func_def)).body[0] + + assert isinstance(names, (list, tuple)) + if not names: + return empty_node() + + mapped = list(filter(lambda n: '.' not in n, names)) + nonlocal_names = sorted( + mapped, + key=mapped.index) # to keep the order, we can't use set() to unique + template = """ + def {func_name}({args}): + nonlocal {nonlocal_vars} + {vars} = {args} + """ + func_def = template.format( + func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), + args=ARGS_NAME, + nonlocal_vars=','.join(nonlocal_names), + vars=",".join(names)) + return gast.parse(textwrap.dedent(func_def)).body[0] + + +def create_nonlocal_stmt_node(names): + assert isinstance(names, (list, tuple)) + + mapped = list(filter(lambda n: '.' not in n, names)) + names = sorted( + mapped, + key=mapped.index) # to keep the order, we can't use set() to unique + func_code = "nonlocal {}".format(','.join(names)) + return gast.parse(func_code).body[0] diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py index 92ef7a3f13d9bb86024299b871800a266680a420..9bbce59fc54cefdf5d7cb2e74140e7f7b073afe9 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py @@ -16,15 +16,17 @@ from __future__ import print_function import six import paddle +import textwrap from paddle.utils import gast -from paddle.fluid import core from paddle.fluid import unique_name from paddle.fluid.framework import Variable -from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, data_layer_not_check __all__ = [ - 'create_bool_as_type', 'create_fill_constant_node', 'to_static_variable', - 'create_undefined_var' + 'create_bool_as_type', + 'create_fill_constant_node', + 'to_static_variable', + 'create_undefined_var', ] @@ -33,12 +35,6 @@ def create_undefined_var(name): return gast.parse(func_code).body[0] -def create_nonlocal_stmt_node(names): - assert isinstance(names, (list, tuple)) - func_code = "nonlocal {}".format(','.join(names)) - return gast.parse(func_code).body[0] - - def create_fill_constant_node(name, value=0): func_code = "{} = paddle.full(shape=[1], ".format(name) if isinstance(value, bool): @@ -66,7 +62,9 @@ def to_static_variable(x): return paddle.full(shape=[1], dtype='float64', fill_value=x) if isinstance(x, six.integer_types): return paddle.full(shape=[1], dtype='int64', fill_value=x) - + if isinstance(x, UndefinedVar): + return data_layer_not_check(unique_name.generator("loop_undefined_var"), + [-1]) return x diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py index f573960b5dba0a7c207bb8e427e812e85815352f..1d64e7b81849f3a08ee35784292ee02e9ff4b7cb 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py @@ -177,7 +177,6 @@ def test_list_pop_in_for_loop(x, iter_num): one = fluid.layers.ones(shape=[1], dtype="int32") for i in range(one.numpy()[0]): item = a.pop() - return a[0], item, b[1] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index 78d97a3884aedf79dccaa40099a25c202d38fcd4..683135b9078dc8fc6ccde2e34d399bf8adce9e74 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -270,7 +270,7 @@ class TestNameVisitor(unittest.TestCase): self.loop_var_names = [ set(["j", "two"]), set(["i", "three", "b"]), - set(["i", "j"]) + set(["i"]) ] self.create_var_names = [set(), set(["b"]), set()]