未验证 提交 3be7f495 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2Static ] transfer list into tensor array at runtime. (#45594)

* 1. make list transformer into jit form.
2. fix some bugs in tensor_array, such as append.
3. enhance the function analysis visitor to recognize push/pop.
4. add setter/getter helper to deal with 2+ name sets.

* fix ci errors:
1. add to_tensor_array logic in convert_cond
2. fix IfExpr error.
3. fix erros while return_names or push_pop_names is None
4. fix slice error in a[i]=1 where a is tensor_array
5. add pop interface in Variable
上级 31b92305
......@@ -75,7 +75,7 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var,
PADDLE_ENFORCE_EQ(dst_tensor.defined(),
true,
paddle::platform::errors::InvalidArgument(
"dst_tensor shall be defined."));
"dst_tensor `%s` shall be defined.", name));
if (dst_tensor.is_dense_tensor()) {
auto &src_tensor = src_var.Get<phi::DenseTensor>();
......
......@@ -93,7 +93,7 @@ class DygraphToStaticAst(BaseTransformer):
EarlyReturnTransformer,
BasicApiTransformer, # Basic Api
TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor)
ListTransformer, # List used in control flow
#ListTransformer, # List used in control flow
BreakContinueTransformer, # break/continue in loops
ReturnTransformer, # return in functions
LogicalTransformer, # logical and/or/not
......
......@@ -25,6 +25,7 @@ from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, lo
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, Dygraph2StaticException
from paddle.fluid.dygraph.dygraph_to_static.utils import GetterSetterHelper
from paddle.fluid.layers.utils import copy_mutable_vars
......@@ -74,7 +75,12 @@ def _unpack_by_structure_paddle(target, structure):
return ret
def convert_while_loop(cond, body, getter, setter):
def convert_while_loop(cond,
body,
getter,
setter,
return_name_ids=None,
push_pop_names=None):
"""
A function representation of a Python ``while`` statement.
......@@ -91,21 +97,41 @@ def convert_while_loop(cond, body, getter, setter):
# If loop_vars is changed during cond callable, then it causes bug, but current logical_and/logical_not/... doesn't change the loop_vars.
pred = cond()
if isinstance(pred, Variable):
_run_paddle_while(cond, body, getter, setter)
_run_paddle_while(cond, body, getter, setter, return_name_ids,
push_pop_names)
else:
_run_py_while(cond, body, getter, setter)
def _run_paddle_while(cond, body, getter, setter):
def _convert_tensor_arrray_if_necessary(setterhelper, push_pop_names):
push_pop_vars = setterhelper.get(push_pop_names)
if push_pop_vars is None:
return
def maybe_to_tensor_array(v):
if isinstance(v, list):
return create_array("float32", initialized_list=v)
else:
return v
setterhelper.set(push_pop_names,
[maybe_to_tensor_array(v) for v in push_pop_vars])
def _run_paddle_while(cond, body, getter, setter, return_name_ids,
push_pop_names):
# NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors.
helper = GetterSetterHelper(getter, setter, return_name_ids, push_pop_names)
_convert_tensor_arrray_if_necessary(helper, push_pop_names)
def new_body_fn(*args):
""" wrap the body() and add return value for `while_loop`
the args may be differ from getter().
"""
mutable_loop_vars = args
setter(mutable_loop_vars)
helper.set(return_name_ids, mutable_loop_vars)
body()
return getter()
return helper.get(return_name_ids)
def new_cond_fn(*args):
""" cond is a zero-args function, which is not
......@@ -116,12 +142,13 @@ def _run_paddle_while(cond, body, getter, setter):
# UndefinedVar will become data layer not check variable with value=NO_VALUE_MAGIC.
loop_vars = [
to_static_variable(var) if not isinstance(var, UndefinedVar) else var
for var in getter()
for var in helper.get(return_name_ids)
]
setter(loop_vars) # change the non-local var to variable
helper.set(return_name_ids,
loop_vars) # change the non-local var to variable
# variable maybe modified to inner var. change it into
loop_vars = control_flow.while_loop(new_cond_fn, new_body_fn, loop_vars)
setter(loop_vars) # change back to loop_vars
helper.set(return_name_ids, loop_vars)
return loop_vars
......@@ -263,8 +290,13 @@ def _run_py_logical_not(x):
return not x
def convert_ifelse(pred, true_fn, false_fn, get_args, set_args,
return_name_ids):
def convert_ifelse(pred,
true_fn,
false_fn,
get_args,
set_args,
return_name_ids,
push_pop_names=None):
"""
A function representation of a Python ``if/else`` statement.
......@@ -274,6 +306,7 @@ def convert_ifelse(pred, true_fn, false_fn, get_args, set_args,
false_fn(callable): A callable to be performed if ``pred`` is false.
get_args(callable): Get all arguments that needed in true_fn and false_fn.
set_args(callable): Update arguments that modified in trure_fn and false_fn.
return_name_ids(list[string]): the returned names.
Returns:
``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` .
......@@ -281,7 +314,7 @@ def convert_ifelse(pred, true_fn, false_fn, get_args, set_args,
"""
if isinstance(pred, Variable):
out = _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
return_name_ids)
return_name_ids, push_pop_names)
else:
out = _run_py_ifelse(pred, true_fn, false_fn, get_args, set_args,
return_name_ids)
......@@ -290,27 +323,30 @@ def convert_ifelse(pred, true_fn, false_fn, get_args, set_args,
def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
return_name_ids):
return_name_ids, push_pop_names):
"""
Paddle cond API will evaluate both ture_fn and false_fn codes.
"""
helper = GetterSetterHelper(get_args, set_args, return_name_ids,
push_pop_names)
_convert_tensor_arrray_if_necessary(helper, push_pop_names)
pred = cast_bool_if_necessary(pred)
init_args = get_args()
init_args = helper.get(return_name_ids)
def new_true_fn():
#init args may contain mutable python container like [var, 2], we copy then like in while_loop
set_args(copy_mutable_vars(init_args))
helper.set(return_name_ids, copy_mutable_vars(init_args))
ret = true_fn()
# IfExpr will return a non-None return value, so we just return ret.
# We assume normal return has no return value.
if ret is None: return get_args()
if ret is None: return helper.get(return_name_ids)
else: return ret
def new_false_fn():
#init args may contain mutable python container like [var, 2], we copy then like in while_loop
set_args(copy_mutable_vars(init_args))
helper.set(return_name_ids, copy_mutable_vars(init_args))
ret = false_fn()
if ret is None: return get_args()
if ret is None: return helper.get(return_name_ids)
else: return ret
try:
......@@ -327,6 +363,8 @@ def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
"Your if/else have different number of return value. TODO: add link to modifty. {}"
.format(str(e)))
raise e
get_args = lambda: helper.get(return_name_ids)
set_args = lambda vs: helper.set(return_name_ids, vs)
return _recover_args_state(cond_outs, get_args, set_args, return_name_ids)
......
......@@ -35,6 +35,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_no
from paddle.fluid.dygraph.dygraph_to_static.utils import create_get_args_node, create_set_args_node
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_INDEX_PREFIX, FOR_ITER_TUPLE_PREFIX, FOR_ITER_TUPLE_INDEX_PREFIX, FOR_ITER_VAR_LEN_PREFIX, FOR_ITER_VAR_NAME_PREFIX, FOR_ITER_ZIP_TO_LIST_PREFIX, FOR_ITER_TARGET_PREFIX, FOR_ITER_ITERATOR_PREFIX
from paddle.fluid.dygraph.dygraph_to_static.utils import GetterSetterHelper, create_name_str
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
......@@ -65,16 +66,16 @@ class IfElseTransformer(BaseTransformer):
def visit_If(self, node):
self.generic_visit(node)
new_vars_stmts, true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids = transform_if_else(
true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids, push_pop_ids = transform_if_else(
node, self.root)
new_node = create_convert_ifelse_node(return_name_ids, node.test,
true_func_node, false_func_node,
get_args_node, set_args_node)
new_node = create_convert_ifelse_node(return_name_ids, push_pop_ids,
node.test, true_func_node,
false_func_node, get_args_node,
set_args_node)
return new_vars_stmts + [
get_args_node, set_args_node, true_func_node, false_func_node
] + [new_node]
return [get_args_node, set_args_node, true_func_node, false_func_node
] + [new_node]
def visit_Call(self, node):
# Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
......@@ -91,7 +92,7 @@ class IfElseTransformer(BaseTransformer):
"""
self.generic_visit(node)
new_node = create_convert_ifelse_node(None, node.test, node.body,
new_node = create_convert_ifelse_node(None, None, node.test, node.body,
node.orelse, None, None, True)
# Note: A blank line will be added separately if transform gast.Expr
# into source code. Using gast.Expr.value instead to avoid syntax error
......@@ -306,16 +307,7 @@ def transform_if_else(node, root):
# TODO(liym27): Consider variable like `self.a` modified in if/else node.
return_name_ids = sorted(list(node.pd_scope.modified_vars()))
# NOTE: Python can create variable only in if body or only in else body, and use it out of if/else.
# E.g.
#
# if x > 5:
# a = 10
# print(a)
#
# Create static variable for those variables
create_new_vars_in_parent_stmts = []
push_pop_ids = sorted(list(node.pd_scope.variadic_length_vars()))
nonlocal_names = list(return_name_ids)
nonlocal_names.sort()
# NOTE: All var in return_name_ids should be in nonlocal_names.
......@@ -359,13 +351,15 @@ def transform_if_else(node, root):
input_args=empty_arg_node,
return_name_ids=[])
get_args_node = create_get_args_node(nonlocal_names)
set_args_node = create_set_args_node(nonlocal_names)
helper = GetterSetterHelper(None, None, nonlocal_names, push_pop_ids)
get_args_node = create_get_args_node(helper.union())
set_args_node = create_set_args_node(helper.union())
return create_new_vars_in_parent_stmts, true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids
return true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids, push_pop_ids
def create_convert_ifelse_node(return_name_ids,
push_pop_ids,
pred,
true_func,
false_func,
......@@ -377,17 +371,6 @@ def create_convert_ifelse_node(return_name_ids,
pred, true_fn, false_fn, get_args, set_args, return_name_ids)`
to replace original `python if/else` statement.
"""
def create_name_str(name_ids):
"""
Return "('x', 'y')" for [x, y]
"""
if not name_ids:
return 'None'
names_str = ["'%s'" % name for name in name_ids]
return "(%s, )" % ','.join(names_str)
if is_if_expr:
true_func_source = "lambda : {}".format(ast_to_source_code(true_func))
false_func_source = "lambda : {}".format(ast_to_source_code(false_func))
......@@ -397,7 +380,7 @@ def create_convert_ifelse_node(return_name_ids,
convert_ifelse_layer = gast.parse(
'_jst.IfElse('
'{pred}, {true_fn}, {false_fn}, {get_args}, {set_args}, {return_name_ids})'
'{pred}, {true_fn}, {false_fn}, {get_args}, {set_args}, {return_name_ids}, push_pop_names={push_pop_ids})'
.format(
pred=ast_to_source_code(pred),
true_fn=true_func_source,
......@@ -406,6 +389,7 @@ def create_convert_ifelse_node(return_name_ids,
'lambda: None', #TODO: better way to deal with this
set_args=set_args_func.name
if not is_if_expr else 'lambda args: None',
return_name_ids=create_name_str(return_name_ids))).body[0]
return_name_ids=create_name_str(return_name_ids),
push_pop_ids=create_name_str(push_pop_ids))).body[0]
return convert_ifelse_layer
......@@ -33,6 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransfor
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ForLoopTuplePreTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import GetterSetterHelper, create_name_str
__all__ = ['LoopTransformer', 'NameVisitor']
......@@ -43,8 +44,8 @@ FOR_CONDITION_PREFIX = 'for_loop_condition'
FOR_BODY_PREFIX = 'for_loop_body'
def create_while_nodes(condition_name, body_name, loop_var_names, getter_name,
setter_name):
def create_while_nodes(condition_name, body_name, loop_var_names,
push_pop_names, getter_name, setter_name):
"""
Returns a list of gast.Node which represents the calling of Paddle
controlflow while_loop.
......@@ -84,9 +85,9 @@ def create_while_nodes(condition_name, body_name, loop_var_names, getter_name,
assign_loop_var_names.append(name)
while_func_name = "_jst.While"
while_node_str = "{}({}, {}, {}, {})".format(while_func_name,
condition_name, body_name,
getter_name, setter_name)
while_node_str = "{}({}, {}, {}, {}, return_name_ids={}, push_pop_names={})".format(
while_func_name, condition_name, body_name, getter_name, setter_name,
create_name_str(loop_var_names), create_name_str(push_pop_names))
while_node = gast.parse(while_node_str).body[0]
ret = [while_node]
......@@ -539,6 +540,7 @@ class LoopTransformer(BaseTransformer):
# 2. get original loop vars
loop_var_names, create_var_names = node.pd_scope.modified_vars(
), node.pd_scope.created_vars()
push_pop_names = list(node.pd_scope.variadic_length_vars())
# TODO: Remove the bunch of code? We have the unique format `for A in B:`
# NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
# we need append new loop var & remove useless loop var
......@@ -607,12 +609,13 @@ class LoopTransformer(BaseTransformer):
type_comment=None)
new_stmts.append(body_func_node)
get_args_node = create_get_args_node(nonlocal_names)
set_args_node = create_set_args_node(nonlocal_names)
helper = GetterSetterHelper(None, None, nonlocal_names, push_pop_names)
get_args_node = create_get_args_node(helper.union())
set_args_node = create_set_args_node(helper.union())
# 7. create & append while loop node
while_loop_nodes = create_while_nodes(condition_func_node.name,
body_func_node.name,
nonlocal_names,
nonlocal_names, push_pop_names,
get_args_node.name,
set_args_node.name)
new_stmts.extend([get_args_node, set_args_node])
......@@ -623,6 +626,7 @@ class LoopTransformer(BaseTransformer):
def get_while_stmt_nodes(self, node):
loop_var_names, create_var_names = node.pd_scope.modified_vars(
), node.pd_scope.created_vars()
push_pop_names = list(node.pd_scope.variadic_length_vars())
new_stmts = []
# create non-local statement for body and cond.
......@@ -675,12 +679,14 @@ class LoopTransformer(BaseTransformer):
returns=None,
type_comment=None)
new_stmts.append(body_func_node)
get_args_node = create_get_args_node(nonlocal_names)
set_args_node = create_set_args_node(nonlocal_names)
helper = GetterSetterHelper(None, None, nonlocal_names, push_pop_names)
get_args_node = create_get_args_node(helper.union())
set_args_node = create_set_args_node(helper.union())
while_loop_nodes = create_while_nodes(condition_func_node.name,
body_func_node.name,
nonlocal_names,
nonlocal_names, push_pop_names,
get_args_node.name,
set_args_node.name)
new_stmts.extend([get_args_node, set_args_node])
......
......@@ -33,6 +33,8 @@ from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers import assign
import collections
from functools import reduce
# Note(Aurelius): Do not forget the dot `.` to distinguish other
# module such as paddlenlp.
......@@ -1020,8 +1022,9 @@ class NameScope:
self.args = set()
self.father = None # point to the nearest function name scope.
self.w_vars = set() # all qualified + normal names been stored
self.created = set(
) # useful for control flow compatibility. may be remove later
self.created = set() # useful for control flow compatibility
# may be remove later.
self.push_pop_vars = set() # we call push and pop in the vars
def set_father(self, father):
self.father = father
......@@ -1040,6 +1043,9 @@ class NameScope:
# may be globals / non-locals / args / qualified names and created_vars
return self.w_vars
def variadic_length_vars(self):
return self.push_pop_vars
def control_flow_vars(self):
valid_names = self.w_vars
tmp = self.father.global_vars & valid_names,
......@@ -1053,17 +1059,25 @@ class NameScope:
self.nonlocals |= name_scope.nonlocals
self.args |= name_scope.args
self.w_vars |= name_scope.w_vars
self.push_pop_vars |= name_scope.push_pop_vars
class FunctionNameLivenessAnalysis(gast.NodeVisitor):
""" analyze the liveness of a function.
every variables stored in this scope will be collected,
in addition with global/nonlocal information.
in addition with global/nonlocal information and
push_pop information.
1. global variable is stored in node.var_globals.
2. nonlocal variable is stored in node.var_nonlocals.
3. arguments is stored in node.var_args.
4. if a variable's push and pop attribute is called,
it will be collected in push_pop_vars. They are
used for transformation to tensor_array.
NOTE: push_pop_vars **may not** in w_vars.
a.push(0) don't modify the variable a, but the content
of a.
For example:
......@@ -1073,8 +1087,12 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
nonlocal x,y
print(a)
i = k
b = []
c = [1,2,3]
for m in range(10):
q = 12
b.push(1)
c.pop()
After this visitor we have:
# node is the FunctionDef node with name: "func"
......@@ -1082,7 +1100,8 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
globals = ['i', 'j'],
nonlocals = ['x', 'y'],
args = ['args', 'kargs'],
wr_vars = ['a', 'i', 'q', 'm']
wr_vars = ['a', 'i', 'q', 'm', 'c', 'b']
push_pop_vars = ['b', 'c']
)
"""
......@@ -1137,7 +1156,7 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
self._get_argument_names(node))
def post_func():
""" NOTE: why we need merge w_vars here ?
""" NOTE: why we need merge w_vars and push_pop_vars here ?
because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope.
"""
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import WHILE_CONDITION_PREFIX, WHILE_BODY_PREFIX, FOR_CONDITION_PREFIX, FOR_BODY_PREFIX
......@@ -1155,6 +1174,8 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
if self._father_name_scope() and is_control_flow_def_node():
self._father_name_scope().w_vars |= self._current_name_scope(
).w_vars
self._father_name_scope(
).push_pop_vars |= self._current_name_scope().push_pop_vars
self._visit_scope_node(node, pre_func, post_func)
......@@ -1210,6 +1231,17 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
name = ast_to_source_code(node).strip()
self._current_name_scope().w_vars.add(name)
def visit_Call(self, node):
self.generic_visit(node)
if not isinstance(node.func, gast.Attribute):
return
variadic_length_method = ['append', 'pop']
if node.func.attr not in variadic_length_method:
return
# we don't treat push and pop as a write operator. such as a[i]=10 is not modify a.
name = ast_to_source_code(node.func.value).strip()
self._current_name_scope().push_pop_vars.add(name)
def _get_argument_names(self, node):
""" get all arguments name in the functiondef node.
this node is local to the function and shouldn't
......@@ -1315,3 +1347,57 @@ def create_nonlocal_stmt_nodes(names):
return []
func_code = "nonlocal {}".format(','.join(names))
return [gast.parse(func_code).body[0]]
class GetterSetterHelper:
""" we have two classes of names in setter and getter function:
w_vars(loop_vars) + push_pop_vars
To simplify the setter logic in convert_while and convert_cond,
we extract the helper class here.
"""
def __init__(self, getter_func, setter_func, *name_lists):
name_lists = map(lambda x: [] if x is None else x, name_lists)
name_sets = map(lambda x: set(x), name_lists)
self._union = list(reduce(lambda x, y: x | y, name_sets, set()))
self._union.sort()
self.getter = getter_func
self.setter = setter_func
self.name2id = {name: idx for idx, name in enumerate(self._union)}
def union(self):
return self._union
def get(self, names):
if names is None: names = []
vars = self.getter()
if vars is None: return tuple()
for n in names:
assert n in self.name2id, "the name `{}` not in name union set`{}`.".format(
n, self.name2id.keys())
return tuple(map(lambda n: vars[self.name2id[n]], names))
def set(self, names, values):
if names is None: names = []
if values is None: values = []
vars = self.getter()
if vars is None: return
for n in names:
assert n in self.name2id, "the name `{}` not in name union set`{}`.".format(
n, self.name2id.keys())
vars = list(vars)
indices = list(map(lambda n: self.name2id[n], names))
for i, v in zip(indices, values):
vars[i] = v
self.setter(vars)
def create_name_str(name_ids):
"""
Return "('x', 'y')" for [x, y]
"""
if not name_ids:
return 'None'
names_str = ["'%s'" % name for name in name_ids]
return "(%s, )" % ','.join(names_str)
......@@ -21,6 +21,7 @@ from .. import core
from ..framework import Variable, unique_name, static_only
from .layer_function_generator import OpProtoHolder
from .control_flow import array_write, array_length
from paddle.fluid.dygraph.base import in_declarative_mode
_supported_int_dtype_ = [
core.VarDesc.VarType.BOOL,
......@@ -211,16 +212,35 @@ def monkey_patch_variable():
"""
if not isinstance(var, Variable):
raise TypeError(
"Required input var should be Variable, but received {}".format(
type(var)))
if in_declarative_mode():
""" in dy2static mode, x may be tensorable values such as int, float, np.array
"""
from paddle.tensor.creation import to_tensor
var = to_tensor(var)
else:
raise TypeError(
"Required input var should be Variable, but received {}".
format(type(var)))
if self.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
raise TypeError(
"Only Variable with VarType.LOD_TENSOR_ARRAY support `append` method, but received type: {}"
.format(self.type))
array_write(x=var, i=array_length(self), array=self)
@static_only
def pop(self, *args):
"""
**Notes**:
**The type variable must be LoD Tensor Array.
"""
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import _run_paddle_pop
if self.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
raise TypeError(
"Only Variable with VarType.LOD_TENSOR_ARRAY support `append` method, but received type: {}"
.format(self.type))
return _run_paddle_pop(self, *args)
def _scalar_op_(var, scale, bias):
block = current_block(var)
out = create_new_tmp_var(block, var.dtype)
......@@ -389,6 +409,7 @@ def monkey_patch_variable():
('cpu', cpu),
('cuda', cuda),
('append', append),
('pop', pop),
('dim', lambda x: len(x.shape)),
('ndimension', lambda x: len(x.shape)),
('ndim', _ndim_),
......
......@@ -39,6 +39,19 @@ class JudgeVisitor(gast.NodeVisitor):
self.generic_visit(node)
class JudgePushPopVisitor(gast.NodeVisitor):
def __init__(self, push_pop_vars):
self.pp_var = push_pop_vars
def visit_FunctionDef(self, node):
scope = node.pd_scope
expected = self.pp_var.get(node.name, set())
assert scope.push_pop_vars == expected, "Not Equals in function:{} . expect {} , but get {}".format(
node.name, expected, scope.push_pop_vars)
self.generic_visit(node)
def test_normal_0(x):
def func():
......@@ -88,9 +101,67 @@ def test_nonlocal(x, *args, **kargs):
return x
def test_push_pop_1(x, *args, **kargs):
""" push_pop_vars in main_function is : `l`, `k`
"""
l = []
k = []
for i in range(10):
l.append(i)
k.pop(i)
return l
def test_push_pop_2(x, *args, **kargs):
""" push_pop_vars in main_function is : `k`
"""
l = []
k = []
def func():
l.append(0)
for i in range(10):
k.append(i)
return l, k
def test_push_pop_3(x, *args, **kargs):
""" push_pop_vars in main_function is : `k`
NOTE: One may expect `k` and `l` because l
is nonlocal. Name bind analysis is
not implemented yet.
"""
l = []
k = []
def func():
nonlocal l
l.append(0)
for i in range(10):
k.append(i)
return l, k
def test_push_pop_4(x, *args, **kargs):
""" push_pop_vars in main_function is : `k`
"""
l = []
k = []
for i in range(10):
for j in range(10):
if True:
l.append(j)
else:
k.pop()
return l, k
class TestClosureAnalysis(unittest.TestCase):
def setUp(self):
self.judge_type = "var and w_vars"
self.init_dygraph_func()
def init_dygraph_func(self):
......@@ -132,12 +203,20 @@ class TestClosureAnalysis(unittest.TestCase):
]
def test_main(self):
for mod, ans, func in zip(self.modified_var, self.answer,
self.all_dygraph_funcs):
test_func = inspect.getsource(func)
gast_root = gast.parse(test_func)
name_visitor = FunctionNameLivenessAnalysis(gast_root)
JudgeVisitor(ans, mod).visit(gast_root)
if self.judge_type == 'push_pop_vars':
for push_pop_vars, func in zip(self.push_pop_vars,
self.all_dygraph_funcs):
test_func = inspect.getsource(func)
gast_root = gast.parse(test_func)
name_visitor = FunctionNameLivenessAnalysis(gast_root)
JudgePushPopVisitor(push_pop_vars).visit(gast_root)
else:
for mod, ans, func in zip(self.modified_var, self.answer,
self.all_dygraph_funcs):
test_func = inspect.getsource(func)
gast_root = gast.parse(test_func)
name_visitor = FunctionNameLivenessAnalysis(gast_root)
JudgeVisitor(ans, mod).visit(gast_root)
def TestClosureAnalysis_Attribute_func():
......@@ -158,5 +237,25 @@ class TestClosureAnalysis_Attribute(TestClosureAnalysis):
}]
class TestClosureAnalysis_PushPop(TestClosureAnalysis):
def init_dygraph_func(self):
self.judge_type = "push_pop_vars"
self.all_dygraph_funcs = [
test_push_pop_1, test_push_pop_2, test_push_pop_3, test_push_pop_4
]
self.push_pop_vars = [{
"test_push_pop_1": set({'l', 'k'}),
}, {
"test_push_pop_2": set({'k'}),
"func": set("l"),
}, {
"test_push_pop_3": set({'k'}),
"func": set("l"),
}, {
"test_push_pop_4": set({'k', 'l'}),
}]
if __name__ == '__main__':
unittest.main()
......@@ -254,13 +254,13 @@ class TestListWithoutControlFlow(unittest.TestCase):
dy_res,
rtol=1e-05,
err_msg='dygraph_res is {}\nstatic_res is {}'.format(
stat_res, dy_res))
dy_res, stat_res))
class TestListInIf(TestListWithoutControlFlow):
def init_dygraph_func(self):
self.all_dygraph_funcs = [test_list_append_in_if, test_list_pop_in_if]
self.all_dygraph_funcs = [test_list_append_in_if]
class TestListInWhileLoop(TestListWithoutControlFlow):
......
......@@ -89,9 +89,12 @@ class StaticCode1():
x_v = x_v + 1
return
_jst.IfElse(
paddle.mean(x_v)[0] > 5, true_fn_0, false_fn_0, get_args_0,
set_args_0, ('x_v', ))
_jst.IfElse(paddle.mean(x_v)[0] > 5,
true_fn_0,
false_fn_0,
get_args_0,
set_args_0, ('x_v', ),
push_pop_names=None)
def get_args_1():
nonlocal __return_0, __return_1, __return_value_0, loss
......@@ -114,9 +117,13 @@ class StaticCode1():
__return_value_0 = x_v
return
_jst.IfElse(label is not None, true_fn_1, false_fn_1, get_args_1,
_jst.IfElse(label is not None,
true_fn_1,
false_fn_1,
get_args_1,
set_args_1,
('__return_0', '__return_1', '__return_value_0', 'loss'))
('__return_0', '__return_1', '__return_value_0', 'loss'),
push_pop_names=None)
return __return_value_0
......@@ -146,9 +153,12 @@ class StaticCode2():
x_v = x_v + 1
return
_jst.IfElse(
paddle.mean(x_v)[0] > 5, true_fn_2, false_fn_2, get_args_2,
set_args_2, ('x_v', ))
_jst.IfElse(paddle.mean(x_v)[0] > 5,
true_fn_2,
false_fn_2,
get_args_2,
set_args_2, ('x_v', ),
push_pop_names=None)
def get_args_3():
nonlocal __return_2, __return_3, __return_value_1, loss
......@@ -171,9 +181,13 @@ class StaticCode2():
__return_value_1 = x_v
return
_jst.IfElse(label is not None, true_fn_3, false_fn_3, get_args_3,
_jst.IfElse(label is not None,
true_fn_3,
false_fn_3,
get_args_3,
set_args_3,
('__return_2', '__return_3', '__return_value_1', 'loss'))
('__return_2', '__return_3', '__return_value_1', 'loss'),
push_pop_names=None)
return __return_value_1
......@@ -195,7 +209,7 @@ class TestDygraphToStaticCode(unittest.TestCase):
def test_decorator(self):
program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_with_if_else)
#print(code)
print(code)
answer = get_source_code(StaticCode1.dyfunc_with_if_else)
self.assertEqual(
answer.replace('\n', '').replace(' ', ''),
......@@ -205,6 +219,7 @@ class TestDygraphToStaticCode(unittest.TestCase):
answer = get_source_code(StaticCode2.dyfunc_with_if_else)
program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_with_if_else)
print(code)
self.assertEqual(
answer.replace('\n', '').replace(' ', ''),
code.replace('\n', '').replace(' ', ''))
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
from paddle.fluid.dygraph.dygraph_to_static.utils import GetterSetterHelper
vars = [1, 2, 3, 4, 5]
def getter():
return vars
def setter(values):
global vars
vars = values
class TestGetterSetterHelper(unittest.TestCase):
def test_1(self):
helper = GetterSetterHelper(getter, setter, ['a', 'b', 'e'],
['d', 'f', 'e'])
print(helper.union())
expect_union = ['a', 'b', 'd', 'e', 'f']
assert helper.union() == expect_union
assert helper.get(expect_union) == (1, 2, 3, 4, 5)
helper.set(['a', 'b'], [1, 1])
assert vars == [1, 1, 3, 4, 5]
helper.set(['f', 'e'], [12, 10])
assert vars == [1, 1, 3, 10, 12]
helper.set(None, None)
assert vars == [1, 1, 3, 10, 12]
assert helper.get(None) == tuple()
assert helper.get([]) == tuple()
if __name__ == '__main__':
unittest.main()
......@@ -551,8 +551,37 @@ def _getitem_impl_(var, item):
return out
def _setitem_for_tensor_array(var, item, value):
""" branches for tensor array setitem operation.
A item can be a:
(1) int/Variable, which is a simple number/variable such as [1], [-2]
(2) Slice, which is represented by bounds such as [2:-1]
(3) Tuple, which includes the above two cases such as [2:-1, 1]
If item is case (1), we perform paddle.tensor.array_write,
in other cases, we raise a NotImplementedError.
"""
from ..framework import LayerHelper, core, _non_static_mode
from .framework import Variable
assert not _non_static_mode(
), "setitem for tensor_array must be called in static graph mode."
if isinstance(item, (Variable, int)):
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle import cast
from paddle.tensor import array_write
item = paddle.cast(to_static_variable(item), dtype='int64')
value = to_static_variable(value)
array_write(x=value, i=item, array=var)
else:
raise NotImplementedError(
"Only support __setitem__ by Int/Variable in tensor_array, but gets {}"
.format(type(item)))
def _setitem_impl_(var, item, value):
from .framework import default_main_program, Variable
from paddle.fluid import core
if var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
return _setitem_for_tensor_array(var, item, value)
inputs = {'Input': var}
if isinstance(item, list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册