未验证 提交 c5c6026e 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2Static ]Change NameVisitor in while to FunctionScopeAnalysis (#44155)

* change NameVisitor to FunctionScopeAnalysis

* polish the logic of undefined var in while_loop. create vars after body execution

* replace old NameVisitor in while and fix all CI

* Togather with CreateVariableTransformer

* add create_variable_transformer

* fix bugs

* merge

* fix some error, TODO: ForNodePreTransform ahead

* merge for unite PR

* fix conflict with base_transformer PR

* fix ci errors, fix [for i in range()] error

* fix according to code review
上级 8759c78d
......@@ -35,6 +35,7 @@ from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTr
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ReturnTransformer
from paddle.fluid.dygraph.dygraph_to_static.create_variable_transformer import CreateVariableTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
......@@ -96,7 +97,7 @@ class DygraphToStaticAst(BaseTransformer):
BreakContinueTransformer, # break/continue in loops
ReturnTransformer, # return in functions
LogicalTransformer, # logical and/or/not
#CreateVariableTransformer, # create undefined var for if / while / for
CreateVariableTransformer, # create undefined var for if / while / for
LoopTransformer, # for/while -> while_op
IfElseTransformer, # if/else -> cond_op
AssertTransformer, # assert statement
......
......@@ -24,6 +24,8 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_TUPLE_INDEX_PR
from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_VAR_LEN_PREFIX
from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_VAR_NAME_PREFIX
from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_ZIP_TO_LIST_PREFIX
from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_TARGET_PREFIX
from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_ITERATOR_PREFIX
class BaseTransformer(gast.NodeTransformer):
......@@ -119,32 +121,20 @@ class NameNodeReplaceTransformer(BaseTransformer):
class ForLoopTuplePreTransformer(BaseTransformer):
"""
ForNodeVisitor parses 3 type statements (Here var is VarBase(Tensor) or python variable):
1). for x in range(var[*]|var.numpy()[*])
2). for x in var|var.numpy()
3). for i, x in enumerate(var|var.numpy())
We chose these 3 types because they are easier (x can be variable name iterating in var).
However, users can write tuples in Python for loop, such as
1). for var1, var2 in var|var.numpy()
2). for t in enumerate(var|var.numpy())
2). for i, (var1, var2, va3) in enumerate(var|var.numpy())
To handle these case, this method will do the rewrite tuple pre-process:
1). Non-enumerate case: for var1, var2 in var|var.numpy() will be re-written as:
for FOR_ITER_TUPLE_PREFIX_x in var | var.numpy():
var1 = FOR_ITER_TUPLE_PREFIX_x[0]
var2 = FOR_ITER_TUPLE_PREFIX_x[1]
2). Enumerate out tuple case: for t in enumerate(var|var.numpy) will be rewritten as:
for FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x in enumerate(var|var.numpy):
t = (FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x)
3). Enumerate inner tuple case: for i, (var1, (var2, va3)) in enumerate(var|var.numpy()) will
be re-written as:
for i, FOR_ITER_TUPLE_PREFIX_x in var | var.numpy():
var1 = FOR_ITER_TUPLE_PREFIX_x[0]
var2 = FOR_ITER_TUPLE_PREFIX_x[1][0]
var3 = FOR_ITER_TUPLE_PREFIX_x[1][1]
""" pre-process of for loop.
>>> for A in B:
>>> C
will be changed into :
>>> UUID_iterator = _jst.Indexable(B) # make iterator-only to indexable list.
>>> for UUID_target in UUID_iterator:
>>> A = _jst.Unpack(UUID_target, structure)
>>> C
make the later loop_transform have unified type:
>>> for target in iter:
>>> body
"""
def __init__(self, wrapper_root):
......@@ -155,104 +145,45 @@ class ForLoopTuplePreTransformer(BaseTransformer):
self.visit(self.root)
def visit_For(self, node):
if self.is_for_enumerate_iter(node):
if isinstance(node.target, (gast.Name, gast.Attribute)):
# Out tuple case
out_tuple_name = ast_to_source_code(node.target).strip()
tuple_iter_name = unique_name.generate(
FOR_ITER_TUPLE_INDEX_PREFIX)
tuple_var_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
node.target = gast.Tuple(elts=[
gast.Name(id=tuple_iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None),
gast.Name(id=tuple_var_name,
ctx=gast.Store(),
self.generic_visit(node)
tuple_target = unique_name.generate(FOR_ITER_TARGET_PREFIX)
tuple_iterator = unique_name.generate(FOR_ITER_ITERATOR_PREFIX)
origin_tuple_node = node.target
assign_iterator_node = gast.parse(
f"{tuple_iterator} = _jst.Indexable({ast_to_source_code(node.iter).strip()})"
).body[0]
node.target = gast.Name(id=tuple_target,
ctx=gast.Store(),
annotation=None,
type_comment=None)
node.iter = gast.Name(id=tuple_iterator,
ctx=gast.Load(),
annotation=None,
type_comment=None)
],
ctx=gast.Store())
node.body.insert(
0,
gast.Assign(targets=[
gast.Name(id=out_tuple_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=gast.Tuple(elts=[
gast.Name(id=tuple_iter_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
gast.Name(id=tuple_var_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)
],
ctx=gast.Load())))
elif isinstance(node.target, (gast.List, gast.Tuple)) and len(
node.target.elts) >= 2 and isinstance(
node.target.elts[1], (gast.List, gast.Tuple)):
# Inner tuple case
inner_tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
origin_inner_tuple_node = node.target.elts[1]
node.target.elts[1] = gast.Name(id=inner_tuple_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
node.body[0:0] = self.tuple_to_stmts(origin_inner_tuple_node,
inner_tuple_name)
elif self.is_for_iter(node) and isinstance(node.target,
(gast.List, gast.Tuple)):
# Non-enumrate case:
tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
origin_tuple_node = node.target
node.target = gast.Name(id=tuple_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
node.body[0:0] = self.tuple_to_stmts(origin_tuple_node, tuple_name)
return node
def tuple_to_stmts(self, node, tuple_name, idx=[]):
if not isinstance(node, (gast.Tuple, gast.List)):
value_node_str = tuple_name
for i in idx:
value_node_str = value_node_str + "[{}]".format(i)
node_str = ast_to_source_code(node).strip()
assign_node_str = "{} = {}".format(node_str, value_node_str)
assign_node = gast.parse(assign_node_str).body[0]
return [assign_node]
# isinstance(node, (gast.Tuple, gast.List))
node.body[0:0] = self.tuple_to_stmts(origin_tuple_node, tuple_target)
# return a list will insert a list of node replace the original for node.
return [assign_iterator_node, node]
def tuple_node_to_unpack_structure(self, node):
""" Create a sequence to represents the structure of nest.
For example: `a, (b,c), [d,e,f]` is represented by
`[1, [1,1], [1,1,1]]`. the `1` is just a notation.
Specially, `a` is represented by `1`.
"""
ret = []
for i, element in enumerate(node.elts):
ret += self.tuple_to_stmts(node.elts[i], tuple_name, idx + [i])
if not isinstance(node, (gast.Tuple, gast.List)):
return 1
for element in node.elts:
ret.append(self.tuple_node_to_unpack_structure(element))
return ret
def is_for_iter(self, for_node):
assert isinstance(for_node,
gast.For), "Input node is not gast.For node."
if isinstance(for_node.iter, (gast.Name, gast.Attribute)):
return True
elif isinstance(for_node.iter, gast.Call) and isinstance(
for_node.iter.func,
gast.Attribute) and for_node.iter.func.attr == 'numpy':
return True
elif isinstance(for_node.iter, gast.Subscript):
return True
else:
return False
def is_for_enumerate_iter(self, for_node):
assert isinstance(for_node,
gast.For), "Input node is not gast.For node."
return isinstance(for_node.iter, gast.Call) and isinstance(
for_node.iter.func,
gast.Name) and for_node.iter.func.id == "enumerate"
def tuple_to_stmts(self, node, tuple_name):
structure_str = str(self.tuple_node_to_unpack_structure(node))
node_str = ast_to_source_code(node).strip()
assign_node_str = f"{node_str} = _jst.Unpack({tuple_name}, {structure_str})"
assign_node = gast.parse(assign_node_str).body[0]
return [assign_node]
class SplitAssignTransformer(BaseTransformer):
......
......@@ -40,7 +40,7 @@ class CallTransformer(BaseTransformer):
Determines whether a function needs to be transformed by `convert_call`.
It doesn't need to be transformed when a function satisfies the following conditions:
1. It's a api of paddle
2. It's a python builtin function not include `len` and `zip`
2. It's a python builtin function not include `len`, `zip`, `range` and `enumerate`
"""
assert isinstance(node, gast.Call)
if is_paddle_api(node):
......@@ -48,11 +48,16 @@ class CallTransformer(BaseTransformer):
func_str = ast_to_source_code(node.func).strip()
try:
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin, is_builtin_zip
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin
need_convert_builtin_func_list = {
'len',
'zip',
'range',
'enumerate',
}
is_builtin = eval("is_builtin({})".format(func_str))
is_builtin_len = eval("is_builtin_len({})".format(func_str))
is_builtin_zip = eval("is_builtin_zip({})".format(func_str))
return is_builtin and not is_builtin_len and not is_builtin_zip
need_convert = func_str in need_convert_builtin_func_list
return is_builtin and not need_convert
except Exception:
return False
......
......@@ -28,6 +28,7 @@ import six
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len, convert_zip
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_range, convert_enumerate
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
......@@ -64,25 +65,22 @@ class ConversionOptions(object):
self.not_convert = not_convert
def is_builtin(func):
if isinstance(func, types.BuiltinFunctionType):
def is_builtin(func, name=None):
""" predict whether a function is a builtin function with name={name}.
if name == None, then any builtin function will return True
"""
def name_judge():
return name is None or func.__name__ == name
if isinstance(func, types.BuiltinFunctionType) and name_judge():
return True
elif func in six.moves.builtins.__dict__.values():
elif func in six.moves.builtins.__dict__.values() and name_judge():
return True
else:
return False
def is_builtin_len(func):
if isinstance(func, types.BuiltinFunctionType) and func.__name__ == 'len':
return True
return False
def is_builtin_zip(func):
return is_builtin(func) and func.__name__ == 'zip'
def is_unsupported(func):
"""
Checks whether the func is supported by dygraph to static graph.
......@@ -165,12 +163,18 @@ def convert_call(func):
.format(func))
return func
if is_builtin_len(func):
if is_builtin(func, "len"):
return convert_len
if is_builtin_zip(func):
if is_builtin(func, "zip"):
return convert_zip
if is_builtin(func, "range"):
return convert_range
if is_builtin(func, "enumerate"):
return convert_enumerate
if is_builtin(func) or is_unsupported(func):
return func
......
......@@ -13,11 +13,12 @@
# limitations under the License.
import re
import paddle
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.framework import core, Variable
from paddle.fluid.layers import Assert, Print
from paddle.fluid.layers import range as paddle_range
from paddle.fluid.layers import array_length, array_read, array_write, create_array
from paddle.fluid.layers import assign, fill_constant, slice, reduce_all, reduce_any
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
......@@ -26,6 +27,45 @@ from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, Dygraph2StaticException
def indexable(x, code=None):
if isinstance(x, Variable): return x
if hasattr(x, '__len__') and hasattr(x, '__getitem__'): return x
if hasattr(x, '__iter__'):
return [i for i in x]
else:
raise RuntimeError("X can't be convert into indexable.")
def unpack_by_structure(target, structure):
""" unified unpack interface for paddle and python.
"""
if isinstance(target, Variable):
return _unpack_by_structure_paddle(target, structure)
else:
return _unpack_by_structure_python(target, structure)
def _unpack_by_structure_python(target, structure):
""" TODO(xiongkun): analysis the differences between python and paddle unpack.
"""
return _unpack_by_structure_paddle(target, structure)
def _unpack_by_structure_paddle(target, structure):
if structure == 1:
return target
ret = []
for idx, ele in enumerate(structure):
if ele == 1:
ret.append(target[idx])
continue
if isinstance(ele, list):
ret.append(unpack_by_structure(target[idx], ele))
continue
assert False, "structure element must be 1 or list"
return ret
def convert_while_loop(cond, body, getter, setter):
"""
A function representation of a Python ``while`` statement.
......@@ -50,12 +90,26 @@ def convert_while_loop(cond, body, getter, setter):
def _run_paddle_while(cond, body, getter, setter):
# NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors.
# UndefinedVar will become data layer not check.
loop_vars = [to_static_variable(var) for var in getter()]
def new_body_fn(*args):
""" wrap the body() and add return value for `while_loop`
"""
body()
return getter()
def new_cond_fn(*args):
""" cond is a zero-args function, which is not
compatible with `while_loop`.
"""
return cond()
# UndefinedVar will become data layer not check variable with value=NO_VALUE_MAGIC.
loop_vars = [
to_static_variable(var) if not isinstance(var, UndefinedVar) else var
for var in getter()
]
setter(loop_vars) # change the non-local var to variable
# variable maybe modified to inner var. change it into
loop_vars = control_flow.while_loop(cond, body, loop_vars)
loop_vars = control_flow.while_loop(new_cond_fn, new_body_fn, loop_vars)
setter(loop_vars) # change the non-local var to variable
return loop_vars
......@@ -368,6 +422,8 @@ def convert_len(var):
'len(var) only supports LoDTensor/LoDTensorArray/SelectedRows, but received %s.'
% type(var))
else:
if isinstance(var, VariableTuple):
return var.__len__()
return len(var)
......@@ -380,6 +436,44 @@ def convert_zip(*args):
return zip(*args)
# TODO(xiongkun): delete when list<variable> is ready.
class VariableTuple:
"""
this class will cause enumerate can't be wrapped by other iterator change function.
this will be fixed when list<Variable> is producted.
VariableTuple can only deal with variables which is fixed.
"""
def __init__(self, var, start=0):
self.var = var
self.len = convert_len(var)
self.rag = paddle_range(start, start + self.len, 1, paddle.int64)
def __getitem__(self, idx):
return self.rag[idx], self.var[idx]
def __len__(self):
return self.len
def convert_enumerate(*args):
has_variable = any(map(lambda x: isinstance(x, Variable), args))
if has_variable:
return VariableTuple(*args)
return enumerate(*args)
def convert_range(*args):
has_variable = any(map(lambda x: isinstance(x, Variable), args))
if has_variable:
if len(args) == 1: return paddle_range(0, args[0], 1, paddle.int64)
if len(args) == 2:
return paddle_range(args[0], args[1], 1, paddle.int64)
if len(args) == 3:
return paddle_range(args[0], args[1], args[2], paddle.int64)
return range(*args)
def convert_shape(x):
"""
A function representation of the shape of variable.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import FunctionNameLivenessAnalysis
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class CreateVariableTransformer(BaseTransformer):
"""
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Type of input node should be AstNodeWrapper, but received %s ." % type(
wrapper_root)
self.root = wrapper_root.node
FunctionNameLivenessAnalysis(self.root)
def transform(self):
"""
Main function to transform AST.
"""
self.visit(self.root)
def visit_FunctionDef(self, node):
#attributes = set(filter(lambda x: '.' in x, node.pd_scope.modified_vars()))
bodys = node.body
names = sorted(node.pd_scope.created_vars())
for name in names:
bodys[0:0] = [create_undefined_var(name)]
return node
......@@ -34,6 +34,7 @@ from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_un
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_nodes
from paddle.fluid.dygraph.dygraph_to_static.utils import create_get_args_node, create_set_args_node
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_INDEX_PREFIX, FOR_ITER_TUPLE_PREFIX, FOR_ITER_TUPLE_INDEX_PREFIX, FOR_ITER_VAR_LEN_PREFIX, FOR_ITER_VAR_NAME_PREFIX, FOR_ITER_ZIP_TO_LIST_PREFIX, FOR_ITER_TARGET_PREFIX, FOR_ITER_ITERATOR_PREFIX
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
......@@ -304,7 +305,6 @@ def transform_if_else(node, root):
"""
# TODO(liym27): Consider variable like `self.a` modified in if/else node.
new_vars_to_create = sorted(list(node.pd_scope.created_vars()))
return_name_ids = sorted(list(node.pd_scope.modified_vars()))
# NOTE: Python can create variable only in if body or only in else body, and use it out of if/else.
# E.g.
......@@ -315,10 +315,6 @@ def transform_if_else(node, root):
#
# Create static variable for those variables
create_new_vars_in_parent_stmts = []
for name in new_vars_to_create:
# NOTE: Consider variable like `self.a` modified in if/else node.
if "." not in name:
create_new_vars_in_parent_stmts.append(create_undefined_var(name))
nonlocal_names = list(return_name_ids)
nonlocal_names.sort()
......@@ -326,8 +322,21 @@ def transform_if_else(node, root):
nonlocal_names = _valid_nonlocal_names(return_name_ids, nonlocal_names)
# TODO(dev): Need a better way to deal this.
if ARGS_NAME in nonlocal_names:
nonlocal_names.remove(ARGS_NAME)
# LoopTransformer will create some special vars, which is not visiable by users. so we can sure it's safe to remove them.
filter_names = [
ARGS_NAME, FOR_ITER_INDEX_PREFIX, FOR_ITER_TUPLE_PREFIX,
FOR_ITER_TARGET_PREFIX, FOR_ITER_ITERATOR_PREFIX,
FOR_ITER_TUPLE_INDEX_PREFIX, FOR_ITER_VAR_LEN_PREFIX,
FOR_ITER_VAR_NAME_PREFIX, FOR_ITER_ZIP_TO_LIST_PREFIX
]
def remove_if(x):
for name in filter_names:
if x.startswith(name): return False
return True
nonlocal_names = list(filter(remove_if, nonlocal_names))
return_name_ids = nonlocal_names
nonlocal_stmt_node = create_nonlocal_stmt_nodes(nonlocal_names)
......
......@@ -26,8 +26,8 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_nodes, create_get_args_node, create_set_args_node
from paddle.fluid.dygraph.dygraph_to_static.utils import FunctionNameLivenessAnalysis
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import RenameTransformer
......@@ -483,10 +483,10 @@ class LoopTransformer(BaseTransformer):
), "Input non-AstNodeWrapper node for the initialization of LoopTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
FunctionNameLivenessAnalysis(self.root)
def transform(self):
ForLoopTuplePreTransformer(self.wrapper_root).transform()
self.name_visitor = NameVisitor(self.root)
self.visit(self.root)
def visit_While(self, node):
......@@ -537,19 +537,19 @@ class LoopTransformer(BaseTransformer):
return [node]
init_stmts, cond_stmt, body_stmts = stmts_tuple
# 2. get original loop vars
loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
node)
loop_var_names, create_var_names = node.pd_scope.modified_vars(
), node.pd_scope.created_vars()
# TODO: Remove the bunch of code? We have the unique format `for A in B:`
# NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
# we need append new loop var & remove useless loop var
# 1. for x in var -> x is no need
# 2. for i, x in enumerate(var) -> x is no need
if current_for_node_parser.is_for_iter(
) or current_for_node_parser.is_for_enumerate_iter():
if current_for_node_parser.is_for_iter():
iter_var_name = current_for_node_parser.iter_var_name
iter_idx_name = current_for_node_parser.iter_idx_name
loop_var_names.add(iter_idx_name)
if iter_var_name not in create_var_names:
loop_var_names.remove(iter_var_name)
if current_for_node_parser.enum_idx_name is not None:
loop_var_names.add(current_for_node_parser.enum_idx_name)
# 3. prepare result statement list
new_stmts = []
......@@ -559,10 +559,8 @@ class LoopTransformer(BaseTransformer):
# y += x
# print(x) # x = 10
#
# We need to create static variable for those variables
for name in create_var_names:
if "." not in name:
new_stmts.append(create_undefined_var(name))
# We don't need to create static variable for them, because
# we do this in CreateUndefinedVarTransformer
# create non-local statement for body and cond.
nonlocal_names = list(loop_var_names | create_var_names)
......@@ -581,10 +579,7 @@ class LoopTransformer(BaseTransformer):
name=unique_name.generate(FOR_CONDITION_PREFIX),
args=gast.arguments(args=[],
posonlyargs=[],
vararg=gast.Name(id=ARGS_NAME,
ctx=gast.Param(),
annotation=None,
type_comment=None),
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
......@@ -597,17 +592,11 @@ class LoopTransformer(BaseTransformer):
# 6. create & append loop body function node
# append return values for loop body
body_stmts.append(
gast.Return(value=generate_name_node(
nonlocal_names, ctx=gast.Load(), gen_tuple_if_single=True)))
body_func_node = gast.FunctionDef(
name=unique_name.generate(FOR_BODY_PREFIX),
args=gast.arguments(args=[],
posonlyargs=[],
vararg=gast.Name(id=ARGS_NAME,
ctx=gast.Param(),
annotation=None,
type_comment=None),
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
......@@ -632,8 +621,8 @@ class LoopTransformer(BaseTransformer):
return new_stmts
def get_while_stmt_nodes(self, node):
loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
node)
loop_var_names, create_var_names = node.pd_scope.modified_vars(
), node.pd_scope.created_vars()
new_stmts = []
# create non-local statement for body and cond.
......@@ -652,19 +641,14 @@ class LoopTransformer(BaseTransformer):
# y = x
# z = y
#
# We need to create static variable for those variables
for name in create_var_names:
if "." not in name:
new_stmts.append(create_fill_constant_node(name))
# We don't need to create static variable for those variables, because
# we do this in CreateUndefinedVarTransformer
condition_func_node = gast.FunctionDef(
name=unique_name.generate(WHILE_CONDITION_PREFIX),
args=gast.arguments(args=[],
posonlyargs=[],
vararg=gast.Name(id=ARGS_NAME,
ctx=gast.Param(),
annotation=None,
type_comment=None),
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
......@@ -677,17 +661,11 @@ class LoopTransformer(BaseTransformer):
new_stmts.append(condition_func_node)
new_body = node.body
new_body.append(
gast.Return(value=generate_name_node(
nonlocal_names, ctx=gast.Load(), gen_tuple_if_single=True)))
body_func_node = gast.FunctionDef(
name=unique_name.generate(WHILE_BODY_PREFIX),
args=gast.arguments(args=[],
posonlyargs=[],
vararg=gast.Name(id=ARGS_NAME,
ctx=gast.Param(),
annotation=None,
type_comment=None),
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
......
......@@ -82,6 +82,8 @@ dygraph_class_to_static_api = {
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
FOR_ITER_TARGET_PREFIX = '__for_loop_iter_target'
FOR_ITER_ITERATOR_PREFIX = '__for_loop_iter_iterator'
FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index'
FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
......@@ -1099,6 +1101,18 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
if isinstance(node, gast.FunctionDef):
return self._get_name_scope(node)
def visit_ListComp(self, node):
""" [ i for i in range(10) ]
In this case, `i` will not created in FunctionScope.
We don't collect `i` by not calling generic_visit.
"""
pass
def visit_DictComp(self, node):
""" the same as ListComp.
"""
pass
def visit_Name(self, node):
self.generic_visit(node)
write_context = (gast.Store, gast.AugStore, gast.Del)
......@@ -1149,8 +1163,13 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
def post_func():
self._father_name_scope().merge_from(self._current_name_scope())
self._nearest_function_scope().merge_from(
self._current_name_scope())
self._current_name_scope().created = self._nearest_function_scope(
).existed_vars() - node.before_created
# gather created vars into father and used in CreateUndefinedVarTransform
self._nearest_function_scope().created |= self._current_name_scope(
).created
def pre_func():
setattr(node, "before_created",
......
......@@ -108,7 +108,6 @@ def select_input(inputs, mask):
def select_input_with_buildin_type(inputs, mask):
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_var_like
support_ret_buildin_type = (bool, float, six.integer_types)
false_var, true_var = inputs
if isinstance(false_var, UndefinedVar) and isinstance(
......@@ -1182,12 +1181,16 @@ class While(object):
})
support_ret_buildin_type = (bool, float, six.integer_types)
def assign_skip_lod_tensor_array(input, output):
"""
Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.
"""
if not isinstance(input, (Variable, core.VarBase)):
if isinstance(output, Variable):
if isinstance(output, Variable) and isinstance(
input, support_ret_buildin_type):
assign(input, output)
else:
output = input
......@@ -1297,6 +1300,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
if not isinstance(output_vars, (list, tuple)):
output_vars = [output_vars]
try:
loop_vars = _deal_with_undefined_var(output_vars, loop_vars)
assert_same_structure(output_vars, loop_vars, check_types=False)
except ValueError as e:
raise ValueError(
......@@ -1308,6 +1312,36 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
return loop_vars
def _deal_with_undefined_var(output_vars, loop_vars):
""" Deal with undefined var cases, We create undefined variable based on the results of body().
In Dy2Static, we use undefined var to represent the var created in control flow. This function
expand the loop_vars and replace original loop_vars.
1. UndefinedVar = Variable # create a variable
2. UndefinedVar = None # create a undefined var with RETURN_NO_VALUE_MAGIC_NUM
3. UndefinedVar = List(int) # create a list of variable
4. UndefinedVar = value # create a variable
"""
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_variable
def create_var_like(o_var):
if isinstance(o_var,
(Variable, ) + support_ret_buildin_type) or o_var is None:
return create_undefined_variable()
if isinstance(o_var, (tuple, list)):
return [create_undefined_variable() for i in range(len(o_var))]
if len(output_vars) != len(loop_vars):
raise ValueError("The length of loop_vars should be the same.")
results = []
for o_var, l_var in zip(output_vars, loop_vars):
if isinstance(l_var, UndefinedVar) or l_var is None:
results.append(create_var_like(o_var))
else:
results.append(l_var)
return results
def lod_rank_table(x, level=0):
"""
LoD Rank Table Operator. Given an input variable **x** and a level number
......@@ -2616,6 +2650,11 @@ def change_none_to_undefinedvar(nest1, nest2):
def expand_undefined_var(nest1, nest2, names):
""" TODO: make this function recursively.
nest1: Var1, (UndefinedVar, [1,2,3])
nest2: Var2, ([1,2,3,4], UndefinedVar)
In this case, we should not expand recursively.
"""
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_VALUE_PREFIX
......
......@@ -385,6 +385,7 @@ class BaseModel(fluid.dygraph.Layer):
dropout_implementation='upscale_in_train')
else:
step_input = new_hidden
cell_outputs = self._split_batch_beams(step_input)
cell_outputs = self.fc(cell_outputs)
......
......@@ -442,13 +442,6 @@ class TestErrorInForLoop(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_dyfunc_not_support
def test_ast_to_func(self):
with self.assertRaisesRegexp(
NotImplementedError,
"Dynamic-to-Static only supports the step value is a constant or negative constant "
):
self._run_static()
if __name__ == '__main__':
with fluid.framework._test_eager_guard():
......
......@@ -66,6 +66,9 @@ def get_source_code(func):
class StaticCode1():
def dyfunc_with_if_else(x_v, label=None):
loss = _jst.UndefinedVar('loss')
__return_1 = _jst.UndefinedVar('__return_1')
__return_0 = _jst.UndefinedVar('__return_0')
__return_value_0 = None
def get_args_0():
......@@ -89,9 +92,6 @@ class StaticCode1():
_jst.IfElse(
paddle.mean(x_v)[0] > 5, true_fn_0, false_fn_0, get_args_0,
set_args_0, ('x_v', ))
__return_0 = _jst.UndefinedVar('__return_0')
__return_1 = _jst.UndefinedVar('__return_1')
loss = _jst.UndefinedVar('loss')
def get_args_1():
nonlocal __return_0, __return_1, __return_value_0, loss
......@@ -123,6 +123,9 @@ class StaticCode1():
class StaticCode2():
# TODO: Transform return statement
def dyfunc_with_if_else(x_v, label=None):
loss = _jst.UndefinedVar('loss')
__return_3 = _jst.UndefinedVar('__return_3')
__return_2 = _jst.UndefinedVar('__return_2')
__return_value_1 = None
def get_args_2():
......@@ -146,9 +149,6 @@ class StaticCode2():
_jst.IfElse(
paddle.mean(x_v)[0] > 5, true_fn_2, false_fn_2, get_args_2,
set_args_2, ('x_v', ))
__return_2 = _jst.UndefinedVar('__return_2')
__return_3 = _jst.UndefinedVar('__return_3')
loss = _jst.UndefinedVar('loss')
def get_args_3():
nonlocal __return_2, __return_3, __return_value_1, loss
......
......@@ -578,8 +578,8 @@ class TestOpNumWithTensorShapeInFor1(TestOpNumBasicWithTensorShape):
self.dygraph_func = dyfunc_with_for_1
def _set_expected_op_num(self):
self.expected_op_num = 22
self.expected_shape_op_num = 3
self.expected_op_num = 29
self.expected_shape_op_num = 2
self.expected_slice_op_num = 3
......@@ -589,7 +589,7 @@ class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape):
self.dygraph_func = dyfunc_with_while_1
def _set_expected_op_num(self):
self.expected_op_num = 22
self.expected_op_num = 21
self.expected_shape_op_num = 3
self.expected_slice_op_num = 3
......
......@@ -21,7 +21,7 @@ import paddle.fluid.layers as layers
from paddle.fluid.dygraph import Embedding, Layer, LayerNorm, Linear, to_variable
from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.layers.utils import map_structure
from paddle.fluid.layers.tensor import range as pd_range
import paddle
def position_encoding_init(n_position, d_pos_vec):
......@@ -634,7 +634,7 @@ class Transformer(Layer):
value=0),
} for i in range(self.n_layer)]
for i in pd_range(0, max_len, 1, dtype="int32"):
for i in range(paddle.to_tensor(max_len)):
trg_pos = layers.fill_constant(shape=trg_word.shape,
dtype="int64",
value=i)
......
......@@ -26,7 +26,8 @@ from .convert_operators import convert_pop as Pop # noqa: F401
from .convert_operators import convert_print as Print # noqa: F401
from .convert_operators import convert_shape as Shape # noqa: F401
from .convert_operators import convert_while_loop as While # noqa: F401
from .convert_operators import unpack_by_structure as Unpack # noqa: F401
from .convert_operators import indexable as Indexable # noqa: F401
from .variable_trans_func import create_bool_as_type # noqa: F401
from .variable_trans_func import to_static_variable # noqa: F401
from .convert_operators import convert_shape_compare # noqa: F401
......
......@@ -26,5 +26,6 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_c
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import unpack_by_structure, indexable # noqa: F401
__all__ = []
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册