Unverified · Commit c5c6026e · authored by xiongkun · committed by GitHub

[ Dy2Static ]Change NameVisitor in while to FunctionScopeAnalysis (#44155)

* change NameVisitor to FunctionScopeAnalysis

* polish the logic of undefined var in while_loop. create vars after body execution

* replace old NameVisitor in while and fix all CI

* Together with CreateVariableTransformer

* add create_variable_transformer

* fix bugs

* merge

* fix some errors, TODO: ForNodePreTransform ahead

* merge for unite PR

* fix conflict with base_transformer PR

* fix ci errors, fix [for i in range()] error

* fix according to code review
Parent 8759c78d
@@ -35,6 +35,7 @@ from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTr
 from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
 from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer
 from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ReturnTransformer
+from paddle.fluid.dygraph.dygraph_to_static.create_variable_transformer import CreateVariableTransformer
 from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
 from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
@@ -96,7 +97,7 @@ class DygraphToStaticAst(BaseTransformer):
             BreakContinueTransformer,  # break/continue in loops
             ReturnTransformer,  # return in functions
             LogicalTransformer,  # logical and/or/not
-            #CreateVariableTransformer,  # create undefined var for if / while / for
+            CreateVariableTransformer,  # create undefined var for if / while / for
             LoopTransformer,  # for/while -> while_op
             IfElseTransformer,  # if/else -> cond_op
             AssertTransformer,  # assert statement
......
@@ -24,6 +24,8 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_TUPLE_INDEX_PR
 from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_VAR_LEN_PREFIX
 from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_VAR_NAME_PREFIX
 from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_ZIP_TO_LIST_PREFIX
+from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_TARGET_PREFIX
+from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_ITERATOR_PREFIX

 class BaseTransformer(gast.NodeTransformer):
@@ -119,32 +121,20 @@ class NameNodeReplaceTransformer(BaseTransformer):
 class ForLoopTuplePreTransformer(BaseTransformer):
-    """
-    ForNodeVisitor parses 3 types of statements (here var is a VarBase(Tensor) or a python variable):
-        1). for x in range(var[*]|var.numpy()[*])
-        2). for x in var|var.numpy()
-        3). for i, x in enumerate(var|var.numpy())
-    We chose these 3 types because they are easier (x can be a variable name iterating in var).
-    However, users can write tuples in a Python for loop, such as
-        1). for var1, var2 in var|var.numpy()
-        2). for t in enumerate(var|var.numpy())
-        3). for i, (var1, var2, var3) in enumerate(var|var.numpy())
-    To handle these cases, this method does the tuple rewrite pre-process:
-        1). Non-enumerate case: for var1, var2 in var|var.numpy() is rewritten as:
-            for FOR_ITER_TUPLE_PREFIX_x in var | var.numpy():
-                var1 = FOR_ITER_TUPLE_PREFIX_x[0]
-                var2 = FOR_ITER_TUPLE_PREFIX_x[1]
-        2). Enumerate out tuple case: for t in enumerate(var|var.numpy) is rewritten as:
-            for FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x in enumerate(var|var.numpy):
-                t = (FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x)
-        3). Enumerate inner tuple case: for i, (var1, (var2, var3)) in enumerate(var|var.numpy())
-            is rewritten as:
-            for i, FOR_ITER_TUPLE_PREFIX_x in var | var.numpy():
-                var1 = FOR_ITER_TUPLE_PREFIX_x[0]
-                var2 = FOR_ITER_TUPLE_PREFIX_x[1][0]
-                var3 = FOR_ITER_TUPLE_PREFIX_x[1][1]
-    """
+    """ Pre-process of for loops.
+    >>> for A in B:
+    >>>     C
+
+    will be changed into:
+
+    >>> UUID_iterator = _jst.Indexable(B)  # make an iterator-only object indexable
+    >>> for UUID_target in UUID_iterator:
+    >>>     A = _jst.Unpack(UUID_target, structure)
+    >>>     C
+
+    so that later loop transforms see a single unified form:
+    >>> for target in iter:
+    >>>     body
+    """
     def __init__(self, wrapper_root):
@@ -155,104 +145,45 @@ class ForLoopTuplePreTransformer(BaseTransformer):
         self.visit(self.root)
     def visit_For(self, node):
-        if self.is_for_enumerate_iter(node):
-            if isinstance(node.target, (gast.Name, gast.Attribute)):
-                # Out tuple case
-                out_tuple_name = ast_to_source_code(node.target).strip()
-                tuple_iter_name = unique_name.generate(
-                    FOR_ITER_TUPLE_INDEX_PREFIX)
-                tuple_var_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
-                node.target = gast.Tuple(elts=[
-                    gast.Name(id=tuple_iter_name,
-                              ctx=gast.Store(),
-                              annotation=None,
-                              type_comment=None),
-                    gast.Name(id=tuple_var_name,
-                              ctx=gast.Store(),
-                              annotation=None,
-                              type_comment=None)
-                ],
-                                         ctx=gast.Store())
-                node.body.insert(
-                    0,
-                    gast.Assign(targets=[
-                        gast.Name(id=out_tuple_name,
-                                  ctx=gast.Store(),
-                                  annotation=None,
-                                  type_comment=None)
-                    ],
-                                value=gast.Tuple(elts=[
-                                    gast.Name(id=tuple_iter_name,
-                                              ctx=gast.Load(),
-                                              annotation=None,
-                                              type_comment=None),
-                                    gast.Name(id=tuple_var_name,
-                                              ctx=gast.Load(),
-                                              annotation=None,
-                                              type_comment=None)
-                                ],
-                                                 ctx=gast.Load())))
-            elif isinstance(node.target, (gast.List, gast.Tuple)) and len(
-                    node.target.elts) >= 2 and isinstance(
-                        node.target.elts[1], (gast.List, gast.Tuple)):
-                # Inner tuple case
-                inner_tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
-                origin_inner_tuple_node = node.target.elts[1]
-                node.target.elts[1] = gast.Name(id=inner_tuple_name,
-                                                ctx=gast.Store(),
-                                                annotation=None,
-                                                type_comment=None)
-                node.body[0:0] = self.tuple_to_stmts(origin_inner_tuple_node,
-                                                     inner_tuple_name)
-        elif self.is_for_iter(node) and isinstance(node.target,
-                                                   (gast.List, gast.Tuple)):
-            # Non-enumerate case:
-            tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
-            origin_tuple_node = node.target
-            node.target = gast.Name(id=tuple_name,
-                                    ctx=gast.Store(),
-                                    annotation=None,
-                                    type_comment=None)
-            node.body[0:0] = self.tuple_to_stmts(origin_tuple_node, tuple_name)
-        return node
-
-    def tuple_to_stmts(self, node, tuple_name, idx=[]):
-        if not isinstance(node, (gast.Tuple, gast.List)):
-            value_node_str = tuple_name
-            for i in idx:
-                value_node_str = value_node_str + "[{}]".format(i)
-            node_str = ast_to_source_code(node).strip()
-            assign_node_str = "{} = {}".format(node_str, value_node_str)
-            assign_node = gast.parse(assign_node_str).body[0]
-            return [assign_node]
-        # isinstance(node, (gast.Tuple, gast.List))
-        ret = []
-        for i, element in enumerate(node.elts):
-            ret += self.tuple_to_stmts(node.elts[i], tuple_name, idx + [i])
-        return ret
-
-    def is_for_iter(self, for_node):
-        assert isinstance(for_node,
-                          gast.For), "Input node is not gast.For node."
-        if isinstance(for_node.iter, (gast.Name, gast.Attribute)):
-            return True
-        elif isinstance(for_node.iter, gast.Call) and isinstance(
-                for_node.iter.func,
-                gast.Attribute) and for_node.iter.func.attr == 'numpy':
-            return True
-        elif isinstance(for_node.iter, gast.Subscript):
-            return True
-        else:
-            return False
-
-    def is_for_enumerate_iter(self, for_node):
-        assert isinstance(for_node,
-                          gast.For), "Input node is not gast.For node."
-        return isinstance(for_node.iter, gast.Call) and isinstance(
-            for_node.iter.func,
-            gast.Name) and for_node.iter.func.id == "enumerate"
+        self.generic_visit(node)
+        tuple_target = unique_name.generate(FOR_ITER_TARGET_PREFIX)
+        tuple_iterator = unique_name.generate(FOR_ITER_ITERATOR_PREFIX)
+        origin_tuple_node = node.target
+        assign_iterator_node = gast.parse(
+            f"{tuple_iterator} = _jst.Indexable({ast_to_source_code(node.iter).strip()})"
+        ).body[0]
+        node.target = gast.Name(id=tuple_target,
+                                ctx=gast.Store(),
+                                annotation=None,
+                                type_comment=None)
+        node.iter = gast.Name(id=tuple_iterator,
+                              ctx=gast.Load(),
+                              annotation=None,
+                              type_comment=None)
+        node.body[0:0] = self.tuple_to_stmts(origin_tuple_node, tuple_target)
+        # Returning a list replaces the original For node with these nodes.
+        return [assign_iterator_node, node]
+
+    def tuple_node_to_unpack_structure(self, node):
+        """ Create a sequence that represents the structure of the nest.
+        For example: `a, (b, c), [d, e, f]` is represented by
+        `[1, [1, 1], [1, 1, 1]]`; the `1` is just a notation.
+
+        Specially, `a` alone is represented by `1`.
+        """
+        ret = []
+        if not isinstance(node, (gast.Tuple, gast.List)):
+            return 1
+        for element in node.elts:
+            ret.append(self.tuple_node_to_unpack_structure(element))
+        return ret
+
+    def tuple_to_stmts(self, node, tuple_name):
+        structure_str = str(self.tuple_node_to_unpack_structure(node))
+        node_str = ast_to_source_code(node).strip()
+        assign_node_str = f"{node_str} = _jst.Unpack({tuple_name}, {structure_str})"
+        assign_node = gast.parse(assign_node_str).body[0]
+        return [assign_node]
 class SplitAssignTransformer(BaseTransformer):
......
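For readers following the diff, here is a hedged sketch of the rewrite the new ForLoopTuplePreTransformer performs. The `__for_loop_iter_*` names stand in for the unique names produced by `unique_name.generate`; the loop itself is made up:

```python
# Original user loop (any target shape is now handled uniformly):
#
#     for i, (x, y) in enumerate(z):
#         out = x + y
#
# After the pre-transform (generated names are illustrative):
#
#     __for_loop_iter_iterator_0 = _jst.Indexable(enumerate(z))
#     for __for_loop_iter_target_0 in __for_loop_iter_iterator_0:
#         (i, (x, y)) = _jst.Unpack(__for_loop_iter_target_0, [1, [1, 1]])
#         out = x + y
```

Every for loop thus reaches LoopTransformer in the single `for target in iter:` shape, which is what lets the enumerate/tuple special cases below be deleted.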
@@ -40,7 +40,7 @@ class CallTransformer(BaseTransformer):
         Determines whether a function needs to be transformed by `convert_call`.
         It doesn't need to be transformed when a function satisfies the following conditions:
           1. It's a api of paddle
-          2. It's a python builtin function not include `len` and `zip`
+          2. It's a Python builtin function, excluding `len`, `zip`, `range` and `enumerate`
         """
         assert isinstance(node, gast.Call)
         if is_paddle_api(node):
@@ -48,11 +48,16 @@ class CallTransformer(BaseTransformer):
         func_str = ast_to_source_code(node.func).strip()
         try:
-            from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin, is_builtin_zip
+            from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin
+            need_convert_builtin_func_list = {
+                'len',
+                'zip',
+                'range',
+                'enumerate',
+            }
             is_builtin = eval("is_builtin({})".format(func_str))
-            is_builtin_len = eval("is_builtin_len({})".format(func_str))
-            is_builtin_zip = eval("is_builtin_zip({})".format(func_str))
-            return is_builtin and not is_builtin_len and not is_builtin_zip
+            need_convert = func_str in need_convert_builtin_func_list
+            return is_builtin and not need_convert
         except Exception:
             return False
......
@@ -28,6 +28,7 @@ import six
 from paddle.fluid.dygraph.container import Sequential
 from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len, convert_zip
+from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_range, convert_enumerate
 from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger
 from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
 from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
@@ -64,25 +65,22 @@ class ConversionOptions(object):
         self.not_convert = not_convert

-def is_builtin(func):
-    if isinstance(func, types.BuiltinFunctionType):
-        return True
-    elif func in six.moves.builtins.__dict__.values():
-        return True
-    else:
-        return False
-
-
-def is_builtin_len(func):
-    if isinstance(func, types.BuiltinFunctionType) and func.__name__ == 'len':
-        return True
-    return False
-
-
-def is_builtin_zip(func):
-    return is_builtin(func) and func.__name__ == 'zip'
+def is_builtin(func, name=None):
+    """ Predicate: is `func` a builtin function with __name__ == {name}?
+    If name is None, any builtin function returns True.
+    """
+
+    def name_judge():
+        return name is None or func.__name__ == name
+
+    if isinstance(func, types.BuiltinFunctionType) and name_judge():
+        return True
+    elif func in six.moves.builtins.__dict__.values() and name_judge():
+        return True
+    else:
+        return False

 def is_unsupported(func):
     """
     Checks whether the func is supported by dygraph to static graph.
@@ -165,12 +163,18 @@ def convert_call(func):
                 .format(func))
             return func

-    if is_builtin_len(func):
+    if is_builtin(func, "len"):
         return convert_len
-    if is_builtin_zip(func):
+    if is_builtin(func, "zip"):
         return convert_zip
+    if is_builtin(func, "range"):
+        return convert_range
+    if is_builtin(func, "enumerate"):
+        return convert_enumerate
     if is_builtin(func) or is_unsupported(func):
         return func
......
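Taken together, these hunks replace the two special-case predicates with one name-aware predicate. A minimal, self-contained sketch of the new dispatch; the function body is copied from this diff, and only the assertions are added for illustration:

```python
import types
import six  # the module uses six for Python 2/3 builtins


def is_builtin(func, name=None):
    # True if func is a builtin; if name is given, its __name__ must also match.
    def name_judge():
        return name is None or func.__name__ == name

    if isinstance(func, types.BuiltinFunctionType) and name_judge():
        return True
    elif func in six.moves.builtins.__dict__.values() and name_judge():
        return True
    return False


# convert_call can now branch on one predicate instead of is_builtin_len/zip:
assert is_builtin(len) and is_builtin(len, "len")
assert not is_builtin(len, "zip")  # the name filter rejects other builtins
assert is_builtin(zip, "zip") and is_builtin(range, "range")
```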
@@ -13,11 +13,12 @@
 # limitations under the License.
 import re
+import paddle
 from paddle.fluid.data_feeder import convert_dtype
 from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
 from paddle.fluid.framework import core, Variable
 from paddle.fluid.layers import Assert, Print
+from paddle.fluid.layers import range as paddle_range
 from paddle.fluid.layers import array_length, array_read, array_write, create_array
 from paddle.fluid.layers import assign, fill_constant, slice, reduce_all, reduce_any
 from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
@@ -26,6 +27,45 @@ from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_
 from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, Dygraph2StaticException

+def indexable(x, code=None):
+    if isinstance(x, Variable): return x
+    if hasattr(x, '__len__') and hasattr(x, '__getitem__'): return x
+    if hasattr(x, '__iter__'):
+        return [i for i in x]
+    else:
+        raise RuntimeError("x cannot be converted into an indexable object.")
+
+
+def unpack_by_structure(target, structure):
+    """ unified unpack interface for paddle and python.
+    """
+    if isinstance(target, Variable):
+        return _unpack_by_structure_paddle(target, structure)
+    else:
+        return _unpack_by_structure_python(target, structure)
+
+
+def _unpack_by_structure_python(target, structure):
+    """ TODO(xiongkun): analyse the differences between python and paddle unpack.
+    """
+    return _unpack_by_structure_paddle(target, structure)
+
+
+def _unpack_by_structure_paddle(target, structure):
+    if structure == 1:
+        return target
+    ret = []
+    for idx, ele in enumerate(structure):
+        if ele == 1:
+            ret.append(target[idx])
+            continue
+        if isinstance(ele, list):
+            ret.append(unpack_by_structure(target[idx], ele))
+            continue
+        assert False, "structure element must be 1 or list"
+    return ret
+
+
 def convert_while_loop(cond, body, getter, setter):
     """
     A function representation of a Python ``while`` statement.
@@ -50,12 +90,26 @@ def convert_while_loop(cond, body, getter, setter):
 def _run_paddle_while(cond, body, getter, setter):
     # NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors.
-    # UndefinedVar will become data layer not check.
-    loop_vars = [to_static_variable(var) for var in getter()]
+    def new_body_fn(*args):
+        """ wrap body() and add a return value for `while_loop`.
+        """
+        body()
+        return getter()
+
+    def new_cond_fn(*args):
+        """ cond is a zero-argument function, which is not
+        compatible with `while_loop`.
+        """
+        return cond()
+
+    # UndefinedVar will become a data-layer-not-check variable with value=NO_VALUE_MAGIC.
+    loop_vars = [
+        to_static_variable(var) if not isinstance(var, UndefinedVar) else var
+        for var in getter()
+    ]
     setter(loop_vars)  # change the non-local var to variable
     # variable maybe modified to inner var. change it into
-    loop_vars = control_flow.while_loop(cond, body, loop_vars)
+    loop_vars = control_flow.while_loop(new_cond_fn, new_body_fn, loop_vars)
     setter(loop_vars)  # change the non-local var to variable
     return loop_vars
@@ -368,6 +422,8 @@ def convert_len(var):
                 'len(var) only supports LoDTensor/LoDTensorArray/SelectedRows, but received %s.'
                 % type(var))
     else:
+        if isinstance(var, VariableTuple):
+            return var.__len__()
         return len(var)
@@ -380,6 +436,44 @@ def convert_zip(*args):
     return zip(*args)

+# TODO(xiongkun): delete when list<Variable> is ready.
+class VariableTuple:
+    """
+    Because enumerate returns this class instead of a real iterator, it can't be
+    wrapped by other iterator-transforming functions; this will be fixed when
+    list<Variable> is ready. VariableTuple can only deal with variables whose
+    length is fixed.
+    """
+
+    def __init__(self, var, start=0):
+        self.var = var
+        self.len = convert_len(var)
+        self.rag = paddle_range(start, start + self.len, 1, paddle.int64)
+
+    def __getitem__(self, idx):
+        return self.rag[idx], self.var[idx]
+
+    def __len__(self):
+        return self.len
+
+
+def convert_enumerate(*args):
+    has_variable = any(map(lambda x: isinstance(x, Variable), args))
+    if has_variable:
+        return VariableTuple(*args)
+    return enumerate(*args)
+
+
+def convert_range(*args):
+    has_variable = any(map(lambda x: isinstance(x, Variable), args))
+    if has_variable:
+        if len(args) == 1: return paddle_range(0, args[0], 1, paddle.int64)
+        if len(args) == 2:
+            return paddle_range(args[0], args[1], 1, paddle.int64)
+        if len(args) == 3:
+            return paddle_range(args[0], args[1], args[2], paddle.int64)
+    return range(*args)
+
+
 def convert_shape(x):
     """
     A function representation of the shape of variable.
......
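The structure encoding used by `_jst.Unpack` is easiest to see on plain Python values. A small, runnable illustration; the helper below restates `_unpack_by_structure_paddle` without the Variable branch:

```python
def unpack(target, structure):
    # structure is 1 for a leaf, or a (possibly nested) list of 1s.
    if structure == 1:
        return target
    ret = []
    for idx, ele in enumerate(structure):
        if ele == 1:
            ret.append(target[idx])
        elif isinstance(ele, list):
            ret.append(unpack(target[idx], ele))
        else:
            raise AssertionError("structure element must be 1 or list")
    return ret


# `a, (b, c), [d, e, f]` is encoded as [1, [1, 1], [1, 1, 1]]:
assert unpack((0, (1, 2), [3, 4, 5]),
              [1, [1, 1], [1, 1, 1]]) == [0, [1, 2], [3, 4, 5]]
```

For Tensors, `__getitem__` slices the Variable instead, which is why a single unpack interface can serve both worlds.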
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import FunctionNameLivenessAnalysis
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class CreateVariableTransformer(BaseTransformer):
    """ Pre-create an undefined variable for every name that is first created
    inside control flow, so the name exists before the if/while/for runs.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Type of input node should be AstNodeWrapper, but received %s ." % type(
            wrapper_root)
        self.root = wrapper_root.node
        FunctionNameLivenessAnalysis(self.root)

    def transform(self):
        """
        Main function to transform AST.
        """
        self.visit(self.root)

    def visit_FunctionDef(self, node):
        #attributes = set(filter(lambda x: '.' in x, node.pd_scope.modified_vars()))
        bodys = node.body
        names = sorted(node.pd_scope.created_vars())
        for name in names:
            bodys[0:0] = [create_undefined_var(name)]
        return node
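The new transformer's effect is easiest to see on a small function. A hedged before/after sketch; the `res = _jst.UndefinedVar('res')` form matches the expected code in the program_translator test further down, while the function itself is made up:

```python
# Before CreateVariableTransformer:
#
#     def fn(x):
#         if x > 0:
#             res = 2 * x
#         return res
#
# After: every name first created inside control flow gets a placeholder at
# the top of the function, so the nonlocal get/set machinery in the generated
# cond/body/true_fn/false_fn functions can refer to it before the branch runs:
#
#     def fn(x):
#         res = _jst.UndefinedVar('res')
#         if x > 0:
#             res = 2 * x
#         return res
```

This is also why the per-transformer `create_undefined_var`/`create_fill_constant_node` calls in the if/else and loop transformers below can be deleted: creation now happens once, up front.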
@@ -34,6 +34,7 @@ from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_un
 from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_nodes
 from paddle.fluid.dygraph.dygraph_to_static.utils import create_get_args_node, create_set_args_node
 from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
+from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_INDEX_PREFIX, FOR_ITER_TUPLE_PREFIX, FOR_ITER_TUPLE_INDEX_PREFIX, FOR_ITER_VAR_LEN_PREFIX, FOR_ITER_VAR_NAME_PREFIX, FOR_ITER_ZIP_TO_LIST_PREFIX, FOR_ITER_TARGET_PREFIX, FOR_ITER_ITERATOR_PREFIX

 TRUE_FUNC_PREFIX = 'true_fn'
 FALSE_FUNC_PREFIX = 'false_fn'
@@ -304,7 +305,6 @@ def transform_if_else(node, root):
     """
     # TODO(liym27): Consider variable like `self.a` modified in if/else node.
-    new_vars_to_create = sorted(list(node.pd_scope.created_vars()))
     return_name_ids = sorted(list(node.pd_scope.modified_vars()))
     # NOTE: Python can create variable only in if body or only in else body, and use it out of if/else.
     # E.g.
@@ -315,10 +315,6 @@ def transform_if_else(node, root):
     #
     # Create static variable for those variables
     create_new_vars_in_parent_stmts = []
-    for name in new_vars_to_create:
-        # NOTE: Consider variable like `self.a` modified in if/else node.
-        if "." not in name:
-            create_new_vars_in_parent_stmts.append(create_undefined_var(name))

     nonlocal_names = list(return_name_ids)
     nonlocal_names.sort()
@@ -326,8 +322,21 @@ def transform_if_else(node, root):
     nonlocal_names = _valid_nonlocal_names(return_name_ids, nonlocal_names)

     # TODO(dev): Need a better way to deal this.
-    if ARGS_NAME in nonlocal_names:
-        nonlocal_names.remove(ARGS_NAME)
+    # LoopTransformer creates some special vars that are not visible to users,
+    # so it is safe to remove them here.
+    filter_names = [
+        ARGS_NAME, FOR_ITER_INDEX_PREFIX, FOR_ITER_TUPLE_PREFIX,
+        FOR_ITER_TARGET_PREFIX, FOR_ITER_ITERATOR_PREFIX,
+        FOR_ITER_TUPLE_INDEX_PREFIX, FOR_ITER_VAR_LEN_PREFIX,
+        FOR_ITER_VAR_NAME_PREFIX, FOR_ITER_ZIP_TO_LIST_PREFIX
+    ]
+
+    def remove_if(x):
+        for name in filter_names:
+            if x.startswith(name): return False
+        return True
+
+    nonlocal_names = list(filter(remove_if, nonlocal_names))
+    return_name_ids = nonlocal_names

     nonlocal_stmt_node = create_nonlocal_stmt_nodes(nonlocal_names)
......
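The prefix filter in isolation, as a runnable sketch. The `__for_loop_*` prefix values come from the utils.py hunk below; the list is abridged, and the `'__args'` value for ARGS_NAME is an assumption for illustration only:

```python
ARGS_NAME = '__args'  # assumed value, for illustration only
filter_names = [
    ARGS_NAME, '__for_loop_var_index', '__for_loop_iter_tuple',
    '__for_loop_iter_target', '__for_loop_iter_iterator',
    '__for_loop_iter_tuple_index', '__for_loop_var_len', '__for_loop_iter_var'
]


def remove_if(x):
    # keep only user-visible names; drop transformer-generated ones
    return not any(x.startswith(name) for name in filter_names)


names = ['x', '__for_loop_iter_target_0', 'loss']
assert list(filter(remove_if, names)) == ['x', 'loss']
```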
@@ -26,8 +26,8 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
 from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
 from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
 from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
-from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
 from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_nodes, create_get_args_node, create_set_args_node
+from paddle.fluid.dygraph.dygraph_to_static.utils import FunctionNameLivenessAnalysis
 from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME
 from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
 from paddle.fluid.dygraph.dygraph_to_static.base_transformer import RenameTransformer
@@ -483,10 +483,10 @@ class LoopTransformer(BaseTransformer):
         ), "Input non-AstNodeWrapper node for the initialization of LoopTransformer."
         self.wrapper_root = wrapper_root
         self.root = wrapper_root.node
+        FunctionNameLivenessAnalysis(self.root)

     def transform(self):
         ForLoopTuplePreTransformer(self.wrapper_root).transform()
-        self.name_visitor = NameVisitor(self.root)
         self.visit(self.root)

     def visit_While(self, node):
@@ -537,19 +537,19 @@ class LoopTransformer(BaseTransformer):
             return [node]
         init_stmts, cond_stmt, body_stmts = stmts_tuple
         # 2. get original loop vars
-        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
-            node)
+        loop_var_names, create_var_names = node.pd_scope.modified_vars(
+        ), node.pd_scope.created_vars()
+        # TODO: Remove this bunch of code? We now have the unified format `for A in B:`.
         # NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
         # we need append new loop var & remove useless loop var
         #    1. for x in var -> x is no need
         #    2. for i, x in enumerate(var) -> x is no need
-        if current_for_node_parser.is_for_iter(
-        ) or current_for_node_parser.is_for_enumerate_iter():
+        if current_for_node_parser.is_for_iter():
             iter_var_name = current_for_node_parser.iter_var_name
             iter_idx_name = current_for_node_parser.iter_idx_name
             loop_var_names.add(iter_idx_name)
-            if iter_var_name not in create_var_names:
-                loop_var_names.remove(iter_var_name)
+            if current_for_node_parser.enum_idx_name is not None:
+                loop_var_names.add(current_for_node_parser.enum_idx_name)

         # 3. prepare result statement list
         new_stmts = []
@@ -559,10 +559,8 @@ class LoopTransformer(BaseTransformer):
         #       y += x
         #       print(x) # x = 10
         #
-        # We need to create static variable for those variables
-        for name in create_var_names:
-            if "." not in name:
-                new_stmts.append(create_undefined_var(name))
+        # We don't need to create static variables for them, because
+        # CreateVariableTransformer already does this.

         # create non-local statement for body and cond.
         nonlocal_names = list(loop_var_names | create_var_names)
@@ -581,10 +579,7 @@ class LoopTransformer(BaseTransformer):
             name=unique_name.generate(FOR_CONDITION_PREFIX),
             args=gast.arguments(args=[],
                                 posonlyargs=[],
-                                vararg=gast.Name(id=ARGS_NAME,
-                                                 ctx=gast.Param(),
-                                                 annotation=None,
-                                                 type_comment=None),
+                                vararg=None,
                                 kwonlyargs=[],
                                 kw_defaults=None,
                                 kwarg=None,
@@ -597,17 +592,11 @@ class LoopTransformer(BaseTransformer):
         # 6. create & append loop body function node
         # append return values for loop body
-        body_stmts.append(
-            gast.Return(value=generate_name_node(
-                nonlocal_names, ctx=gast.Load(), gen_tuple_if_single=True)))
         body_func_node = gast.FunctionDef(
             name=unique_name.generate(FOR_BODY_PREFIX),
             args=gast.arguments(args=[],
                                 posonlyargs=[],
-                                vararg=gast.Name(id=ARGS_NAME,
-                                                 ctx=gast.Param(),
-                                                 annotation=None,
-                                                 type_comment=None),
+                                vararg=None,
                                 kwonlyargs=[],
                                 kw_defaults=None,
                                 kwarg=None,
@@ -632,8 +621,8 @@ class LoopTransformer(BaseTransformer):
         return new_stmts

     def get_while_stmt_nodes(self, node):
-        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
-            node)
+        loop_var_names, create_var_names = node.pd_scope.modified_vars(
+        ), node.pd_scope.created_vars()
         new_stmts = []

         # create non-local statement for body and cond.
@@ -652,19 +641,14 @@ class LoopTransformer(BaseTransformer):
         #        y = x
         #        z = y
         #
-        # We need to create static variable for those variables
-        for name in create_var_names:
-            if "." not in name:
-                new_stmts.append(create_fill_constant_node(name))
+        # We don't need to create static variables for those variables, because
+        # CreateVariableTransformer already does this.

         condition_func_node = gast.FunctionDef(
             name=unique_name.generate(WHILE_CONDITION_PREFIX),
             args=gast.arguments(args=[],
                                 posonlyargs=[],
-                                vararg=gast.Name(id=ARGS_NAME,
-                                                 ctx=gast.Param(),
-                                                 annotation=None,
-                                                 type_comment=None),
+                                vararg=None,
                                 kwonlyargs=[],
                                 kw_defaults=None,
                                 kwarg=None,
@@ -677,17 +661,11 @@ class LoopTransformer(BaseTransformer):
         new_stmts.append(condition_func_node)

         new_body = node.body
-        new_body.append(
-            gast.Return(value=generate_name_node(
-                nonlocal_names, ctx=gast.Load(), gen_tuple_if_single=True)))
         body_func_node = gast.FunctionDef(
             name=unique_name.generate(WHILE_BODY_PREFIX),
             args=gast.arguments(args=[],
                                 posonlyargs=[],
-                                vararg=gast.Name(id=ARGS_NAME,
-                                                 ctx=gast.Param(),
-                                                 annotation=None,
-                                                 type_comment=None),
+                                vararg=None,
                                 kwonlyargs=[],
                                 kw_defaults=None,
                                 kwarg=None,
......
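With NameVisitor gone, the loop variables now come from the name scope that FunctionNameLivenessAnalysis attaches to each node. A hedged sketch of what the two sets contain for a simple while loop; the exact sets may differ in edge cases:

```python
# def fn(x):
#     i = 0
#     s = 0          # `s` exists before the loop
#     while i < 10:
#         t = i + s  # `t` is first created inside the loop body
#         s = t
#         i += 1
#     return s
#
# For the While node:
#     node.pd_scope.modified_vars() -> {'i', 's', 't'}  # become loop_var_names
#     node.pd_scope.created_vars()  -> {'t'}            # gets an UndefinedVar
#                                                       # from CreateVariableTransformer
```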
@@ -82,6 +82,8 @@ dygraph_class_to_static_api = {
 FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
 FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
+FOR_ITER_TARGET_PREFIX = '__for_loop_iter_target'
+FOR_ITER_ITERATOR_PREFIX = '__for_loop_iter_iterator'
 FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index'
 FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
 FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
@@ -1099,6 +1101,18 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
         if isinstance(node, gast.FunctionDef):
             return self._get_name_scope(node)

+    def visit_ListComp(self, node):
+        """ [ i for i in range(10) ]
+        In this case, `i` is not created in the FunctionScope.
+        We skip collecting `i` by not calling generic_visit.
+        """
+        pass
+
+    def visit_DictComp(self, node):
+        """ the same as ListComp.
+        """
+        pass
+
     def visit_Name(self, node):
         self.generic_visit(node)
         write_context = (gast.Store, gast.AugStore, gast.Del)
@@ -1149,8 +1163,13 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
         def post_func():
             self._father_name_scope().merge_from(self._current_name_scope())
+            self._nearest_function_scope().merge_from(
+                self._current_name_scope())
             self._current_name_scope().created = self._nearest_function_scope(
             ).existed_vars() - node.before_created
+            # gather created vars into the father scope; used by CreateVariableTransformer
+            self._nearest_function_scope().created |= self._current_name_scope(
+            ).created

         def pre_func():
             setattr(node, "before_created",
......
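Why the two comprehension visitors stop the traversal: in Python 3 a comprehension target lives in its own scope, so it must not be reported as a variable created by the enclosing function. A runnable illustration of the intended behavior:

```python
def fn():
    squares = [i * i for i in range(10)]  # creates `squares`, not `i`
    return squares


fn()
# `i` does not leak out of the comprehension in Python 3, so the analysis
# must not list it among fn's created vars either:
assert 'i' not in fn.__code__.co_varnames
```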
@@ -108,7 +108,6 @@ def select_input(inputs, mask):
 def select_input_with_buildin_type(inputs, mask):
     from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
     from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_var_like
-    support_ret_buildin_type = (bool, float, six.integer_types)
     false_var, true_var = inputs

     if isinstance(false_var, UndefinedVar) and isinstance(
@@ -1182,12 +1181,16 @@ class While(object):
         })

+support_ret_buildin_type = (bool, float, six.integer_types)
+
+
 def assign_skip_lod_tensor_array(input, output):
     """
     Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.
     """
     if not isinstance(input, (Variable, core.VarBase)):
-        if isinstance(output, Variable):
+        if isinstance(output, Variable) and isinstance(
+                input, support_ret_buildin_type):
             assign(input, output)
         else:
             output = input
@@ -1297,6 +1300,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
         if not isinstance(output_vars, (list, tuple)):
             output_vars = [output_vars]
         try:
+            loop_vars = _deal_with_undefined_var(output_vars, loop_vars)
             assert_same_structure(output_vars, loop_vars, check_types=False)
         except ValueError as e:
             raise ValueError(
     return loop_vars
+def _deal_with_undefined_var(output_vars, loop_vars):
+    """ Deal with undefined var cases: we create undefined variables based on the results of body().
+    In Dy2Static we use undefined vars to represent variables created inside control flow. This
+    function expands loop_vars and replaces the original loop_vars.
+        1. UndefinedVar = Variable  # create a variable
+        2. UndefinedVar = None      # create an undefined var with RETURN_NO_VALUE_MAGIC_NUM
+        3. UndefinedVar = List(int) # create a list of variables
+        4. UndefinedVar = value     # create a variable
+    """
+    from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_variable
+
+    def create_var_like(o_var):
+        if isinstance(o_var,
+                      (Variable, ) + support_ret_buildin_type) or o_var is None:
+            return create_undefined_variable()
+        if isinstance(o_var, (tuple, list)):
+            return [create_undefined_variable() for i in range(len(o_var))]
+
+    if len(output_vars) != len(loop_vars):
+        raise ValueError("The length of loop_vars should be the same.")
+    results = []
+    for o_var, l_var in zip(output_vars, loop_vars):
+        if isinstance(l_var, UndefinedVar) or l_var is None:
+            results.append(create_var_like(o_var))
+        else:
+            results.append(l_var)
+    return results
+
+
 def lod_rank_table(x, level=0):
     """
     LoD Rank Table Operator. Given an input variable **x** and a level number
@@ -2616,6 +2650,11 @@ def change_none_to_undefinedvar(nest1, nest2):
 def expand_undefined_var(nest1, nest2, names):
+    """ TODO: make this function recursive.
+        nest1: Var1, (UndefinedVar, [1,2,3])
+        nest2: Var2, ([1,2,3,4], UndefinedVar)
+    In this case, we should not expand recursively.
+    """
     from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
     from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_VALUE_PREFIX
......
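A worked example of `_deal_with_undefined_var` on hypothetical values; the names are illustrative, and `create_undefined_variable` builds the placeholder static variables:

```python
# output_vars (from tracing body once): [tensor_a, tensor_b, [tensor_c, tensor_d]]
# loop_vars   (entering while_loop):    [tensor_a, UndefinedVar('b'), None]
#
# Result: positions holding UndefinedVar or None are replaced by fresh
# undefined variables shaped after the matching body output, so
# assert_same_structure(output_vars, loop_vars) no longer fails:
#
#     [tensor_a, <undefined variable>, [<undefined variable>, <undefined variable>]]
```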
@@ -385,6 +385,7 @@ class BaseModel(fluid.dygraph.Layer):
                     dropout_implementation='upscale_in_train')
             else:
                 step_input = new_hidden
+
             cell_outputs = self._split_batch_beams(step_input)
             cell_outputs = self.fc(cell_outputs)
......
@@ -442,13 +442,6 @@ class TestErrorInForLoop(TestTransformForLoop):
     def _init_dyfunc(self):
         self.dyfunc = for_loop_dyfunc_not_support

-    def test_ast_to_func(self):
-        with self.assertRaisesRegexp(
-                NotImplementedError,
-                "Dynamic-to-Static only supports the step value is a constant or negative constant "
-        ):
-            self._run_static()

 if __name__ == '__main__':
     with fluid.framework._test_eager_guard():
......
@@ -66,6 +66,9 @@ def get_source_code(func):
 class StaticCode1():

     def dyfunc_with_if_else(x_v, label=None):
+        loss = _jst.UndefinedVar('loss')
+        __return_1 = _jst.UndefinedVar('__return_1')
+        __return_0 = _jst.UndefinedVar('__return_0')
         __return_value_0 = None

         def get_args_0():
@@ -89,9 +92,6 @@ class StaticCode1():
         _jst.IfElse(
             paddle.mean(x_v)[0] > 5, true_fn_0, false_fn_0, get_args_0,
             set_args_0, ('x_v', ))
-        __return_0 = _jst.UndefinedVar('__return_0')
-        __return_1 = _jst.UndefinedVar('__return_1')
-        loss = _jst.UndefinedVar('loss')

         def get_args_1():
             nonlocal __return_0, __return_1, __return_value_0, loss
@@ -123,6 +123,9 @@ class StaticCode1():
 class StaticCode2():
     # TODO: Transform return statement

     def dyfunc_with_if_else(x_v, label=None):
+        loss = _jst.UndefinedVar('loss')
+        __return_3 = _jst.UndefinedVar('__return_3')
+        __return_2 = _jst.UndefinedVar('__return_2')
         __return_value_1 = None

         def get_args_2():
@@ -146,9 +149,6 @@ class StaticCode2():
         _jst.IfElse(
             paddle.mean(x_v)[0] > 5, true_fn_2, false_fn_2, get_args_2,
             set_args_2, ('x_v', ))
-        __return_2 = _jst.UndefinedVar('__return_2')
-        __return_3 = _jst.UndefinedVar('__return_3')
-        loss = _jst.UndefinedVar('loss')

         def get_args_3():
             nonlocal __return_2, __return_3, __return_value_1, loss
@@ -578,8 +578,8 @@ class TestOpNumWithTensorShapeInFor1(TestOpNumBasicWithTensorShape):
         self.dygraph_func = dyfunc_with_for_1

     def _set_expected_op_num(self):
-        self.expected_op_num = 22
-        self.expected_shape_op_num = 3
+        self.expected_op_num = 29
+        self.expected_shape_op_num = 2
         self.expected_slice_op_num = 3
@@ -589,7 +589,7 @@ class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape):
         self.dygraph_func = dyfunc_with_while_1

     def _set_expected_op_num(self):
-        self.expected_op_num = 22
+        self.expected_op_num = 21
         self.expected_shape_op_num = 3
         self.expected_slice_op_num = 3
......
@@ -21,7 +21,7 @@ import paddle.fluid.layers as layers
 from paddle.fluid.dygraph import Embedding, Layer, LayerNorm, Linear, to_variable
 from paddle.fluid.dygraph.jit import dygraph_to_static_func
 from paddle.fluid.layers.utils import map_structure
-from paddle.fluid.layers.tensor import range as pd_range
+import paddle


 def position_encoding_init(n_position, d_pos_vec):
@@ -634,7 +634,7 @@ class Transformer(Layer):
                      value=0),
             } for i in range(self.n_layer)]

-        for i in pd_range(0, max_len, 1, dtype="int32"):
+        for i in range(paddle.to_tensor(max_len)):
             trg_pos = layers.fill_constant(shape=trg_word.shape,
                                            dtype="int64",
                                            value=i)
......
@@ -26,7 +26,8 @@ from .convert_operators import convert_pop as Pop  # noqa: F401
 from .convert_operators import convert_print as Print  # noqa: F401
 from .convert_operators import convert_shape as Shape  # noqa: F401
 from .convert_operators import convert_while_loop as While  # noqa: F401
+from .convert_operators import unpack_by_structure as Unpack  # noqa: F401
+from .convert_operators import indexable as Indexable  # noqa: F401
 from .variable_trans_func import create_bool_as_type  # noqa: F401
 from .variable_trans_func import to_static_variable  # noqa: F401
 from .convert_operators import convert_shape_compare  # noqa: F401
......
@@ -26,5 +26,6 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_c
 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype  # noqa: F401
 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape  # noqa: F401
 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop  # noqa: F401
+from ...fluid.dygraph.dygraph_to_static.convert_operators import unpack_by_structure, indexable  # noqa: F401

 __all__ = []