未验证 提交 b603dd55 编写于 作者: X xiongkun 提交者: GitHub

[Dy2static] FunctionScopeVisitor Enhance and substitute the original NameVisitor in If (#43967)

* add support for control flow block analysis

* move FunctionNameLivenessAnalysis into utils

* pass test_ifelse.py

* remove duplicate data_layer_not_check

* pass the test_ifelse.py

* fix unittest error .

* fix all ci error in first version

* temporay disable CreateVariableTransformer

* fix ci errors

* fix function name liveness analysis bugs

* modifty def cond

* fix

* fix ci error - v2

* fix by code review

* change return_name_ids -> return_name
上级 bbe99555
......@@ -96,6 +96,7 @@ class DygraphToStaticAst(BaseTransformer):
BreakContinueTransformer, # break/continue in loops
ReturnTransformer, # return in functions
LogicalTransformer, # logical and/or/not
#CreateVariableTransformer, # create undefined var for if / while / for
LoopTransformer, # for/while -> while_op
IfElseTransformer, # if/else -> cond_op
AssertTransformer, # assert statement
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.framework import core, Variable
......@@ -21,7 +23,7 @@ from paddle.fluid.layers import assign, fill_constant, slice, reduce_all, reduce
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, Dygraph2StaticException
def convert_while_loop(cond, body, getter, setter):
......@@ -41,11 +43,9 @@ def convert_while_loop(cond, body, getter, setter):
# If loop_vars is changed during cond callable, then it causes bug, but current logical_and/logical_not/... doesn't change the loop_vars.
pred = cond()
if isinstance(pred, Variable):
loop_vars = _run_paddle_while(cond, body, getter, setter)
_run_paddle_while(cond, body, getter, setter)
else:
loop_vars = _run_py_while(cond, body, getter, setter)
return loop_vars
_run_py_while(cond, body, getter, setter)
def _run_paddle_while(cond, body, getter, setter):
......@@ -61,10 +61,13 @@ def _run_paddle_while(cond, body, getter, setter):
def _run_py_while(cond, body, getter, setter):
loop_vars = getter()
while cond():
loop_vars = body()
return loop_vars
while True:
pred = cond()
if isinstance(pred, Variable):
raise Dygraph2StaticException(
"python while pred change from bool to variable.")
if not pred: break
body()
def convert_logical_and(x_func, y_func):
......@@ -231,17 +234,32 @@ def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
def new_true_fn():
set_args(init_args)
outs = true_fn()
_check_no_undefined_var(outs, return_name_ids, 'if_body')
return outs
ret = true_fn()
# IfExpr will return a non-None return value, so we just return ret.
# We assume normal return has no return value.
if ret is None: return get_args()
else: return ret
def new_false_fn():
set_args(init_args)
outs = false_fn()
_check_no_undefined_var(outs, return_name_ids, 'else_body')
return outs
ret = false_fn()
if ret is None: return get_args()
else: return ret
cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn)
try:
cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn, None,
return_name_ids)
except Exception as e:
if re.search("Unsupported return type of true_fn and false_fn in cond",
str(e)):
raise Dygraph2StaticException(
"Your if/else have different return type. TODO: add link to modifty. {}"
.format(str(e)))
if re.search("Incompatible return values of", str(e)):
raise Dygraph2StaticException(
"Your if/else have different number of return value. TODO: add link to modifty. {}"
.format(str(e)))
raise e
return _recover_args_state(cond_outs, get_args, set_args, return_name_ids)
......@@ -251,8 +269,7 @@ def _run_py_ifelse(pred, true_fn, false_fn, get_args, set_args,
Evaluate python original branch function if-else.
"""
py_outs = true_fn() if pred else false_fn()
py_outs = _remove_no_value_return_var(py_outs)
return _recover_args_state(py_outs, get_args, set_args, return_name_ids)
return py_outs
def _remove_no_value_return_var(out):
......@@ -317,9 +334,10 @@ def _recover_args_state(outs, get_args, set_args, return_name_ids):
assert num_outs <= num_args
if num_args == 1:
final_outs = (outs, )
final_outs = (outs, ) if not isinstance(outs,
(list, tuple)) else tuple(outs)
else:
outs = (outs, ) if num_outs == 1 else outs
outs = (outs, ) if num_outs == 1 else tuple(outs)
final_outs = outs + init_args[num_outs:]
set_args(final_outs)
......
......@@ -27,11 +27,11 @@ from paddle.utils import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node, FunctionNameLivenessAnalysis
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_nodes
from paddle.fluid.dygraph.dygraph_to_static.utils import create_get_args_node, create_set_args_node
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
......@@ -53,7 +53,8 @@ class IfElseTransformer(BaseTransformer):
), "Type of input node should be AstNodeWrapper, but received %s ." % type(
wrapper_root)
self.root = wrapper_root.node
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
FunctionNameLivenessAnalysis(
self.root) # name analysis of current ast tree.
def transform(self):
"""
......@@ -273,193 +274,6 @@ class NameVisitor(gast.NodeVisitor):
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
def get_name_ids(nodes, after_node=None, end_node=None):
"""
Return all ast.Name.id of python variable in nodes range from
(after_node, end_node) exclusively. If after_node or end_node is None, the
range is unlimited.
"""
name_visitor = NameVisitor(after_node, end_node)
for node in nodes:
name_visitor.visit(node)
return name_visitor.name_ids
def parse_cond_args(parent_ids,
var_ids_dict,
modified_ids_dict=None,
ctx=gast.Load):
"""
Find out the ast.Name.id list of input by analyzing node's AST information.
"""
# 1. filter the var fit the ctx
arg_name_ids = [
var_id for var_id, var_ctx in six.iteritems(var_ids_dict)
if isinstance(var_ctx[0], ctx)
]
# 2. args should contain modified var ids in if-body or else-body
# case:
#
# ```
# if b < 1:
# z = y
# else:
# z = x
# ```
#
# In the above case, `z` should be in the args of cond()
if modified_ids_dict:
arg_name_ids = set(arg_name_ids) | set(modified_ids_dict)
# 3. args should not contain the vars not in parent ids
# case :
#
# ```
# x = 1
# if x > y:
# z = [v for v in range(i)]
# ```
#
# In the above case, `v` should not be in the args of cond()
arg_name_ids = set(arg_name_ids) & set(parent_ids)
return arg_name_ids
def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
after_ifelse_vars_dict):
"""
Find out the ast.Name list of output by analyzing node's AST information.
One of the following conditions should be satisfied while determining whether a variable is a return value:
1. the var in parent scope is modified in If.body or If.orelse node.
2. new var is both created in If.body and If.orelse node.
3. new var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node.
For example:
x, y = 5, 10
if x > 4:
x = x+1
z = x*x
q = 10
else:
y = y - 1
z = y*y
m = 20
n = 20
print(q)
n = 30
print(n)
The return_ids are (x, y, z, q) for `If.body` and `If.orelse`node, because
1. x is modified in If.body node,
2. y is modified in If.body node,
3. z is both created in If.body and If.orelse node,
4. q is created only in If.body, and it is used by `print(q)` as gast.Load.
Note:
After transformed, q and z are created in parent scope. For example,
x, y = 5, 10
q = paddle.jit.dy2static.UndefindVar('q')
z = paddle.jit.dy2static.UndefindVar('z')
def true_func(x, y, q):
x = x+1
z = x*x
q = 10
return x,y,z,q
def false_func(x, y, q):
y = y - 1
z = y*y
m = 20
n = 20
return x,y,z,q
x,y,z,q = fluid.layers.cond(x>4, lambda: true_func(x, y), lambda: false_func(x, y, q))
m and n are not in return_ids, because
5. m is created only in If.orelse, but it is not used after gast.If node.
6. n is created only in If.orelse, and it is used by `n = 30` and `print(n)`, but it is not used as gast.Load firstly but gast.Store .
"""
def _is_return_var(ctxs):
for ctx in ctxs:
if isinstance(ctx, (gast.Store, gast.Param)):
return True
return False
def _vars_with_store(ids_dict):
vars = []
for k, ctxs in six.iteritems(ids_dict):
if _is_return_var(ctxs):
vars.append(k)
return vars
def _modified_vars(child_dict, parent_dict):
return set(
[var for var in _vars_with_store(child_dict) if var in parent_dict])
def _vars_loaded(ids_dict):
"""
gast.Param is also a kind of `load` semantic.
"""
new_dict = defaultdict(list)
for k, ctxs in six.iteritems(ids_dict):
for ctx in ctxs:
if isinstance(ctx, (gast.Load, gast.Param)):
new_dict[k].append(ctx)
return new_dict
# modified vars
body_modified_vars = _modified_vars(if_vars_dict, parent_vars_dict)
body_modified_vars = set(
filter(lambda x: x != ARGS_NAME, body_modified_vars))
orelse_modified_vars = _modified_vars(else_vars_dict, parent_vars_dict)
orelse_modified_vars = set(
filter(lambda x: x != ARGS_NAME, orelse_modified_vars))
modified_vars = body_modified_vars | orelse_modified_vars
# new vars
# TODO(remove __args when new FunctionScopeAnalysis has been used.)
body_new_vars = set([
var for var in _vars_with_store(if_vars_dict)
if var not in parent_vars_dict and var != ARGS_NAME
])
orelse_new_vars = set([
var for var in _vars_with_store(else_vars_dict)
if var not in parent_vars_dict and var != ARGS_NAME
])
new_vars_in_body_or_orelse = body_new_vars | orelse_new_vars
new_vars_in_one_of_body_or_orelse = body_new_vars ^ orelse_new_vars
# 1. the var in parent scope is modified in If.body or If.orelse node.
modified_vars_from_parent = modified_vars - new_vars_in_body_or_orelse
# 2. new var is both created in If.body and If.orelse node.
new_vars_in_body_and_orelse = body_new_vars & orelse_new_vars
# 3. new var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node.
# TODO(zhhsplendid): the _vars_loaded can be optimized as _vars_loaded_before_store. Because if a variable is stored before load,
# the value would change by the store statement, we don't have to return to change the value. However, analysis is
# complex because if the IfElse is nested and outer IfElse store statement may not run at all. We will put this optimization
# as the future TODO
used_vars_after_ifelse = set(
[var for var in _vars_loaded(after_ifelse_vars_dict)])
new_vars_to_create = new_vars_in_one_of_body_or_orelse & used_vars_after_ifelse | new_vars_in_body_and_orelse
# 4. generate return_ids of if/else node.
return_ids = list(modified_vars_from_parent | new_vars_in_body_and_orelse
| new_vars_to_create)
return_ids.sort()
return return_ids, modified_vars_from_parent, new_vars_to_create
def _valid_nonlocal_names(return_name_ids, nonlocal_names):
"""
All var in return_name_ids should be in nonlocal_names.
......@@ -490,15 +304,8 @@ def transform_if_else(node, root):
"""
# TODO(liym27): Consider variable like `self.a` modified in if/else node.
parent_name_ids = get_name_ids([root], end_node=node)
body_name_ids = get_name_ids(node.body)
orelse_name_ids = get_name_ids(node.orelse)
# Get after_ifelse_name_ids, which means used var names after If.body and If.orelse node.
after_ifelse_name_ids = get_name_ids([root], after_node=node)
return_name_ids, modified_name_ids_from_parent, new_vars_to_create = parse_cond_return(
parent_name_ids, body_name_ids, orelse_name_ids, after_ifelse_name_ids)
new_vars_to_create = sorted(list(node.pd_scope.created_vars()))
return_name_ids = sorted(list(node.pd_scope.modified_vars()))
# NOTE: Python can create variable only in if body or only in else body, and use it out of if/else.
# E.g.
#
......@@ -513,16 +320,7 @@ def transform_if_else(node, root):
if "." not in name:
create_new_vars_in_parent_stmts.append(create_undefined_var(name))
parent_ids_set = set()
for k, ctxs in parent_name_ids.items():
if any([not isinstance(ctx, gast.Load) for ctx in ctxs]):
parent_ids_set.add(k)
true_args = parse_cond_args(parent_ids_set, body_name_ids,
modified_name_ids_from_parent)
false_args = parse_cond_args(parent_ids_set, orelse_name_ids,
modified_name_ids_from_parent)
nonlocal_names = list(true_args | false_args | new_vars_to_create)
nonlocal_names = list(return_name_ids)
nonlocal_names.sort()
# NOTE: All var in return_name_ids should be in nonlocal_names.
nonlocal_names = _valid_nonlocal_names(return_name_ids, nonlocal_names)
......@@ -531,8 +329,7 @@ def transform_if_else(node, root):
if ARGS_NAME in nonlocal_names:
nonlocal_names.remove(ARGS_NAME)
nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)
] if nonlocal_names else []
nonlocal_stmt_node = create_nonlocal_stmt_nodes(nonlocal_names)
empty_arg_node = gast.arguments(args=[],
posonlyargs=[],
......@@ -546,12 +343,12 @@ def transform_if_else(node, root):
nonlocal_stmt_node + node.body,
name=unique_name.generate(TRUE_FUNC_PREFIX),
input_args=empty_arg_node,
return_name_ids=return_name_ids)
return_name_ids=[])
false_func_node = create_funcDef_node(
nonlocal_stmt_node + node.orelse,
name=unique_name.generate(FALSE_FUNC_PREFIX),
input_args=empty_arg_node,
return_name_ids=return_name_ids)
return_name_ids=[])
get_args_node = create_get_args_node(nonlocal_names)
set_args_node = create_set_args_node(nonlocal_names)
......
......@@ -30,7 +30,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_node, create_get_args_node, create_set_args_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_nodes, create_get_args_node, create_set_args_node
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
......@@ -93,101 +93,6 @@ def create_while_nodes(condition_name, body_name, loop_var_names, getter_name,
return ret
class NameScope:
def __init__(self):
""" we don't analyze the read only variable
because they keep the same in control flow.
"""
self.globals = set()
self.nonlocals = set()
self.args = set()
# all vars been stored,
# may be globals or non-locals
self.w_vars = set()
def created_vars(self):
return self.w_vars - self.globals - self.nonlocals - self.args
def write_vars(self):
return self.w_vars
def global_vars(self):
return self.globals
class FunctionNameLivenessAnalysis(gast.NodeVisitor):
""" analyze the liveness of a function.
every variables stored in this scope will be collected,
in addition with global/nonlocal information.
1. global variable is stored in node.var_globals.
2. nonlocal variable is stored in node.var_nonlocals.
3. arguments is stored in node.var_args.
For example:
def func(*args, **kargs):
a = 12
global i,j
nonlocal x,y
print(a)
i = k
for m in range(10):
q = 12
After this visitor we have:
# node is the FunctionDef node with name: "func"
node.pd_scope = NameScope(
globals = ['i', 'j'],
nonlocals = ['x', 'y'],
args = ['args', 'kargs'],
wr_vars = ['a', 'i', 'q', 'm']
)
"""
def __init__(self, root_node):
self.funcdef_stack = []
self.visit(root_node)
def _current_funcdef_scope(self):
return self.funcdef_stack[-1].pd_scope
def visit_Name(self, node):
self.generic_visit(node)
write_context = (gast.Store, gast.AugStore, gast.Del)
if isinstance(node.ctx, write_context):
self._current_funcdef_scope().w_vars.add(node.id)
def visit_FunctionDef(self, node):
setattr(node, 'pd_scope', NameScope())
self.funcdef_stack.append(node)
self._current_funcdef_scope().args |= set(
self._get_argument_names(node))
self.generic_visit(node)
self.funcdef_stack.pop()
def visit_Global(self, node):
self._current_funcdef_scope().globals |= set(node.names)
def visit_Nonlocal(self, node):
self._current_funcdef_scope().nonlocals |= set(node.names)
def _get_argument_names(self, node):
""" get all arguments name in the functiondef node.
this node is local to the function and shouldn't
be created.
"""
assert isinstance(
node, gast.FunctionDef), "Input node is not function define node"
names = [a for a in node.args.args]
names.append(node.args.vararg)
names.append(node.args.kwarg)
names = [i.id for i in names if i is not None]
return names
class NameVisitor(gast.NodeVisitor):
'''
Analysis name liveness for loop transformer
......@@ -665,7 +570,7 @@ class LoopTransformer(BaseTransformer):
if ARGS_NAME in nonlocal_names:
nonlocal_names.remove(ARGS_NAME)
nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)]
nonlocal_stmt_node = create_nonlocal_stmt_nodes(nonlocal_names)
# 4. append init statements
new_stmts.extend(init_stmts)
......@@ -737,7 +642,7 @@ class LoopTransformer(BaseTransformer):
if ARGS_NAME in nonlocal_names:
nonlocal_names.remove(ARGS_NAME)
nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)]
nonlocal_stmt_node = create_nonlocal_stmt_nodes(nonlocal_names)
# Python can create variable in loop and use it out of loop, E.g.
#
......
......@@ -43,7 +43,9 @@ RETURN_VALUE_INIT_NAME = '__return_value_init'
# solve it in dy2stat, we put float64 value with this magic number at Static
# graph as a place holder to indicate the returning placeholder means no value
# should return.
RETURN_NO_VALUE_MAGIC_NUM = 1.77113e+279
# Assign not support float64, use float32 value as magic number.
RETURN_NO_VALUE_MAGIC_NUM = 1.77113e+27
RETURN_NO_VALUE_VAR_NAME = "__no_value_return_var"
......@@ -216,44 +218,17 @@ class ReturnTransformer(BaseTransformer):
ctx=gast.Load(),
annotation=None,
type_comment=None)))
init_names = [
unique_name.generate(RETURN_VALUE_INIT_NAME)
for i in range(max_return_length)
]
assign_zero_nodes = [
create_fill_constant_node(iname, 0.0) for iname in init_names
]
if len(init_names) == 1:
return_value_nodes = gast.Name(id=init_names[0],
ctx=gast.Load(),
annotation=None,
type_comment=None)
else:
# We need to initialize return value as a tuple because control
# flow requires some inputs or outputs have same structure
return_value_nodes = gast.Tuple(elts=[
gast.Name(id=iname,
ctx=gast.Load(),
annotation=None,
type_comment=None) for iname in init_names
],
ctx=gast.Load())
assign_return_value_node = gast.Assign(targets=[
gast.Name(id=value_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=return_value_nodes)
value=gast.Constant(
kind=None, value=None))
node.body.insert(0, assign_return_value_node)
node.body[:0] = assign_zero_nodes
# Prepend no value placeholders
for name in self.return_no_value_name[node]:
assign_no_value_node = create_fill_constant_node(
name, RETURN_NO_VALUE_MAGIC_NUM)
node.body.insert(0, assign_no_value_node)
self.function_def.pop()
return node
......@@ -340,64 +315,11 @@ class ReturnTransformer(BaseTransformer):
cur_func_node = self.function_def[-1]
return_length = get_return_size(return_node)
if return_length < max_return_length:
# In this case we should append RETURN_NO_VALUE placeholder
#
# max_return_length must be >= 1 here because return_length will be
# 0 at least.
if self.return_value_name[cur_func_node] is None:
self.return_value_name[cur_func_node] = unique_name.generate(
RETURN_VALUE_PREFIX)
no_value_names = [
unique_name.generate(RETURN_NO_VALUE_VAR_NAME)
for j in range(max_return_length - return_length)
]
self.return_no_value_name[cur_func_node].extend(no_value_names)
# Handle tuple/non-tuple case
if max_return_length == 1:
assign_nodes.append(
gast.Assign(targets=[
gast.Name(id=self.return_value_name[cur_func_node],
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=gast.Name(id=no_value_names[0],
ctx=gast.Load(),
annotation=None,
type_comment=None)))
else:
# max_return_length > 1 which means we should assign tuple
fill_tuple = [
gast.Name(id=n,
ctx=gast.Load(),
annotation=None,
type_comment=None) for n in no_value_names
]
if return_node.value is not None:
if isinstance(return_node.value, gast.Tuple):
fill_tuple[:0] = return_node.value.elts
else:
fill_tuple.insert(0, return_node.value)
assign_nodes.append(
gast.Assign(targets=[
gast.Name(id=self.return_value_name[cur_func_node],
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=gast.Tuple(elts=fill_tuple,
ctx=gast.Load())))
else:
# In this case we should NOT append RETURN_NO_VALUE placeholder
if return_node.value is not None:
cur_func_node = self.function_def[-1]
if self.return_value_name[cur_func_node] is None:
self.return_value_name[
cur_func_node] = unique_name.generate(
self.return_value_name[cur_func_node] = unique_name.generate(
RETURN_VALUE_PREFIX)
assign_nodes.append(
......
......@@ -30,8 +30,9 @@ import numpy as np
import paddle
from paddle.fluid import unique_name
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers import assign
# Note(Aurelius): Do not forget the dot `.` to distinguish other
# module such as paddlenlp.
......@@ -64,6 +65,34 @@ class BaseNodeVisitor(gast.NodeVisitor):
return ret
# imp is deprecated in python3
from importlib.machinery import SourceFileLoader
dygraph_class_to_static_api = {
"CosineDecay": "cosine_decay",
"ExponentialDecay": "exponential_decay",
"InverseTimeDecay": "inverse_time_decay",
"NaturalExpDecay": "natural_exp_decay",
"NoamDecay": "noam_decay",
"PiecewiseDecay": "piecewise_decay",
"PolynomialDecay": "polynomial_decay",
}
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index'
FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
FOR_ITER_ZIP_TO_LIST_PREFIX = '__for_loop_iter_zip'
# FullArgSpec is valid from Python3. Defined a Namedtuple to
# to make it available in Python2.
FullArgSpec = collections.namedtuple('FullArgSpec', [
'args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', 'kwonlydefaults',
'annotations'
])
def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
"""
This function creates a Tensor on the global block. The created Tensor
......@@ -99,7 +128,7 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
if shape[i] is None:
shape[i] = -1
return helper.create_variable(name=name,
return helper.create_global_variable(name=name,
shape=shape,
dtype=dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
......@@ -109,32 +138,22 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
need_check_feed=False)
# imp is deprecated in python3
from importlib.machinery import SourceFileLoader
dygraph_class_to_static_api = {
"CosineDecay": "cosine_decay",
"ExponentialDecay": "exponential_decay",
"InverseTimeDecay": "inverse_time_decay",
"NaturalExpDecay": "natural_exp_decay",
"NoamDecay": "noam_decay",
"PiecewiseDecay": "piecewise_decay",
"PolynomialDecay": "polynomial_decay",
}
def create_undefined_var_like(variable):
""" create a undefined var with the same shape and dtype like varaible.
"""
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
var = data_layer_not_check(unique_name.generate("undefined_var"),
variable.shape, variable.dtype)
assign(RETURN_NO_VALUE_MAGIC_NUM, var)
return var
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index'
FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
FOR_ITER_ZIP_TO_LIST_PREFIX = '__for_loop_iter_zip'
# FullArgSpec is valid from Python3. Defined a Namedtuple to
# to make it available in Python2.
FullArgSpec = collections.namedtuple('FullArgSpec', [
'args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', 'kwonlydefaults',
'annotations'
])
def create_undefined_variable():
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
var = data_layer_not_check(unique_name.generate("undefined_var"), [1],
"float64")
assign(RETURN_NO_VALUE_MAGIC_NUM, var)
return var
class UndefinedVar:
......@@ -147,6 +166,12 @@ class UndefinedVar:
"local variable '{}' should be created before using it.")
class Dygraph2StaticException(Exception):
def __init__(self, message):
super().__init__(message)
def saw(x):
if isinstance(x, UndefinedVar):
return x.check()
......@@ -1600,6 +1625,209 @@ def slice_is_num(slice_node):
return False
class NameScope:
def __init__(self):
"""
A NameScope is a object which manager all the variable names.
only FunctionDef and Controlflow node will have a namescope property.
type can be "function" and "controlflow"
we don't analyze the read only variable because they don't affect the analysis.
"""
self.globals = set()
self.nonlocals = set()
self.args = set()
self.father = None # point to the nearest function name scope.
self.w_vars = set() # all qualified + normal names been stored
self.created = set(
) # useful for control flow compatibility. may be remove later
def set_father(self, father):
self.father = father
def existed_vars(self):
""" vars existing in current scope.
they must not contain qualified names.
"""
local_vars = self.w_vars - self.globals - self.nonlocals - self.args
return set(filter(lambda x: '.' not in x, local_vars))
def created_vars(self):
return self.created
def modified_vars(self):
# may be globals / non-locals / args / qualified names and created_vars
return self.w_vars
def control_flow_vars(self):
valid_names = self.w_vars
tmp = self.father.global_vars & valid_names,
return {"global": tmp, "nonlocal": self.w_vars - tmp}
def global_vars(self):
return self.globals
def merge_from(self, name_scope):
self.globals |= name_scope.globals
self.nonlocals |= name_scope.nonlocals
self.args |= name_scope.args
self.w_vars |= name_scope.w_vars
class FunctionNameLivenessAnalysis(gast.NodeVisitor):
""" analyze the liveness of a function.
every variables stored in this scope will be collected,
in addition with global/nonlocal information.
1. global variable is stored in node.var_globals.
2. nonlocal variable is stored in node.var_nonlocals.
3. arguments is stored in node.var_args.
For example:
def func(*args, **kargs):
a = 12
global i,j
nonlocal x,y
print(a)
i = k
for m in range(10):
q = 12
After this visitor we have:
# node is the FunctionDef node with name: "func"
node.pd_scope = NameScope(
globals = ['i', 'j'],
nonlocals = ['x', 'y'],
args = ['args', 'kargs'],
wr_vars = ['a', 'i', 'q', 'm']
)
"""
def __init__(self, root_node):
self.scope_node_stack = [] # controlflow, functiondef node
self.visit(root_node)
def _reset_name_scope(self, node):
# always reset the node as empty namescope.
setattr(node, "pd_scope", NameScope())
def _get_name_scope(self, node):
if not hasattr(node, "pd_scope"):
setattr(node, "pd_scope", NameScope())
return node.pd_scope
def _current_name_scope(self):
return self._get_name_scope(self.scope_node_stack[-1])
def _father_name_scope(self):
if len(self.scope_node_stack) == 1: return None
return self._get_name_scope(self.scope_node_stack[-2])
def _nearest_function_scope(self):
if len(self.scope_node_stack) == 1: return None
for node in self.scope_node_stack[-2::-1]:
if isinstance(node, gast.FunctionDef):
return self._get_name_scope(node)
def visit_Name(self, node):
self.generic_visit(node)
write_context = (gast.Store, gast.AugStore, gast.Del)
if isinstance(node.ctx, write_context):
self._current_name_scope().w_vars.add(node.id)
def visit_FunctionDef(self, node):
def pre_func():
self._current_name_scope().args |= set(
self._get_argument_names(node))
def post_func():
""" NOTE: why we need merge w_vars here ?
because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope.
"""
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import WHILE_CONDITION_PREFIX, WHILE_BODY_PREFIX, FOR_CONDITION_PREFIX, FOR_BODY_PREFIX
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import TRUE_FUNC_PREFIX, FALSE_FUNC_PREFIX
control_flow_function_def = [
WHILE_BODY_PREFIX, WHILE_BODY_PREFIX, FOR_CONDITION_PREFIX,
FOR_BODY_PREFIX, TRUE_FUNC_PREFIX, FALSE_FUNC_PREFIX
]
def is_control_flow_def_node():
for prefix in control_flow_function_def:
if node.name.startswith(prefix): return True
return False
if self._father_name_scope() and is_control_flow_def_node():
self._father_name_scope().w_vars |= self._current_name_scope(
).w_vars
self._visit_scope_node(node, pre_func, post_func)
def _visit_scope_node(self, node, pre_func, post_func):
""" scope node main visit logic.
pre_func and post_func is callbacks
"""
self._reset_name_scope(node)
self.scope_node_stack.append(node)
self._current_name_scope().father = self._nearest_function_scope()
if pre_func: pre_func()
self.generic_visit(node)
if post_func: post_func()
self.scope_node_stack.pop()
def _visit_controlflow_node(self, node):
def post_func():
self._father_name_scope().merge_from(self._current_name_scope())
self._current_name_scope().created = self._nearest_function_scope(
).existed_vars() - node.before_created
def pre_func():
setattr(node, "before_created",
self._nearest_function_scope().existed_vars())
self._visit_scope_node(node, pre_func, post_func)
def visit_For(self, node):
self._visit_controlflow_node(node)
def visit_While(self, node):
self._visit_controlflow_node(node)
def visit_If(self, node):
self._visit_controlflow_node(node)
def visit_Global(self, node):
self._current_name_scope().globals |= set(node.names)
def visit_Nonlocal(self, node):
self._current_name_scope().nonlocals |= set(node.names)
def visit_Attribute(self, node):
self.generic_visit(node)
write_context = (gast.Store, gast.AugStore, gast.Del)
if isinstance(node.ctx, write_context):
name = ast_to_source_code(node).strip()
self._current_name_scope().w_vars.add(name)
def _get_argument_names(self, node):
""" get all arguments name in the functiondef node.
this node is local to the function and shouldn't
be created.
"""
assert isinstance(
node, gast.FunctionDef), "Input node is not function define node"
names = [a for a in node.args.args]
names.append(node.args.vararg)
names.append(node.args.kwarg)
names = [i.id for i in names if i is not None]
return names
def create_get_args_node(names):
"""
Create get_args function as follows:
......@@ -1617,21 +1845,24 @@ def create_get_args_node(names):
return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple))
if not names:
return empty_node()
mapped = list(filter(lambda n: '.' not in n, names))
nonlocal_names = sorted(
mapped,
key=mapped.index) # to keep the order, we can't use set() to unique
if not names:
return empty_node()
if not nonlocal_names:
nonlocal_vars = "\n"
else:
nonlocal_vars = "nonlocal " + ",".join(nonlocal_names)
template = """
def {func_name}():
nonlocal {nonlocal_vars}
{nonlocal_vars}
return {vars},
"""
func_def = template.format(
func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX),
nonlocal_vars=','.join(nonlocal_names),
nonlocal_vars=nonlocal_vars,
vars=",".join(names))
return gast.parse(textwrap.dedent(func_def)).body[0]
......@@ -1654,32 +1885,37 @@ def create_set_args_node(names):
return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple))
if not names:
return empty_node()
mapped = list(filter(lambda n: '.' not in n, names))
nonlocal_names = sorted(
mapped,
key=mapped.index) # to keep the order, we can't use set() to unique
if not names:
return empty_node()
if not nonlocal_names:
nonlocal_vars = "\n"
else:
nonlocal_vars = "nonlocal " + ",".join(nonlocal_names)
template = """
def {func_name}({args}):
nonlocal {nonlocal_vars}
{nonlocal_vars}
{vars}, = {args}
"""
func_def = template.format(
func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
args=ARGS_NAME,
nonlocal_vars=','.join(nonlocal_names),
nonlocal_vars=nonlocal_vars,
vars=",".join(names))
return gast.parse(textwrap.dedent(func_def)).body[0]
def create_nonlocal_stmt_node(names):
def create_nonlocal_stmt_nodes(names):
assert isinstance(names, (list, tuple))
mapped = list(filter(lambda n: '.' not in n, names))
names = sorted(
mapped,
key=mapped.index) # to keep the order, we can't use set() to unique
if not names:
return []
func_code = "nonlocal {}".format(','.join(names))
return gast.parse(func_code).body[0]
return [gast.parse(func_code).body[0]]
......@@ -20,7 +20,7 @@ import textwrap
from paddle.utils import gast
from paddle.fluid import unique_name
from paddle.fluid.framework import Variable
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, data_layer_not_check
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_variable
__all__ = [
'create_bool_as_type',
......@@ -62,9 +62,10 @@ def to_static_variable(x):
return paddle.full(shape=[1], dtype='float64', fill_value=x)
if isinstance(x, six.integer_types):
return paddle.full(shape=[1], dtype='int64', fill_value=x)
if isinstance(x, UndefinedVar):
return data_layer_not_check(unique_name.generator("loop_undefined_var"),
[-1])
if isinstance(x, UndefinedVar) or x is None:
""" for early return case, we need a variable to represent None, current we use data_layer_not_check.
"""
return create_undefined_variable()
return x
......
......@@ -21,7 +21,7 @@ from .. import core
from ..framework import Program, Variable, Operator, _non_static_mode, static_only, _in_legacy_dygraph, in_dygraph_mode
from ..layer_helper import LayerHelper, unique_name
from .nn import logical_and, logical_not, logical_or
from .utils import assert_same_structure, map_structure, hold_mutable_vars, copy_mutable_vars
from .utils import assert_same_structure, map_structure, hold_mutable_vars, copy_mutable_vars, padding_to_same_structure, is_sequence, pack_sequence_as, flatten, to_sequence
import numpy
import warnings
import six
......@@ -107,9 +107,16 @@ def select_input(inputs, mask):
def select_input_with_buildin_type(inputs, mask):
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_var_like
support_ret_buildin_type = (bool, float, six.integer_types)
false_var, true_var = inputs
if isinstance(false_var, UndefinedVar) and isinstance(
true_var, UndefinedVar):
""" None -> UndefinedVar, so the real value is a [None, UndefinedVar] or [None, None], we just return None.
"""
return None
if isinstance(false_var, Variable) and isinstance(true_var, Variable):
return select_input(inputs, mask)
......@@ -132,6 +139,27 @@ def select_input_with_buildin_type(inputs, mask):
"Return results from different branches in cond are not same type: "
"false_var returned by fasle_fn is '{}' and true_var of true_fn is "
"'{}'".format(type(false_var), type(true_var)))
elif ((isinstance(false_var, UndefinedVar)
and isinstance(true_var, (Variable, ) + support_ret_buildin_type))
or (isinstance(true_var, UndefinedVar)
and isinstance(false_var,
(Variable, ) + support_ret_buildin_type))):
def create_var_if_not_undefined_var(a):
if isinstance(a, UndefinedVar): return a
return to_static_variable(a)
def create_like_if_undefined_var(a, b):
if isinstance(a, UndefinedVar): return create_undefined_var_like(b)
return a
# TODO(xiongkun): add warning here.
true_var, false_var = create_var_if_not_undefined_var(
true_var), create_var_if_not_undefined_var(false_var)
inputs = [
create_like_if_undefined_var(false_var, true_var),
create_like_if_undefined_var(true_var, false_var)
]
else:
raise TypeError(
"Unsupported return type of true_fn and false_fn in cond: false_var "
......@@ -1158,7 +1186,10 @@ def assign_skip_lod_tensor_array(input, output):
"""
Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.
"""
if not isinstance(input, Variable) and not isinstance(input, core.VarBase):
if not isinstance(input, (Variable, core.VarBase)):
if isinstance(output, Variable):
assign(input, output)
else:
output = input
return
......@@ -2377,7 +2408,7 @@ def copy_var_to_parent_block(var, layer_helper):
return parent_block_var
def cond(pred, true_fn=None, false_fn=None, name=None):
def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
"""
This API returns ``true_fn()`` if the predicate ``pred`` is true else
``false_fn()`` . Users could also set ``true_fn`` or ``false_fn`` to
......@@ -2423,6 +2454,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
true. The default value is ``None`` .
false_fn(callable, optional): A callable to be performed if ``pred`` is
false. The default value is ``None`` .
return_names: A list of strings to represents the name of returned vars. useful to debug.
name(str, optional): The default value is ``None`` . Normally users
don't have to set this parameter. For more information, please
refer to :ref:`api_guide_Name` .
......@@ -2536,12 +2568,30 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
"true_fn returns non-None while false_fn returns None")
# Merge ture and false output if they are not None
if return_names is None:
return_names = ["no name"] * len(to_sequence(true_output))
else:
"""
dy2static will set the return_names and expand the return values to UndefinedVar.
"""
true_output, false_output = expand_undefined_var(
true_output, false_output, return_names)
true_output, false_output = change_none_to_undefinedvar(
true_output, false_output)
if len(to_sequence(true_output)) != len(to_sequence(false_output)):
raise ValueError(
"true fn returns {} vars, but false fn returns {} vars, which is not equals"
.format(len(to_sequence(true_output)),
len(to_sequence(false_output))))
for true_out, false_out, return_name in zip(to_sequence(true_output),
to_sequence(false_output),
to_sequence(return_names)):
try:
assert_same_structure(true_output, false_output, check_types=False)
assert_same_structure(true_out, false_out, check_types=False)
except ValueError as e:
raise ValueError(
"Incompatible return values of true_fn and false_fn in cond: {}".
format(e))
"Incompatible return values of `{}` in true_fn and false_fn in cond: {}"
.format(return_name, e))
mask = cast(pred, dtype='int32')
merge_func = lambda false_var, true_var: select_input_with_buildin_type(
......@@ -2550,6 +2600,41 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
return merged_output
def change_none_to_undefinedvar(nest1, nest2):
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
def map_fn(x):
if x is None: return UndefinedVar("padding")
return x
nest1_out = pack_sequence_as(nest1, list(map(map_fn, flatten(nest1))))
nest2_out = pack_sequence_as(nest2, list(map(map_fn, flatten(nest2))))
return nest1_out, nest2_out
def expand_undefined_var(nest1, nest2, names):
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_VALUE_PREFIX
def pack_undefined_var_as(seq):
return pack_sequence_as(seq,
[UndefinedVar("padding") for i in flatten(seq)])
def map_fn(n1, n2, name):
if not name.startswith(RETURN_VALUE_PREFIX) and (isinstance(
n1, UndefinedVar) or n1 is None):
return pack_undefined_var_as(n2)
return n1
nest1_out = list(
map(map_fn, to_sequence(nest1), to_sequence(nest2), to_sequence(names)))
nest2_out = list(
map(map_fn, to_sequence(nest2), to_sequence(nest1), to_sequence(names)))
if not is_sequence(nest1): nest1_out = nest1_out[0]
if not is_sequence(nest2): nest2_out = nest2_out[0]
return nest1_out, nest2_out
def _error_message(what, arg_name, op_name, right_value, error_value):
error_message = "{what} of '{arg_name}' in {op_name} must be " \
"{right_value}, but received: {error_value}.".format(
......
......@@ -125,6 +125,13 @@ def _yield_flat_nest(nest):
yield n
def to_sequence(nest):
if is_sequence(nest):
return nest
else:
return [nest]
def flatten(nest):
"""
:alias_main: paddle.flatten
......@@ -260,6 +267,26 @@ def _recursive_assert_same_structure(nest1, nest2, check_types):
_recursive_assert_same_structure(n1, n2, check_types)
def padding_to_same_structure(nest1, nest2, obj=None):
def _padding_to_same_structure_single(value, obj):
def change_none_to_obj(x):
if x is None: return obj
return x
if is_sequence(value):
value = pack_sequence_as(
value, [change_none_to_obj(item) for item in flatten(value)])
else:
value = change_none_to_obj(value)
return value
nest1 = _padding_to_same_structure_single(nest1, obj)
nest2 = _padding_to_same_structure_single(nest2, obj)
return nest1, nest2
def assert_same_structure(nest1, nest2, check_types=True):
"""
Confirm two nested structures with the same structure.
......
......@@ -117,7 +117,7 @@ def dyfunc_with_if_else_early_return1():
b = paddle.zeros([3, 3])
return a, b
a = paddle.zeros([2, 2]) + 1
return a
return a, None
def dyfunc_with_if_else_early_return2():
......@@ -131,7 +131,7 @@ def dyfunc_with_if_else_early_return2():
d = paddle.zeros([3, 3]) + 1
return c, d
e = paddle.zeros([2, 2]) + 3
return e
return e, None
def dyfunc_with_if_else_with_list_geneator(x):
......
......@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from paddle.fluid import layers
......@@ -360,7 +361,7 @@ class BaseModel(fluid.dygraph.Layer):
predicted_ids = []
parent_ids = []
for step_idx in range(self.beam_max_step_num):
for step_idx in range(paddle.to_tensor(self.beam_max_step_num)):
if fluid.layers.reduce_sum(1 - beam_finished).numpy()[0] == 0:
break
step_input = self._merge_batch_beams(step_input)
......
......@@ -19,11 +19,29 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.utils import Dygraph2StaticException
SEED = 2020
np.random.seed(SEED)
class TestDy2staticException(unittest.TestCase):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = None
self.error = "Your if/else have different number of return value."
def test_error(self):
if self.dyfunc:
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
ProgramTranslator().enable(True)
self.assertTrue(declarative(self.dyfunc)(self.x))
paddle.fluid.dygraph.base._in_declarative_mode_ = False
ProgramTranslator().enable(False)
def test_continue_in_for(x):
x = fluid.dygraph.to_variable(x)
for i in range(10):
......@@ -265,10 +283,12 @@ class TestWhileLoopClassVar(TestContinueInWhile):
self.dygraph_func = while_loop_class_var
class TestOptimBreakInFor(TestContinueInWhile):
class TestOptimBreakInFor(TestDy2staticException):
def init_dygraph_func(self):
self.dygraph_func = test_optim_break_in_for
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = test_optim_break_in_for
self.error = "python while pred change from bool to variable."
class TestOptimBreakInWhile(TestContinueInWhile):
......
......@@ -17,20 +17,25 @@ from __future__ import print_function
import unittest
import paddle
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import FunctionNameLivenessAnalysis
from paddle.fluid.dygraph.dygraph_to_static.utils import FunctionNameLivenessAnalysis
from paddle.utils import gast
import inspect
class JudgeVisitor(gast.NodeVisitor):
def __init__(self, ans):
def __init__(self, ans, mod):
self.ans = ans
self.mod = mod
def visit_FunctionDef(self, node):
scope = node.pd_scope
expected = self.ans.get(node.name, set())
assert scope.created_vars() == expected, "Not Equals."
exp_mod = self.mod.get(node.name, set())
assert scope.existed_vars() == expected, "Not Equals."
assert scope.modified_vars(
) == exp_mod, "Not Equals in function:{} . expect {} , but get {}".format(
node.name, exp_mod, scope.modified_vars())
self.generic_visit(node)
......@@ -108,12 +113,31 @@ class TestClosureAnalysis(unittest.TestCase):
},
]
self.modified_var = [
{
'func': set('ki'),
'test_nonlocal': set('i')
},
{
'func': set({'i'}),
'test_global': set({"t"})
},
{
'func': set('i'),
},
{
'func': set('i'),
'test_normal_argument': set('x')
},
]
def test_main(self):
for ans, func in zip(self.answer, self.all_dygraph_funcs):
for mod, ans, func in zip(self.modified_var, self.answer,
self.all_dygraph_funcs):
test_func = inspect.getsource(func)
gast_root = gast.parse(test_func)
name_visitor = FunctionNameLivenessAnalysis(gast_root)
JudgeVisitor(ans).visit(gast_root)
JudgeVisitor(ans, mod).visit(gast_root)
def TestClosureAnalysis_Attribute_func():
......@@ -128,6 +152,10 @@ class TestClosureAnalysis_Attribute(TestClosureAnalysis):
self.all_dygraph_funcs = [TestClosureAnalysis_Attribute_func]
self.answer = [{"TestClosureAnalysis_Attribute_func": set({'i'})}]
self.modified_var = [{
"TestClosureAnalysis_Attribute_func":
set({'i', 'self.current.function'})
}]
if __name__ == '__main__':
......
......@@ -20,6 +20,7 @@ import unittest
import paddle
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.utils import Dygraph2StaticException
import paddle.fluid.core as core
from ifelse_simple_func import *
......@@ -32,6 +33,22 @@ else:
place = fluid.CPUPlace()
class TestDy2staticException(unittest.TestCase):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = None
self.error = "Your if/else have different number of return value."
def test_error(self):
if self.dyfunc:
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
ProgramTranslator().enable(True)
self.assertTrue(declarative(self.dyfunc)(self.x))
paddle.fluid.dygraph.base._in_declarative_mode_ = False
ProgramTranslator().enable(False)
class TestDygraphIfElse(unittest.TestCase):
"""
TestCase for the transformation from control flow `if/else`
......@@ -417,16 +434,12 @@ class TestDy2StIfElseRetInt1(unittest.TestCase):
self.assertIsInstance(self.out[1], int)
class TestDy2StIfElseRetInt2(TestDy2StIfElseRetInt1):
class TestDy2StIfElseRetInt2(TestDy2staticException):
def setUp(self):
self.x = np.random.random([5]).astype('float32')
self.error = "Your if/else have different number of return value."
self.dyfunc = dyfunc_ifelse_ret_int2
self.out = self.get_dy2stat_out()
def test_ast_to_func(self):
self.assertIsInstance(self.out[0], (paddle.Tensor, core.eager.Tensor))
self.assertIsInstance(self.out[1], (paddle.Tensor, core.eager.Tensor))
class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1):
......@@ -448,7 +461,7 @@ class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
def test_ast_to_func(self):
ProgramTranslator().enable(True)
with self.assertRaises(TypeError):
with self.assertRaises(Dygraph2StaticException):
static_func = paddle.jit.to_static(self.dyfunc)
out = static_func(self.x)
# Why need set `_in_declarative_mode_` here?
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,264 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import textwrap
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import get_name_ids
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
class TestGetNameIds(unittest.TestCase):
"""
Test for parsing the ast.Name list from the ast.Nodes
"""
def setUp(self):
self.source = """
def test_fn(x):
return x+1
"""
self.all_name_ids = {'x': [gast.Param(), gast.Load()]}
def test_get_name_ids(self):
source = textwrap.dedent(self.source)
root = gast.parse(source)
all_name_ids = get_name_ids([root])
self.assertDictEqual(self.transfer_dict(self.all_name_ids),
self.transfer_dict(all_name_ids))
def transfer_dict(self, name_ids_dict):
new_dict = {}
for name, ctxs in name_ids_dict.items():
new_dict[name] = [type(ctx) for ctx in ctxs]
return new_dict
class TestGetNameIds2(TestGetNameIds):
def setUp(self):
self.source = """
def test_fn(x, y):
a = 1
x = y + a
if x > y:
z = x * x
z = z + a
else:
z = y * y
return z
"""
self.all_name_ids = {
'x':
[gast.Param(),
gast.Store(),
gast.Load(),
gast.Load(),
gast.Load()],
'a': [gast.Store(), gast.Load(),
gast.Load()],
'y': [
gast.Param(),
gast.Load(),
gast.Load(),
gast.Load(),
gast.Load(),
],
'z': [
gast.Store(),
gast.Load(),
gast.Store(),
gast.Store(),
gast.Load(),
]
}
class TestGetNameIds3(TestGetNameIds):
def setUp(self):
self.source = """
def test_fn(x, y):
z = 1
if x > y:
z = x * x
z = z + y
return z
"""
self.all_name_ids = {
'x': [
gast.Param(),
gast.Load(),
gast.Load(),
gast.Load(),
],
'y': [
gast.Param(),
gast.Load(),
gast.Load(),
],
'z': [
gast.Store(),
gast.Store(),
gast.Load(),
gast.Store(),
gast.Load(),
]
}
class TestIsControlFlowIf(unittest.TestCase):
def check_false_case(self, code):
code = textwrap.dedent(code)
node = gast.parse(code)
node_test = node.body[0].value
self.assertFalse(is_control_flow_to_transform(node_test))
def test_expr(self):
# node is not ast.Compare
self.check_false_case("a+b")
def test_expr2(self):
# x is a Tensor.
node = gast.parse("a + x.numpy()")
node_test = node.body[0].value
self.assertTrue(is_control_flow_to_transform(node_test))
def test_is_None(self):
self.check_false_case("x is None")
def test_is_None2(self):
self.check_false_case("fluid.layers.sum(x) is None")
def test_is_None3(self):
self.check_false_case("fluid.layers.sum(x).numpy() != None")
def test_is_None4(self):
node = gast.parse("fluid.layers.sum(x) and 2>1")
node_test = node.body[0].value
self.assertTrue(is_control_flow_to_transform(node_test))
def test_if(self):
node = gast.parse("x.numpy()[1] > 1")
node_test = node.body[0].value
self.assertTrue(is_control_flow_to_transform(node_test))
def test_if_with_and(self):
node = gast.parse("x and 1 < x.numpy()[1]")
node_test = node.body[0].value
self.assertTrue(is_control_flow_to_transform(node_test))
def test_if_with_or(self):
node = gast.parse("1 < fluid.layers.sum(x).numpy()[2] or x+y < 0")
node_test = node.body[0].value
self.assertTrue(is_control_flow_to_transform(node_test))
def test_shape(self):
code = """
def foo(x):
batch_size = fluid.layers.shape(x)
if batch_size[0] > 16:
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
static_analysis_visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[1].test
self.assertTrue(
is_control_flow_to_transform(test_node, static_analysis_visitor))
def test_shape_with_andOr(self):
code = """
def foo(x):
batch_size = fluid.layers.shape(x)
if x is not None and batch_size[0] > 16 or 2 > 1:
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
static_analysis_visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[1].test
self.assertTrue(
is_control_flow_to_transform(test_node, static_analysis_visitor))
def test_paddle_api(self):
code = """
def foo(x):
if fluid.layers.shape(x)[0] > 16:
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
static_analysis_visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[0].test
self.assertTrue(
is_control_flow_to_transform(test_node, static_analysis_visitor))
def test_paddle_api_with_andOr(self):
code_or = """
def foo(x):
if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None :
x = x + 1
return x
"""
code_and = """
def foo(x):
if 2 > 1 and fluid.layers.shape(x)[0] > 16 and x is not None :
x = x + 1
return x
"""
for code in [code_or, code_and]:
code = textwrap.dedent(code)
node = gast.parse(code)
static_analysis_visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[0].test
self.assertTrue(
is_control_flow_to_transform(test_node,
static_analysis_visitor))
def test_with_node_var_type_map(self):
node = gast.parse("x > 1")
node_test = node.body[0].value
# if x is a Tensor
var_name_to_type = {"x": {NodeVarType.TENSOR}}
self.assertTrue(
is_control_flow_to_transform(node_test,
var_name_to_type=var_name_to_type))
# if x is not a Tensor
var_name_to_type = {"x": {NodeVarType.NUMPY_NDARRAY}}
self.assertFalse(
is_control_flow_to_transform(node_test,
var_name_to_type=var_name_to_type))
def test_raise_error(self):
node = "a + b"
with self.assertRaises(Exception) as e:
self.assertRaises(TypeError, is_control_flow_to_transform(node))
self.assertTrue(
"The type of input node must be gast.AST" in str(e.exception))
if __name__ == '__main__':
unittest.main()
......@@ -66,11 +66,7 @@ def get_source_code(func):
class StaticCode1():
def dyfunc_with_if_else(x_v, label=None):
__return_value_init_0 = paddle.full(shape=[1],
dtype='float64',
fill_value=0.0,
name='__return_value_init_0')
__return_value_0 = __return_value_init_0
__return_value_0 = None
def get_args_0():
nonlocal x_v
......@@ -83,51 +79,51 @@ class StaticCode1():
def true_fn_0():
nonlocal x_v
x_v = x_v - 1
return x_v
return
def false_fn_0():
nonlocal x_v
x_v = x_v + 1
return x_v
return
_jst.IfElse(
paddle.mean(x_v)[0] > 5, true_fn_0, false_fn_0, get_args_0,
set_args_0, ('x_v', ))
__return_0 = _jst.UndefinedVar('__return_0')
__return_1 = _jst.UndefinedVar('__return_1')
loss = _jst.UndefinedVar('loss')
def get_args_1():
nonlocal __return_value_0, label, x_v
return __return_value_0, label, x_v,
nonlocal __return_0, __return_1, __return_value_0, loss
return __return_0, __return_1, __return_value_0, loss
def set_args_1(__args):
nonlocal __return_value_0, label, x_v
__return_value_0, label, x_v, = __args
nonlocal __return_0, __return_1, __return_value_0, loss
__return_0, __return_1, __return_value_0, loss = __args
def true_fn_1():
nonlocal __return_value_0, label, x_v
nonlocal __return_0, __return_1, __return_value_0, loss
loss = fluid.layers.cross_entropy(x_v, label)
__return_0 = _jst.create_bool_as_type(label is not None, True)
__return_value_0 = loss
return __return_value_0
return
def false_fn_1():
nonlocal __return_value_0, label, x_v
nonlocal __return_0, __return_1, __return_value_0, loss
__return_1 = _jst.create_bool_as_type(label is not None, True)
__return_value_0 = x_v
return __return_value_0
return
_jst.IfElse(label is not None, true_fn_1, false_fn_1, get_args_1,
set_args_1, ('__return_value_0', ))
set_args_1,
('__return_0', '__return_1', '__return_value_0', 'loss'))
return __return_value_0
class StaticCode2():
# TODO: Transform return statement
def dyfunc_with_if_else(x_v, label=None):
__return_value_init_1 = paddle.full(shape=[1],
dtype='float64',
fill_value=0.0,
name='__return_value_init_1')
__return_value_1 = __return_value_init_1
__return_value_1 = None
def get_args_2():
nonlocal x_v
......@@ -140,40 +136,44 @@ class StaticCode2():
def true_fn_2():
nonlocal x_v
x_v = x_v - 1
return x_v
return
def false_fn_2():
nonlocal x_v
x_v = x_v + 1
return x_v
return
_jst.IfElse(
paddle.mean(x_v)[0] > 5, true_fn_2, false_fn_2, get_args_2,
set_args_2, ('x_v', ))
__return_2 = _jst.UndefinedVar('__return_2')
__return_3 = _jst.UndefinedVar('__return_3')
loss = _jst.UndefinedVar('loss')
def get_args_3():
nonlocal __return_value_1, label, x_v
return __return_value_1, label, x_v,
nonlocal __return_2, __return_3, __return_value_1, loss
return __return_2, __return_3, __return_value_1, loss
def set_args_3(__args):
nonlocal __return_value_1, label, x_v
__return_value_1, label, x_v, = __args
nonlocal __return_2, __return_3, __return_value_1, loss
__return_2, __return_3, __return_value_1, loss = __args
def true_fn_3():
nonlocal __return_value_1, label, x_v
nonlocal __return_2, __return_3, __return_value_1, loss
loss = fluid.layers.cross_entropy(x_v, label)
__return_2 = _jst.create_bool_as_type(label is not None, True)
__return_value_1 = loss
return __return_value_1
return
def false_fn_3():
nonlocal __return_value_1, label, x_v
nonlocal __return_2, __return_3, __return_value_1, loss
__return_3 = _jst.create_bool_as_type(label is not None, True)
__return_value_1 = x_v
return __return_value_1
return
_jst.IfElse(label is not None, true_fn_3, false_fn_3, get_args_3,
set_args_3, ('__return_value_1', ))
set_args_3,
('__return_2', '__return_3', '__return_value_1', 'loss'))
return __return_value_1
......@@ -195,6 +195,7 @@ class TestDygraphToStaticCode(unittest.TestCase):
def test_decorator(self):
program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_with_if_else)
#print(code)
answer = get_source_code(StaticCode1.dyfunc_with_if_else)
self.assertEqual(
answer.replace('\n', '').replace(' ', ''),
......@@ -380,13 +381,13 @@ class TestIfElseEarlyReturn(unittest.TestCase):
answer = np.zeros([2, 2]) + 1
static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1)
out = static_func()
self.assertTrue(np.allclose(answer, out.numpy()))
self.assertTrue(np.allclose(answer, out[0].numpy()))
def test_ifelse_early_return2(self):
answer = np.zeros([2, 2]) + 3
static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return2)
out = static_func()
self.assertTrue(np.allclose(answer, out.numpy()))
self.assertTrue(np.allclose(answer, out[0].numpy()))
class TestRemoveCommentInDy2St(unittest.TestCase):
......
......@@ -19,6 +19,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.jit import to_static
from paddle.jit import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.utils import Dygraph2StaticException
import unittest
import numpy as np
......@@ -245,7 +246,7 @@ class TestReturnBase(unittest.TestCase):
return res.numpy()
return res
def test_transformed_static_result(self):
def _test_value_impl(self):
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
if isinstance(dygraph_res, tuple):
......@@ -264,6 +265,13 @@ class TestReturnBase(unittest.TestCase):
else:
self.assertEqual(dygraph_res, static_res)
def test_transformed_static_result(self):
if hasattr(self, "error"):
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
self._test_value_impl()
else:
self._test_value_impl()
class TestInsideFuncBase(TestReturnBase):
......@@ -312,12 +320,14 @@ class TestReturnDifferentLengthIfBody(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_different_length_if_body
self.error = "Your if/else have different number of return value."
class TestReturnDifferentLengthElse(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_different_length_else
self.error = "Your if/else have different number of return value."
class TestNoReturn(TestReturnBase):
......@@ -330,12 +340,14 @@ class TestReturnNone(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_none
self.error = "Your if/else have different number of return value."
class TestReturnNoVariable(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_no_variable
self.error = "Your if/else have different number of return value."
class TestReturnListOneValue(TestReturnBase):
......
......@@ -21,6 +21,7 @@ import paddle.fluid.layers as layers
from paddle.fluid.dygraph import Embedding, Layer, LayerNorm, Linear, to_variable
from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.layers.utils import map_structure
from paddle.fluid.layers.tensor import range as pd_range
def position_encoding_init(n_position, d_pos_vec):
......@@ -633,7 +634,7 @@ class Transformer(Layer):
value=0),
} for i in range(self.n_layer)]
for i in range(max_len):
for i in pd_range(0, max_len, 1, dtype="int32"):
trg_pos = layers.fill_constant(shape=trg_word.shape,
dtype="int64",
value=i)
......
......@@ -25,6 +25,7 @@ import paddle.fluid.framework as framework
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard
from simple_nets import simple_fc_net_with_inputs, batchnorm_fc_with_inputs
import paddle
np.random.seed(123)
......@@ -41,6 +42,8 @@ class TestCondInputOutput(unittest.TestCase):
return -1
"""
paddle.enable_static()
def true_func():
return layers.fill_constant(shape=[2, 3], dtype='int32', value=2)
......@@ -73,6 +76,8 @@ class TestCondInputOutput(unittest.TestCase):
return 3, 2
"""
paddle.enable_static()
def true_func():
return layers.fill_constant(shape=[1, 2], dtype='int32',
value=1), layers.fill_constant(
......@@ -114,6 +119,8 @@ class TestCondInputOutput(unittest.TestCase):
a = a - (i - 1)
"""
paddle.enable_static()
def true_func(a, i):
a = a * (i + 1)
return a
......@@ -152,6 +159,8 @@ class TestCondInputOutput(unittest.TestCase):
pass
"""
paddle.enable_static()
def true_func():
pass
......@@ -181,6 +190,8 @@ class TestCondInputOutput(unittest.TestCase):
test returning different number of tensors cannot merge into output
"""
paddle.enable_static()
def func_return_none():
return None
......@@ -223,10 +234,11 @@ class TestCondInputOutput(unittest.TestCase):
out = layers.cond(pred, func_return_one_tensor,
func_return_two_tensors)
self.assertTrue(
"Incompatible return values of true_fn and false_fn in cond" in
str(e.exception))
"true fn returns 1 vars, but false fn returns 2 vars, which is not equals"
in str(e.exception))
def test_extremely_simple_net_with_op_in_condition(self):
paddle.enable_static()
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
......@@ -272,6 +284,8 @@ class TestCondNestedControlFlow(unittest.TestCase):
return a / a
"""
paddle.enable_static()
def less_than_branch(i, a):
return layers.cond(i >= 3.0, lambda: layers.elementwise_add(a, a),
lambda: layers.elementwise_sub(a, a))
......@@ -308,6 +322,7 @@ class TestCondNestedControlFlow(unittest.TestCase):
self.assertEqual(ret[1][0], expected_a_grad)
def test_cond_op_in_condition(self):
paddle.enable_static()
main_program = fluid.Program()
startup_program = fluid.Program()
......@@ -344,6 +359,7 @@ class TestCondBackward(unittest.TestCase):
"""
Helper function that compares calculated backward value is close to dy/dx
"""
paddle.enable_static()
main_program = Program()
main_program.random_seed = 123
startup_program = Program()
......@@ -474,6 +490,8 @@ class TestCondBackward(unittest.TestCase):
def test_cond_backward(self):
paddle.enable_static()
def cond_func(i, img, label):
predicate = ((i % 2) == 0)
return layers.cond(
......@@ -494,6 +512,7 @@ class TestCondBackward(unittest.TestCase):
use_parallel_exe)
def test_half_nested_cond_backward(self):
paddle.enable_static()
def branch(i, img, label):
return layers.cond(
......@@ -530,6 +549,7 @@ class TestCondBackward(unittest.TestCase):
use_parallel_exe)
def test_nested_cond_backward(self):
paddle.enable_static()
def branch(i, img, label, mod_two):
if mod_two:
......@@ -560,6 +580,7 @@ class TestCondBackward(unittest.TestCase):
class TestCondWithError(unittest.TestCase):
def test_input_type_error(self):
paddle.enable_static()
main_program = framework.Program()
startup_program = framework.Program()
with framework.program_guard(main_program, startup_program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册